コード例 #1
0
def train(args):
    """Train a SimpleDecoder to reconstruct real images from the feature
    maps of a frozen, pre-trained Discriminator.

    The discriminator weights are loaded from ``args.ckpt`` and the network
    is kept in eval mode; only the decoder is optimized, using an LPIPS
    (perceptual) reconstruction loss.  Checkpoints and sample
    reconstructions are written to the folders returned by ``get_dir``.

    Args:
        args: parsed CLI namespace with ``path``, ``iter``, ``ckpt``,
            ``batch_size`` and ``im_size`` attributes.
    """
    data_root = args.path
    total_iterations = args.iter
    checkpoint = args.ckpt
    batch_size = args.batch_size
    im_size = args.im_size
    ndf = 64  # base channel width shared by discriminator and decoder
    nlr = 0.0002  # Adam learning rate
    nbeta1 = 0.5  # Adam beta1 (common GAN default)
    use_cuda = True
    dataloader_workers = 8
    current_iteration = 0
    save_interval = 100
    saved_model_folder, saved_image_folder = get_dir(args)

    device = torch.device("cpu")
    if use_cuda:
        device = torch.device("cuda:0")

    transform_list = [
        transforms.Resize((int(im_size), int(im_size))),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ]
    trans = transforms.Compose(transform_list)

    # NOTE(review): `next(dataloader)` below treats a batch as a plain image
    # tensor, so this `ImageFolder` is presumably a project-local loader that
    # yields images only (torchvision's yields (image, label) pairs) -- confirm.
    dataset = ImageFolder(root=data_root, transform=trans)
    # InfiniteSamplerWrapper makes the stream endless, so a single iter()
    # can be drawn from for the whole run without re-creating the loader.
    dataloader = iter(
        DataLoader(dataset,
                   batch_size=batch_size,
                   shuffle=False,
                   sampler=InfiniteSamplerWrapper(dataset),
                   num_workers=dataloader_workers,
                   pin_memory=True))

    netD = Discriminator(ndf=ndf, im_size=im_size)
    netD.apply(weights_init)

    net_decoder = SimpleDecoder(ndf * 16)
    net_decoder.apply(weights_init)

    net_decoder.to(device)

    # map_location guards against the checkpoint having been saved on a
    # different device than the one we run on.
    ckpt = torch.load(checkpoint, map_location=device)
    netD.load_state_dict(ckpt['d'])
    netD.to(device)
    netD.eval()  # the discriminator is used only as a frozen feature extractor

    optimizerG = optim.Adam(net_decoder.parameters(),
                            lr=nlr,
                            betas=(nbeta1, 0.999))

    log_rec_loss = 0  # running LPIPS loss, reset every 100 iterations

    for iteration in tqdm(range(current_iteration, total_iterations + 1)):
        real_image = next(dataloader)
        real_image = real_image.to(device)

        net_decoder.zero_grad()

        # Decode the discriminator's features back into an image.
        feat = netD.get_feat(real_image)
        g_imag = net_decoder(feat)

        # Downscale the target to whatever resolution the decoder produced.
        target_image = F.interpolate(real_image, g_imag.shape[2])

        # `percept` is a module-level perceptual (LPIPS) loss.
        rec_loss = percept(g_imag, target_image).sum()

        rec_loss.backward()

        optimizerG.step()

        log_rec_loss += rec_loss.item()

        if iteration % 100 == 0:
            print("lpips loss d: %.5f " % (log_rec_loss / 100))
            log_rec_loss = 0

        if iteration % (save_interval * 10) == 0:

            # Save target/reconstruction pairs, mapped from [-1, 1] to [0, 1].
            with torch.no_grad():
                vutils.save_image(
                    torch.cat([target_image, g_imag]).add(1).mul(0.5),
                    saved_image_folder + '/rec_%d.jpg' % iteration)

        if iteration % (save_interval *
                        50) == 0 or iteration == total_iterations:
            torch.save(
                {
                    'd': netD.state_dict(),
                    'dec': net_decoder.state_dict()
                }, saved_model_folder + '/%d.pth' % iteration)
コード例 #2
0
    args = parse_args()
    nz = args.nz
    batch_size = args.batch_size
    epochs = args.epochs
    gpu = args.gpu

    # MNIST images without labels, shaped as 3-d (channel, H, W) arrays.
    train, _ = datasets.get_mnist(withlabel=False, ndim=3)
    train_iter = iterators.SerialIterator(train, batch_size)
    z_iter = iterators.RandomNoiseIterator(UniformNoiseGenerator(-1, 1, nz),
                                           batch_size)

    # Separate Adam settings for generator and discriminator.
    optimizer_generator = optimizers.Adam(alpha=1e-3, beta1=0.5)
    optimizer_discriminator = optimizers.Adam(alpha=2e-4, beta1=0.5)

    optimizer_generator.setup(Generator(nz))
    optimizer_discriminator.setup(Discriminator(train.shape[2:]))

    # BUGFIX: the original bound this to the name `updater`, shadowing the
    # imported `updater` module it is constructed from; use a distinct name.
    gan_updater = updater.GenerativeAdversarialUpdater(
        iterator=train_iter,
        noise_iterator=z_iter,
        optimizer_generator=optimizer_generator,
        optimizer_discriminator=optimizer_discriminator,
        device=gpu)

    trainer = training.Trainer(gan_updater, stop_trigger=(epochs, 'epoch'))
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.PrintReport(['epoch', 'gen/loss', 'dis/loss']))
    trainer.extend(extensions.ProgressBar())
    trainer.extend(extensions.GeneratorSample(), trigger=(1, 'epoch'))
    trainer.run()
コード例 #3
0
def TrainSourceTargetModel_exp(options):
    """Jointly train an encoder and classifier on a labelled source domain,
    also classifying synthetic samples drawn from the latent space.

    Expected keys in ``options``: 'latentD', 'lrA', 'lrB', 'epochs',
    'sourceTrainLoader', 'targetTrainLoader'.  The chosen device is written
    back into ``options['device']``.

    Returns:
        dict with the trained modules ('encoder', 'disc', 'classifier'),
        per-step loss logs ('lossA', 'lossB') and per-epoch validation
        texts ('sourceLogText', 'targetLogText').
    """
    # defining the models
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    options['device'] = device
    cEnc = ConvEncoder(latentDimension=options['latentD']).to(device)
    disc = Discriminator(latentDimension=options['latentD']).to(device)
    classify = Classifier(latentDimension=options['latentD']).to(device)

    # defining loss functions -- only cross entropy is used by the active
    # training steps below
    CELoss = nn.CrossEntropyLoss().to(device)  # cross entropy loss

    # loss log data
    lossData_a, lossData_b = [], []

    optimizer_classification = optim.Adam(itertools.chain(
        classify.parameters()),
                                          lr=options['lrA'])
    optimizer_encoder = optim.Adam(itertools.chain(cEnc.parameters()),
                                   lr=options['lrA'])
    # NOTE(review): the adversarial (discriminator) update is currently
    # disabled; its optimizer is kept so the step can be re-enabled easily.
    optimizer_discriminator = optim.Adam(itertools.chain(disc.parameters()),
                                         lr=options['lrB'])

    sourceLogText, targetLogText = "", ""
    sourceTrainLoader = options['sourceTrainLoader']
    targetTrainLoader = options['targetTrainLoader']

    for epochIdx in range(options['epochs']):

        # zip() stops at the shorter of the two loaders each epoch
        for (batchData,
             batchLabels), (targetBatchData,
                            targetBatchLabels) in zip(sourceTrainLoader,
                                                      targetTrainLoader):

            # data setup
            _batchSize = batchData.shape[0]
            _targetBatchSize = targetBatchData.shape[0]
            # source
            batchData = batchData.to(device)
            batchLabels = batchLabels.to(device)
            # target
            targetBatchData = targetBatchData.to(device)
            targetBatchLabels = targetBatchLabels.to(device)

            # generate synthetic samples (and their labels) from the
            # feature space; both draws are kept to preserve the original
            # random-number stream even though sample B is unused here
            latentSpaceSample_A, latentSpaceClasses_A = getSampleFromLatentSpace(
                [_batchSize, options['latentD']], options)
            latentSpaceSample_B, latentSpaceClasses_B = getSampleFromLatentSpace(
                [_batchSize, options['latentD']], options)

            # cEnc pass
            encodedDataBatch = cEnc(batchData)
            targetEncodedDataBatch = cEnc(targetBatchData)

            # classification pass
            classesPrediction = classify(encodedDataBatch)
            sampleClassesPrediction = classify(latentSpaceSample_A)

            # discriminator pass -- outputs currently unused, but the
            # forward passes are kept so module behaviour matches the
            # original training recipe
            sourceDiscOutput = disc(encodedDataBatch)
            targetDiscOutput = disc(targetEncodedDataBatch)

            # optimization step I -- classifier on latent-space samples
            optimizer_classification.zero_grad()
            loss_a = CELoss(sampleClassesPrediction, latentSpaceClasses_A)
            loss_a.backward()
            optimizer_classification.step()

            # optimization step II -- encoder via source classification
            optimizer_encoder.zero_grad()
            loss_b = CELoss(classesPrediction, batchLabels)
            loss_b.backward()
            optimizer_encoder.step()

            # BUGFIX: record the losses -- previously nothing was ever
            # appended, so the returned logs were always empty lists
            lossData_a.append(loss_a.item())
            lossData_b.append(loss_b.item())

        ####
        #### End of an epoch -- check accuracy on source and target sets
        ####

        sourceLogText += validateModel(epochIdx,
                                       options,
                                       models=[cEnc, classify, disc])
        targetLogText += validateModel(epochIdx,
                                       options,
                                       source=False,
                                       models=[cEnc, classify, disc])

    outputs = {
        'lossA': lossData_a,
        'lossB': lossData_b,
        'encoder': cEnc,
        'disc': disc,
        'classifier': classify,
        'sourceLogText': sourceLogText,
        'targetLogText': targetLogText,
    }
    return outputs
コード例 #4
0
# Create optimizers for the generators and discriminators
g_optimizer = optim.Adam(g_params, lr, [beta1, beta2])
# d_x_optimizer = optim.Adam(D_X.parameters(), lr, [beta1, beta2])
# d_y_optimizer = optim.Adam(D_Y.parameters(), lr, [beta1, beta2])
#-----------------------#

# Mode switches only: decoder and vgg run in inference mode, decoder2 is
# the network being trained (eval()/train() do not freeze gradients).
decoder.eval()
decoder2.train()
vgg.eval()

styletransfer = stylenet.Net(vgg, decoder)
# NOTE(review): these iterators are consumed with bare next() in the loop
# below, so `style`/`content` are presumably infinite (sampler-backed)
# loaders -- confirm, otherwise training stops with StopIteration.
style_iter = iter(style)
content_iter = iter(content)

# s_test = iter(style_test); c_test = iter(content_test);
# netD='pixel' selects the pixel-wise discriminator variant.
dOne = Discriminator.define_D(3, 8, netD='pixel')
dOne.to(device)

# alpha = config['sty']['alpha']
import random  #for alpha
for i in tqdm(range(config['exp']['max_iter'])):

    adjust_learning_rate(g_optimizer, iteration_count=i)
    content_images = next(content_iter)[0].to(device)
    style_images = next(style_iter)[0].to(device)

    # Style/content mixing weight, drawn fresh each iteration in [0, 1].
    alpha = round(random.uniform(0, 1), 3)
    with torch.no_grad():
        # VGG features for both batches; no gradients needed here.
        sty_ft = vgg(style_images)
        cont_ft = vgg(content_images)
        feat = adaptive_instance_normalization(content_feat=cont_ft,
コード例 #5
0
# Remaining CLI options.
parser.add_argument('--default_rate', type=float, default=0.5, help='Set the lambda weight between GAN loss and Recon loss after curriculum period. We used the 0.5 weight.')

parser.add_argument('--sample_step', type=int, default=1000)
parser.add_argument('--model_save_step', type=int, default=10000)

opt = parser.parse_args()
print(opt)

# Warn when a GPU is present but CUDA was not requested.
if torch.cuda.is_available() and not opt.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")

###### Definition of variables ######
# Networks: one generator per direction, one discriminator per domain.
netG_A2B = Generator()
netG_B2A = Generator()
netD_A = Discriminator()
netD_B = Discriminator()

_all_nets = (netG_A2B, netG_B2A, netD_A, netD_B)

print('---------- Networks initialized -------------')
for _net in _all_nets:
    print_network(_net)
print('-----------------------------------------------')

# Move every network to the GPU when requested.
if opt.cuda:
    for _net in _all_nets:
        _net.cuda()
コード例 #6
0
from torch.utils.data import DataLoader

from models import Generator, Discriminator, VGG16
from model_wrapper import ModelWrapper
import data

if __name__ == '__main__':
    # Init models: build fresh networks, or resume from serialized ones.
    if args.load_generator_network is None:
        generator = Generator(channels_factor=args.channel_factor)
    else:
        # Whole-module checkpoints may have been saved wrapped in
        # nn.DataParallel; unwrap to get the bare module.
        generator = torch.load(args.load_generator_network)
        if isinstance(generator, nn.DataParallel):
            generator = generator.module
    if args.load_discriminator_network is None:
        # NOTE(review): keyword is `channel_factor` here but
        # `channels_factor` for Generator above -- presumably each matches
        # its model definition; confirm against the models module.
        discriminator = Discriminator(channel_factor=args.channel_factor)
    else:
        discriminator = torch.load(args.load_discriminator_network)
        if isinstance(discriminator, nn.DataParallel):
            discriminator = discriminator.module
    # Perceptual-feature network; argument is presumably a weights path --
    # see VGG16 in the models module.
    vgg16 = VGG16(args.load_pretrained_vgg16)

    # Init data parallel (optional multi-GPU wrapping of all three nets)
    if args.use_data_parallel:
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)
        vgg16 = nn.DataParallel(vgg16)

    # Init optimizers
    generator_optimizer = torch.optim.Adam(generator.parameters(), lr=args.lr)
    discriminator_optimizer = torch.optim.Adam(discriminator.parameters(),
コード例 #7
0
# ...

# Draw a random seed and report it so runs can be reproduced.
seed = random.randint(1, 10000)
print("Random Seed: ", seed)
torch.manual_seed(seed)
if opt.cuda:
    torch.cuda.manual_seed(seed)

# build network
print('==>building network...')
generator = Generator(in_nc=opt.in_nc,
                      mid_nc=opt.mid_nc,
                      out_nc=opt.out_nc,
                      scale_factor=opt.scale_factor,
                      num_RRDBS=opt.num_RRDBs)
discriminator = Discriminator()
feature_extractor = FeatureExtractor()

# loss

# content loss
if opt.content_loss_type == 'L1_Charbonnier':
    content_loss = L1_Charbonnier_loss()
elif opt.content_loss_type == 'L1':
    content_loss = torch.nn.L1Loss()
elif opt.content_loss_type == 'L2':
    content_loss = torch.nn.MSELoss()
else:
    # BUGFIX: an unknown value previously left `content_loss` undefined and
    # surfaced as a confusing NameError much later; fail fast instead.
    raise ValueError('unknown content_loss_type: %s' % opt.content_loss_type)

# pixel loss
if opt.pixel_loss_type == 'L1':
    pixel_loss = torch.nn.L1Loss()
コード例 #8
0
     transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])

