Ejemplo n.º 1
0
def train(netD, netG, criterion, optimizerG, optimizerD):
    """Train a generator/discriminator pair for ``Opt.epoch`` epochs.

    Each step updates the discriminator on a real batch (target 1) and a
    generated batch (target 0). It then updates the generator so that the
    discriminator scores its samples as real.

    Per-epoch average losses are printed and appended to
    ``<Opt.root>/logs.txt``. Every epoch, a sample grid generated from the
    module-level ``fixed_noise`` is saved to ``Opt.results_dir``. Every
    200 epochs, both state dicts are written to ``Opt.checkpoint_dir`` when
    ``Opt.save_model`` is set.

    Relies on module-level ``Opt``, ``train_loader``, ``fixed_noise``,
    ``Variable``, ``vutils``, ``torch`` and ``os``.

    Args:
        netD: discriminator; must live on the GPU because inputs are moved
            there with ``.cuda()``.
        netG: generator; must live on the GPU for the same reason.
        criterion: loss called as ``criterion(output, target)``.
        optimizerG: optimizer over ``netG``'s parameters.
        optimizerD: optimizer over ``netD``'s parameters.
    """
    def _scalar(loss):
        # Newer PyTorch reads a 0-dim loss via .item(); older releases
        # (which_pc != 0) indexed .data[0], which raises on newer versions.
        return loss.item() if Opt.which_pc == 0 else loss.data[0]

    for epoch in range(Opt.epoch):
        avg_lossD = 0
        avg_lossG = 0
        n_batches = 0
        # The append mode belongs to open(); previously 'a' was passed to
        # os.path.join, producing the path '<root>/logs.txt/a' opened read-only.
        with open(os.path.join(Opt.root, 'logs.txt'), 'a') as file:
            for i, (data, _) in enumerate(train_loader):
                mini_batch = data.shape[0]

                # Update D on a real batch: target 1.
                input = Variable(data.cuda())
                real_label = Variable(torch.ones(mini_batch).cuda())
                output = netD(input)
                D_real_loss = criterion(output, real_label)

                # Update D on a generated batch: target 0.
                noise = Variable(torch.randn(mini_batch, Opt.nz).view(-1, Opt.nz, 1, 1).cuda())
                fake = netG(noise)
                fake_label = Variable(torch.zeros(mini_batch).cuda())
                output = netD(fake.detach())  # detach so D's update does not backprop into G
                D_fake_loss = criterion(output, fake_label)
                D_loss = D_real_loss + D_fake_loss
                netD.zero_grad()
                D_loss.backward()
                D_loss_value = _scalar(D_loss)
                avg_lossD += D_loss_value
                optimizerD.step()

                # Update G: make D score the generated batch as real.
                output = netD(fake)
                G_loss = criterion(output, real_label)
                G_loss_value = _scalar(G_loss)
                avg_lossG += G_loss_value
                netG.zero_grad()
                G_loss.backward()
                optimizerG.step()
                n_batches += 1

                print('Epoch [%d/%d], Step [%d/%d], D_loss: %.4f, G_loss: %.4f'
                      % (epoch + 1, Opt.epoch, i + 1, len(train_loader), D_loss_value, G_loss_value))

            # Average over the number of batches actually seen. Dividing by the
            # last index was off by one and failed for one (or zero) batches.
            if n_batches:
                avg_lossD /= n_batches
                avg_lossG /= n_batches
            print('epoch: ' + str(epoch) + ', G_loss: ' + str(avg_lossG) + ', D_loss: ' + str(avg_lossD))
            file.write('epoch: ' + str(epoch) + ', G_loss: ' + str(avg_lossG) + ', D_loss: ' + str(avg_lossD) + '\n')

        # Save a sample grid from the fixed latent batch so epochs are comparable.
        fixed_pred = netG(fixed_noise)
        vutils.save_image(fixed_pred.data, os.path.join(Opt.results_dir, 'img' + str(epoch) + '.png'),
                          nrow=10, scale_each=True)

        if epoch % 200 == 0 and Opt.save_model:
            torch.save(netD.state_dict(), os.path.join(Opt.checkpoint_dir, 'netD-01.pt'))
            torch.save(netG.state_dict(), os.path.join(Opt.checkpoint_dir, 'netG-01.pt'))