# Prepare data: CIFAR10 and CelebA, both using the transform defined above.
trainset0 = datasets.CIFAR10(root='./data/', train=True, download=False, transform=transform)
trainloader0 = torch.utils.data.DataLoader(trainset0, batch_size=args.batch_size, shuffle=True, num_workers=2)
trainset1 = datasets.CelebA('./data/celeba', transform=transform)
trainloader1 = torch.utils.data.DataLoader(trainset1, batch_size=args.batch_size, shuffle=True, num_workers=2)

current_loader = None
fullLoader = None 

# Create N single dataset loaders
loaders = [multiLoaders(args, x) for x in (trainset0, trainset1)]

model_d = Discriminator()
model_g = Generator()

# Generate fixed noise -- reused for sampling so generator progress is
# comparable across iterations.
fixed_noise = torch.FloatTensor(SAMPLE_SIZE, args.nz, 1, 1).normal_(0, 1)

# Send model to the current device
model_d.cuda()
model_g.cuda()
fixed_noise = Variable(fixed_noise).cuda()

meta_iteration = 0

# Generate random sample of singles loader from each dataset
# NOTE(review): eval() on a constructed variable name is fragile; indexing
# a list of datasets would avoid it.
for n in range(NUM_DATASET - NUM_FULL_LOADER_DATASET):
    print("sample: generated from dataset: {}".format(eval("trainset"+str(n)).__class__.__name__))
コード例 #9
0
def caculate_fitness_for_first_time(mask_input, gpu_id, fitness_id,
                                    A2B_or_B2A):
    """Evaluate the fitness of one channel-pruning mask for a CycleGAN
    generator (first generation of the evolutionary search).

    A pruned copy (``model``) of the pre-trained generator is built by
    zeroing the masked channels' weights/biases, then its outputs on the
    validation set are compared -- through the frozen discriminator --
    against the full generator's outputs.  Fitness rewards similarity to
    the full model plus the number of pruned channels.

    Relies on module-level state: ``opt``, ``mask_chns``,
    ``compute_layer_mask`` and the shared ``current_fitness_A2B`` /
    ``current_fitness_B2A`` result arrays.

    Args:
        mask_input: raw mask vector, expanded via compute_layer_mask().
        gpu_id: CUDA device index to run on.
        fitness_id: slot in the shared fitness array to write.
        A2B_or_B2A: translation direction to evaluate, 'A2B' or 'B2A'.
    """

    ###### Definition of variables ######
    torch.cuda.set_device(gpu_id)
    if A2B_or_B2A == 'A2B':
        netG_A2B = Generator(opt.input_nc, opt.output_nc)
        netD_B = Discriminator(opt.output_nc)
        netG_A2B.cuda(gpu_id)
        netD_B.cuda(gpu_id)
        # `model` is the copy that will be pruned in place below.
        model = Generator(opt.input_nc, opt.output_nc)
        model.cuda(gpu_id)
        netG_A2B.load_state_dict(torch.load('/cache/models/netG_A2B.pth'))
        netD_B.load_state_dict(torch.load('/cache/models/netD_B.pth'))
        model.load_state_dict(torch.load('/cache/models/netG_A2B.pth'))
        model.eval()
        netD_B.eval()
        netG_A2B.eval()

    elif A2B_or_B2A == 'B2A':
        netG_B2A = Generator(opt.output_nc, opt.input_nc)
        netD_A = Discriminator(opt.input_nc)
        netG_B2A.cuda(gpu_id)
        netD_A.cuda(gpu_id)
        model = Generator(opt.input_nc, opt.output_nc)
        model.cuda(gpu_id)
        netG_B2A.load_state_dict(torch.load('/cache/models/netG_B2A.pth'))
        netD_A.load_state_dict(torch.load('/cache/models/netD_A.pth'))
        model.load_state_dict(torch.load('/cache/models/netG_B2A.pth'))
        model.eval()
        netD_A.eval()
        netG_B2A.eval()

    else:
        # BUGFIX: previously an invalid direction surfaced later as a
        # NameError on `model`; fail fast with a clear message instead.
        raise ValueError("A2B_or_B2A must be 'A2B' or 'B2A', got %r" %
                         (A2B_or_B2A, ))

    criterion_GAN = torch.nn.MSELoss()
    fitness = 0
    # Expand the flat mask vector into one 0/1 keep-mask per prunable layer.
    cfg_mask = compute_layer_mask(mask_input, mask_chns)
    cfg_full_mask = [y for x in cfg_mask for y in x]
    cfg_full_mask = np.array(cfg_full_mask)
    cfg_id = 0
    start_mask = np.ones(3)  # network input: 3 (RGB) channels, none pruned
    end_mask = cfg_mask[cfg_id]

    # Zero the weights/biases of pruned channels layer by layer.
    # `start_mask` covers a layer's input channels, `end_mask` its outputs.
    # NOTE(review): assumes every conv layer has a bias term.
    for m in model.modules():
        if isinstance(m, nn.Conv2d):

            mask = np.ones(m.weight.data.shape)
            mask_bias = np.ones(m.bias.data.shape)

            # Invert the keep-masks: 1 marks channels to remove.
            cfg_mask_start = np.ones(start_mask.shape) - start_mask
            cfg_mask_end = np.ones(end_mask.shape) - end_mask
            idx0 = np.squeeze(np.argwhere(np.asarray(cfg_mask_start)))
            idx1 = np.squeeze(np.argwhere(np.asarray(cfg_mask_end)))
            if idx1.size == 1:
                # keep the index array 1-D when a single channel is pruned
                idx1 = np.resize(idx1, (1, ))

            # Conv2d weight layout: (out, in, kH, kW).
            mask[:, idx0.tolist(), :, :] = 0
            mask[idx1.tolist(), :, :, :] = 0
            mask_bias[idx1.tolist()] = 0

            m.weight.data = m.weight.data * torch.FloatTensor(mask).cuda(
                gpu_id)

            m.bias.data = m.bias.data * torch.FloatTensor(mask_bias).cuda(
                gpu_id)

            cfg_id += 1
            start_mask = end_mask
            if cfg_id < len(cfg_mask):
                end_mask = cfg_mask[cfg_id]
            continue
        elif isinstance(m, nn.ConvTranspose2d):

            mask = np.ones(m.weight.data.shape)
            mask_bias = np.ones(m.bias.data.shape)

            cfg_mask_start = np.ones(start_mask.shape) - start_mask
            cfg_mask_end = np.ones(end_mask.shape) - end_mask

            idx0 = np.squeeze(np.argwhere(np.asarray(cfg_mask_start)))
            idx1 = np.squeeze(np.argwhere(np.asarray(cfg_mask_end)))
            if idx1.size == 1:
                # mirror the Conv2d branch (this guard was missing here)
                idx1 = np.resize(idx1, (1, ))

            # ConvTranspose2d weight layout: (in, out, kH, kW) -- the
            # in/out axes are swapped relative to Conv2d.
            mask[idx0.tolist(), :, :, :] = 0

            mask[:, idx1.tolist(), :, :] = 0

            mask_bias[idx1.tolist()] = 0

            m.weight.data = m.weight.data * torch.FloatTensor(mask).cuda(
                gpu_id)
            m.bias.data = m.bias.data * torch.FloatTensor(mask_bias).cuda(
                gpu_id)

            cfg_id += 1
            start_mask = end_mask
            # BUGFIX: bounds guard was missing here (unlike the Conv2d
            # branch), causing an IndexError when the last masked layer
            # is a transposed convolution.
            if cfg_id < len(cfg_mask):
                end_mask = cfg_mask[cfg_id]
            continue

    # Dataset loader
    Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
    input_A = Tensor(opt.batchSize, opt.input_nc, opt.size, opt.size)
    input_B = Tensor(opt.batchSize, opt.output_nc, opt.size, opt.size)

    lambda_prune = 0.001  # weight of the pruned-channel reward term

    with torch.no_grad():

        transforms_ = [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]

        dataloader = DataLoader(ImageDataset(opt.dataroot,
                                             transforms_=transforms_,
                                             mode='val'),
                                batch_size=opt.batchSize,
                                shuffle=False,
                                drop_last=True)

        Loss_resemble_G = 0
        if A2B_or_B2A == 'A2B':
            for i, batch in enumerate(dataloader):
                # Set model input
                real_A = Variable(input_A.copy_(batch['A']))

                # Translate with both the pruned and the full generator.
                fake_B = model(real_A)
                fake_B_full_model = netG_A2B(real_A)

                # Compare their discriminator responses.
                pred_fake = netD_B(fake_B.detach())

                pred_fake_full = netD_B(fake_B_full_model.detach())

                loss_D_fake = criterion_GAN(pred_fake.detach(),
                                            pred_fake_full.detach())
                Loss_resemble_G = Loss_resemble_G + loss_D_fake

            # Higher fitness: close to the full model (small loss) plus
            # many pruned channels.
            fitness = 500 / Loss_resemble_G.detach() + sum(
                np.ones(cfg_full_mask.shape) - cfg_full_mask) * lambda_prune
            print("A2B first generation")
            print("GPU_ID is %d" % (gpu_id))
            print("channel num is: %d" % (sum(cfg_full_mask)))
            print("Loss_resemble_G is %f prune_loss is %f " %
                  (500 / Loss_resemble_G,
                   sum(np.ones(cfg_full_mask.shape) - cfg_full_mask)))
            print("fitness is %f \n" % (fitness))

            current_fitness_A2B[fitness_id] = fitness.item()

        elif A2B_or_B2A == 'B2A':
            for i, batch in enumerate(dataloader):

                real_B = Variable(input_B.copy_(batch['B']))

                fake_A = model(real_B)
                fake_A_full_model = netG_B2A(real_B)

                pred_fake = netD_A(fake_A.detach())

                pred_fake_full = netD_A(fake_A_full_model.detach())

                loss_D_fake = criterion_GAN(pred_fake.detach(),
                                            pred_fake_full.detach())
                Loss_resemble_G = Loss_resemble_G + loss_D_fake

            fitness = 500 / Loss_resemble_G.detach() + sum(
                np.ones(cfg_full_mask.shape) - cfg_full_mask) * lambda_prune
            print("B2A first generation")
            print("GPU_ID is %d" % (gpu_id))
            print("channel num is: %d" % (sum(cfg_full_mask)))
            print("Loss_resemble_G is %f prune_loss is %f " %
                  (500 / Loss_resemble_G,
                   sum(np.ones(cfg_full_mask.shape) - cfg_full_mask)))
            print("fitness is %f \n" % (fitness))

            current_fitness_B2A[fitness_id] = fitness.item()