Ejemplo n.º 2
0
    def __init__(self, args):
        """Create the data pipeline, the G/D networks, their optimizers and a fixed noise batch.

        Args:
            args: run configuration. Read here: ``use_cuda``, ``lr``, ``beta1``,
                ``batch_size`` and ``nz``. The object is also forwarded to
                ``get_data_loader``, ``model.netG`` and ``model.netD``.
        """
        self.args = args
        self.data_loader = get_data_loader(args)

        generator, discriminator = model.netG(args), model.netD(args)
        if args.use_cuda:
            generator, discriminator = generator.cuda(), discriminator.cuda()
        self.netG, self.netD = generator, discriminator

        # Both networks use the same Adam hyper-parameters.
        adam_kwargs = dict(lr=args.lr, betas=(args.beta1, 0.999))
        self.optimizer_D = optim.Adam(self.netD.parameters(), **adam_kwargs)
        self.optimizer_G = optim.Adam(self.netG.parameters(), **adam_kwargs)
        self.criterion = nn.BCELoss()
        # Label values for real / generated samples (presumably BCE targets
        # used by the training loop elsewhere).
        self.real_label, self.fake_label = 1, 0

        # Fixed standard-normal latent batch shaped (batch_size, nz, 1, 1).
        noise_np = np.random.normal(0.0, 1.0, size=(args.batch_size, args.nz, 1, 1))
        noise = torch.from_numpy(noise_np).type(torch.FloatTensor)
        self.fixed_noise = Variable(noise.cuda() if args.use_cuda else noise)
Ejemplo n.º 3
0
    def __init__(self, args):
        """Set up logging, word embeddings, the mapping GAN, its optimizers and evaluator.

        Args:
            args: run configuration. Read here: ``seed``, ``output_dir``,
                ``source_vec_file``/``source_lang``,
                ``target_vec_file``/``target_lang``, ``use_cuda``,
                ``multi_gpu``, ``lr`` and ``beta1``. This method also stores
                ``logger``, ``src_dico`` and ``tgt_dico`` on it.
        """
        np.random.seed(args.seed)

        self.args = args

        self.logger = logger.Logger(args.output_dir)
        self.args.logger = self.logger

        try:
            # check_output returns bytes on Python 3; decode so the log shows
            # the hash itself rather than its b'...' repr.
            current_commit_hash = subprocess.check_output(
                ["git", "rev-parse", "HEAD"]).strip().decode('utf-8')
        except (OSError, subprocess.CalledProcessError):
            # Not a git checkout, or git is not installed. The hash is only
            # informational, so log a placeholder instead of aborting the run.
            current_commit_hash = 'unknown'
        self.logger.log('current git commit hash: %s' % current_commit_hash)

        print('load vec')
        source_vecs, source_dico =\
            utils.load_word_vec_list(args.source_vec_file, args.source_lang)
        target_vecs, target_dico =\
            utils.load_word_vec_list(args.target_vec_file, args.target_lang)

        self.src_dico = source_dico
        self.tgt_dico = target_dico
        args.src_dico = source_dico
        args.tgt_dico = target_dico

        src_embed, tgt_embed =\
            utils.get_embeds_from_numpy(source_vecs, target_vecs)
        if args.use_cuda:
            self.src_embed = src_embed.cuda()
            self.tgt_embed = tgt_embed.cuda()
        else:
            self.src_embed = src_embed
            self.tgt_embed = tgt_embed

        print('setting models')
        netD = model.netD(self.args)
        netG = model.netG()
        # Start the mapping at the identity. Size it from the layer's own weight
        # rather than a hard-coded 300, so other embedding dimensions work
        # (identical to diag(ones(300)) for a 300x300 weight).
        netG.W.weight.data.copy_(torch.eye(*netG.W.weight.size()))
        if args.multi_gpu:
            netD = nn.DataParallel(netD)
            netG = nn.DataParallel(netG)
        if args.use_cuda:
            netD = netD.cuda()
            netG = netG.cuda()
        self.netD = netD
        self.netG = netG
        self.optimizer_D = optim.Adam(self.netD.parameters(),
                                      lr=args.lr,
                                      betas=(args.beta1, 0.999))
        self.optimizer_G = optim.Adam(self.netG.parameters(),
                                      lr=args.lr,
                                      betas=(args.beta1, 0.999))
        self.criterion = nn.BCELoss()
        self.prefix = os.path.basename(args.output_dir)

        self.evaluator = Evaluator(self)