コード例 #10
0
def caculate_fitness(mask_input_A2B, mask_input_B2A, gpu_id, fitness_id,
                     A2B_or_B2A):

    torch.cuda.set_device(gpu_id)
    #print("GPU_ID is%d\n"%(gpu_id))

    model_A2B = Generator(opt.input_nc, opt.output_nc)
    model_B2A = Generator(opt.input_nc, opt.output_nc)

    netD_A = Discriminator(opt.input_nc)
    netD_B = Discriminator(opt.output_nc)

    netD_A.cuda(gpu_id)
    netD_B.cuda(gpu_id)
    model_A2B.cuda(gpu_id)
    model_B2A.cuda(gpu_id)

    model_A2B.load_state_dict(torch.load('/cache/models/netG_A2B.pth'))
    model_B2A.load_state_dict(torch.load('/cache/models/netG_B2A.pth'))
    netD_A.load_state_dict(torch.load('/cache/models/netD_A.pth'))
    netD_B.load_state_dict(torch.load('/cache/models/netD_B.pth'))

    # Lossess
    criterion_GAN = torch.nn.MSELoss()
    criterion_cycle = torch.nn.L1Loss()
    criterion_identity = torch.nn.L1Loss()

    fitness = 0
    cfg_mask_A2B = compute_layer_mask(mask_input_A2B, mask_chns)
    cfg_mask_B2A = compute_layer_mask(mask_input_B2A, mask_chns)
    cfg_full_mask_A2B = [y for x in cfg_mask_A2B for y in x]
    cfg_full_mask_A2B = np.array(cfg_full_mask_A2B)
    cfg_full_mask_B2A = [y for x in cfg_mask_B2A for y in x]
    cfg_full_mask_B2A = np.array(cfg_full_mask_B2A)
    cfg_id = 0
    start_mask = np.ones(3)
    end_mask = cfg_mask_A2B[cfg_id]

    for m in model_A2B.modules():
        if isinstance(m, nn.Conv2d):

            #print("conv2d")
            #print(m.weight.data.shape)
            #out_channels = m.weight.data.shape[0]
            mask = np.ones(m.weight.data.shape)

            mask_bias = np.ones(m.bias.data.shape)

            cfg_mask_start = np.ones(start_mask.shape) - start_mask
            cfg_mask_end = np.ones(end_mask.shape) - end_mask
            idx0 = np.squeeze(np.argwhere(np.asarray(cfg_mask_start)))
            idx1 = np.squeeze(np.argwhere(np.asarray(cfg_mask_end)))
            if idx1.size == 1:
                idx1 = np.resize(idx1, (1, ))

            mask[:, idx0.tolist(), :, :] = 0
            mask[idx1.tolist(), :, :, :] = 0
            mask_bias[idx1.tolist()] = 0

            m.weight.data = m.weight.data * torch.FloatTensor(mask).cuda(
                gpu_id)

            m.bias.data = m.bias.data * torch.FloatTensor(mask_bias).cuda(
                gpu_id)

            idx_mask = np.argwhere(np.asarray(np.ones(mask.shape) - mask))

            m.weight.data[:, idx0.tolist(), :, :].requires_grad = False
            m.weight.data[idx1.tolist(), :, :, :].requires_grad = False
            m.bias.data[idx1.tolist()].requires_grad = False

            cfg_id += 1
            start_mask = end_mask
            if cfg_id < len(cfg_mask):
                end_mask = cfg_mask_A2B[cfg_id]
            continue
        elif isinstance(m, nn.ConvTranspose2d):

            mask = np.ones(m.weight.data.shape)
            mask_bias = np.ones(m.bias.data.shape)

            cfg_mask_start = np.ones(start_mask.shape) - start_mask
            cfg_mask_end = np.ones(end_mask.shape) - end_mask

            idx0 = np.squeeze(np.argwhere(np.asarray(cfg_mask_start)))
            idx1 = np.squeeze(np.argwhere(np.asarray(cfg_mask_end)))

            mask[idx0.tolist(), :, :, :] = 0

            mask[:, idx1.tolist(), :, :] = 0

            mask_bias[idx1.tolist()] = 0

            m.weight.data = m.weight.data * torch.FloatTensor(mask).cuda(
                gpu_id)
            m.bias.data = m.bias.data * torch.FloatTensor(mask_bias).cuda(
                gpu_id)

            m.weight.data[idx0.tolist(), :, :, :].requires_grad = False
            m.weight.data[:, idx1.tolist(), :, :].requires_grad = False
            m.bias.data[idx1.tolist()].requires_grad = False

            cfg_id += 1
            start_mask = end_mask
            end_mask = cfg_mask_A2B[cfg_id]
            continue

    cfg_id = 0
    start_mask = np.ones(3)
    end_mask = cfg_mask_B2A[cfg_id]

    for m in model_B2A.modules():
        if isinstance(m, nn.Conv2d):

            #print("conv2d")
            #print(m.weight.data.shape)
            #out_channels = m.weight.data.shape[0]
            mask = np.ones(m.weight.data.shape)

            mask_bias = np.ones(m.bias.data.shape)

            cfg_mask_start = np.ones(start_mask.shape) - start_mask
            cfg_mask_end = np.ones(end_mask.shape) - end_mask
            idx0 = np.squeeze(np.argwhere(np.asarray(cfg_mask_start)))
            idx1 = np.squeeze(np.argwhere(np.asarray(cfg_mask_end)))
            if idx1.size == 1:
                idx1 = np.resize(idx1, (1, ))

            mask[:, idx0.tolist(), :, :] = 0
            mask[idx1.tolist(), :, :, :] = 0
            mask_bias[idx1.tolist()] = 0

            m.weight.data = m.weight.data * torch.FloatTensor(mask).cuda(
                gpu_id)

            m.bias.data = m.bias.data * torch.FloatTensor(mask_bias).cuda(
                gpu_id)

            idx_mask = np.argwhere(np.asarray(np.ones(mask.shape) - mask))

            m.weight.data[:, idx0.tolist(), :, :].requires_grad = False
            m.weight.data[idx1.tolist(), :, :, :].requires_grad = False
            m.bias.data[idx1.tolist()].requires_grad = False

            cfg_id += 1
            start_mask = end_mask
            if cfg_id < len(cfg_mask):
                end_mask = cfg_mask_B2A[cfg_id]
            continue
        elif isinstance(m, nn.ConvTranspose2d):

            mask = np.ones(m.weight.data.shape)
            mask_bias = np.ones(m.bias.data.shape)

            cfg_mask_start = np.ones(start_mask.shape) - start_mask
            cfg_mask_end = np.ones(end_mask.shape) - end_mask

            idx0 = np.squeeze(np.argwhere(np.asarray(cfg_mask_start)))
            idx1 = np.squeeze(np.argwhere(np.asarray(cfg_mask_end)))

            mask[idx0.tolist(), :, :, :] = 0

            mask[:, idx1.tolist(), :, :] = 0

            mask_bias[idx1.tolist()] = 0

            m.weight.data = m.weight.data * torch.FloatTensor(mask).cuda(
                gpu_id)
            m.bias.data = m.bias.data * torch.FloatTensor(mask_bias).cuda(
                gpu_id)

            m.weight.data[idx0.tolist(), :, :, :].requires_grad = False
            m.weight.data[:, idx1.tolist(), :, :].requires_grad = False
            m.bias.data[idx1.tolist()].requires_grad = False

            cfg_id += 1
            start_mask = end_mask
            end_mask = cfg_mask_B2A[cfg_id]
            continue

    # Dataset loader
    Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
    input_A = Tensor(opt.batchSize, opt.input_nc, opt.size, opt.size)
    input_B = Tensor(opt.batchSize, opt.output_nc, opt.size, opt.size)
    target_real = Variable(Tensor(opt.batchSize).fill_(1.0),
                           requires_grad=False)
    target_fake = Variable(Tensor(opt.batchSize).fill_(0.0),
                           requires_grad=False)
    fake_A_buffer = ReplayBuffer()
    fake_B_buffer = ReplayBuffer()

    lamda_loss_ID = 5.0
    lamda_loss_G = 1.0
    lamda_loss_cycle = 10.0
    optimizer_G = torch.optim.Adam(itertools.chain(
        filter(lambda p: p.requires_grad, model_A2B.parameters()),
        filter(lambda p: p.requires_grad, model_B2A.parameters())),
                                   lr=opt.lr,
                                   betas=(0.5, 0.999))
    optimizer_D_A = torch.optim.Adam(netD_A.parameters(),
                                     lr=opt.lr,
                                     betas=(0.5, 0.999))
    optimizer_D_B = torch.optim.Adam(netD_B.parameters(),
                                     lr=opt.lr,
                                     betas=(0.5, 0.999))
    transforms_ = [
        transforms.Resize(int(opt.size * 1.12), Image.BICUBIC),
        transforms.RandomCrop(opt.size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]

    dataloader = DataLoader(ImageDataset(opt.dataroot,
                                         transforms_=transforms_,
                                         unaligned=True,
                                         mode='train'),
                            batch_size=opt.batchSize,
                            shuffle=True,
                            drop_last=True)

    for epoch in range(opt.epoch, opt.n_epochs):
        for i, batch in enumerate(dataloader):

            # Set model input
            real_A = Variable(input_A.copy_(batch['A']))
            real_B = Variable(input_B.copy_(batch['B']))

            ###### Generators A2B and B2A ######
            optimizer_G.zero_grad()

            # Identity loss
            # G_A2B(B) should equal B if real B is fed
            same_B = model_A2B(real_B)
            loss_identity_B = criterion_identity(
                same_B, real_B) * lamda_loss_ID  #initial 5.0
            # G_B2A(A) should equal A if real A is fed
            same_A = model_B2A(real_A)
            loss_identity_A = criterion_identity(
                same_A, real_A) * lamda_loss_ID  #initial 5.0

            # GAN loss
            fake_B = model_A2B(real_A)
            pred_fake = netD_B(fake_B)
            loss_GAN_A2B = criterion_GAN(
                pred_fake, target_real) * lamda_loss_G  #initial 1.0

            fake_A = model_B2A(real_B)
            pred_fake = netD_A(fake_A)
            loss_GAN_B2A = criterion_GAN(
                pred_fake, target_real) * lamda_loss_G  #initial 1.0

            # Cycle loss
            recovered_A = model_B2A(fake_B)
            loss_cycle_ABA = criterion_cycle(
                recovered_A, real_A) * lamda_loss_cycle  #initial 10.0

            recovered_B = model_A2B(fake_A)
            loss_cycle_BAB = criterion_cycle(
                recovered_B, real_B) * lamda_loss_cycle  #initial 10.0

            # Total loss
            loss_G = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB
            loss_G.backward()

            optimizer_G.step()

            ###### Discriminator A ######
            optimizer_D_A.zero_grad()

            # Real loss
            pred_real = netD_A(real_A)
            loss_D_real = criterion_GAN(pred_real, target_real)

            # Fake loss
            fake_A = fake_A_buffer.push_and_pop(fake_A)
            pred_fake = netD_A(fake_A.detach())
            loss_D_fake = criterion_GAN(pred_fake, target_fake)

            # Total loss
            loss_D_A = (loss_D_real + loss_D_fake) * 0.5
            loss_D_A.backward()

            optimizer_D_A.step()
            ###################################

            ###### Discriminator B ######
            optimizer_D_B.zero_grad()

            # Real loss
            pred_real = netD_B(real_B)
            loss_D_real = criterion_GAN(pred_real, target_real)

            # Fake loss
            fake_B = fake_B_buffer.push_and_pop(fake_B)
            pred_fake = netD_B(fake_B.detach())
            loss_D_fake = criterion_GAN(pred_fake, target_fake)

            # Total loss
            loss_D_B = (loss_D_real + loss_D_fake) * 0.5
            loss_D_B.backward()

            optimizer_D_B.step()

    with torch.no_grad():

        transforms_ = [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]

        dataloader = DataLoader(ImageDataset(opt.dataroot,
                                             transforms_=transforms_,
                                             mode='val'),
                                batch_size=opt.batchSize,
                                shuffle=False,
                                drop_last=True)

        Loss_resemble_G = 0
        if A2B_or_B2A == 'A2B':
            netG_A2B = Generator(opt.output_nc, opt.input_nc)
            netD_B = Discriminator(opt.output_nc)

            netG_A2B.cuda(gpu_id)
            netD_B.cuda(gpu_id)

            model_A2B.eval()
            netD_B.eval()
            netG_A2B.eval()

            netD_B.load_state_dict(torch.load('/cache/models/netD_B.pth'))
            netG_A2B.load_state_dict(torch.load('/cache/models/netG_A2B.pth'))

            for i, batch in enumerate(dataloader):

                real_A = Variable(input_A.copy_(batch['A']))

                fake_B = model_A2B(real_A)
                fake_B_full_model = netG_A2B(real_A)
                recovered_A = model_B2A(fake_B)

                pred_fake = netD_B(fake_B.detach())

                pred_fake_full = netD_B(fake_B_full_model.detach())

                loss_D_fake = criterion_GAN(pred_fake.detach(),
                                            pred_fake_full.detach())
                cycle_loss = criterion_cycle(recovered_A,
                                             real_A) * lamda_loss_cycle
                Loss_resemble_G = Loss_resemble_G + loss_D_fake + cycle_loss

                lambda_prune = 0.001

            fitness = 500 / Loss_resemble_G.detach() + sum(
                np.ones(cfg_full_mask_A2B.shape) -
                cfg_full_mask_A2B) * lambda_prune

            print('A2B')
            print("GPU_ID is %d" % (gpu_id))
            print("channel num is: %d" % (sum(cfg_full_mask_A2B)))
            print("Loss_resemble_G is %f prune_loss is %f " %
                  (500 / Loss_resemble_G,
                   sum(np.ones(cfg_full_mask_A2B.shape) - cfg_full_mask_A2B)))
            print("fitness is %f \n" % (fitness))

            current_fitness_A2B[fitness_id] = fitness.item()

        if A2B_or_B2A == 'B2A':
            netG_B2A = Generator(opt.output_nc, opt.input_nc)
            netD_A = Discriminator(opt.output_nc)

            netG_B2A.cuda(gpu_id)
            netD_A.cuda(gpu_id)

            model_B2A.eval()
            netD_A.eval()
            netG_B2A.eval()

            netD_A.load_state_dict(torch.load('/cache/models/netD_A.pth'))
            netG_B2A.load_state_dict(torch.load('/cache/models/netG_B2A.pth'))

            for i, batch in enumerate(dataloader):

                real_B = Variable(input_B.copy_(batch['B']))

                fake_A = model_B2A(real_B)
                fake_A_full_model = netG_B2A(real_B)
                recovered_B = model_A2B(fake_A)

                pred_fake = netD_A(fake_A.detach())

                pred_fake_full = netD_A(fake_A_full_model.detach())

                loss_D_fake = criterion_GAN(pred_fake.detach(),
                                            pred_fake_full.detach())
                cycle_loss = criterion_cycle(recovered_B,
                                             real_B) * lamda_loss_cycle
                Loss_resemble_G = Loss_resemble_G + loss_D_fake + cycle_loss

                lambda_prune = 0.001

            fitness = 500 / Loss_resemble_G.detach() + sum(
                np.ones(cfg_full_mask_B2A.shape) -
                cfg_full_mask_B2A) * lambda_prune

            print('B2A')
            print("GPU_ID is %d" % (gpu_id))
            print("channel num is: %d" % (sum(cfg_full_mask_B2A)))
            print("Loss_resemble_G is %f prune_loss is %f " %
                  (500 / Loss_resemble_G,
                   sum(np.ones(cfg_full_mask_B2A.shape) - cfg_full_mask_B2A)))
            print("fitness is %f \n" % (fitness))

            current_fitness_B2A[fitness_id] = fitness.item()
コード例 #11
0
# BUGFIX: tf.keras.losses has no class `BinaryCrossEntropy`; the correct name
# is `BinaryCrossentropy` (lowercase "entropy") — the original line raised
# AttributeError. `from_logits=True` means raw (un-sigmoided) scores are expected.
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)