Ejemplo n.º 4
0
def main(args):
    """Map one source-language word through the generator and print its 5 nearest target words.

    Loads the trained generator from ``<args.load_dir>/netG_state.pth``. It then
    maps the embedding of ``args.source_word`` through it and ranks every
    target-language vector by cosine similarity to the mapped vector.

    Args:
        args: namespace providing ``load_dir``, ``source_vec_file``,
            ``target_vec_file`` (text word2vec files, ``binary=False``) and
            ``source_word``.

    Returns:
        The mapped source vector (the ``.data`` tensor of the generator output).
    """
    print('loading model')
    netG = model.netG()
    load_path = os.path.join(args.load_dir, 'netG_state.pth')
    # The generator runs on CPU here. Remap storages to CPU so a checkpoint
    # saved from GPU tensors still loads on a CPU-only machine.
    netG.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage))

    print('preparing source word vectors')
    source_w2v = KeyedVectors.load_word2vec_format(args.source_vec_file,
                                                   binary=False)
    print('preparing target word vectors')
    target_w2v = KeyedVectors.load_word2vec_format(args.target_vec_file,
                                                   binary=False)
    # Word for each row of target_w2v.vectors. gensim >= 4 exposes this as
    # index_to_key (and removed .vocab); gensim 3 calls it index2word.
    target_vocab = getattr(target_w2v, 'index_to_key', None)
    if target_vocab is None:
        target_vocab = target_w2v.index2word

    # Get the source word vector.
    print('')
    source_vec = source_w2v.get_vector(args.source_word)
    source_vec = Variable(torch.from_numpy(source_vec).type(torch.FloatTensor))

    # Map the source vector into the target space.
    print('transfering source word vec')
    transfered_source_vec = netG(source_vec).data

    # Search for the most similar target words. Cosine similarity is computed
    # for all target vectors at once, in float64 for a stable ranking.
    print('searching for most similar words')
    target_vectors = np.asarray(target_w2v.vectors, dtype=np.float64)
    query = np.asarray(transfered_source_vec.cpu().numpy(), dtype=np.float64).ravel()
    norms = np.linalg.norm(target_vectors, axis=1) * np.linalg.norm(query)
    with np.errstate(divide='ignore', invalid='ignore'):
        similarities = target_vectors.dot(query) / norms
    # Zero-norm rows give NaN, which argsort would place at the top; rank them last.
    similarities[np.isnan(similarities)] = -np.inf
    top_n_indexes = list(reversed(np.argsort(similarities)[-5:]))
    for i in top_n_indexes:
        print(target_vocab[i])

    return transfered_source_vec
Ejemplo n.º 5
0
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision.utils import save_image
from tqdm import tqdm
import time
from model import netD, netG
from config import opt
from Myutils import dataloader

# Build both networks from the shared config.
netd = netD(opt)
netg = netG(opt)

# Presumably (dataset, loader) from Myutils.dataloader — confirm.
ds, dl = dataloader()

device = torch.device('cuda:0') if opt.use_gpu else torch.device('cpu')
# Optionally resume from saved weights. map_location=device lets a checkpoint
# written from GPU tensors be restored when running on CPU.
if opt.d_save_path:
    netd.load_state_dict(torch.load(opt.d_save_path, map_location=device))
    print("net_D loads weight successfully......")
    print('___' * 10)
if opt.g_save_path:
    netg.load_state_dict(torch.load(opt.g_save_path, map_location=device))
    print("net_G loads weight successfully......")
    print('___' * 10)
netd.to(device)
netg.to(device)
optm_g = torch.optim.Adam(netg.parameters(),
                          lr=opt.lr,
                          betas=(opt.beta1, 0.999))
optm_d = torch.optim.Adam(netd.parameters(),
                          lr=opt.lr,
Ejemplo n.º 6
0
# Fail fast if the dataset could not be built. An explicit raise, unlike
# assert, is not stripped under `python -O`.
if not dataset:
    raise RuntimeError('dataset could not be created')
dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
                                         shuffle=True, num_workers=int(opt.workers))


nz = int(opt.nz)
ngf = int(opt.ngf)
ndf = int(opt.ndf)
# Image channels: 1 for MNIST, 3 otherwise.
nc = 1 if opt.dataset == 'mnist' else 3
# Number of class labels; the same 10 in every case.
nb_label = 10

netG = model.netG(nz, ngf, nc)

if opt.netG != '':
    netG.load_state_dict(torch.load(opt.netG))
print(netG)

netD = model.netD(ndf, nc, nb_label)

if opt.netD != '':
    netD.load_state_dict(torch.load(opt.netD))
print(netD)

# Presumably the real/fake (BCE) and class (NLL) losses — confirm against the training loop.
s_criterion = nn.BCELoss()
c_criterion = nn.NLLLoss()

# Pre-allocated input buffer. Its channel count must match nc (1 for MNIST)
# rather than a hard-coded 3.
input = torch.FloatTensor(opt.batchSize, nc, opt.imageSize, opt.imageSize)
Ejemplo n.º 7
0
            file.write('epoch: ' + str(epoch) + ', G_loss: ' + str(avg_lossG) + ', D_loss: ' + str(avg_lossD) + '\n')

        # save generated images
        fixed_pred = netG(fixed_noise)
        vutils.save_image(fixed_pred.data, os.path.join(Opt.results_dir,'img'+str(epoch)+'.png'), nrow=10, scale_each=True)

        if epoch % 200 == 0:
            if Opt.save_model:
                torch.save(netD.state_dict(), os.path.join(Opt.checkpoint_dir, 'netD-01.pt'))
                torch.save(netG.state_dict(), os.path.join(Opt.checkpoint_dir, 'netG-01.pt'))


if __name__ == '__main__':
    # One fixed latent batch of 100 vectors, shaped (100, nz, 1, 1) and kept on
    # the GPU, so generated samples are comparable across epochs.
    fixed_noise = Variable(torch.randn(100, Opt.nz).view(-1, Opt.nz, 1, 1).cuda())

    # Instantiate both networks (rebinding the class names to the instances)
    # and move them to the GPU.
    netG, netD = netG(), netD()
    netG.cuda()
    netD.cuda()

    # Binary cross-entropy on the discriminator output.
    criterion = torch.nn.BCELoss()

    # One Adam optimizer per network, sharing the learning rate and betas from Opt.
    optimizerG, optimizerD = (
        torch.optim.Adam(net.parameters(), lr=Opt.lr, betas=Opt.betas)
        for net in (netG, netD)
    )

    train(netD, netG, criterion, optimizerG, optimizerD)
Ejemplo n.º 8
0
    # cv2.waitKey(0)

    combine_low = decomp_combine_image(ir_image, vi_image)
    input_ir = (ir_image - 127.5) / 127.5  # 将该幅图像的数据归一化
    input_vi = (vi_image - 127.5) / 127.5

    # 扩充输入数据的维度
    train_data_ir = np.expand_dims(input_ir, axis=0)
    train_data_vi = np.expand_dims(input_vi, axis=0)
    train_data_ir = np.expand_dims(train_data_ir, axis=3)
    train_data_vi = np.expand_dims(train_data_vi, axis=3)
    return train_data_ir, train_data_vi, combine_low


# Load the trained model weights and run the image-fusion test.
# Build the generator on the GPU and switch it to eval mode for inference.
fusion_model = netG().cuda().eval()
# print(fusion_model)
# discriminator = netD().cuda()
# Which saved epoch to load weights from: <cwd>/weight/epoch<ep>/netG.pth.
ep = 4
model_path = os.path.join(os.getcwd(), 'weight', 'epoch' + str(ep))
netG_path = os.path.join(model_path, 'netG.pth')
# netD_path = os.path.join(model_path, 'netD.pth')
fusion_model.load_state_dict(torch.load(netG_path))
# discriminator.load_state_dict(torch.load(netD_path))
# prepare_data is defined elsewhere; presumably it returns the list of test
# images (or their paths) in each directory — confirm.
data_ir = prepare_data('Test_ir')
data_vi = prepare_data('Test_vi')
# Iterate over the test pairs, one index per infrared image.
for i in range(0, len(data_ir)):
    # Start time, presumably for per-image timing (the rest of the loop body is
    # not visible in this snippet).
    start = time.time()
    train_data_ir, train_data_vi, combine_low = input_setup(i)
    # Squeeze out the size-1 dimensions to get image data that can be processed.
    # from_numpy yields a DoubleTensor, which needs converting to FloatTensor.