def generator_loss(disc_generated_output, gen_output, target):
    """Pix2pix-style generator loss.

    Combines an adversarial term (the discriminator should call the generated
    output "real") with an L1 reconstruction term weighted by the module-level
    constant ``LAMBDA``.

    Returns:
        (total_gen_loss, gan_loss, l1_loss) — total plus both components.
    """
    # Adversarial part: push the discriminator's scores on generated output
    # towards the all-ones ("real") target.
    adversarial_term = loss_object(tf.ones_like(disc_generated_output),
                                   disc_generated_output)

    # Mean absolute error between generated output and ground truth.
    reconstruction_term = tf.reduce_mean(tf.abs(target - gen_output))

    combined = adversarial_term + LAMBDA * reconstruction_term
    return combined, adversarial_term, reconstruction_term


discriminator = Discriminator()

# Render the discriminator architecture to an image (needs pydot/graphviz).
tf.keras.utils.plot_model(discriminator, show_shapes=True, dpi=64)

# Score one (input, generated) pair without training.
# NOTE(review): `inp` and `gen_output` are defined earlier in this script —
# presumably an input image and a generator output; confirm shapes upstream.
disc_out = discriminator([inp[tf.newaxis, ...], gen_output], training=False)

# Visualise the raw discriminator scores of the last channel as a diverging
# heat map (negative = "fake", positive = "real", since logits are unbounded).
plt.imshow(disc_out[0, ..., -1], vmin=-20, vmax=20, cmap="RdBu_r")
plt.colorbar()


def discrminator_loss(disc_real_output, disc_generated_output):
    """Standard GAN discriminator loss (name typo kept for callers).

    Real patches should be scored as ones and generated patches as zeros;
    the two cross-entropy terms are summed.

    BUGFIX: the original body computed both terms and then fell off the end,
    implicitly returning None and discarding the loss; the sum is now returned.
    """
    real_loss = loss_object(tf.ones_like(disc_real_output), disc_real_output)

    generated_loss = loss_object(tf.zeros_like(disc_generated_output),
                                 disc_generated_output)

    total_disc_loss = real_loss + generated_loss
    return total_disc_loss
コード例 #12
0
ファイル: usps_mnist.py プロジェクト: wogong/pytorch-adda
                                      dataset_root=params.dataset_root,
                                      batch_size=params.batch_size,
                                      train=True)
    tgt_data_loader = get_data_loader(params.tgt_dataset,
                                      dataset_root=params.dataset_root,
                                      batch_size=params.batch_size,
                                      train=True)

    # load models
    src_encoder = init_model(net=LeNetEncoder(),
                             restore=params.src_encoder_restore)
    src_classifier = init_model(net=LeNetClassifier(),
                                restore=params.src_classifier_restore)
    tgt_encoder = init_model(net=LeNetEncoder(),
                             restore=params.tgt_encoder_restore)
    critic = init_model(Discriminator(), restore=params.d_model_restore)

    # train source model
    print("=== Training classifier for source domain ===")
    print(">>> Source Encoder <<<")
    print(src_encoder)
    print(">>> Source Classifier <<<")
    print(src_classifier)

    if not (src_encoder.restored and src_classifier.restored
            and params.src_model_trained):
        src_encoder, src_classifier = train_src(src_encoder, src_classifier,
                                                src_data_loader, params)

    # eval source model
    print("=== Evaluating classifier for source domain ===")
コード例 #13
0
ファイル: train.py プロジェクト: wenchenhui/MDvsFA
import torch 
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms

initialize_logger('./logs')

args = para_parser()

# torch.cuda.is_available() is already a bool; no ternary needed.
cuda = torch.cuda.is_available()
FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor


# Models: one discriminator and the two MDvsFA generators
# (CAN8 favours miss-detection, UCAN64 favours false-alarm reduction).
disc_net = Discriminator()
gen1_net = Generator1_CAN8(is_anm=args.is_anm)
gen2_net = Generator2_UCAN64(is_anm=args.is_anm)

if cuda:
    # BUGFIX: the original used 'cuda: 0' (with a space), which is not a valid
    # torch device string; the canonical form is 'cuda:0'.
    disc_net.to(device='cuda:0')
    gen1_net.to(device='cuda:0')
    gen2_net.to(device='cuda:0')

if args.parallel and cuda:
    # BUGFIX: wrap the networks constructed above instead of building brand-new
    # (freshly re-initialised) models, and move all three replicas to the GPU —
    # the original also dropped gen2_net's .to() call.
    disc_net = torch.nn.DataParallel(disc_net)
    gen1_net = torch.nn.DataParallel(gen1_net)
    gen2_net = torch.nn.DataParallel(gen2_net)

    disc_net.to(device='cuda:0')
    gen1_net.to(device='cuda:0')
    gen2_net.to(device='cuda:0')
コード例 #14
0
# NOTE(review): `axis_x` and `height` come from earlier in the file —
# presumably axis_x = np.arange(0, width); confirm upstream.
axis_y = np.arange(0, height)
# Stack the two meshgrid planes: shape (2, height, width).
grid_axes = np.array(np.meshgrid(axis_x, axis_y))
# Rearrange to (height, width, 2) so each cell holds an (x, y) coordinate pair.
grid_axes = np.transpose(grid_axes, (1, 2, 0))
from scipy.spatial import Delaunay

# Delaunay-triangulate the flattened pixel grid; the resulting triangle list
# seeds the persistence-diagram layer's filtration.
tri = Delaunay(grid_axes.reshape([-1, 2]))
faces = tri.simplices.copy()
F = DiagramlayerToplevel().init_filtration(faces)
diagramlayerToplevel = DiagramlayerToplevel.apply
''' '''

###### Definition of variables ######
# Networks
# CycleGAN setup: two generators (A->B and B->A) plus one discriminator per domain.
netG_A2B = Generator(opt.input_nc, opt.output_nc)
netG_B2A = Generator(opt.output_nc, opt.input_nc)
netD_A = Discriminator(opt.input_nc)
netD_B = Discriminator(opt.output_nc)

if opt.cuda:
    # netG_A2B.cuda()
    netG_A2B.to(torch.device('cuda'))
    # netG_B2A.cuda()
    netG_B2A.to(torch.device('cuda'))
    # netD_A.cuda()
    netD_A.to(torch.device('cuda'))
    # netD_B.cuda()
    netD_B.to(torch.device('cuda'))

    # Replicate each model across three GPUs; batches are split along dim 0.
    # NOTE(review): netD_B is presumably wrapped just below this excerpt — confirm.
    netG_A2B = nn.DataParallel(netG_A2B, device_ids=[0, 1, 2])
    netG_B2A = nn.DataParallel(netG_B2A, device_ids=[0, 1, 2])
    netD_A = nn.DataParallel(netD_A, device_ids=[0, 1, 2])
コード例 #15
0
def train(args):
    """Train a conditional time-series GAN (LSGAN or WGAN-GP objective).

    ``args.criterion`` selects the loss:
      * ``'l2'``     — MSE (LSGAN) loss, one discriminator step per generator step.
      * ``'wgangp'`` — Wasserstein loss with gradient penalty, five critic steps
                       per generator step; the penalty is weighted by
                       ``args.lambda_gp``.

    Side effects: creates output directories, prints per-epoch statistics and,
    every ``args.save_every`` epochs, saves generator weights plus sample
    plots/CSVs.

    Raises:
        NotImplementedError: for an unrecognised ``args.criterion``.
    """

    # Device Configuration #
    device = torch.device(
        f'cuda:{args.gpu_num}' if torch.cuda.is_available() else 'cpu')

    # Fix Seed for Reproducibility #
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # Samples, Plots, Weights and CSV Path #
    paths = [
        args.samples_path, args.plots_path, args.weights_path, args.csv_path
    ]
    for path in paths:
        make_dirs(path)

    # Prepare Data #
    data = pd.read_csv(args.data_path)[args.column]

    # Pre-processing #
    scaler_1 = StandardScaler()
    scaler_2 = StandardScaler()
    preprocessed_data = pre_processing(data, scaler_1, scaler_2, args.delta)

    # Sliding windows over the scaled series (model input) and the raw
    # series (kept for plots/CSV labels).
    X = moving_windows(preprocessed_data, args.ts_dim)
    label = moving_windows(data.to_numpy(), args.ts_dim)

    # Prepare Networks #
    D = Discriminator(args.ts_dim).to(device)
    G = Generator(args.latent_dim, args.ts_dim,
                  args.conditional_dim).to(device)

    # Loss Function #
    if args.criterion == 'l2':
        criterion = nn.MSELoss()
    elif args.criterion == 'wgangp':
        pass  # WGAN-GP works on mean critic scores; no loss object needed.
    else:
        raise NotImplementedError

    # Optimizers #
    D_optim = torch.optim.Adam(D.parameters(), lr=args.lr, betas=(0.5, 0.9))
    G_optim = torch.optim.Adam(G.parameters(), lr=args.lr, betas=(0.5, 0.9))

    D_optim_scheduler = get_lr_scheduler(D_optim, args)
    G_optim_scheduler = get_lr_scheduler(G_optim, args)

    # Lists #
    D_losses, G_losses = list(), list()

    # Train #
    print("Training Time Series GAN started with total epoch of {}.".format(
        args.num_epochs))

    for epoch in range(args.num_epochs):

        if args.criterion == 'l2':
            n_critics = 1
        elif args.criterion == 'wgangp':
            n_critics = 5

        #######################
        # Train Discriminator #
        #######################

        for j in range(n_critics):
            # BUGFIX: zero D's gradients on every critic step. The original
            # code zeroed both optimizers once per epoch, so gradients
            # accumulated across critic iterations and across epochs.
            D_optim.zero_grad()

            series, start_dates = get_samples(X, label, args.batch_size)

            # Data Preparation #
            series = series.to(device)
            noise = torch.randn(args.batch_size, 1, args.latent_dim).to(device)

            # Adversarial Loss using Real Image #
            prob_real = D(series.float())

            if args.criterion == 'l2':
                real_labels = torch.ones(prob_real.size()).to(device)
                D_real_loss = criterion(prob_real, real_labels)

            elif args.criterion == 'wgangp':
                D_real_loss = -torch.mean(prob_real)

            # Adversarial Loss using Fake Image #
            fake_series = G(noise)
            # Condition the fake window on the real window's leading
            # (conditional) segment.
            fake_series = torch.cat(
                (series[:, :, :args.conditional_dim].float(),
                 fake_series.float()),
                dim=2)

            prob_fake = D(fake_series.detach())

            if args.criterion == 'l2':
                fake_labels = torch.zeros(prob_fake.size()).to(device)
                D_fake_loss = criterion(prob_fake, fake_labels)

            elif args.criterion == 'wgangp':
                D_fake_loss = torch.mean(prob_fake)
                # BUGFIX: keep the raw gradient penalty here; it is weighted
                # by args.lambda_gp exactly once below. The original
                # multiplied by lambda_gp both here and when adding it to
                # D_loss, squaring the penalty weight.
                D_gp_loss = get_gradient_penalty(
                    D, series.float(), fake_series.float(), device)

            # Calculate Total Discriminator Loss #
            D_loss = D_fake_loss + D_real_loss

            if args.criterion == 'wgangp':
                D_loss += args.lambda_gp * D_gp_loss

            # Back Propagation and Update #
            D_loss.backward()
            D_optim.step()

        ###################
        # Train Generator #
        ###################

        # BUGFIX: clear G's gradients before its update (see note above).
        G_optim.zero_grad()

        # Adversarial Loss # (reuses the last critic batch's noise/series)
        fake_series = G(noise)
        fake_series = torch.cat(
            (series[:, :, :args.conditional_dim].float(), fake_series.float()),
            dim=2)
        prob_fake = D(fake_series)

        # Calculate Total Generator Loss #
        if args.criterion == 'l2':
            real_labels = torch.ones(prob_fake.size()).to(device)
            G_loss = criterion(prob_fake, real_labels)

        elif args.criterion == 'wgangp':
            G_loss = -torch.mean(prob_fake)

        # Back Propagation and Update #
        G_loss.backward()
        G_optim.step()

        # Add items to Lists #
        D_losses.append(D_loss.item())
        G_losses.append(G_loss.item())

        ####################
        # Print Statistics #
        ####################

        # Running mean over all epochs so far (matches original behaviour).
        print("Epochs [{}/{}] | D Loss {:.4f} | G Loss {:.4f}".format(
            epoch + 1, args.num_epochs, np.average(D_losses),
            np.average(G_losses)))

        # Adjust Learning Rate #
        D_optim_scheduler.step()
        G_optim_scheduler.step()

        # Save Model Weights and Series #
        if (epoch + 1) % args.save_every == 0:
            torch.save(
                G.state_dict(),
                os.path.join(
                    args.weights_path,
                    'TimeSeries_Generator_using{}_Epoch_{}.pkl'.format(
                        args.criterion.upper(), epoch + 1)))

            series, fake_series = generate_fake_samples(
                X, label, G, scaler_1, scaler_2, args, device)
            plot_sample(series, fake_series, epoch, args)
            make_csv(series, fake_series, epoch, args)

    print("Training finished.")