def train(train_loader, netD, netG, criterion, optimizerG, optimizerD):
    """Adversarially train a mask generator for ``Opt.epochs`` epochs.

    Each step trains D to score ``image * mask`` (ground-truth mask) as real
    and ``image * netG(image)`` as fake. G is trained with ``criterion`` on its
    output against the true mask, plus an adversarial term that pushes D to
    score G's masked images as real.

    Per-epoch average losses are printed and appended to ``logs.txt``. A 4x4
    grid of G outputs for the ids in ``Opt.save_img_id`` is saved to
    ``Opt.results_dir`` every epoch. Every 50 epochs, both state dicts are
    written to ``Opt.checkpoint_dir`` when ``Opt.save_model`` is set.

    Relies on module-level ``Opt``, ``intersect_index``, ``Variable``,
    ``vutils``, ``torch`` and ``os``.

    Args:
        train_loader: iterable of dict batches with keys ``'image'``,
            ``'mask'`` and ``'img_id'``.
        netD: discriminator; must live on the GPU because inputs are moved
            there with ``.cuda()``.
        netG: generator mapping images to masks; must also live on the GPU.
        criterion: loss called as ``criterion(output, target)``.
        optimizerG: optimizer over ``netG``'s parameters.
        optimizerD: optimizer over ``netD``'s parameters.
    """
    def _scalar(loss):
        # Newer PyTorch reads a 0-dim loss via .item(); older releases
        # (which_pc != 0) indexed .data[0], which raises on newer versions.
        return loss.item() if Opt.which_pc == 0 else loss.data[0]

    for epoch in range(Opt.epochs):
        avg_lossD = 0
        avg_lossG = 0
        n_batches = 0
        # 16 single-channel 128x128 slots, saved as a 4x4 grid below.
        save_img = torch.zeros(16, 1, 128, 128)
        with open('logs.txt', 'a') as file:
            for i, sample_batched in enumerate(train_loader):
                image = sample_batched['image'].type(torch.FloatTensor)
                label = sample_batched['mask'].type(torch.FloatTensor)
                mini_batch = label.shape[0]
                # GPU copy of the image, shared by G's input and the fake pair.
                image_var = Variable(image.cuda())

                # Update D on real pairs: the image masked by the ground-truth mask.
                input = Variable((image * label).cuda())
                real_label = Variable(torch.ones(mini_batch).cuda())
                output = netD(input)
                D_real_loss = criterion(output, real_label)

                # Update D on fake pairs: the image masked by G's predicted mask.
                fake = netG(image_var)
                fake_concat = fake * image_var
                fake_label = Variable(torch.zeros(mini_batch).cuda())
                output = netD(fake_concat.detach())  # detach so D's update does not backprop into G
                D_fake_loss = criterion(output, fake_label)
                D_loss = D_real_loss + D_fake_loss
                netD.zero_grad()
                D_loss.backward()
                D_loss_value = _scalar(D_loss)
                avg_lossD += D_loss_value
                optimizerD.step()

                # Update G: match the ground-truth mask and make D score fakes as real.
                G_loss1 = criterion(fake, Variable(label.cuda()))
                output = netD(fake_concat)
                G_loss2 = criterion(output, real_label)
                G_loss = G_loss2 + G_loss1
                G_loss_value = _scalar(G_loss)
                avg_lossG += G_loss_value
                netG.zero_grad()
                G_loss.backward()
                optimizerG.step()
                n_batches += 1

                print(
                    'Epoch [%d/%d], Step [%d/%d], D_loss: %.4f, G_loss: %.4f' %
                    (epoch + 1, Opt.epochs, i + 1, len(train_loader),
                     D_loss_value, G_loss_value))

                # Keep G's output for any tracked image ids present in this batch.
                for idx in intersect_index(Opt.save_img_id,
                                           sample_batched['img_id']):
                    save_img[idx[0]] = fake[idx[1]].data.cpu()

            # Average over the number of batches actually seen. Dividing by the
            # last index was off by one and failed for one (or zero) batches.
            if n_batches:
                avg_lossD /= n_batches
                avg_lossG /= n_batches
            print('epoch: ' + str(epoch + 1) + ', D_loss: ' + str(avg_lossD) +
                  ', G_loss: ' + str(avg_lossG))
            file.write('epoch: ' + str(epoch + 1) + ', D_loss: ' +
                       str(avg_lossD) + ', G_loss: ' + str(avg_lossG) + '\n')

        # Save the generated images collected this epoch.
        vutils.save_image(save_img,
                          os.path.join(Opt.results_dir,
                                       'img' + str(epoch) + '.png'),
                          nrow=4,
                          scale_each=True)

        if epoch % 50 == 0 and Opt.save_model:
            torch.save(netD.state_dict(),
                       os.path.join(Opt.checkpoint_dir, 'netD-01.pt'))
            torch.save(netG.state_dict(),
                       os.path.join(Opt.checkpoint_dir, 'netG-01.pt'))