コード例 #16
0
print(f.size())

# Concatenate the two tensors along dim 2.
# NOTE(review): `f` and `w` come from earlier notebook cells — presumably CHW
# image tensors, making dim=2 the width axis; confirm upstream.
temp_img = torch.cat((f, w), dim=2)

# print(torch.max(f))
# print(torch.min(f))

myimshow(temp_img)

myimshow(normalize(f))

# In[63]:

# CycleGAN networks: two generators (A->B, B->A) and one discriminator per domain.
netG_A2B = Generator(opt.input_nc, opt.output_nc).to(device)
netG_B2A = Generator(opt.output_nc, opt.input_nc).to(device)
netD_A = Discriminator(opt.input_nc).to(device)
netD_B = Discriminator(opt.output_nc).to(device)

# Apply the project's weight-init function recursively to every submodule.
netG_A2B.apply(weights_init_normal)
netG_B2A.apply(weights_init_normal)
netD_A.apply(weights_init_normal)
netD_B.apply(weights_init_normal)

# In[64]:

# LSGAN-style MSE adversarial loss; L1 for the cycle and identity terms.
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

# In[65]:
コード例 #17
0
class GAN():
    """Keras GAN wrapper pairing a Generator with a Discriminator.

    The discriminator is compiled stand-alone; the stacked `gan` model freezes
    it so that `gan.train_on_batch` only updates the generator's weights.
    """

    def __init__(self, latent_dim=1, feature_num=1):
        self.latent_dim = latent_dim
        self.feature_num = feature_num
        self.generator = Generator(self.latent_dim, self.feature_num).build_generator()
        self.discriminator = Discriminator(self.latent_dim).build_discriminator()
        discriminator_optimizer = keras.optimizers.RMSprop(lr=8e-4, clipvalue=2.0, decay=1e-8)
        self.discriminator.compile(optimizer=discriminator_optimizer, loss='binary_crossentropy')

        self._build_gan()

    def _build_gan(self):
        """Stack generator -> (frozen) discriminator into the `gan` model."""
        # Freezing here only affects the stacked model; the stand-alone
        # discriminator was compiled trainable above.
        self.discriminator.trainable = False
        gan_input = keras.Input(shape=(self.latent_dim, self.feature_num))
        gan_output = self.discriminator(self.generator(gan_input))
        self.gan = keras.models.Model(gan_input, gan_output)
        gan_optimizer = keras.optimizers.RMSprop(lr=4e-4, clipvalue=2.0, decay=1e-8)
        self.gan.compile(optimizer=gan_optimizer, loss='binary_crossentropy')

    def load_data(self, trainx, trainy):
        """Store training inputs/targets and echo their shapes."""
        self.trainX = trainx
        self.trainY = trainy
        print("X_train")
        print(np.array(self.trainX).shape)
        print("Y_train")
        print(np.array(self.trainY).shape)

    def _generate_real_samples(self, index, batch_size):
        """Return the real (input, target) pair at `index`, reshaped to 2D."""
        temp_X = self.trainX[index]
        temp_Y = self.trainY[index]
        # make data in 2D
        temp_X = np.array(temp_X).reshape(self.latent_dim, batch_size)
        temp_Y = np.array(temp_Y).reshape(self.latent_dim, batch_size)
        return temp_X, temp_Y

    def _generate_fake_samples(self, index, batch_size):
        """Run the generator on the input at `index`, reshaped to 3D."""
        temp_X = self.trainX[index]
        # make data in 3D
        temp_X = np.array(temp_X).reshape(self.latent_dim, batch_size, self.feature_num)
        predictions = self.generator.predict(temp_X)
        return predictions

    # evaluate the discriminator
    def _summarize_performance(self, index, batch_size):
        """Report discriminator accuracy on one real and one fake batch.

        BUGFIX: the original called both sample helpers without the required
        `index` argument (TypeError) and unpacked two values from
        `_generate_fake_samples`, which returns only the predictions.
        NOTE(review): `evaluate` returns a bare loss unless the model is
        compiled with an accuracy metric — confirm compile options before
        re-enabling this helper in `train`.
        """
        # prepare real samples
        X_real, y_real = self._generate_real_samples(index, batch_size)
        # evaluate discriminator on real examples
        _, acc_real = self.discriminator.evaluate(X_real, y_real, verbose=0)
        # prepare fake examples
        x_fake = self._generate_fake_samples(index, batch_size)
        y_fake = np.zeros((batch_size, 1))
        # evaluate discriminator on fake examples
        _, acc_fake = self.discriminator.evaluate(x_fake, y_fake, verbose=0)
        # summarize discriminator performance
        print('>Accuracy real: %.0f%%, fake: %.0f%%' % (acc_real*100, acc_fake*100))

    def train(self, epochs, batch_size=1, print_output_every_n_steps=100):
        """Alternate one discriminator and one generator update per epoch."""
        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------
            # Generate fake and real inputs
            predictions = self._generate_fake_samples(epoch, batch_size)
            temp_X, temp_Y = self._generate_real_samples(epoch, batch_size)
            input_f = np.concatenate([temp_X, predictions], 0)
            input_r = np.concatenate([temp_X, temp_Y], 0)

            # Renamed from `input`/`labels`: `input` shadowed the builtin.
            d_input = np.concatenate([[input_r], [input_f]])
            d_labels = np.concatenate([[np.ones((2, 1))], [np.zeros((2, 1))]])

            d_loss = self.discriminator.train_on_batch(d_input, d_labels)
            # ---------------------
            #  Train Generator
            # ---------------------
            valid_y = np.ones((batch_size, 1))
            temp_X = temp_X.reshape(self.latent_dim, batch_size, self.feature_num)
            # BUGFIX: feed the stacked model the generator *input* (temp_X,
            # which the original reshaped and then never used) — not the real
            # targets temp_Y.
            g_loss = self.gan.train_on_batch(temp_X, valid_y)

            # if epoch % print_output_every_n_steps == 0:
            #     self._summarize_performance(epoch, batch_size)

        return self.generator
コード例 #18
0
ファイル: train.py プロジェクト: MasoumehVahedi/GANs-Model
                          transforms.ToTensor(),
                          # [0.5 for _ in range(IMG_CHANNEL)] = (0.5, 0.5, 0.5)
                          transforms.Normalize(
                              [0.5 for _ in range(IMG_CHANNEL)], [0.5 for _ in range(IMG_CHANNEL)]
                          )
                      ]))
# Create the dataloader
dataloader = DataLoader(dataset,
                        batch_size = batch_size,
                        shuffle = True)

"""Next, we initialize the generator, discriminator, and optimizers """

# Generator maps a z_dim-dimensional latent vector to an image; the
# discriminator scores images. Both use Adam with the same hyperparameters.
# NOTE(review): beta_1/beta_2 are defined upstream — presumably (0.5, 0.999)
# per DCGAN convention; confirm.
gen = Generator(z_dim).to(device)
opt_gen = optim.Adam(gen.parameters(), lr=lr, betas=(beta_1, beta_2))
disc = Discriminator().to(device)
opt_disc = optim.Adam(disc.parameters(), lr=lr, betas=(beta_1, beta_2))

# Here, we want to initialize the weights to the normal distribution
# with mean 0 and standard deviation 0.02
def initialize_weights(m):
    """In-place DCGAN-style init; apply recursively via ``module.apply(...)``.

    Conv / transposed-conv and batch-norm weights are drawn from N(0, 0.02);
    batch-norm biases are zeroed.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        # NOTE(review): DCGAN convention for BN weights is N(1.0, 0.02);
        # kept at N(0.0, 0.02) as originally written — confirm intent.
        nn.init.normal_(m.weight, 0.0, 0.02)
        # BUGFIX: the original `nn.init.normal_(m.bias, 0)` passed 0 as the
        # *mean* and kept the default std=1, filling the bias with
        # unit-variance noise instead of zeroing it.
        nn.init.constant_(m.bias, 0)

# Apply the init recursively to every submodule; nn.Module.apply returns the
# module itself, so the rebind is a no-op kept for readability.
gen = gen.apply(initialize_weights)
disc = disc.apply(initialize_weights)

print(gen)
コード例 #19
0
ファイル: train.py プロジェクト: KC900201/Python_Learning
# BUGFIX: renamed the variable from `transforms` to `transform` — the original
# rebound (shadowed) the imported torchvision `transforms` module.
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    # Normalize each channel to [-1, 1] (mean 0.5, std 0.5 per channel).
    transforms.Normalize(
        [0.5 for _ in range(CHANNELS_IMG)],
        [0.5 for _ in range(CHANNELS_IMG)],
    )
])

dataset = datasets.MNIST(root="/data/",
                         train=True,
                         transform=transform,
                         download=True)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
disc = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device)

initialize_weights(gen)
initialize_weights(disc)
opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
criterion = nn.BCELoss()

# Fixed latent batch for visually comparable samples across training steps.
fixed_noise = torch.randn(32, Z_DIM, 1, 1).to(device)
# f-prefixes dropped: the strings contain no placeholders.
writer_real = SummaryWriter("logs/real")
writer_fake = SummaryWriter("logs/fake")
step = 0

gen.train()
disc.train()
コード例 #20
0
    n_continuous = args.n_continuous
    max_epochs = args.max_epochs
    batch_size = args.batch_size
    out_generator_filename = args.out_generator_filename

    # Prepare the training data
    train, _ = datasets.get_mnist(withlabel=False, ndim=2)
    train_size = train.shape[0]
    im_shape = train.shape[1:]

    # Prepare the models
    generator = Generator(n_z + n_categorical + n_continuous, im_shape)
    generator_optimizer = O.Adam(alpha=1e-3, beta1=0.5)
    generator_optimizer.setup(generator)

    discriminator = Discriminator(im_shape, n_categorical, n_continuous)
    discriminator_optimizer = O.Adam(alpha=2e-4, beta1=0.5)
    discriminator_optimizer.setup(discriminator)

    if gpu >= 0:
        cuda.check_cuda_available()
        cuda.get_device(gpu).use()
        generator.to_gpu()
        discriminator.to_gpu()
        xp = cuda.cupy
    else:
        xp = np

    for epoch in range(max_epochs):
        generator_epoch_loss = np.float32(0)
        discriminator_epoch_loss = np.float32(0)
コード例 #21
0
opt = parser.parse_args()
print(opt)

# Seed Python's RNG and torch's CPU / CUDA RNGs for reproducible initialisation.
random.seed(opt.seed)
torch.manual_seed(opt.seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(opt.seed)

# Networks
# 'ori' selects the original-upsampling generator variant; any other value
# falls through to the default Generator.
if opt.upsample == 'ori':
    netG_A2B = Generator_ori(opt.input_nc, opt.output_nc)
    netG_B2A = Generator_ori(opt.output_nc, opt.input_nc)
else:
    netG_A2B = Generator(opt.input_nc, opt.output_nc)
    netG_B2A = Generator(opt.output_nc, opt.input_nc)
netD_A = Discriminator(opt.input_nc)
netD_B = Discriminator(opt.output_nc)

# NOTE(review): these unconditional .cuda() calls raise without a GPU even
# though the seeding above guards on availability — confirm CUDA is required.
netG_A2B.cuda()
netG_B2A.cuda()
netD_A.cuda()
netD_B.cuda()

netG_A2B.apply(weights_init_normal)
netG_B2A.apply(weights_init_normal)
netD_A.apply(weights_init_normal)
netD_B.apply(weights_init_normal)

# Snapshot the freshly initialised weights so runs are reproducible per seed.
torch.save(netG_A2B.state_dict(), "initial_weights/netG_A2B_seed_{}.pth.tar".format(opt.seed))
torch.save(netG_B2A.state_dict(), "initial_weights/netG_B2A_seed_{}.pth.tar".format(opt.seed))
torch.save(netD_A.state_dict(), "initial_weights/netD_A_seed_{}.pth.tar".format(opt.seed))
コード例 #22
0
    print(img.shape)
    img = np.transpose(img, (1, 2, 0))
    img = ((img + 1) * 255 / (2)).astype(
        np.uint8)  # rescale to pixel range (0-255)
    img = Image.fromarray(img, 'RGB')
    # print(img)
    img.show()


if __name__ == "__main__":

    z_size = 100
    samples = []
    sample_size = 1

    D = Discriminator()
    G = Generator(z_size)

    dir = str(pathlib.Path().absolute()) + '/'

    G.load_state_dict(torch.load(dir + 'checkpoint_G.pth', map_location='cpu'))
    D.load_state_dict(torch.load(dir + 'checkpoint_D.pth', map_location='cpu'))

    G.eval()

    fixed_z = np.random.uniform(-1, 1, size=(sample_size, z_size))
    fixed_z = torch.from_numpy(fixed_z).float()
    print(fixed_z.shape)
    sample = G(fixed_z)
    _ = view_samples(sample)
コード例 #23
0
def train():
    """Train a CycleGAN on the horse2zebra dataset.

    Uses MSE (LSGAN) adversarial losses plus L1 cycle-consistency and identity
    terms. Side effects: creates output directories, prints statistics, saves
    sample images and periodic generator weights, and writes a training GIF
    and loss plots at the end.
    """

    # Fix Seed for Reproducibility #
    torch.manual_seed(9)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(9)

    # Samples, Weights and Results Path #
    paths = [config.samples_path, config.weights_path, config.plots_path]
    # BUGFIX(idiom): the original `[make_dirs(path) for path in paths]` used a
    # comprehension purely for its side effect and clobbered `paths` with the
    # helper's return values; a plain loop states the intent.
    for path in paths:
        make_dirs(path)

    # Prepare Data Loader #
    train_horse_loader, train_zebra_loader = get_horse2zebra_loader(purpose='train', batch_size=config.batch_size)
    test_horse_loader, test_zebra_loader = get_horse2zebra_loader(purpose='test', batch_size=config.val_batch_size)
    total_batch = min(len(train_horse_loader), len(train_zebra_loader))

    # Prepare Networks #
    D_A = Discriminator()
    D_B = Discriminator()
    G_A2B = Generator()
    G_B2A = Generator()

    networks = [D_A, D_B, G_A2B, G_B2A]

    for network in networks:
        network.to(device)

    # Loss Function #
    criterion_Adversarial = nn.MSELoss()
    criterion_Cycle = nn.L1Loss()
    criterion_Identity = nn.L1Loss()

    # Optimizers #  (both generators share one optimizer via chain())
    D_A_optim = torch.optim.Adam(D_A.parameters(), lr=config.lr, betas=(0.5, 0.999))
    D_B_optim = torch.optim.Adam(D_B.parameters(), lr=config.lr, betas=(0.5, 0.999))
    G_optim = torch.optim.Adam(chain(G_A2B.parameters(), G_B2A.parameters()), lr=config.lr, betas=(0.5, 0.999))

    D_A_optim_scheduler = get_lr_scheduler(D_A_optim)
    D_B_optim_scheduler = get_lr_scheduler(D_B_optim)
    G_optim_scheduler = get_lr_scheduler(G_optim)

    # Lists #
    D_losses_A, D_losses_B, G_losses = [], [], []

    # Training #
    print("Training CycleGAN started with total epoch of {}.".format(config.num_epochs))
    for epoch in range(config.num_epochs):
        for i, (horse, zebra) in enumerate(zip(train_horse_loader, train_zebra_loader)):

            # Data Preparation #
            real_A = horse.to(device)
            real_B = zebra.to(device)

            # Initialize Optimizers #
            G_optim.zero_grad()
            D_A_optim.zero_grad()
            D_B_optim.zero_grad()

            ###################
            # Train Generator #
            ###################

            # Freeze the discriminators so the generator step does not
            # accumulate gradients into them.
            set_requires_grad([D_A, D_B], requires_grad=False)

            # Adversarial Loss #
            fake_A = G_B2A(real_B)
            prob_fake_A = D_A(fake_A)
            real_labels = torch.ones(prob_fake_A.size()).to(device)
            G_mse_loss_B2A = criterion_Adversarial(prob_fake_A, real_labels)

            fake_B = G_A2B(real_A)
            prob_fake_B = D_B(fake_B)
            real_labels = torch.ones(prob_fake_B.size()).to(device)
            G_mse_loss_A2B = criterion_Adversarial(prob_fake_B, real_labels)

            # Identity Loss #  (G_B2A(A) ~ A, G_A2B(B) ~ B)
            identity_A = G_B2A(real_A)
            G_identity_loss_A = config.lambda_identity * criterion_Identity(identity_A, real_A)

            identity_B = G_A2B(real_B)
            G_identity_loss_B = config.lambda_identity * criterion_Identity(identity_B, real_B)

            # Cycle Loss #  (A -> B -> A and B -> A -> B round trips)
            reconstructed_A = G_B2A(fake_B)
            G_cycle_loss_ABA = config.lambda_cycle * criterion_Cycle(reconstructed_A, real_A)

            reconstructed_B = G_A2B(fake_A)
            G_cycle_loss_BAB = config.lambda_cycle * criterion_Cycle(reconstructed_B, real_B)

            # Calculate Total Generator Loss #
            G_loss = G_mse_loss_B2A + G_mse_loss_A2B + G_identity_loss_A + G_identity_loss_B + G_cycle_loss_ABA + G_cycle_loss_BAB

            # Back Propagation and Update #
            G_loss.backward()
            G_optim.step()

            #######################
            # Train Discriminator #
            #######################

            set_requires_grad([D_A, D_B], requires_grad=True)

            ## Train Discriminator A ##
            # Real Loss #
            prob_real_A = D_A(real_A)
            real_labels = torch.ones(prob_real_A.size()).to(device)
            D_real_loss_A = criterion_Adversarial(prob_real_A, real_labels)

            # Fake Loss #
            fake_A = G_B2A(real_B)
            prob_fake_A = D_A(fake_A.detach())
            fake_labels = torch.zeros(prob_fake_A.size()).to(device)
            D_fake_loss_A = criterion_Adversarial(prob_fake_A, fake_labels)

            # Calculate Total Discriminator A Loss #
            # NOTE(review): the D losses are scaled by lambda_identity; the
            # conventional CycleGAN factor here is 0.5 — confirm intentional.
            D_loss_A = config.lambda_identity * (D_real_loss_A + D_fake_loss_A).mean()

            # Back propagation and Update #
            D_loss_A.backward()
            D_A_optim.step()

            ## Train Discriminator B ##
            # Real Loss #
            prob_real_B = D_B(real_B)
            real_labels = torch.ones(prob_real_B.size()).to(device)
            loss_real_B = criterion_Adversarial(prob_real_B, real_labels)

            # Fake Loss #
            fake_B = G_A2B(real_A)
            prob_fake_B = D_B(fake_B.detach())
            fake_labels = torch.zeros(prob_fake_B.size()).to(device)
            loss_fake_B = criterion_Adversarial(prob_fake_B, fake_labels)

            # Calculate Total Discriminator B Loss #
            D_loss_B = config.lambda_identity * (loss_real_B + loss_fake_B).mean()

            # Back propagation and Update #
            D_loss_B.backward()
            D_B_optim.step()

            # Add items to Lists #
            D_losses_A.append(D_loss_A.item())
            D_losses_B.append(D_loss_B.item())
            G_losses.append(G_loss.item())

            ####################
            # Print Statistics #
            ####################

            if (i+1) % config.print_every == 0:
                print("CycleGAN | Epoch [{}/{}] | Iterations [{}/{}] | D_A Loss {:.4f} | D_B Loss {:.4f} | G Loss {:.4f}"
                      .format(epoch + 1, config.num_epochs, i + 1, total_batch, np.average(D_losses_A), np.average(D_losses_B), np.average(G_losses)))

                # Save Sample Images #
                sample_images(test_horse_loader, test_zebra_loader, G_A2B, G_B2A, epoch, config.samples_path)

        # Adjust Learning Rate #
        D_A_optim_scheduler.step()
        D_B_optim_scheduler.step()
        G_optim_scheduler.step()

        # Save Model Weights #
        if (epoch+1) % config.save_every == 0:
            torch.save(G_A2B.state_dict(), os.path.join(config.weights_path, 'CycleGAN_Generator_A2B_Epoch_{}.pkl'.format(epoch+1)))
            torch.save(G_B2A.state_dict(), os.path.join(config.weights_path, 'CycleGAN_Generator_B2A_Epoch_{}.pkl'.format(epoch+1)))

    # Make a GIF file #
    make_gifs_train("CycleGAN", config.samples_path)

    # Plot Losses #
    plot_losses(D_losses_A, D_losses_B, G_losses, config.num_epochs, config.plots_path)

    print("Training finished.")
コード例 #24
0
def TrainSourceModel(options):
    """Train the source-domain encoder/classifier against a latent discriminator.

    Adversarial-autoencoder-style training: ``optimizer_a`` jointly updates
    the convolutional encoder and the classifier (classification loss on real
    data, classification loss on latent-space samples, plus a term pushing
    encoded data to fool the discriminator); ``optimizer_b`` then trains the
    discriminator to separate encoded data (label 0) from latent-space
    samples (label 1).

    Args:
        options: configuration dict. Keys read here: 'latentD', 'lrA', 'lrB',
            'lrA_DiscCoeff', 'batchSize', 'epochs', 'trainLoader'. The chosen
            torch device is written back to options['device'].

    Returns:
        dict with per-batch loss histories ('lossA', 'lossB'), the trained
        modules ('encoder', 'disc', 'classifier'), and the accumulated
        validation 'logText'.
    """
    # defining the models
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    options['device'] = device
    cEnc = ConvEncoder(latentDimension=options['latentD']).to(device)
    disc = Discriminator(latentDimension=options['latentD']).to(device)
    classify = Classifier(latentDimension=options['latentD']).to(device)

    # defining loss functions (only the two that are actually used below;
    # the original also constructed unused MSE / NLL losses)
    BCELoss = nn.BCELoss().to(device)  # binary cross-entropy (discriminator)
    CELoss = nn.CrossEntropyLoss().to(device)  # cross entropy (classifier)

    # loss log data
    lossData_a, lossData_b = [], []

    optimizer_a = optim.Adam(itertools.chain(cEnc.parameters(),
                                             classify.parameters()),
                             lr=options['lrA'])
    optimizer_b = optim.Adam(itertools.chain(disc.parameters()),
                             lr=options['lrB'])

    logText = ""
    train_loader = options['trainLoader']

    for epochIdx in range(options['epochs']):

        # constant real/fake label tensors, sliced per batch below
        _ones = Variable(torch.FloatTensor(options['batchSize'], 1).fill_(1.0),
                         requires_grad=False).to(device)
        _zeros = Variable(torch.FloatTensor(options['batchSize'],
                                            1).fill_(0.0),
                          requires_grad=False).to(device)

        for batchIdx, (batchData, batchLabels) in enumerate(train_loader):

            # data setup (the last batch may be smaller than options['batchSize'])
            _batchSize = batchData.shape[0]
            batchData = batchData.to(device)
            batchLabels = batchLabels.to(device)
            latentSpaceSample_A, latentSpaceClasses_A = getSampleFromLatentSpace(
                [_batchSize, options['latentD']], options)

            # convAutoencoder pass
            encodedDataBatch = cEnc(batchData)
            classesPrediction = classify(encodedDataBatch)

            # sample pass
            sampleClassesPrediction = classify(latentSpaceSample_A)

            # optimization step I: encoder + classifier
            # --- first loss function: classify real data, make encoded data
            #     fool the discriminator, and classify latent samples
            optimizer_a.zero_grad()
            loss_a = CELoss(classesPrediction, batchLabels)+\
                     options["lrA_DiscCoeff"]*BCELoss(disc(encodedDataBatch),_ones[:_batchSize])+\
                     CELoss(sampleClassesPrediction,latentSpaceClasses_A)

            loss_a.backward()
            optimizer_a.step()
            lossData_a.append(loss_a.data.item())

            # optimization step II
            # --- train the discriminator, 1/0 is real/fake data
            optimizer_b.zero_grad()
            latentSpaceSample_B, latentSpaceClasses_B = getSampleFromLatentSpace(
                [_batchSize, options['latentD']], options)
            # detach the encoded batch (via .data) so discriminator gradients
            # do not flow back into the encoder
            discDataInput = Variable(encodedDataBatch.view(
                _batchSize, options['latentD']).cpu().data,
                                     requires_grad=False).to(device)

            # --- second loss function
            loss_b=BCELoss(disc(discDataInput),_zeros[:_batchSize])+\
                   BCELoss(disc(latentSpaceSample_B),_ones[:_batchSize])
            loss_b.backward()
            optimizer_b.step()
            lossData_b.append(loss_b.data.item())

        ####
        #### End of an epoch: run validation on the test set, append to log
        ####

        logText += validateModel(epochIdx,
                                 options,
                                 models=[cEnc, classify, disc])

    outputs = {
        'lossA': lossData_a,
        'lossB': lossData_b,
        'encoder': cEnc,
        'disc': disc,
        'classifier': classify,
        'logText': logText
    }
    return outputs
def train(n_channels=3,
          resolution=32,
          z_dim=128,
          n_labels=0,
          lr=1e-3,
          e_drift=1e-3,
          wgp_target=750,
          initial_resolution=4,
          total_kimg=25000,
          training_kimg=500,
          transition_kimg=500,
          iters_per_checkpoint=500,
          n_checkpoint_images=16,
          glob_str='cifar10',
          out_dir='cifar10'):
    """Train a progressive-growing GAN (WGAN-GP style) in Keras.

    Grows from ``initial_resolution`` to ``resolution`` in phases of
    ``training_kimg`` + ``transition_kimg`` thousand images, updating the
    minibatch size per phase and blending in new blocks via a fractional
    ``cur_block`` level of detail.

    Args:
        n_channels: number of image channels.
        resolution: final image resolution (power of two).
        z_dim: latent dimensionality.
        n_labels: number of class labels (0 = unconditional).
        lr: Adam learning rate for both networks.
        e_drift: weight of the epsilon-drift ('mse') term on D.
        wgp_target: gradient-penalty target norm.
        initial_resolution: starting resolution of the growth schedule.
        total_kimg: total training length, in thousands of images.
        training_kimg / transition_kimg: phase durations, in kimg.
        iters_per_checkpoint: image/model checkpoint period (iterations).
        n_checkpoint_images: number of fixed latent samples to log.
        glob_str: dataset glob pattern.
        out_dir: output directory for logs and model files.
    """

    # instantiate logger
    logger = SummaryWriter(out_dir)

    # load data
    batch_size = MINIBATCH_OVERWRITES[0]
    train_iterator = iterate_minibatches(glob_str, batch_size, resolution)

    # build models
    G = Generator(n_channels, resolution, z_dim, n_labels)
    D = Discriminator(n_channels, resolution, n_labels)

    G_train, D_train = GAN(G, D, z_dim, n_labels, resolution, n_channels)

    D_opt = Adam(lr=lr, beta_1=0.0, beta_2=0.99, epsilon=1e-8)
    G_opt = Adam(lr=lr, beta_1=0.0, beta_2=0.99, epsilon=1e-8)

    # define loss functions
    D_loss = [loss_mean, loss_gradient_penalty, 'mse']
    G_loss = [loss_wasserstein]

    # compile graphs used during training; D is frozen only inside G_train
    G.compile(G_opt, loss=loss_wasserstein)
    D.trainable = False
    G_train.compile(G_opt, loss=G_loss)
    D.trainable = True
    D_train.compile(D_opt, loss=D_loss, loss_weights=[1, GP_WEIGHT, e_drift])

    # for computing the loss
    ones = np.ones((batch_size, 1), dtype=np.float32)
    zeros = ones * 0.0

    # fix a z vector for training evaluation
    z_fixed = np.random.normal(0, 1, size=(n_checkpoint_images, z_dim))

    # vars
    resolution_log2 = int(np.log2(resolution))
    starting_block = resolution_log2
    starting_block -= np.floor(np.log2(initial_resolution))
    cur_block = starting_block
    cur_nimg = 0

    # compute duration of each phase and use proxy to update minibatch size
    phase_kdur = training_kimg + transition_kimg
    phase_idx_prev = 0

    # offset variable for transitioning between blocks
    offset = 0
    i = 0
    while cur_nimg < total_kimg * 1000:
        # block processing
        kimg = cur_nimg / 1000.0
        phase_idx = int(np.floor((kimg + transition_kimg) / phase_kdur))
        # clamp with an int so phase_idx stays a valid list index
        # (the original used max(phase_idx, 0.0), yielding a float at phase 0)
        phase_idx = max(phase_idx, 0)
        phase_kimg = phase_idx * phase_kdur

        # update batch size and ones vector if we switched phases
        if phase_idx_prev < phase_idx:
            batch_size = MINIBATCH_OVERWRITES[phase_idx]
            # FIX: pass resolution as well, matching the initial call above —
            # the original omitted it when rebuilding the iterator
            train_iterator = iterate_minibatches(glob_str, batch_size,
                                                 resolution)
            ones = np.ones((batch_size, 1), dtype=np.float32)
            zeros = ones * 0.0
            phase_idx_prev = phase_idx

        # possibly gradually update current level of detail
        if transition_kimg > 0 and phase_idx > 0:
            offset = (kimg + transition_kimg - phase_kimg) / transition_kimg
            offset = min(offset, 1.0)
            offset = offset + phase_idx - 1
            cur_block = max(starting_block - offset, 0.0)

        # update level of detail
        K.set_value(G_train.cur_block, np.float32(cur_block))
        K.set_value(D_train.cur_block, np.float32(cur_block))

        # train D: N critic iterations per generator step (WGAN recipe)
        for j in range(N_CRITIC_ITERS):
            z = np.random.normal(0, 1, size=(batch_size, z_dim))
            real_batch = next(train_iterator)
            fake_batch = G.predict_on_batch([z])
            interpolated_batch = get_interpolated_images(
                real_batch, fake_batch)
            losses_d = D_train.train_on_batch(
                [real_batch, fake_batch, interpolated_batch],
                [ones, ones * wgp_target, zeros])
            cur_nimg += batch_size

        # train G
        z = np.random.normal(0, 1, size=(batch_size, z_dim))
        loss_g = G_train.train_on_batch(z, -1 * ones)

        logger.add_scalar("cur_block", cur_block, i)
        logger.add_scalar("learning_rate", lr, i)
        logger.add_scalar("batch_size", z.shape[0], i)
        print("iter", i, "cur_block", cur_block, "lr", lr, "kimg", kimg,
              "losses_d", losses_d, "loss_g", loss_g)
        if (i % iters_per_checkpoint) == 0:
            G.trainable = False
            fake_images = G.predict(z_fixed)
            # log fake images
            log_images(fake_images, 'fake', i, logger, fake_images.shape[1],
                       fake_images.shape[2], int(np.sqrt(n_checkpoint_images)))

            # plot real images for reference
            log_images(real_batch[:n_checkpoint_images], 'real', i, logger,
                       real_batch.shape[1], real_batch.shape[2],
                       int(np.sqrt(n_checkpoint_images)))

            # save the model to eventually resume training or do inference
            save_model(G, out_dir + "/model.json", out_dir + "/model.h5")

        log_losses(losses_d, loss_g, i, logger)
        i += 1
コード例 #26
0
    type=int,
    default=8,
    help='number of cpu threads to use during batch generation')
opt = parser.parse_args()
print(opt)

# Warn (but do not fail) when a CUDA device exists yet --cuda was not passed.
if torch.cuda.is_available() and not opt.cuda:
    print(
        "WARNING: You have a CUDA device, so you should probably run with --cuda"
    )

###### Definition of variables ######
# Networks
# One generator per translation direction (A->BC and BC->A) and one
# discriminator per domain (A, B, C) — a CycleGAN variant with two
# target domains sharing the B/C generators.
netG_A2BC = Generator(opt.input_nc, opt.output_nc)
netG_BC2A = Generator(opt.output_nc, opt.input_nc)
netD_A = Discriminator(opt.input_nc)
netD_B = Discriminator(opt.output_nc)
netD_C = Discriminator(opt.output_nc)

# Move every network to the GPU only when requested on the command line.
if opt.cuda:
    netG_A2BC.cuda()
    netG_BC2A.cuda()
    netD_A.cuda()
    netD_B.cuda()
    netD_C.cuda()

# Initialize all weights with the project's normal-distribution scheme.
netG_A2BC.apply(weights_init_normal)
netG_BC2A.apply(weights_init_normal)
netD_A.apply(weights_init_normal)
netD_B.apply(weights_init_normal)
netD_C.apply(weights_init_normal)
コード例 #27
0
ファイル: gan.py プロジェクト: omikader/flag-gan
                epoch, batch_idx, d_error.data[0], g_error.data[0]))


def test(fixed_noise, epoch):
    """Generate one sample from a fixed noise vector and display it.

    Uses the module-level generator ``G``. The fixed noise input keeps the
    visualization comparable from epoch to epoch.

    Args:
        fixed_noise: latent vector fed to the generator.
        epoch: current epoch number, shown in the plot title.
    """
    G.eval()
    # Run noise through generator and reshape output vector to 4x16x32 to match
    # flag size for display purposes. Convert to RGBA PIL image and display
    sample = G(fixed_noise).data[0].view(4, 16, 32)
    img = transforms.functional.to_pil_image(sample, mode='RGBA')
    # The AxesImage return value was bound to an unused local (imgplot);
    # drop the binding.
    plt.imshow(img)
    plt.title('Epoch {}'.format(epoch))
    plt.show()


if __name__ == '__main__':
    # Build the generator/discriminator pair (module-level globals read by
    # train() and test()) and a BCE loss for the vanilla GAN objective.
    G, D = Generator(), Discriminator()
    loss = nn.BCELoss()
    # Adam with a tuned beta1 (args.beta1), the usual choice for GAN training.
    g_optim = optim.Adam(G.parameters(), lr=args.lr, betas=(args.beta1, 0.999))
    d_optim = optim.Adam(D.parameters(), lr=args.lr, betas=(args.beta1, 0.999))

    flag_loader = get_flag_loader()
    # Fixed 1x100x1x1 latent vector reused at every evaluation so generated
    # flags are visually comparable across epochs.
    noise = Variable(torch.randn(1, 100, 1, 1))

    for epoch in range(1, args.epochs + 1):
        # Train Model
        train(data_loader=flag_loader, epoch=epoch)

        # Test Generator (every args.test_interval epochs)
        if epoch % args.test_interval == 0:
            test(fixed_noise=noise, epoch=epoch)
コード例 #28
0
def train():
    """GAN refinement stage: train a refinement generator (net_ig) and a patch
    discriminator (net_id) on top of a frozen, pretrained autoencoder (net_ae)
    that maps (sketch, color image) pairs to coarse reconstructions.

    All configuration comes from the project's ``config`` module. The loop
    logs losses and sample images to disk, keeps an EMA copy of the
    generator's parameters for visualization, periodically checkpoints, and
    evaluates FID against disk-cached real-image Inception features.
    """
    from benchmark import calc_fid, extract_feature_from_generator_fn, load_patched_inception_v3, real_image_loader, image_generator, image_generator_perm
    import lpips

    from config import IM_SIZE_GAN, BATCH_SIZE_GAN, NFC, NBR_CLS, DATALOADER_WORKERS, EPOCH_GAN, ITERATION_AE, GAN_CKECKPOINT
    from config import SAVE_IMAGE_INTERVAL, SAVE_MODEL_INTERVAL, LOG_INTERVAL, SAVE_FOLDER, TRIAL_NAME, DATA_NAME, MULTI_GPU
    from config import FID_INTERVAL, FID_BATCH_NBR, PRETRAINED_AE_PATH
    from config import data_root_colorful, data_root_sketch_1, data_root_sketch_2, data_root_sketch_3

    # Inception network for FID feature extraction; real features are
    # computed lazily and cached on disk (see the FID_INTERVAL branch below).
    real_features = None
    inception = load_patched_inception_v3().cuda()
    inception.eval()

    # LPIPS perceptual loss, used as the reconstruction term for all
    # datasets except 'shoe' (which uses plain L1+MSE instead).
    percept = lpips.PerceptualLoss(model='net-lin', net='vgg', use_gpu=True)

    saved_image_folder = saved_model_folder = None
    log_file_path = None
    if saved_image_folder is None:
        saved_image_folder, saved_model_folder = make_folders(
            SAVE_FOLDER, 'GAN_' + TRIAL_NAME)
        log_file_path = saved_image_folder + '/../gan_log.txt'
        log_file = open(log_file_path, 'w')  # truncate any previous log
        log_file.close()

    dataset = PairedMultiDataset(data_root_colorful,
                                 data_root_sketch_1,
                                 data_root_sketch_2,
                                 data_root_sketch_3,
                                 im_size=IM_SIZE_GAN,
                                 rand_crop=True)
    print('the dataset contains %d images.' % len(dataset))
    # Infinite sampler: the dataloader is consumed with next() and never ends.
    dataloader = iter(
        DataLoader(dataset,
                   BATCH_SIZE_GAN,
                   sampler=InfiniteSamplerWrapper(dataset),
                   num_workers=DATALOADER_WORKERS,
                   pin_memory=True))

    from datasets import ImageFolder
    from datasets import trans_maker_augment as trans_maker

    # Separate RGB/sketch datasets at 512px, used only for the image matrix
    # visualization (make_matrix) below.
    dataset_rgb = ImageFolder(data_root_colorful, trans_maker(512))
    dataset_skt = ImageFolder(data_root_sketch_3, trans_maker(512))

    net_ae = AE(nfc=NFC, nbr_cls=NBR_CLS)

    # Locate the pretrained autoencoder weights: either the default path from
    # this trial's AE stage, or an explicitly configured pretrained run.
    if PRETRAINED_AE_PATH is None:
        PRETRAINED_AE_PATH = 'train_results/' + 'AE_' + TRIAL_NAME + '/models/%d.pth' % ITERATION_AE
    else:
        from config import PRETRAINED_AE_ITER
        PRETRAINED_AE_PATH = PRETRAINED_AE_PATH + '/models/%d.pth' % PRETRAINED_AE_ITER

    net_ae.load_state_dicts(PRETRAINED_AE_PATH)
    net_ae.cuda()
    net_ae.eval()  # the autoencoder stays frozen during this stage

    # Pick the dataset-specific refinement generator architecture.
    RefineGenerator = None
    if DATA_NAME == 'celeba':
        from models import RefineGenerator_face as RefineGenerator
    elif DATA_NAME == 'art' or DATA_NAME == 'shoe':
        from models import RefineGenerator_art as RefineGenerator
    net_ig = RefineGenerator(nfc=NFC, im_size=IM_SIZE_GAN).cuda()
    net_id = Discriminator(nc=3).cuda(
    )  # we use the patch_gan, so the im_size for D should be 512 even if training image size is 1024

    if MULTI_GPU:
        net_ae = nn.DataParallel(net_ae)
        net_ig = nn.DataParallel(net_ig)
        net_id = nn.DataParallel(net_id)

    # EMA copy of the generator parameters (a plain parameter list, updated
    # in place each iteration and swapped in for visualization/inference).
    net_ig_ema = copy_G_params(net_ig)

    opt_ig = optim.Adam(net_ig.parameters(), lr=2e-4, betas=(0.5, 0.999))
    opt_id = optim.Adam(net_id.parameters(), lr=2e-4, betas=(0.5, 0.999))

    # Resume networks, EMA params, and optimizer state from a checkpoint.
    if GAN_CKECKPOINT is not None:
        ckpt = torch.load(GAN_CKECKPOINT)
        net_ig.load_state_dict(ckpt['ig'])
        net_id.load_state_dict(ckpt['id'])
        net_ig_ema = ckpt['ig_ema']
        opt_ig.load_state_dict(ckpt['opt_ig'])
        opt_id.load_state_dict(ckpt['opt_id'])

    ## create a log file
    losses_g_img = AverageMeter()
    losses_d_img = AverageMeter()
    losses_mse = AverageMeter()
    losses_rec_s = AverageMeter()

    losses_rec_ae = AverageMeter()

    fixed_skt = fixed_rgb = fixed_perm = None

    # fid[i] = [fid, fid_with_permuted_styles]; seeded with zeros for logging.
    fid = [[0, 0]]

    for epoch in range(EPOCH_GAN):
        for iteration in tqdm(range(10000)):
            rgb_img, skt_img_1, skt_img_2, skt_img_3 = next(dataloader)

            rgb_img = rgb_img.cuda()

            # Randomly pick one of the three sketch styles; rd in {0,1,2,3}
            # maps 0->style 1, 1->style 2, and both 2 and 3 -> style 3, so
            # style 3 is sampled twice as often as the others.
            rd = random.randint(0, 3)
            if rd == 0:
                skt_img = skt_img_1.cuda()
            elif rd == 1:
                skt_img = skt_img_2.cuda()
            else:
                skt_img = skt_img_3.cuda()

            # Freeze the first batch as a fixed visualization set.
            if iteration == 0:
                fixed_skt = skt_img_3[:8].clone().cuda()
                fixed_rgb = rgb_img[:8].clone()
                fixed_perm = true_randperm(fixed_rgb.shape[0], 'cuda')

            ### 1. train D
            gimg_ae, style_feats = net_ae(skt_img, rgb_img)
            g_image = net_ig(gimg_ae, style_feats)

            pred_r = net_id(rgb_img)
            pred_f = net_id(g_image.detach())  # detach: D step must not update G

            loss_d = d_hinge_loss(pred_r, pred_f)

            net_id.zero_grad()
            loss_d.backward()
            opt_id.step()

            # Monitor (not optimize) how well the frozen AE reconstructs.
            loss_rec_ae = F.mse_loss(gimg_ae, rgb_img) + F.l1_loss(
                gimg_ae, rgb_img)
            losses_rec_ae.update(loss_rec_ae.item(), BATCH_SIZE_GAN)

            ### 2. train G
            pred_g = net_id(g_image)
            loss_g = g_hinge_loss(pred_g)

            if DATA_NAME == 'shoe':
                # pixel losses for shoes; perceptual LPIPS loss otherwise
                loss_mse = 10 * (F.l1_loss(g_image, rgb_img) +
                                 F.mse_loss(g_image, rgb_img))
            else:
                loss_mse = 10 * percept(
                    F.adaptive_avg_pool2d(g_image, output_size=256),
                    F.adaptive_avg_pool2d(rgb_img, output_size=256)).sum()
            losses_mse.update(loss_mse.item() / BATCH_SIZE_GAN, BATCH_SIZE_GAN)

            loss_all = loss_g + loss_mse

            if DATA_NAME == 'shoe':
                ### the grey image reconstruction
                # NOTE(review): other call sites pass a device argument to
                # true_randperm — confirm its default matches 'cuda' here.
                perm = true_randperm(BATCH_SIZE_GAN)
                img_ae_perm, style_feats_perm = net_ae(skt_img, rgb_img[perm])

                gimg_grey = net_ig(img_ae_perm, style_feats_perm)
                gimg_grey = gimg_grey.mean(dim=1, keepdim=True)
                real_grey = rgb_img.mean(dim=1, keepdim=True)
                loss_rec_grey = F.mse_loss(gimg_grey, real_grey)
                loss_all += 10 * loss_rec_grey

            net_ig.zero_grad()
            loss_all.backward()
            opt_ig.step()

            # EMA update of the generator parameters (decay 0.999).
            for p, avg_p in zip(net_ig.parameters(), net_ig_ema):
                avg_p.mul_(0.999).add_(p.data, alpha=0.001)

            ### 3. logging
            losses_g_img.update(pred_g.mean().item(), BATCH_SIZE_GAN)
            losses_d_img.update(pred_r.mean().item(), BATCH_SIZE_GAN)

            if iteration % SAVE_IMAGE_INTERVAL == 0:  #show the current images
                with torch.no_grad():

                    # Temporarily swap the EMA weights into the generator for
                    # visualization, then restore the training weights.
                    backup_para_g = copy_G_params(net_ig)
                    load_params(net_ig, net_ig_ema)

                    gimg_ae, style_feats = net_ae(fixed_skt, fixed_rgb)
                    gmatch = net_ig(gimg_ae, style_feats)

                    # Mismatched pairs: same sketches, permuted style images.
                    gimg_ae_perm, style_feats = net_ae(fixed_skt,
                                                       fixed_rgb[fixed_perm])
                    gmismatch = net_ig(gimg_ae_perm, style_feats)

                    gimg = torch.cat([
                        F.interpolate(fixed_rgb, IM_SIZE_GAN),
                        F.interpolate(fixed_skt.repeat(1, 3, 1, 1),
                                      IM_SIZE_GAN), gmatch,
                        F.interpolate(gimg_ae, IM_SIZE_GAN), gmismatch,
                        F.interpolate(gimg_ae_perm, IM_SIZE_GAN)
                    ])

                    vutils.save_image(
                        gimg,
                        f'{saved_image_folder}/img_iter_{epoch}_{iteration}.jpg',
                        normalize=True,
                        range=(-1, 1))
                    del gimg

                    make_matrix(
                        dataset_rgb, dataset_skt, net_ae, net_ig, 5,
                        f'{saved_image_folder}/img_iter_{epoch}_{iteration}_matrix.jpg'
                    )

                    load_params(net_ig, backup_para_g)

            if iteration % LOG_INTERVAL == 0:
                log_msg = 'Iter: [{0}/{1}] G: {losses_g_img.avg:.4f}  D: {losses_d_img.avg:.4f}  MSE: {losses_mse.avg:.4f}  Rec: {losses_rec_s.avg:.5f}  FID: {fid:.4f}'.format(
                    epoch,
                    iteration,
                    losses_g_img=losses_g_img,
                    losses_d_img=losses_d_img,
                    losses_mse=losses_mse,
                    losses_rec_s=losses_rec_s,
                    fid=fid[-1][0])

                print(log_msg)
                print('%.5f' % (losses_rec_ae.avg))

                if log_file_path is not None:
                    log_file = open(log_file_path, 'a')
                    log_file.write(log_msg + '\n')
                    log_file.close()

                # Reset running averages after every log line.
                losses_g_img.reset()
                losses_d_img.reset()
                losses_mse.reset()
                losses_rec_s.reset()
                losses_rec_ae.reset()

            if iteration % SAVE_MODEL_INTERVAL == 0 or iteration + 1 == 10000:
                print('Saving history model')
                torch.save(
                    {
                        'ig': net_ig.state_dict(),
                        'id': net_id.state_dict(),
                        'ae': net_ae.state_dict(),
                        'ig_ema': net_ig_ema,
                        'opt_ig': opt_ig.state_dict(),
                        'opt_id': opt_id.state_dict(),
                    }, '%s/%d.pth' % (saved_model_folder, epoch))

            if iteration % FID_INTERVAL == 0 and iteration > 1:
                print("calculating FID ...")
                fid_batch_images = FID_BATCH_NBR
                if real_features is None:
                    # Real-image features are cached on disk. NOTE: despite
                    # the .npy extension the file is written/read with pickle.
                    if os.path.exists('%s_fid_feats.npy' % (DATA_NAME)):
                        real_features = pickle.load(
                            open('%s_fid_feats.npy' % (DATA_NAME), 'rb'))
                    else:
                        real_features = extract_feature_from_generator_fn(
                            real_image_loader(dataloader,
                                              n_batches=fid_batch_images),
                            inception)
                        real_mean = np.mean(real_features, 0)
                        real_cov = np.cov(real_features, rowvar=False)
                        pickle.dump(
                            {
                                'feats': real_features,
                                'mean': real_mean,
                                'cov': real_cov
                            }, open('%s_fid_feats.npy' % (DATA_NAME), 'wb'))
                        # reload so real_features is the cached dict form
                        real_features = pickle.load(
                            open('%s_fid_feats.npy' % (DATA_NAME), 'rb'))

                sample_features = extract_feature_from_generator_fn(
                    image_generator(dataset,
                                    net_ae,
                                    net_ig,
                                    n_batches=fid_batch_images),
                    inception,
                    total=fid_batch_images)
                cur_fid = calc_fid(sample_features,
                                   real_mean=real_features['mean'],
                                   real_cov=real_features['cov'])
                # Second FID with permuted style inputs, as a diversity check.
                sample_features_perm = extract_feature_from_generator_fn(
                    image_generator_perm(dataset,
                                         net_ae,
                                         net_ig,
                                         n_batches=fid_batch_images),
                    inception,
                    total=fid_batch_images)
                cur_fid_perm = calc_fid(sample_features_perm,
                                        real_mean=real_features['mean'],
                                        real_cov=real_features['cov'])

                fid.append([cur_fid, cur_fid_perm])
                print('fid:', fid)
                if log_file_path is not None:
                    log_file = open(log_file_path, 'a')
                    log_msg = 'fid: %.5f, %.5f' % (fid[-1][0], fid[-1][1])
                    log_file.write(log_msg + '\n')
                    log_file.close()
コード例 #29
0
# Distributed setup: pin this process to its local GPU and join the NCCL
# process group before any model is constructed.
if DISTRIBUTED:
    torch.cuda.set_device(LOCAL_RANK)
    torch.distributed.init_process_group(backend="nccl", init_method="env://")
    synchronize()

# StyleGAN2-style hyperparameters: latent size, mapping-network depth,
# and the iteration to resume from.
LATENT = 512
N_MLP = 8
START_ITER = 0

generator = Generator(SIZEX,
                      SIZEY,
                      LATENT,
                      N_MLP,
                      channel_multiplier=CHANNEL_MULTIPLIER).to(device)
discriminator = Discriminator(SIZEX,
                              SIZEY,
                              channel_multiplier=CHANNEL_MULTIPLIER).to(device)

from torchsummary import summary

summary(generator, (1, LATENT))

torch.cuda.empty_cache()

# average of the weights of generator to visualize each epochs
# FIX: construct g_ema with the same arguments as `generator` above — the
# original called Generator(LATENT, N_MLP, ...), dropping SIZEX/SIZEY and
# shifting every positional argument; also use .to(device) for consistency.
g_ema = Generator(SIZEX,
                  SIZEY,
                  LATENT,
                  N_MLP,
                  channel_multiplier=CHANNEL_MULTIPLIER).to(device)

# eval mode
g_ema.eval()

# slowly move through each generator steps
コード例 #30
0
        torch.save(netD_A.state_dict(),
                   os.path.join(args.output_dir, 'netD_A.pth'))
        torch.save(netD_B.state_dict(),
                   os.path.join(args.output_dir, 'netD_B.pth'))
        torch.save(netGaze.state_dict(),
                   os.path.join(args.output_dir, 'netGaze.pth'))
        plot_confusion_matrix(target_all, pred_all, activity_classes)

    return val_accuracy


if __name__ == '__main__':
    # networks
    netG_A2B = Generator(args.nc, args.nc)
    netG_B2A = Generator(args.nc, args.nc)
    netD_A = Discriminator(args.nc)
    netD_B = Discriminator(args.nc)
    netGaze = SqueezeNet(args.version)

    if args.snapshot_dir is not None:
        if os.path.exists(os.path.join(args.snapshot_dir, 'netG_A2B.pth')):
            netG_A2B.load_state_dict(torch.load(
                os.path.join(args.snapshot_dir, 'netG_A2B.pth')),
                                     strict=False)
        if os.path.exists(os.path.join(args.snapshot_dir, 'netG_B2A.pth')):
            netG_B2A.load_state_dict(torch.load(
                os.path.join(args.snapshot_dir, 'netG_B2A.pth')),
                                     strict=False)
        if os.path.exists(os.path.join(args.snapshot_dir, 'netD_A.pth')):
            netD_A.load_state_dict(torch.load(
                os.path.join(args.snapshot_dir, 'netD_A.pth')),
コード例 #31
0
parser.add_argument("--epoch", default=200, type=int)
parser.add_argument("--iterate", default=10, type=int)
# lambda1 weights the L1 reconstruction term in the pix2pix objective.
parser.add_argument("--lambda1", default=100, type=int)

args = parser.parse_args()

# Fixed training configuration: batches of one 256x256 RGB image pair.
batchsize = 1
input_channel = 3
output_channel = 3
input_height = input_width = output_height = output_width = 256

# Samples 1..299 — presumably an inclusive index range; verify against the
# project's Dataset implementation.
input_data = Dataset(data_start = 1, data_end = 299)
train_len = input_data.len()

# Build the pix2pix generator/discriminator pair and apply the project's
# weight initialization before moving both to the GPU.
generator_G = Generator(input_channel, output_channel)
discriminator_D = Discriminator(input_channel, output_channel)
weights_init(generator_G)
weights_init(discriminator_D)

generator_G.cuda()
discriminator_D.cuda()

# L1 loss for reconstruction, BCE for the adversarial real/fake term.
loss_L1 = nn.L1Loss().cuda()
loss_binaryCrossEntropy = nn.BCELoss().cuda()

optimizer_G= torch.optim.Adam(generator_G.parameters(), lr= 0.0002, betas=(0.5, 0.999), weight_decay= 0.00001)
optimizer_D= torch.optim.Adam(discriminator_D.parameters(), lr= 0.0002, betas=(0.5, 0.999), weight_decay=0.00001)

# Preallocated numpy buffers for one batch of input / target images,
# filled elsewhere before each training step.
input_x_np = np.zeros((batchsize, input_channel, input_height, input_width)).astype(np.float32)
input_real_np = np.zeros((batchsize, output_channel, output_height, output_width)).astype(np.float32)