Code Example #1
File: gan.py  Project: dawnonme/Eureka
def train_D_Without_G():
    model = Discriminator()
    model.cuda()
    criterion = nn.CrossEntropyLoss()
    learning_rate = 0.0001
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    for epoch in range(100):
        avoidOverflow(optimizer)
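        # step-decay schedule: divide the learning rate by 10 at epoch 50 and by 100 at epoch 75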
        if (epoch == 50):
            for param_group in optimizer.param_groups:
                param_group['lr'] = learning_rate / 10.0
        if (epoch == 75):
            for param_group in optimizer.param_groups:
                param_group['lr'] = learning_rate / 100.0
        for batch_idx, (X_train_batch,
                        Y_train_batch) in enumerate(trainloader):

            if Y_train_batch.shape[0] < batch_size:
                continue

            X_train_batch = Variable(X_train_batch).cuda()
            Y_train_batch = Variable(Y_train_batch).cuda()
            _, output = model(X_train_batch)

            loss = criterion(output, Y_train_batch)
            optimizer.zero_grad()

            loss.backward()
            optimizer.step()

    torch.save(model, 'cifar10.model')
Code Example #2
def main(args):
    cfg = vars(args)
    cfg["session_dir"] = create_session_dir("./sessions")
    m_gen = Generator(cfg["z_dim"])
    m_disc = Discriminator()
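    # when resuming, reuse the session directory that contains the checkpoint and append to its log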
    if len(cfg["resume_path"]) > 0:
        cfg["session_dir"] = os.path.dirname(os.path.abspath(cfg["resume_path"]))
        start_epoch = load_models(m_gen, m_disc, cfg)
        filemode = "a"
    else:
        start_epoch = 0
        filemode = "w"
    init_session_log(cfg, filemode)
    train_loader = get_loader(cfg)
    cudev = cfg["cuda"]
    if cudev >= 0 and not torch.cuda.is_available():
        raise RuntimeError("CUDA device specified but CUDA not available")
    if cudev >= 0:
        m_gen.cuda(cudev)
        m_disc.cuda(cudev)
    optG = torch.optim.SGD([{"params" : m_gen.parameters()}], lr=cfg["lr_g"],
        momentum=cfg["momentum"])
    optD = torch.optim.SGD([{"params" : m_disc.parameters()}], lr=cfg["lr_d"],
        momentum=cfg["momentum"])
    train(m_gen, m_disc, train_loader, (optD,optG), cfg, start_epoch)
Code Example #3
class LSGAN(object):
    def __init__(self, batch_size, adopt_gas=False):
        self.batch_size = batch_size
        self.generator = Generator(batch_size=self.batch_size, base_filter=32)
        self.discriminator = Discriminator(batch_size=self.batch_size,
                                           base_filter=32,
                                           adopt_gas=adopt_gas)
        self.generator.cuda()
        self.discriminator.cuda()
        self.gen_optimizer = RMSprop(self.generator.parameters())
        self.dis_optimizer = RMSprop(self.discriminator.parameters())

    def train(self, epoch, loader):
        self.generator.train()
        self.discriminator.train()
        self.gen_loss_sum = 0.0
        self.dis_loss_sum = 0.0
        for i, (batch_img, batch_tag) in enumerate(loader):
            # Get logits
            batch_img = Variable(batch_img.cuda())
            batch_z = Variable(torch.randn(self.batch_size, 100).cuda())
            self.gen_image = self.generator(batch_z)
            true_logits = self.discriminator(batch_img)
            fake_logits = self.discriminator(self.gen_image)

            # Get loss
            # Least-squares GAN objective: push real logits toward 1 and fake logits toward 0
            self.dis_loss = torch.sum((true_logits - 1)**2 + fake_logits**2) / 2
            self.gen_loss = torch.sum((fake_logits - 1)**2) / 2

            # Update
            self.dis_optimizer.zero_grad()
            self.dis_loss.backward(retain_graph=True)
            self.dis_loss_sum += self.dis_loss.data.cpu().numpy()[0]
            self.dis_optimizer.step()
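            # update the generator only on every 5th discriminator step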
            if i % 5 == 0:
                self.gen_optimizer.zero_grad()
                self.gen_loss.backward()
                self.gen_loss_sum += self.gen_loss.data.cpu().numpy()[0]
                self.gen_optimizer.step()

            if i > 300:
                break

    def eval(self):
        self.generator.eval()
        batch_z = Variable(torch.randn(32, 100).cuda())
        return self.generator(batch_z)
Code Example #4
    def build_model(self):
        if cfg.train.loss_type == cfg.VANILLA:
            self.loss = nn.BCELoss()
        elif cfg.train.loss_type == cfg.WGAN:
            self.loss = lambda logits, labels: torch.mean(logits)

        self.D_global = Discriminator(cfg.dataset.dataset_name)
        self.G_global = Generator(cfg.dataset.dataset_name)

        # Enable cuda if available
        if torch.cuda.is_available():
            self.D_global.cuda()
            self.G_global.cuda()

        # Optimizers
        self.D_global_optimizer = Adam(self.D_global.parameters(),
                                       lr=cfg.train.learning_rate,
                                       betas=(cfg.train.beta1, 0.999))
        self.G_global_optimizer = Adam(self.G_global.parameters(),
                                       lr=cfg.train.learning_rate,
                                       betas=(cfg.train.beta1, 0.999))

        self.D_pairs = []
        self.G_pairs = []
        self.D_pairs_optimizers = []
        self.G_pairs_optimizers = []

        self.D_msg_pairs = []
        self.D_msg_pairs_optimizers = []
        for id in range(1, cfg.train.N_pairs + 1):
            discriminator = Discriminator(cfg.dataset.dataset_name)
            generator = Generator(cfg.dataset.dataset_name)

            # Enable cuda if available
            if torch.cuda.is_available():
                generator.cuda()
                discriminator.cuda()

            self.D_pairs.append(discriminator)
            self.G_pairs.append(generator)

            # Optimizers
            D_optimizer = Adam(discriminator.parameters(),
                               lr=cfg.train.learning_rate,
                               betas=(cfg.train.beta1, 0.999))
            G_optimizer = Adam(generator.parameters(),
                               lr=cfg.train.learning_rate,
                               betas=(cfg.train.beta1, 0.999))

            self.D_pairs_optimizers.append(D_optimizer)
            self.G_pairs_optimizers.append(G_optimizer)

            # create msg Discriminator pair for G_global
            discriminator = Discriminator(cfg.dataset.dataset_name)

            # Enable cuda if available
            if torch.cuda.is_available():
                generator.cuda()
                discriminator.cuda()

            self.D_msg_pairs.append(discriminator)

            # Optimizers
            D_optimizer = Adam(discriminator.parameters(),
                               lr=cfg.train.learning_rate,
                               betas=(cfg.train.beta1, 0.999))

            self.D_msg_pairs_optimizers.append(D_optimizer)

        self.logger = Logger(model_name='DCGAN',
                             data_name='MNIST',
                             logdir=cfg.validation.validation_dir)

        return
Code Example #5
# Discriminator
# discriminator = nn.Sequential(
#     nn.Conv2d(in_channels=1, out_channels=4, kernel_size=(3, 3), stride=(1, 1), padding=1),
#     nn.LeakyReLU(0.3),
#     # nn.Dropout(0.3),
#     # nn.Conv2d(in_channels=4, out_channels=8, kernel_size=(3, 3), stride=(1, 1), padding=1),
#     # nn.LeakyReLU(0.3),
#     # nn.Dropout(0.3),
#     nn.Flatten(),
#     nn.Linear(in_features=5 * 3 * 4, out_features=1),
#     nn.Sigmoid()
# )
if use_gpu:
    generator = generator.cuda()
    discriminator = discriminator.cuda()

# Test the discriminator
# discriminator.eval()
# decision = discriminator(generated_data)
# print(decision)

# loss function
cross_entropy = nn.BCELoss()


# Define the discriminator loss. Real data must ultimately be classified as 1, so real_loss (r_loss) is the loss
# between the real data x and the correct labels y (a matrix of all ones). Fake data must be classified as 0, so
# fake_loss (f_loss) is the loss between the generated fake data x and the correct labels y (a matrix of all zeros).
def discriminator_loss(r_loss, f_loss):
    return r_loss + f_loss
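
A minimal sketch of how the two inputs to discriminator_loss are typically produced (the helper below and its arguments are illustrative, not part of the original script): real outputs are scored against all-ones labels and generated outputs against all-zeros labels, using the cross_entropy criterion defined above.

def discriminator_step(discriminator, generator, real_batch, noise):
    # Hypothetical helper: score real samples against 1s and generated samples against 0s.
    real_out = discriminator(real_batch)                 # sigmoid probabilities in (0, 1)
    fake_out = discriminator(generator(noise).detach())  # detach so only the discriminator is trained here
    r_loss = cross_entropy(real_out, torch.ones_like(real_out))
    f_loss = cross_entropy(fake_out, torch.zeros_like(fake_out))
    return discriminator_loss(r_loss, f_loss)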
Code Example #6
def main():
    random.seed(SEED)
    np.random.seed(SEED)

    # Define Networks
    generator = Generator(VOCAB_SIZE, g_emb_dim, g_hidden_dim, opt.cuda)
    discriminator = Discriminator(d_num_class, VOCAB_SIZE, d_emb_dim,
                                  d_filter_sizes, d_num_filters, d_dropout)
    if opt.cuda:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
    # Generate toy data using the target LSTM, i.e. build a dataset and treat it as the real data

    # Load data from file
    # each iteration yields a data batch and a target batch: data is each sequence with a 0 prepended, target has a 0 appended
    gen_data_iter = GenDataIter(POSITIVE_FILE, BATCH_SIZE)

    # Pretrain Generator using MLE
    gen_criterion = nn.NLLLoss(
        reduction='sum'
    )  #You may use CrossEntropyLoss instead, if you prefer not to add an extra LogSoftmax layer .
    gen_optimizer = optim.Adam(generator.parameters())
    if opt.cuda:
        gen_criterion = gen_criterion.cuda()
    print('Pretrain with MLE ...')
    for epoch in range(PRE_EPOCH_NUM):  #PRE_EPOCH_NUM =120
        loss = train_epoch(generator, gen_data_iter, gen_criterion,
                           gen_optimizer)  # update the generator parameters to fit gen_data_iter
        print('Epoch [%d] Model Loss: %f' % (epoch, loss))

    # Pretrain Discriminator
    dis_criterion = nn.NLLLoss(reduction='sum')
    dis_optimizer = optim.Adam(discriminator.parameters())
    if opt.cuda:
        dis_criterion = dis_criterion.cuda()
    print('Pretrain Discriminator ...')
    for epoch in range(5):
        generate_samples(generator, BATCH_SIZE, GENERATED_NUM, NEGATIVE_FILE)
        dis_data_iter = DisDataIter(POSITIVE_FILE, NEGATIVE_FILE, BATCH_SIZE)
        for _ in range(3):
            loss = train_epoch(discriminator, dis_data_iter, dis_criterion,
                               dis_optimizer)
            print('Epoch [%d], loss: %f' % (epoch, loss))
    # Adversarial Training
    rollout = Rollout(generator, 0.8)
    print('#####################################################')
    print('Start Adversarial Training...\n')
    gen_gan_loss = GANLoss()
    gen_gan_optm = optim.Adam(generator.parameters())
    if opt.cuda:
        gen_gan_loss = gen_gan_loss.cuda()
    gen_criterion = nn.NLLLoss(reduction='sum')
    if opt.cuda:
        gen_criterion = gen_criterion.cuda()
    dis_criterion = nn.NLLLoss(reduction='sum')
    dis_optimizer = optim.Adam(discriminator.parameters())
    if opt.cuda:
        dis_criterion = dis_criterion.cuda()
    for total_batch in range(TOTAL_BATCH):
        ## Train the generator for one step
        for it in range(1):
            samples = generator.sample(BATCH_SIZE, g_sequence_len)
            # construct the input to the generator: prepend zeros to the samples and drop the last column
            zeros = torch.zeros((BATCH_SIZE, 1)).type(torch.LongTensor)
            if samples.is_cuda:
                zeros = zeros.cuda()
            inputs = Variable(
                torch.cat([zeros, samples.data], dim=1)[:, :-1].contiguous())
            targets = Variable(samples.data).contiguous().view(
                (-1, ))  # inputs has shape (batch_size, seq_len); targets is the sampled tokens flattened to 1-D
            # calculate the reward for each sampled token via Monte-Carlo rollout
            rewards = rollout.get_reward(
                samples, 16, discriminator)  # rewards:(batch_size,seq_len)
            rewards = Variable(torch.Tensor(rewards))
            rewards = torch.exp(rewards).contiguous().view(
                (-1, ))  # exponentiate because the discriminator's last layer is a log-softmax
            if opt.cuda:
                rewards = rewards.cuda()
            prob = generator.forward(inputs)
            loss = gen_gan_loss(prob, targets, rewards)  # the key step: policy-gradient loss weighted by the rollout rewards
            gen_gan_optm.zero_grad()
            loss.backward()
            gen_gan_optm.step()  # updates the generator, the same model the rollout was built from

        rollout.update_params()  # refresh the rollout with the generator's updated parameters

        for _ in range(4):
            generate_samples(generator, BATCH_SIZE, GENERATED_NUM,
                             NEGATIVE_FILE)
            dis_data_iter = DisDataIter(POSITIVE_FILE, NEGATIVE_FILE,
                                        BATCH_SIZE)
            for _ in range(2):
                loss = train_epoch(discriminator, dis_data_iter, dis_criterion,
                                   dis_optimizer)

        print('Adversarial Training %d complete \n' % (total_batch))

    print('Saving the generator model')
    torch.save(generator.state_dict(), PATH_GPU)
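
For reference, the adversarial generator objective computed by gen_gan_loss above is a policy gradient. A rough, hypothetical stand-in (the real GANLoss class may differ in detail) weights the log-probability of each sampled token by its rollout reward and negates the sum:

def policy_gradient_loss(log_probs, targets, rewards):
    # log_probs: (batch * seq_len, vocab_size), log-softmax output of the generator
    # targets:   (batch * seq_len,) ids of the sampled tokens
    # rewards:   (batch * seq_len,) rollout reward for each sampled token
    picked = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)
    return -torch.sum(picked * rewards)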
Code Example #7
class WESPE(object):
    """docstring for WESPE"""
    def __init__(self, config):
        super(WESPE, self).__init__()
        
        self.config = config
        self.batch_size = config.batch_size
        self.patch_size = config.patch_size
        self.mode = config.mode
        self.channels = config.channels
        self.augmentation = config.augmentation
        self.checkpoint_dir = config.checkpoint_dir
        self.sample_dir = config.sample_dir
        self.result_img_dir = config.result_img_dir
        self.content_layer = config.content_layer
        self.vgg_dir = config.vgg_dir

        # Data
        self.dataset_name = config.dataset_name
        self.dataset = Dataset(self.config)
        self.data_loader = torch.utils.data.DataLoader(self.dataset,
                                                       batch_size=self.config.batch_size,
                                                       shuffle=True,
                                                       num_workers=self.config.data_loader_workers,
                                                       pin_memory=self.config.pin_memory,
                                                       drop_last=True)
        
        # self.dataset_phone = dataset_phone
        # self.dataset_canon = dataset_canon
        # self.dataset_DIV2K = dataset_DIV2K        

        # Loss Weights
        self.w_content = config.w_content
        self.w_profile = config.w_profile
        self.w_texture = config.w_texture 
        self.w_color = config.w_color
        self.w_tv = config.w_tv
        self.gamma = config.gamma

        # Total Losses
        self.total_profile_loss = 0
        self.total_color_loss = 0
        self.total_var_loss = 0
        self.total_texture_loss = 0
        self.total_content_loss = 0

        # Networks
        self.generator = Generator()
        self.discriminator1 = Discriminator(in_channels=3)
        self.discriminator2 = Discriminator(in_channels=3)
        self.discriminator3 = Discriminator(in_channels=1)
        
        if torch.cuda.is_available():
            self.generator, self.discriminator1, self.discriminator2, self.discriminator3 = \
            self.generator.cuda(), self.discriminator1.cuda(), self.discriminator2.cuda(), self.discriminator3.cuda()

        # Network Optimizers
        self.optimizer_G = torch.optim.Adam(self.generator.parameters())
        self.optimizer_D1 = torch.optim.Adam(self.discriminator1.parameters())
        self.optimizer_D2 = torch.optim.Adam(self.discriminator2.parameters())
        self.optimizer_D3 = torch.optim.Adam(self.discriminator3.parameters())

        # Discriminator Loss Function
        self.loss_fn_D = nn.BCEWithLogitsLoss()
        

    def build_discriminator_unit(self, generated_patch, actual_batch, index, preprocess):

        if index == 1:
            act, _ = self.discriminator1(actual_batch, preprocess = preprocess)
            fake, _ = self.discriminator1(generated_patch, preprocess = preprocess)

        elif index == 2:
            act, _ = self.discriminator2(actual_batch, preprocess = preprocess)
            fake, _ = self.discriminator2(generated_patch, preprocess = preprocess)

        elif index == 3:
            act, _ = self.discriminator3(actual_batch, preprocess = preprocess)
            fake, _ = self.discriminator3(generated_patch, preprocess = preprocess)

        else:
            raise NotImplementedError

        loss_real = self.loss_fn_D(act, torch.ones_like(act))
        loss_fake = self.loss_fn_D(fake, torch.zeros_like(fake))
        total_loss = loss_real+loss_fake

        return total_loss, act, fake
    
    
    
    def train(self):
        
        for i in range(self.config.train_iter):
            self.train_one_epoch(i)

            
    def train_one_epoch(self, epoch):
        
        start = time.time()
        
        self.total_profile_loss = 0
        self.total_color_loss = 0
        self.total_var_loss = 0
        self.total_texture_loss = 0
        self.total_content_loss = 0

        for i in range(self.config.train_iter):

            for step, (phone_patch, canon_patch, DIV2K_patch) in enumerate(self.data_loader):

                phone_patch, canon_patch, DIV2K_patch = phone_patch.float(), canon_patch.float(), DIV2K_patch.float()
                
                if torch.cuda.is_available():
                    phone_patch, canon_patch, DIV2K_patch = phone_patch.cuda(), canon_patch.cuda(), DIV2K_patch.cuda()
                
                self.optimizer_G.zero_grad()

                # Generator
                enhanced_patch = self.generator(phone_patch)
                
                # Discrimiator 1
                d_loss_profile, logits_DIV2K_profile, logits_enhanced_profile = self.build_discriminator_unit(enhanced_patch, DIV2K_patch, index=1, preprocess='blur')
                
                # Discrimiator 2
                d_loss_color, logits_original_color, logits_enhanced_color = self.build_discriminator_unit(enhanced_patch, canon_patch, index=2, preprocess='none')
                
                # Discrimiator 3
                d_loss_texture, logits_original_texture, logits_enhanced_texture = self.build_discriminator_unit(enhanced_patch, canon_patch, index=3, preprocess='gray')

                # Generator Loss
                original_vgg = net(self.vgg_dir, canon_patch * 255)
                enhanced_vgg = net(self.vgg_dir, enhanced_patch * 255)
                
                #content loss
                content_loss = torch.mean(torch.pow(original_vgg[self.content_layer] - enhanced_vgg[self.content_layer], 2))
                
                #profile loss(gan, enhanced-div2k)
                profile_loss = self.loss_fn_D(logits_DIV2K_profile, logits_enhanced_profile)
                
                # color loss (gan, enhanced-original)
                color_loss = self.loss_fn_D(logits_original_color, logits_enhanced_color)
                
                # texture loss (gan, enhanced-original)
                texture_loss = self.loss_fn_D(logits_original_texture, logits_enhanced_texture)
                
                # tv loss (total variation of enhanced)
                tv_loss = torch.mean(torch.abs(self.total_variation_loss(enhanced_patch) - self.total_variation_loss(canon_patch)))

                g_loss = self.w_content*content_loss + self.w_profile*profile_loss + self.w_color*color_loss + self.w_texture*texture_loss + self.w_tv*tv_loss
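                # retain the graph so the discriminator losses below can backpropagate through the same forward passes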

                g_loss.backward(retain_graph=True)
                self.optimizer_G.step()

                self.optimizer_D1.zero_grad()
                self.optimizer_D2.zero_grad()
                self.optimizer_D3.zero_grad()

                d_loss_profile.backward()
                self.optimizer_D1.step()

                d_loss_color.backward()
                self.optimizer_D2.step()

                d_loss_texture.backward()
                self.optimizer_D3.step()
                
                self.total_profile_loss += profile_loss
                self.total_color_loss += color_loss
                self.total_var_loss += tv_loss
                self.total_texture_loss += texture_loss
                self.total_content_loss += content_loss
                
                if i % self.config.test_every == 0:
                    print("Iteration %d, runtime: %.3f s, generator loss: %.6f" %(i, time.time() - start, g_loss))      
                    print("Loss per component: content %.6f,profile %.6f, color %.6f, texture %.6f, tv %.6f" %(content_loss, profile_loss, color_loss, texture_loss, tv_loss))
                    self.test_generator(100, 0)
    



    def test_generator(self, test_num_patch = 200, test_num_image = 5, load = False):
        
        self.generator.eval()
        
        # test for patches
        start = time.time()
        test_list_phone = sorted(glob(self.config.test_path_phone_patch))
        PSNR_phone_enhanced_list = np.zeros([test_num_patch])
        
        indexes = []
        for i in range(test_num_patch):
            index = np.random.randint(len(test_list_phone))
            indexes.append(index)
            test_img = scipy.misc.imread(test_list_phone[index], mode = "RGB").astype("float32")
            test_patch_phone = get_patch(test_img, self.config.patch_size)
            test_patch_phone = preprocess(test_patch_phone)
            
            with torch.no_grad():
                test_patch_phone = torch.from_numpy(np.transpose(test_patch_phone, (2,1,0))).float().unsqueeze(0)
                if torch.cuda.is_available():
                    test_patch_phone = test_patch_phone.cuda()

                test_patch_enhanced = self.generator(test_patch_phone)
            
            test_patch_enhanced = np.transpose(test_patch_enhanced.cpu().data.numpy(), (0,2,3,1))
            test_patch_phone = np.transpose(test_patch_phone.cpu().data.numpy(), (0,2,3,1))

            if i % 50 == 0:
                imageio.imwrite(("%s/phone_%d.png" %(self.result_img_dir, i)), postprocess(test_patch_phone[0]))
                imageio.imwrite(("%s/enhanced_%d.png" %(self.result_img_dir,i)), postprocess(test_patch_enhanced[0]))

            PSNR = calc_PSNR(postprocess(test_patch_enhanced[0]), postprocess(test_patch_phone[0]))
            PSNR_phone_enhanced_list[i] = PSNR

        print("(runtime: %.3f s) Average test PSNR for %d random test image patches: phone-enhanced %.3f" %(time.time()-start, test_num_patch, np.mean(PSNR_phone_enhanced_list)))
        
        # test for images
        start = time.time()
        test_list_phone = sorted(glob(self.config.test_path_phone_image))
        PSNR_phone_enhanced_list = np.zeros([test_num_image])

        indexes = []
        for i in range(test_num_image):
            index = i
            indexes.append(index)
            
            test_image_phone = preprocess(scipy.misc.imread(test_list_phone[index], mode = "RGB").astype("float32"))
            
            with torch.no_grad():
                test_image_phone = torch.from_numpy(np.transpose(test_image_phone, (2,1,0))).float().unsqueeze(0)
                if torch.cuda.is_available():
                    test_image_phone = test_image_phone.cuda()

                test_image_enhanced = self.generator(test_image_phone)
            
            test_image_enhanced = np.transpose(test_image_enhanced.cpu().data.numpy(), (0,2,3,1))
            test_image_phone = np.transpose(test_image_phone.cpu().data.numpy(), (0,2,3,1))
                        
            imageio.imwrite(("%s/phone_%d.png" %(self.sample_dir, i)), postprocess(test_image_phone[0]))
            imageio.imwrite(("%s/enhanced_%d.png" %(self.sample_dir, i)), postprocess(test_image_enhanced[0]))
            
            PSNR = calc_PSNR(postprocess(test_image_enhanced[0]), postprocess(test_image_phone[0]))
            PSNR_phone_enhanced_list[i] = PSNR
            
        if test_num_image > 0:
            print("(runtime: %.3f s) Average test PSNR for %d random full test images: original-enhanced %.3f" %(time.time()-start, test_num_image, np.mean(PSNR_phone_enhanced_list)))


    def total_variation_loss(self, images):

        ndims = len(images.shape)

        if ndims == 3:
            pixel_dif1 = images[:, 1:, :] - images[:, :-1, :]
            pixel_dif2 = images[:, :, 1:] - images[:, :, :-1]
            sum_axis = None

        elif ndims == 4:
            pixel_dif1 = images[:, :, 1:, :] - images[:, :, :-1, :]
            pixel_dif2 = images[:, :, :, 1:] - images[:, :, :, :-1]
            sum_axis = (1, 2, 3)

        else:
            raise ValueError('\'images\' must be either 3 or 4-dimensional.')

        tot_var = (
            torch.sum(torch.abs(pixel_dif1), dim=sum_axis) +
            torch.sum(torch.abs(pixel_dif2), dim=sum_axis))

        return tot_var
Code Example #8
        #model= torch.load(model_name, map_location=lambda storage, loc: storage)
        model.load_state_dict(
            torch.load(model_name, map_location=lambda storage, loc: storage))
        print('Pre-trained SR model is loaded.')

if opt.load_pretrained_D:
    D_name = os.path.join(opt.save_folder, opt.pretrained_D)
    if os.path.exists(D_name):
        #model= torch.load(model_name, map_location=lambda storage, loc: storage)
        D.load_state_dict(
            torch.load(D_name, map_location=lambda storage, loc: storage))
        print('Pre-trained Discriminator model is loaded.')

if cuda:
    model = model.cuda(gpus_list[0])
    D = D.cuda(gpus_list[0])
    feature_extractor = feature_extractor.cuda(gpus_list[0])
    MSE_loss = MSE_loss.cuda(gpus_list[0])
    BCE_loss = BCE_loss.cuda(gpus_list[0])

optimizer = optim.Adam(model.parameters(),
                       lr=opt.lr,
                       betas=(0.9, 0.999),
                       eps=1e-8)
D_optimizer = optim.Adam(D.parameters(),
                         lr=opt.lr,
                         betas=(0.9, 0.999),
                         eps=1e-8)

##PRETRAINED
if opt.pretrained:
Code Example #9
def train_d(args, dataset):
    logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S',
                        level=logging.DEBUG)

    use_cuda = (torch.cuda.device_count() >= 1)

    # check checkpoints saving path
    if not os.path.exists('checkpoints/discriminator'):
        os.makedirs('checkpoints/discriminator')

    checkpoints_path = 'checkpoints/discriminator/'

    logging_meters = OrderedDict()
    logging_meters['train_loss'] = AverageMeter()
    logging_meters['train_acc'] = AverageMeter()
    logging_meters['valid_loss'] = AverageMeter()
    logging_meters['valid_acc'] = AverageMeter()
    logging_meters['update_times'] = AverageMeter()

    # Build model
    discriminator = Discriminator(args,
                                  dataset.src_dict,
                                  dataset.dst_dict,
                                  use_cuda=use_cuda)

    # Load generator
    assert os.path.exists('checkpoints/generator/best_gmodel.pt')
    generator = LSTMModel(args,
                          dataset.src_dict,
                          dataset.dst_dict,
                          use_cuda=use_cuda)
    model_dict = generator.state_dict()
    pretrained_dict = torch.load('checkpoints/generator/best_gmodel.pt')
    # 1. filter out unnecessary keys
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in model_dict
    }
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    generator.load_state_dict(model_dict)

    if use_cuda:
        if torch.cuda.device_count() > 1:
            discriminator = torch.nn.DataParallel(discriminator).cuda()
            # generator = torch.nn.DataParallel(generator).cuda()
            generator.cuda()
        else:
            generator.cuda()
            discriminator.cuda()
    else:
        discriminator.cpu()
        generator.cpu()

    criterion = torch.nn.CrossEntropyLoss()

    # optimizer = eval("torch.optim." + args.d_optimizer)(filter(lambda x: x.requires_grad, discriminator.parameters()),
    #                                                     args.d_learning_rate, momentum=args.momentum, nesterov=True)

    optimizer = torch.optim.RMSprop(
        filter(lambda x: x.requires_grad, discriminator.parameters()), 1e-4)

    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, patience=0, factor=args.lr_shrink)

    # Train until the accuracy achieve the define value
    max_epoch = args.max_epoch or math.inf
    epoch_i = 1
    trg_acc = 0.82
    best_dev_loss = math.inf
    lr = optimizer.param_groups[0]['lr']

    # validation set data loader (only prepare once)
    train = prepare_training_data(args, dataset, 'train', generator, epoch_i,
                                  use_cuda)
    valid = prepare_training_data(args, dataset, 'valid', generator, epoch_i,
                                  use_cuda)
    data_train = DatasetProcessing(data=train, maxlen=args.fixed_max_len)
    data_valid = DatasetProcessing(data=valid, maxlen=args.fixed_max_len)

    # main training loop
    while lr > args.min_d_lr and epoch_i <= max_epoch:
        logging.info("At {0}-th epoch.".format(epoch_i))

        seed = args.seed + epoch_i
        torch.manual_seed(seed)

        if args.sample_without_replacement > 0 and epoch_i > 1:
            train = prepare_training_data(args, dataset, 'train', generator,
                                          epoch_i, use_cuda)
            data_train = DatasetProcessing(data=train,
                                           maxlen=args.fixed_max_len)

        # discriminator training dataloader
        train_loader = train_dataloader(data_train,
                                        batch_size=args.joint_batch_size,
                                        seed=seed,
                                        epoch=epoch_i,
                                        sort_by_source_size=False)

        valid_loader = eval_dataloader(data_valid,
                                       num_workers=4,
                                       batch_size=args.joint_batch_size)

        # set training mode
        discriminator.train()

        # reset meters
        for key, val in logging_meters.items():
            if val is not None:
                val.reset()

        for i, sample in enumerate(train_loader):
            if use_cuda:
                # wrap input tensors in cuda tensors
                sample = utils.make_variable(sample, cuda=use_cuda)

            disc_out = discriminator(sample['src_tokens'],
                                     sample['trg_tokens'])

            loss = criterion(disc_out, sample['labels'])
            _, prediction = F.softmax(disc_out, dim=1).topk(1)
            acc = torch.sum(
                prediction == sample['labels'].unsqueeze(1)).float() / len(
                    sample['labels'])

            logging_meters['train_acc'].update(acc.item())
            logging_meters['train_loss'].update(loss.item())
            logging.debug("D training loss {0:.3f}, acc {1:.3f}, avgAcc {2:.3f}, lr={3} at batch {4}: ". \
                          format(logging_meters['train_loss'].avg, acc, logging_meters['train_acc'].avg,
                                 optimizer.param_groups[0]['lr'], i))

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(discriminator.parameters(),
                                           args.clip_norm)
            optimizer.step()

            # del src_tokens, trg_tokens, loss, disc_out, labels, prediction, acc
            del disc_out, loss, prediction, acc

        # set validation mode
        discriminator.eval()

        for i, sample in enumerate(valid_loader):
            with torch.no_grad():
                if use_cuda:
                    # wrap input tensors in cuda tensors
                    sample = utils.make_variable(sample, cuda=use_cuda)

                disc_out = discriminator(sample['src_tokens'],
                                         sample['trg_tokens'])

                loss = criterion(disc_out, sample['labels'])
                _, prediction = F.softmax(disc_out, dim=1).topk(1)
                acc = torch.sum(
                    prediction == sample['labels'].unsqueeze(1)).float() / len(
                        sample['labels'])

                logging_meters['valid_acc'].update(acc.item())
                logging_meters['valid_loss'].update(loss.item())
                logging.debug("D eval loss {0:.3f}, acc {1:.3f}, avgAcc {2:.3f}, lr={3} at batch {4}: ". \
                              format(logging_meters['valid_loss'].avg, acc, logging_meters['valid_acc'].avg,
                                     optimizer.param_groups[0]['lr'], i))

            del disc_out, loss, prediction, acc

        lr_scheduler.step(logging_meters['valid_loss'].avg)

        if logging_meters['valid_acc'].avg >= 0.70:
            torch.save(discriminator.state_dict(), checkpoints_path + "ce_{0:.3f}_acc_{1:.3f}.epoch_{2}.pt" \
                       .format(logging_meters['valid_loss'].avg, logging_meters['valid_acc'].avg, epoch_i))

            if logging_meters['valid_loss'].avg < best_dev_loss:
                best_dev_loss = logging_meters['valid_loss'].avg
                torch.save(discriminator.state_dict(),
                           checkpoints_path + "best_dmodel.pt")

        # pretrain the discriminator to achieve accuracy 82%
        if logging_meters['valid_acc'].avg >= trg_acc:
            return

        epoch_i += 1
Code Example #10
File: train.py  Project: zhoutao1996/simpleGAN
def main(args):

    #transformer
    transform = transforms.Compose([
        transforms.Resize(64),
        transforms.ToTensor(),
        transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
    ])

    #dateset
    anime = AnimeData(args.tags, args.imgs, transform=transform)
    dataloder = DataLoader(anime, batch_size=args.bs, shuffle=True)

    #model
    gen = Generator(args.noise, momentum=args.momentum)
    dis = Discriminator(momentum=args.momentum)

    #criterion
    criterion = nn.BCELoss()

    if torch.cuda.is_available():
        gen = gen.cuda()
        dis = dis.cuda()
        criterion = criterion.cuda()

    #optimizer
    optimizer_gen = optim.Adam(gen.parameters(),
                               lr=args.lr_g,
                               betas=args.betas)
    optimizer_dis = optim.Adam(dis.parameters(),
                               lr=args.lr_d,
                               betas=args.betas)

    loss_history_d = []
    loss_history_g = []
    out_history_true_d = []
    out_history_fake_d = []
    out_history_fake_g = []

    for epoch in range(args.epochs):
        print('----------------start epoch %d ---------------' % epoch)
        step = 0
        for data in dataloder:
            step += 1
            start = time.time()
            img = Variable(data)
            noise = Variable(torch.randn(img.shape[0], args.noise))
            labels_true = Variable(torch.ones(img.shape[0], 1))
            labels_fake = Variable(torch.zeros(img.shape[0], 1))
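            # optional label smoothing: soften the hard 1/0 targets by up to 0.1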
            if args.label_smoothing:
                labels_true = labels_true - torch.rand(img.shape[0], 1) * 0.1
                labels_fake = labels_fake + torch.rand(img.shape[0], 1) * 0.1

            #train on GPU
            if torch.cuda.is_available():
                img = img.cuda()
                noise = noise.cuda()
                labels_true = labels_true.cuda()
                labels_fake = labels_fake.cuda()

            #train D
            out_true_d = dis(img)
            out_fake_d = dis(gen(noise))
            out_history_true_d.append(torch.mean(out_true_d).item())
            out_history_fake_d.append(torch.mean(out_fake_d).item())
            #d_loss_ture = -torch.mean(labels_true * torch.log(out_true_d) + (1. - labels_true) * torch.log(1. - out_true_d))
            loss_true_d = criterion(out_true_d, labels_true)
            loss_fake_d = criterion(out_fake_d, labels_fake)
            loss_d = loss_true_d + loss_fake_d
            optimizer_dis.zero_grad()
            loss_d.backward()

            if args.check:
                print('>>>>>>>>>>check_d_grad<<<<<<<<<<')
                try:
                    check_grad(dis, 'conv2.weight')
                except ValueError as e:
                    print(e)
                    show(loss_history_d, loss_history_g, out_history_true_d,
                         out_history_fake_d, out_history_fake_g)
                    torch.save(dis.state_dict(),
                               os.path.join(os.getcwd(), args.d, 'bad.pth'))
                    torch.save(gen.state_dict(),
                               os.path.join(os.getcwd(), args.g, 'bad.pth'))
                    return
            loss_history_d.append(loss_d.item())
            optimizer_dis.step()

            #train G
            noise = Variable(torch.randn(img.shape[0], args.noise))
            if torch.cuda.is_available():
                noise = noise.cuda()
            out_fake_g = dis(gen(noise))
            labels_fake = 1. - labels_fake
            out_history_fake_g.append(torch.mean(out_fake_g).item())
            loss_g = criterion(out_fake_g, labels_fake)
            optimizer_gen.zero_grad()
            loss_g.backward()

            if args.check:
                print('>>>>>>>>>>check_g_grad<<<<<<<<<<')
                try:
                    check_grad(gen, 'convTrans.weight')
                except ValueError as e:
                    print(e)
                    show(loss_history_d, loss_history_g, out_history_true_d,
                         out_history_fake_d, out_history_fake_g)
                    torch.save(dis.state_dict(),
                               os.path.join(os.getcwd(), args.d, 'bad.pth'))
                    torch.save(gen.state_dict(),
                               os.path.join(os.getcwd(), args.g, 'bad.pth'))
                    return
            loss_history_g.append(loss_g.item())
            optimizer_gen.step()
            end = time.time()
            print(
                'epoch: %d  step: %d  d_true: %.2f  d_fake: %.2f  g_fake: %.2f time: %.2f'
                % (epoch, step, out_history_true_d[-1], out_history_fake_d[-1],
                   out_history_fake_g[-1], end - start))

        #save model
        torch.save(dis.state_dict(),
                   os.path.join(os.getcwd(), args.d, '{}.pth'.format(epoch)))
        torch.save(gen.state_dict(),
                   os.path.join(os.getcwd(), args.g, '{}.pth'.format(epoch)))
Code Example #11
class SGAN:
    def __init__(self):
        self.read_dataset()
        if not os.path.exists(cfg.train.run_directory):
            os.makedirs(cfg.train.run_directory)
        with open(cfg.train.run_directory + 'params.txt', 'w') as f:
            f.write(str(vars(cfg)))

        self.build_model()

        return

    def read_dataset(self):
        self.train_loader, self.valid_loader = get_train_valid_loader(
            data_dir=cfg.dataset.data_dir,
            dataset_type=cfg.dataset.dataset_name,
            train_batch_size=cfg.train.batch_size,
            valid_batch_size=cfg.validation.batch_size,
            augment=False if cfg.dataset.dataset_name == 'mnist' else True,
            random_seed=cfg.dataset.seed,
            valid_size=cfg.train.valid_part,
            shuffle=True,
            show_sample=False,
            num_workers=multiprocessing.cpu_count(),
            pin_memory=False)

        return

    def real_data_target(self, size):
        '''
        Tensor containing ones, with shape = size
        '''
        data = Variable(torch.ones(size, 1))
        if torch.cuda.is_available(): return data.cuda()
        return data

    def fake_data_target(self, size):
        '''
        Tensor containing zeros, with shape = size
        '''
        data = Variable(torch.zeros(size, 1))
        if torch.cuda.is_available(): return data.cuda()
        return data

    def train_discriminator(self, discriminator, optimizer, real_data,
                            fake_data, labels):
        # Reset gradients
        optimizer.zero_grad()

        # 1. Train on Real Data
        D_real = discriminator(cfg.dataset.dataset_name, real_data, labels)
        # Calculate error and backpropagate
        D_loss_real = self.loss(D_real,
                                self.real_data_target(real_data.size(0)))
        D_loss_real.backward()

        # 2. Train on Fake Data
        D_fake = discriminator(cfg.dataset.dataset_name, fake_data, labels)
        # Calculate error and backpropagate
        D_loss_fake = self.loss(D_fake,
                                self.fake_data_target(fake_data.size(0)))
        D_loss_fake.backward()

        if cfg.train.loss_type == cfg.VANILLA:
            D_loss = D_loss_real + D_loss_fake
        elif cfg.train.loss_type == cfg.WGAN:
            D_loss = D_loss_fake - D_loss_real
            if cfg.train.use_GP:
                grad_penalty, gradient_norm = gradient_penalty(
                    discriminator, real_data, fake_data, cfg.train.gp_weight,
                    labels, cfg.dataset.dataset_name)
                D_loss += grad_penalty

        # Update weights with gradients
        optimizer.step()

        return D_real, D_fake, D_loss, D_loss_real, D_loss_fake

    def train_generator(self, generator, discriminator, optimizer, z_noise,
                        labels):
        # Reset gradients
        optimizer.zero_grad()

        # Sample noise and generate fake data
        G_fake_data = generator(cfg.dataset.dataset_name, z_noise, labels)
        D_fake = discriminator(cfg.dataset.dataset_name, G_fake_data, labels)
        # Calculate error and backpropagate
        G_loss = self.loss(D_fake, self.real_data_target(D_fake.size(0)))
        if cfg.train.loss_type == cfg.WGAN:
            G_loss = -1 * G_loss
        G_loss.backward()
        # Update weights with gradients
        optimizer.step()
        # Return error
        return G_fake_data, G_loss

    def build_model(self):
        if cfg.train.loss_type == cfg.VANILLA:
            self.loss = nn.BCELoss()
        elif cfg.train.loss_type == cfg.WGAN:
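            # WGAN: the "loss" is just the mean critic output; the sign is handled in train_discriminator / train_generator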
            self.loss = lambda logits, labels: torch.mean(logits)

        self.D_global = Discriminator(cfg.dataset.dataset_name)
        self.G_global = Generator(cfg.dataset.dataset_name)

        # Enable cuda if available
        if torch.cuda.is_available():
            self.D_global.cuda()
            self.G_global.cuda()

        # Optimizers
        self.D_global_optimizer = Adam(self.D_global.parameters(),
                                       lr=cfg.train.learning_rate,
                                       betas=(cfg.train.beta1, 0.999))
        self.G_global_optimizer = Adam(self.G_global.parameters(),
                                       lr=cfg.train.learning_rate,
                                       betas=(cfg.train.beta1, 0.999))

        self.D_pairs = []
        self.G_pairs = []
        self.D_pairs_optimizers = []
        self.G_pairs_optimizers = []

        self.D_msg_pairs = []
        self.D_msg_pairs_optimizers = []
        for id in range(1, cfg.train.N_pairs + 1):
            discriminator = Discriminator(cfg.dataset.dataset_name)
            generator = Generator(cfg.dataset.dataset_name)

            # Enable cuda if available
            if torch.cuda.is_available():
                generator.cuda()
                discriminator.cuda()

            self.D_pairs.append(discriminator)
            self.G_pairs.append(generator)

            # Optimizers
            D_optimizer = Adam(discriminator.parameters(),
                               lr=cfg.train.learning_rate,
                               betas=(cfg.train.beta1, 0.999))
            G_optimizer = Adam(generator.parameters(),
                               lr=cfg.train.learning_rate,
                               betas=(cfg.train.beta1, 0.999))

            self.D_pairs_optimizers.append(D_optimizer)
            self.G_pairs_optimizers.append(G_optimizer)

            # create msg Discriminator pair for G_global
            discriminator = Discriminator(cfg.dataset.dataset_name)

            # Enable cuda if available
            if torch.cuda.is_available():
                generator.cuda()
                discriminator.cuda()

            self.D_msg_pairs.append(discriminator)

            # Optimizers
            D_optimizer = Adam(discriminator.parameters(),
                               lr=cfg.train.learning_rate,
                               betas=(cfg.train.beta1, 0.999))

            self.D_msg_pairs_optimizers.append(D_optimizer)

        self.logger = Logger(model_name='DCGAN',
                             data_name='MNIST',
                             logdir=cfg.validation.validation_dir)

        return

    def run_validation(self, generator, discriminator, epoch, i, type_GAN):
        nrof_batches = len(self.valid_loader)
        for batch_idx, (valid_batch_images,
                        valid_batch_labels) in enumerate(self.valid_loader):
            valid_batch_size = len(valid_batch_images)
            valid_batch_labels = valid_batch_labels.type(torch.float32)
            valid_batch_z = torch.from_numpy(
                np.random.uniform(-1, 1,
                                  [valid_batch_size, cfg.train.z_dim]).astype(
                                      np.float32))

            if torch.cuda.is_available():
                valid_batch_images = valid_batch_images.cuda()
                valid_batch_labels = valid_batch_labels.cuda()
                valid_batch_z = valid_batch_z.cuda()

            G_fake_data = generator(cfg.dataset.dataset_name, valid_batch_z,
                                    valid_batch_labels)
            D_fake = discriminator(cfg.dataset.dataset_name, G_fake_data,
                                   valid_batch_labels)
            G_loss = self.loss(D_fake, self.real_data_target(D_fake.size(0)))

            D_real = discriminator(cfg.dataset.dataset_name,
                                   valid_batch_images, valid_batch_labels)
            D_loss_real = self.loss(
                D_real, self.real_data_target(valid_batch_images.size(0)))
            D_fake = discriminator(cfg.dataset.dataset_name, G_fake_data,
                                   valid_batch_labels)
            D_loss_fake = self.loss(D_fake,
                                    self.fake_data_target(D_fake.size(0)))
            D_loss = D_loss_real + D_loss_fake

            if len(valid_batch_images) == cfg.validation.batch_size:
                inception_score, std = Score.inception_score(G_fake_data)
                self.logger.log_score(inception_score, epoch, batch_idx,
                                      nrof_batches, type_GAN, 'IS_validation')

            # self.logger.log_images(generated_images, valid_batch_size, epoch, val_i, nrof_valid_batches,
            #                        type_GAN='pairs', format='NHWC')
            print("[Sample] d_loss: %.8f, g_loss: %.8f" % (D_loss, G_loss))
            if batch_idx > 0 and batch_idx % 15 == 0:
                generated_images = G_fake_data.detach().cpu()
                generated_images = generated_images.permute([0, 2, 3, 1])
                self.logger.log_images2(generated_images,
                                        epoch,
                                        batch_idx,
                                        type_GAN=type_GAN)

            batch_idx += 1

            # self.logger.save_models(self.G_pairs[id], self.D_pairs[id], epoch, 'pairs')
        return

    def copy_network_parameters(self, src_network, dest_network):
        params_src = src_network.named_parameters()
        params_dest = dest_network.named_parameters()

        dict_dest_params = dict(params_dest)

        for name_src, param_src in params_src:
            if name_src in dict_dest_params:
                dict_dest_params[name_src].data.copy_(param_src.data)
        return

    def run_train(self):
        for epoch in range(cfg.train.num_epochs):
            for id in range(cfg.train.N_pairs):
                print('Train pairs')
                self.train_pairs_epoch(id, epoch)
                self.copy_network_parameters(self.D_pairs[id],
                                             self.D_msg_pairs[id])
                self.train_G_global_epoch(id, epoch)
                self.train_D_global_epoch(id, epoch)
                self.run_validation(self.G_global, self.D_global, epoch, None,
                                    'global_pair')
                self.logger.save_models(self.G_global, self.D_global, epoch,
                                        'global_pair')
        return

    def train_D_global_epoch(self, id, epoch):
        # torch.set_default_tensor_type('torch.DoubleTensor')
        nrof_batches = len(self.train_loader)
        train_time = 0
        for batch_idx, (batch_images,
                        batch_labels) in enumerate(self.train_loader):
            start_time = time.time()
            batch_size = len(batch_images)
            batch_labels = batch_labels.type(torch.float32)
            batch_z = torch.from_numpy(
                np.random.uniform(-1, 1, [batch_size, cfg.train.z_dim]).astype(
                    np.float32))

            # 1. Train Discriminator
            if torch.cuda.is_available():
                batch_images = batch_images.cuda()
                batch_labels = batch_labels.cuda()
                batch_z = batch_z.cuda()
            # Generate fake data
            G_fake_data = self.G_pairs[id](cfg.dataset.dataset_name, batch_z,
                                           batch_labels).detach()
            # Train D
            D_real, D_fake, D_loss, D_loss_real, D_loss_fake = self.train_discriminator(
                self.D_global, self.D_global_optimizer, batch_images,
                G_fake_data, batch_labels)

            # 2. Train Generator
            G_fake_data, G_loss = self.train_generator(
                self.G_pairs[id], self.D_global, self.G_pairs_optimizers[id],
                batch_z, batch_labels)

            # 3. Train Discriminator twice
            # Generate fake data
            G_fake_data = self.G_pairs[id](cfg.dataset.dataset_name, batch_z,
                                           batch_labels).detach()
            # Train D
            D_real, D_fake, D_loss, D_loss_real, D_loss_fake = self.train_discriminator(
                self.D_global, self.D_global_optimizer, batch_images,
                G_fake_data, batch_labels)

            # Log error
            self.logger.log(D_loss, G_loss, epoch, batch_idx, nrof_batches,
                            'D0-' + str(id + 1))

            if len(batch_images) == cfg.train.batch_size:
                inception_score, std = Score.inception_score(G_fake_data)
                self.logger.log_score(inception_score, epoch, batch_idx,
                                      nrof_batches, 'D0-' + str(id + 1), 'IS')

            duration = time.time() - start_time
            print("Epoch: [%2d/%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
                  % (epoch, cfg.train.num_epochs, batch_idx, nrof_batches,
                     time.time() - start_time, D_loss, G_loss))
            train_time += duration
            if batch_idx > 0 and batch_idx % 101 == 0:
                self.run_validation(self.G_pairs[id], self.D_global, epoch,
                                    batch_idx, 'D_global_pairs-' + str(id + 1))
            batch_idx += 1

        self.logger.save_models(self.G_pairs[id], self.D_global, epoch,
                                'D_global_pairs-' + str(id + 1))

        return

    def train_G_global_epoch(self, id, epoch):
        # torch.set_default_tensor_type('torch.DoubleTensor')
        nrof_batches = len(self.train_loader)
        train_time = 0
        for batch_idx, (batch_images,
                        batch_labels) in enumerate(self.train_loader):
            start_time = time.time()
            batch_size = len(batch_images)
            batch_labels = batch_labels.type(torch.float32)
            batch_z = torch.from_numpy(
                np.random.uniform(-1, 1, [batch_size, cfg.train.z_dim]).astype(
                    np.float32))

            # 1. Train Discriminator
            if torch.cuda.is_available():
                batch_images = batch_images.cuda()
                batch_labels = batch_labels.cuda()
                batch_z = batch_z.cuda()
            # Generate fake data
            G_fake_data = self.G_global(cfg.dataset.dataset_name, batch_z,
                                        batch_labels).detach()
            # Train D
            D_real, D_fake, D_loss, D_loss_real, D_loss_fake = self.train_discriminator(
                self.D_msg_pairs[id], self.D_msg_pairs_optimizers[id],
                batch_images, G_fake_data, batch_labels)

            # 2. Train Generator
            G_fake_data, G_loss = self.train_generator(self.G_global,
                                                       self.D_msg_pairs[id],
                                                       self.G_global_optimizer,
                                                       batch_z, batch_labels)

            # 3. Train Discriminator twice
            # Generate fake data
            G_fake_data = self.G_global(cfg.dataset.dataset_name, batch_z,
                                        batch_labels).detach()
            # Train D
            D_real, D_fake, D_loss, D_loss_real, D_loss_fake = self.train_discriminator(
                self.D_msg_pairs[id], self.D_msg_pairs_optimizers[id],
                batch_images, G_fake_data, batch_labels)

            # Log error
            self.logger.log(D_loss, G_loss, epoch, batch_idx, nrof_batches,
                            'G0-' + str(id + 1))

            if len(batch_images) == cfg.train.batch_size:
                inception_score, std = Score.inception_score(G_fake_data)
                self.logger.log_score(inception_score, epoch, batch_idx,
                                      nrof_batches, 'G0-' + str(id + 1), 'IS')

            duration = time.time() - start_time
            print("Epoch: [%2d/%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
                  % (epoch, cfg.train.num_epochs, batch_idx, nrof_batches,
                     time.time() - start_time, D_loss, G_loss))
            train_time += duration
            if batch_idx > 0 and batch_idx % 101 == 0:
                self.run_validation(self.G_global, self.D_msg_pairs[id], epoch,
                                    batch_idx, 'G_global_pairs-' + str(id + 1))
            batch_idx += 1

        self.logger.save_models(self.G_global, self.D_msg_pairs[id], epoch,
                                'G_global_pairs-' + str(id + 1))

        return

    def train_pairs_epoch(self, id, epoch):
        nrof_batches = len(self.train_loader)
        train_time = 0
        for batch_idx, (batch_images,
                        batch_labels) in enumerate(self.train_loader):
            start_time = time.time()
            batch_size = len(batch_images)
            batch_labels = batch_labels.type(torch.float32)
            batch_z = torch.from_numpy(
                np.random.uniform(-1, 1, [batch_size, cfg.train.z_dim]).astype(
                    np.float32))

            # 1. Train Discriminator
            if torch.cuda.is_available():
                batch_images = batch_images.cuda()
                batch_labels = batch_labels.cuda()
                batch_z = batch_z.cuda()
            # Generate fake data
            G_fake_data = self.G_pairs[id](cfg.dataset.dataset_name, batch_z,
                                           batch_labels).detach()
            # Train D
            D_real, D_fake, D_loss, D_loss_real, D_loss_fake = self.train_discriminator(
                self.D_pairs[id], self.D_pairs_optimizers[id], batch_images,
                G_fake_data, batch_labels)

            # 2. Train Generator
            G_fake_data, G_loss = self.train_generator(
                self.G_pairs[id], self.D_pairs[id],
                self.G_pairs_optimizers[id], batch_z, batch_labels)

            # 3. Train Discriminator twice
            # Generate fake data
            G_fake_data = self.G_pairs[id](cfg.dataset.dataset_name, batch_z,
                                           batch_labels).detach()
            # Train D
            D_real, D_fake, D_loss, D_loss_real, D_loss_fake = self.train_discriminator(
                self.D_pairs[id], self.D_pairs_optimizers[id], batch_images,
                G_fake_data, batch_labels)

            # Log error
            self.logger.log(D_loss, G_loss, epoch, batch_idx, nrof_batches,
                            str(id + 1))

            if len(batch_images) == cfg.train.batch_size:
                inception_score, std = Score.inception_score(G_fake_data)
                self.logger.log_score(inception_score, epoch, batch_idx,
                                      nrof_batches, str(id + 1), 'IS')

            duration = time.time() - start_time
            print("Epoch: [%2d/%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
                  % (epoch, cfg.train.num_epochs, batch_idx, nrof_batches,
                     time.time() - start_time, D_loss, G_loss))
            train_time += duration
            if batch_idx > 0 and batch_idx % 101 == 0:
                self.run_validation(self.G_pairs[id], self.D_pairs[id], epoch,
                                    batch_idx, 'pairs-' + str(id + 1))

        self.logger.save_models(self.G_pairs[id], self.D_pairs[id], epoch,
                                'pairs-' + str(id + 1))

        return
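
The loop above relies on self.train_discriminator(...) and self.train_generator(...), which are defined elsewhere in the class and not shown in this excerpt. The following is a minimal, hypothetical sketch of such helpers for a standard conditional GAN; the BCE-with-logits objective, the argument order, and the omission of the dataset-name argument that the real generator receives are all assumptions, not the project's actual code.

import torch
import torch.nn.functional as F

def train_discriminator(D, D_optimizer, real_images, fake_images, labels):
    # One discriminator step: push real pairs towards 1 and fake pairs towards 0.
    D_optimizer.zero_grad()
    D_real = D(real_images, labels)
    D_fake = D(fake_images, labels)
    D_loss_real = F.binary_cross_entropy_with_logits(D_real, torch.ones_like(D_real))
    D_loss_fake = F.binary_cross_entropy_with_logits(D_fake, torch.zeros_like(D_fake))
    D_loss = D_loss_real + D_loss_fake
    D_loss.backward()
    D_optimizer.step()
    return D_real, D_fake, D_loss, D_loss_real, D_loss_fake

def train_generator(G, D, G_optimizer, batch_z, labels):
    # One generator step: make the discriminator score fresh fakes as real.
    G_optimizer.zero_grad()
    fake_images = G(batch_z, labels)
    D_fake = D(fake_images, labels)
    G_loss = F.binary_cross_entropy_with_logits(D_fake, torch.ones_like(D_fake))
    G_loss.backward()
    G_optimizer.step()
    return fake_images, G_loss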
Code example #12
def train(opt):

    netG_A2B = Unet2(3, 3)
    netG_B2A = Unet2(3, 3)
    netD_A = Discriminator(3)
    netD_B = Discriminator(3)

    if opt.use_cuda:
        netG_A2B = netG_A2B.cuda()
        netG_B2A = netG_B2A.cuda()
        netD_A = netD_A.cuda()
        netD_B = netD_B.cuda()

    netG_A2B_optimizer = optimizer.Adam(params=netG_A2B.parameters(),
                                        lr=opt.lr,
                                        betas=(0.5, 0.999))
    netG_B2A_optimizer = optimizer.Adam(params=netG_B2A.parameters(),
                                        lr=opt.lr,
                                        betas=(0.5, 0.999))
    netD_A_optimizer = optimizer.Adam(params=netD_A.parameters(),
                                      lr=opt.lr,
                                      betas=(0.5, 0.999))
    netD_B_optimizer = optimizer.Adam(params=netD_B.parameters(),
                                      lr=opt.lr,
                                      betas=(0.5, 0.999))

    optimizers = dict()
    optimizers['G1'] = netG_A2B_optimizer
    optimizers['G2'] = netG_B2A_optimizer
    optimizers['D1'] = netD_A_optimizer
    optimizers['D2'] = netD_B_optimizer

    # Dataset loader
    transforms_ = [
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]

    traindataloader = DataLoader(ImageDataset(opt.dataroot,
                                              transforms_=transforms_,
                                              unaligned=True),
                                 batch_size=opt.batchSize,
                                 shuffle=True)

    #writer
    writer = SummaryWriter(opt.log_dir)

    for epoch in range(0, opt.n_epochs):
        for ii, batch in enumerate(traindataloader):
            # Set model input
            real_A = Variable(batch['A'])
            real_B = Variable(batch['B'])

            if opt.use_cuda:
                real_A = real_A.cuda()
                real_B = real_B.cuda()

            train_one_step(use_cuda=opt.use_cuda,
                           netG_A2B=netG_A2B,
                           netG_B2A=netG_B2A,
                           netD_A=netD_A,
                           netD_B=netD_B,
                           real_A=real_A,
                           real_B=real_B,
                           optimizers=optimizers,
                           iteration=ii,
                           writer=writer)

            print("\nEpoch: %s Batch: %s" % (epoch, ii))

    writer.export_scalars_to_json("./all_scalars.json")
    writer.close()
    torch.save(netG_A2B.state_dict(),
               os.path.join(opt.save_dir, '%s' % "netG_A2B"))
    torch.save(netG_B2A.state_dict(),
               os.path.join(opt.save_dir, '%s' % "netG_B2A"))
    torch.save(netD_A.state_dict(), os.path.join(opt.save_dir,
                                                 '%s' % "netD_A"))
    torch.save(netD_B.state_dict(), os.path.join(opt.save_dir,
                                                 '%s' % "netD_B"))
Code example #13
File: main.py Project: ivan-selchenkov/dcgan
                          shuffle=True,
                          num_workers=0)

display_data(train_loader)

conv_size = 32
z_size = 100

D = Discriminator(conv_size)
G = Generator(z_size, conv_size)

cuda = False

if torch.cuda.is_available():
    cuda = True
    D = D.cuda()
    G = G.cuda()

lr = 0.0002
beta1 = 0.5
beta2 = 0.99

d_optim = optim.Adam(D.parameters(), lr, [beta1, beta2])
g_optim = optim.Adam(G.parameters(), lr, [beta1, beta2])


def train_discriminator(real_images, optimizer, batch_size, z_size):
    optimizer.zero_grad()

    if cuda:
        real_images = real_images.cuda()
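    # NOTE: the original snippet is truncated at this point. The lines below are a
    # hypothetical continuation sketching a standard DCGAN discriminator step
    # (real/fake BCE losses); they assume the globals G and D defined above and
    # that D returns raw logits. This is not the project's actual code.
    d_real = D(real_images)
    real_loss = torch.nn.functional.binary_cross_entropy_with_logits(
        d_real, torch.ones_like(d_real))

    # Score a batch of generated images; detach so only D is updated here.
    z = torch.randn(batch_size, z_size)
    if cuda:
        z = z.cuda()
    d_fake = D(G(z).detach())
    fake_loss = torch.nn.functional.binary_cross_entropy_with_logits(
        d_fake, torch.zeros_like(d_fake))

    d_loss = real_loss + fake_loss
    d_loss.backward()
    optimizer.step()
    return d_loss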
Code example #14
def main(options):

    use_cuda = (len(options.gpuid) >= 1)
    # if options.gpuid:
    #   cuda.set_device(options.gpuid[0])

    src_train, src_dev, src_test, src_vocab = torch.load(
        open(options.data_file + "." + options.src_lang, 'rb'))
    trg_train, trg_dev, trg_test, trg_vocab = torch.load(
        open(options.data_file + "." + options.trg_lang, 'rb'))

    batched_train_src, batched_train_src_mask, sort_index = utils.tensor.advanced_batchize(
        src_train, options.batch_size, src_vocab.stoi["<blank>"])
    batched_train_trg, batched_train_trg_mask = utils.tensor.advanced_batchize_no_sort(
        trg_train, options.batch_size, trg_vocab.stoi["<blank>"], sort_index)
    batched_dev_src, batched_dev_src_mask, sort_index = utils.tensor.advanced_batchize(
        src_dev, options.batch_size, src_vocab.stoi["<blank>"])
    batched_dev_trg, batched_dev_trg_mask = utils.tensor.advanced_batchize_no_sort(
        trg_dev, options.batch_size, trg_vocab.stoi["<blank>"], sort_index)

    print "preprocessing batched data..."
    processed_src = list()
    processed_trg = list()
    processed_src_mask = list()
    processed_trg_mask = list()
    for batch_i in range(len(batched_train_src)):
        if batched_train_src[batch_i].size(
                0) <= 35 and batched_train_trg[batch_i].size(0) <= 35:
            processed_src.append(batched_train_src[batch_i])
            processed_trg.append(batched_train_trg[batch_i])
            processed_src_mask.append(batched_train_src_mask[batch_i])
            processed_trg_mask.append(batched_train_trg_mask[batch_i])

    batched_train_src = processed_src
    batched_train_trg = processed_trg
    batched_train_src_mask = processed_src_mask
    batched_train_trg_mask = processed_trg_mask

    processed_src = list()
    processed_trg = list()
    processed_src_mask = list()
    processed_trg_mask = list()
    for batch_i in range(len(batched_dev_src)):
        if batched_dev_src[batch_i].size(
                0) <= 35 and batched_dev_trg[batch_i].size(0) <= 35:
            processed_src.append(batched_dev_src[batch_i])
            processed_trg.append(batched_dev_trg[batch_i])
            processed_src_mask.append(batched_dev_src_mask[batch_i])
            processed_trg_mask.append(batched_dev_trg_mask[batch_i])

    batched_dev_src = processed_src
    batched_dev_trg = processed_trg
    batched_dev_src_mask = processed_src_mask
    batched_dev_trg_mask = processed_trg_mask

    del processed_src, processed_trg, processed_trg_mask, processed_src_mask

    trg_vocab_size = len(trg_vocab)
    src_vocab_size = len(src_vocab)
    word_emb_size = 50
    hidden_size = 1024

    nmt = NMT(src_vocab_size,
              trg_vocab_size,
              word_emb_size,
              hidden_size,
              src_vocab,
              trg_vocab,
              attn_model="general",
              use_cuda=True)
    discriminator = Discriminator(src_vocab_size,
                                  trg_vocab_size,
                                  word_emb_size,
                                  src_vocab,
                                  trg_vocab,
                                  use_cuda=True)

    if use_cuda > 0:
        #nmt = torch.nn.DataParallel(nmt,device_ids=options.gpuid).cuda()
        nmt.cuda()
        #discriminator = torch.nn.DataParallel(discriminator,device_ids=options.gpuid).cuda()
        discriminator.cuda()
    else:
        nmt.cpu()
        discriminator.cpu()

    criterion_g = torch.nn.NLLLoss().cuda()
    criterion = torch.nn.CrossEntropyLoss().cuda()

    # Configure optimization
    optimizer_g = eval("torch.optim." + options.optimizer)(
        nmt.parameters(), options.learning_rate)
    optimizer_d = eval("torch.optim." + options.optimizer)(
        discriminator.parameters(), options.learning_rate)

    # main training loop
    f1 = open("train_loss", "a")
    f2 = open("dev_loss", "a")
    last_dev_avg_loss = float("inf")
    for epoch_i in range(options.epochs):
        logging.info("At {0}-th epoch.".format(epoch_i))
        # srange generates a lazy sequence over a shuffled range of batch indices

        train_loss_g = 0.0
        train_loss_d = 0.0
        train_loss_g_nll = 0.0
        train_loss_g_ce = 0.0
        train_loss_nll_batch_num = 0
        train_loss_ce_batch_num = 0
        for i, batch_i in enumerate(utils.rand.srange(len(batched_train_src))):
            if i == 1500:
                break
            # if i==5:
            #   break
            train_src_batch = Variable(batched_train_src[batch_i]
                                       )  # of size (src_seq_len, batch_size)
            train_trg_batch = Variable(batched_train_trg[batch_i]
                                       )  # of size (src_seq_len, batch_size)
            train_src_mask = Variable(batched_train_src_mask[batch_i])
            train_trg_mask = Variable(batched_train_trg_mask[batch_i])
            if use_cuda:
                train_src_batch = train_src_batch.cuda()
                train_trg_batch = train_trg_batch.cuda()
                train_src_mask = train_src_mask.cuda()
                train_trg_mask = train_trg_mask.cuda()

            # train discriminator
            sys_out_batch = nmt(train_src_batch, train_trg_batch,
                                True).detach()
            _, predict_batch = sys_out_batch.topk(1)
            del _
            predict_batch = predict_batch.squeeze(2)
            real_dis_label_out = discriminator(train_src_batch,
                                               train_trg_batch, True)
            fake_dis_label_out = discriminator(train_src_batch, predict_batch,
                                               True)
            optimizer_d.zero_grad()
            loss_d_real = criterion(
                real_dis_label_out,
                Variable(
                    torch.ones(options.batch_size *
                               len(options.gpuid)).long()).cuda())
            loss_d_real.backward()
            loss_d_fake = criterion(
                fake_dis_label_out,
                Variable(
                    torch.zeros(options.batch_size *
                                len(options.gpuid)).long()).cuda())
            #loss_d_fake.backward(retain_graph=True)
            loss_d_fake.backward()
            loss_d = loss_d_fake.data[0] + loss_d_real.data[0]
            del loss_d_fake, loss_d_real
            logging.debug("D loss at batch {0}: {1}".format(i, loss_d))
            f1.write("D train loss at batch {0}: {1}\n".format(i, loss_d))
            optimizer_d.step()

            if use_cuda > 0:
                sys_out_batch = sys_out_batch.cuda()
                train_trg_batch = train_trg_batch.cuda()
            else:
                sys_out_batch = sys_out_batch.cpu()
                train_trg_batch = train_trg_batch.cpu()

            # train nmt
            sys_out_batch = nmt(train_src_batch, train_trg_batch, True)
            _, predict_batch = sys_out_batch.topk(1)
            predict_batch = predict_batch.squeeze(2)
            fake_dis_label_out = discriminator(train_src_batch, predict_batch,
                                               True)
            if random.random() > 0.5:
                train_trg_mask = train_trg_mask.view(-1)
                train_trg_batch = train_trg_batch.view(-1)
                train_trg_batch = train_trg_batch.masked_select(train_trg_mask)
                train_trg_mask = train_trg_mask.unsqueeze(1).expand(
                    len(train_trg_mask), trg_vocab_size)
                sys_out_batch = sys_out_batch.view(-1, trg_vocab_size)
                sys_out_batch = sys_out_batch.masked_select(
                    train_trg_mask).view(-1, trg_vocab_size)
                loss_g = criterion_g(sys_out_batch, train_trg_batch)
                train_loss_g_nll += loss_g
                train_loss_nll_batch_num += 1
                f1.write("G train NLL loss at batch {0}: {1}\n".format(
                    i, loss_g.data[0]))
            else:
                loss_g = criterion(
                    fake_dis_label_out,
                    Variable(
                        torch.ones(options.batch_size *
                                   len(options.gpuid)).long()).cuda())
                train_loss_g_ce += loss_g
                train_loss_ce_batch_num += 1
                f1.write("G train CE loss at batch {0}: {1}\n".format(
                    i, loss_g.data[0]))

            logging.debug("G loss at batch {0}: {1}".format(i, loss_g.data[0]))

            optimizer_g.zero_grad()
            loss_g.backward()

            # # gradient clipping
            torch.nn.utils.clip_grad_norm(nmt.parameters(), 5.0)
            optimizer_g.step()

            train_loss_d += loss_d
        train_avg_loss_g_nll = train_loss_g_nll / train_loss_nll_batch_num
        train_avg_loss_g_ce = train_loss_g_ce / train_loss_ce_batch_num
        train_avg_loss_d = train_loss_d / len(train_src_batch)
        logging.info(
            "G TRAIN Average NLL loss value per instance is {0} at the end of epoch {1}"
            .format(train_avg_loss_g_nll, epoch_i))
        logging.info(
            "G TRAIN Average CE loss value per instance is {0} at the end of epoch {1}"
            .format(train_avg_loss_g_ce, epoch_i))
        logging.info(
            "D TRAIN Average loss value per instance is {0} at the end of epoch {1}"
            .format(train_avg_loss_d, epoch_i))

        # validation -- this is a crude estimation because there might be some paddings at the end
        # dev_loss_g_nll = 0.0
        # dev_loss_g_ce = 0.0
        # dev_loss_d = 0.0

        # for batch_i in range(len(batched_dev_src)):
        #   dev_src_batch = Variable(batched_dev_src[batch_i], volatile=True)
        #   dev_trg_batch = Variable(batched_dev_trg[batch_i], volatile=True)
        #   dev_src_mask = Variable(batched_dev_src_mask[batch_i], volatile=True)
        #   dev_trg_mask = Variable(batched_dev_trg_mask[batch_i], volatile=True)
        #   if use_cuda:
        #     dev_src_batch = dev_src_batch.cuda()
        #     dev_trg_batch = dev_trg_batch.cuda()
        #     dev_src_mask = dev_src_mask.cuda()
        #     dev_trg_mask = dev_trg_mask.cuda()

        #   sys_out_batch = nmt(dev_src_batch, dev_trg_batch, False).detach()
        #   _,predict_batch = sys_out_batch.topk(1)
        #   predict_batch = predict_batch.squeeze(2)
        #   real_dis_label_out = discriminator(dev_src_batch, dev_trg_batch, True).detach()
        #   fake_dis_label_out = discriminator(dev_src_batch, predict_batch, True).detach()

        #   if use_cuda > 0:
        #     sys_out_batch = sys_out_batch.cuda()
        #     dev_trg_batch = dev_trg_batch.cuda()
        #   else:
        #     sys_out_batch = sys_out_batch.cpu()
        #     dev_trg_batch = dev_trg_batch.cpu()

        #   dev_trg_mask = dev_trg_mask.view(-1)
        #   dev_trg_batch = dev_trg_batch.view(-1)
        #   dev_trg_batch = dev_trg_batch.masked_select(dev_trg_mask)
        #   dev_trg_mask = dev_trg_mask.unsqueeze(1).expand(len(dev_trg_mask), trg_vocab_size)
        #   sys_out_batch = sys_out_batch.view(-1, trg_vocab_size)
        #   sys_out_batch = sys_out_batch.masked_select(dev_trg_mask).view(-1, trg_vocab_size)
        #   loss_g_nll = criterion_g(sys_out_batch, dev_trg_batch)
        #   loss_g_ce = criterion(fake_dis_label_out, Variable(torch.ones(options.batch_size*len(options.gpuid)).long(),volatile=True).cuda())
        #   loss_d = criterion(real_dis_label_out, Variable(torch.ones(options.batch_size*len(options.gpuid)).long(),volatile=True).cuda()) + criterion(fake_dis_label_out, Variable(torch.zeros(options.batch_size*len(options.gpuid)).long(),volatile=True).cuda())
        #   logging.debug("G dev NLL loss at batch {0}: {1}".format(batch_i, loss_g_nll.data[0]))
        #   logging.debug("G dev CE loss at batch {0}: {1}".format(batch_i, loss_g_ce.data[0]))
        #   f2.write("G dev NLL loss at batch {0}: {1}\n".format(batch_i, loss_g_nll.data[0]))
        #   f2.write("G dev CE loss at batch {0}: {1}\n".format(batch_i, loss_g_ce.data[0]))
        #   logging.debug("D dev loss at batch {0}: {1}".format(batch_i, loss_d.data[0]))
        #   f2.write("D dev loss at batch {0}: {1}\n".format(batch_i, loss_d.data[0]))
        #   dev_loss_g_nll += loss_g_nll
        #   dev_loss_g_ce += loss_g_ce
        #   dev_loss_d += loss_d
        # dev_avg_loss_g_nll = dev_loss_g_nll / len(batched_dev_src)
        # dev_avg_loss_g_ce = dev_loss_g_ce / len(batched_dev_src)
        # dev_avg_loss_d = dev_loss_d / len(batched_dev_src)
        # logging.info("G DEV Average NLL loss value per instance is {0} at the end of epoch {1}".format(dev_avg_loss_g_nll.cpu().data[0], epoch_i))
        # logging.info("G DEV Average CE loss value per instance is {0} at the end of epoch {1}".format(dev_avg_loss_g_ce.cpu().data[0], epoch_i))
        # logging.info("D DEV Average loss value per instance is {0} at the end of epoch {1}".format(dev_avg_loss_d.data[0], epoch_i))
        # # if (last_dev_avg_loss - dev_avg_loss).data[0] < options.estop:
        # #   logging.info("Early stopping triggered with threshold {0} (previous dev loss: {1}, current: {2})".format(epoch_i, last_dev_avg_loss.data[0], dev_avg_loss.data[0]))
        # #   break
    torch.save(nmt,
               open(
                   "nmt.nll_{0:.2f}.epoch_{1}".format(
                       train_avg_loss_g_nll.cpu().data[0], epoch_i), 'wb'),
               pickle_module=dill)
    torch.save(discriminator,
               open(
                   "discriminator.nll_{0:.2f}.epoch_{1}".format(
                       train_avg_loss_d.data[0], epoch_i), 'wb'),
               pickle_module=dill)
    f1.close()
    f2.close()
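
The training loop above draws batch indices from utils.rand.srange(...). A plausible minimal sketch, assuming the helper simply yields a lazily shuffled range of indices (the real utility may differ), is:

import random

def srange(n):
    # Yield 0..n-1 in a random order, one index at a time.
    indices = list(range(n))
    random.shuffle(indices)
    for i in indices:
        yield i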
Code example #15
def main():
    random.seed(SEED)
    np.random.seed(SEED)
    calc_bleu([1, 10, 12])
    exit()
    # Build up dataset
    s_train, s_test = load_from_big_file('../data/train_data_obama.txt')
    # idx_to_word: List of id to word
    # word_to_idx: Dictionary mapping word to id
    idx_to_word, word_to_idx = fetch_vocab(s_train, s_train, s_test)
    # TODO: 1. Prepare data for attention model
    # input_seq, target_seq = prepare_data(DATA_GERMAN, DATA_ENGLISH, word_to_idx)

    global VOCAB_SIZE
    VOCAB_SIZE = len(idx_to_word)

    save_vocab(CHECKPOINT_PATH + 'metadata.data', idx_to_word, word_to_idx,
               VOCAB_SIZE, g_emb_dim, g_hidden_dim, g_sequence_len)

    print('VOCAB SIZE:', VOCAB_SIZE)
    # Define Networks
    generator = Generator(VOCAB_SIZE, g_emb_dim, g_hidden_dim, g_sequence_len,
                          BATCH_SIZE, opt.cuda)
    discriminator = Discriminator(d_num_class, VOCAB_SIZE, d_emb_dim,
                                  d_filter_sizes, d_num_filters, d_dropout)
    target_lstm = TargetLSTM(VOCAB_SIZE, g_emb_dim, g_hidden_dim, opt.cuda)
    if opt.cuda:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        target_lstm = target_lstm.cuda()
    # Generate toy data using target lstm
    print('Generating data ...')
    generate_real_data('../data/train_data_obama.txt', BATCH_SIZE,
                       GENERATED_NUM, idx_to_word, word_to_idx, POSITIVE_FILE,
                       TEST_FILE)
    # Create Test data iterator for testing
    test_iter = GenDataIter(TEST_FILE, BATCH_SIZE)
    # generate_samples(target_lstm, BATCH_SIZE, GENERATED_NUM, POSITIVE_FILE, idx_to_word)

    # Load data from file
    gen_data_iter = GenDataIter(POSITIVE_FILE, BATCH_SIZE)

    # Pretrain Generator using MLE
    # gen_criterion = nn.NLLLoss(size_average=False)
    gen_criterion = nn.CrossEntropyLoss()
    gen_optimizer = optim.Adam(generator.parameters())
    if opt.cuda:
        gen_criterion = gen_criterion.cuda()
    print('Pretrain with MLE ...')
    for epoch in range(PRE_EPOCH_NUM):
        loss = train_epoch(generator, gen_data_iter, gen_criterion,
                           gen_optimizer)
        print('Epoch [%d] Model Loss: %f' % (epoch, loss))
        print('Training Output')
        test_predict(generator, test_iter, idx_to_word, train_mode=True)

        sys.stdout.flush()
        # TODO: 2. Flags to ensure dimension of model input is handled
        # generate_samples(generator, BATCH_SIZE, GENERATED_NUM, EVAL_FILE)
        """
        eval_iter = GenDataIter(EVAL_FILE, BATCH_SIZE)
        print('Iterator Done')
        loss = eval_epoch(target_lstm, eval_iter, gen_criterion)
        print('Epoch [%d] True Loss: %f' % (epoch, loss))
        """
    print('OUTPUT AFTER PRE-TRAINING')
    test_predict(generator, test_iter, idx_to_word, train_mode=True)

    # Pretrain Discriminator
    dis_criterion = nn.NLLLoss(size_average=False)
    dis_optimizer = optim.Adam(discriminator.parameters())
    if opt.cuda:
        dis_criterion = dis_criterion.cuda()
    print('Pretrain Discriminator ...')
    for epoch in range(3):
        generate_samples(generator, BATCH_SIZE, GENERATED_NUM, NEGATIVE_FILE)
        dis_data_iter = DisDataIter(POSITIVE_FILE, NEGATIVE_FILE, BATCH_SIZE)
        for _ in range(3):
            loss = train_epoch(discriminator, dis_data_iter, dis_criterion,
                               dis_optimizer)
            print('Epoch [%d], loss: %f' % (epoch, loss))
            sys.stdout.flush()
    # Adversarial Training
    rollout = Rollout(generator, 0.8)
    print('#####################################################')
    print('Start Adversarial Training...\n')
    gen_gan_loss = GANLoss()

    gen_gan_optm = optim.Adam(generator.parameters())
    if opt.cuda:
        gen_gan_loss = gen_gan_loss.cuda()
    gen_criterion = nn.NLLLoss(size_average=False)
    if opt.cuda:
        gen_criterion = gen_criterion.cuda()
    dis_criterion = nn.NLLLoss(size_average=False)
    dis_optimizer = optim.Adam(discriminator.parameters())
    if opt.cuda:
        dis_criterion = dis_criterion.cuda()
    real_iter = GenDataIter(POSITIVE_FILE, BATCH_SIZE)
    for total_batch in range(TOTAL_BATCH):
        ## Train the generator for one step
        for it in range(1):
            if real_iter.idx >= real_iter.data_num:
                real_iter.reset()
            inputs = real_iter.next()[0]
            inputs = inputs.cuda()
            samples = generator.sample(BATCH_SIZE, g_sequence_len, inputs)
            samples = samples.cpu()
            rewards = rollout.get_reward(samples, 16, discriminator)
            rewards = Variable(torch.Tensor(rewards))
            if opt.cuda:
                rewards = torch.exp(rewards.cuda()).contiguous().view((-1, ))
            prob = generator.forward(inputs)
            mini_batch = prob.shape[0]
            prob = torch.reshape(
                prob,
                (prob.shape[0] * prob.shape[1], -1))  #prob.view(-1, g_emb_dim)
            targets = copy.deepcopy(inputs).contiguous().view((-1, ))
            loss = gen_gan_loss(prob, targets, rewards)
            gen_gan_optm.zero_grad()
            loss.backward()
            gen_gan_optm.step()
            """
            samples = generator.sample(BATCH_SIZE, g_sequence_len)
            # construct the input to the generator, add zeros before samples and delete the last column
            zeros = torch.zeros((BATCH_SIZE, 1)).type(torch.LongTensor)
            if samples.is_cuda:
                zeros = zeros.cuda()
            inputs = Variable(torch.cat([zeros, samples.data], dim = 1)[:, :-1].contiguous())
            targets = Variable(samples.data).contiguous().view((-1,))
            print('', inputs.shape, targets.shape)
            print(inputs, targets)
            # calculate the reward
            rewards = rollout.get_reward(samples, 16, discriminator)
            rewards = Variable(torch.Tensor(rewards))
            if opt.cuda:
                rewards = torch.exp(rewards.cuda()).contiguous().view((-1,))
            prob = generator.forward(inputs)
            mini_batch = prob.shape[0]
            prob = torch.reshape(prob, (prob.shape[0] * prob.shape[1], -1)) #prob.view(-1, g_emb_dim)
            loss = gen_gan_loss(prob, targets, rewards)
            gen_gan_optm.zero_grad()
            loss.backward()
            gen_gan_optm.step()
            """
        print('Batch [%d] True Loss: %f' % (total_batch, loss))

        if total_batch % 1 == 0 or total_batch == TOTAL_BATCH - 1:
            # generate_samples(generator, BATCH_SIZE, GENERATED_NUM, EVAL_FILE)
            # eval_iter = GenDataIter(EVAL_FILE, BATCH_SIZE)
            # loss = eval_epoch(target_lstm, eval_iter, gen_criterion)
            if len(prob.shape) > 2:
                prob = torch.reshape(prob, (prob.shape[0] * prob.shape[1], -1))
            predictions = torch.max(prob, dim=1)[1]
            predictions = predictions.view(mini_batch, -1)
            for each_sen in list(predictions):
                print('Train Output:',
                      generate_sentence_from_id(idx_to_word, each_sen))

            test_predict(generator, test_iter, idx_to_word, train_mode=True)
            torch.save(generator.state_dict(),
                       CHECKPOINT_PATH + 'generator.model')
            torch.save(discriminator.state_dict(),
                       CHECKPOINT_PATH + 'discriminator.model')
        rollout.update_params()

        for _ in range(4):
            generate_samples(generator, BATCH_SIZE, GENERATED_NUM,
                             NEGATIVE_FILE)
            dis_data_iter = DisDataIter(POSITIVE_FILE, NEGATIVE_FILE,
                                        BATCH_SIZE)
            for _ in range(2):
                loss = train_epoch(discriminator, dis_data_iter, dis_criterion,
                                   dis_optimizer)
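
The adversarial update above feeds (prob, targets, rewards) into GANLoss, i.e. the SeqGAN-style policy-gradient objective: the negative log-likelihood of each generated token weighted by its rollout reward. A minimal sketch, assuming prob already holds log-probabilities (the project's class may differ in masking or reduction):

import torch.nn as nn

class GANLoss(nn.Module):
    def forward(self, prob, target, reward):
        # prob:   (N, vocab) log-probabilities per generated position
        # target: (N,)       token ids that were actually generated
        # reward: (N,)       rollout/discriminator reward per position
        nll = -prob.gather(1, target.view(-1, 1)).squeeze(1)
        return (nll * reward).sum()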
Code example #16
def main(args):
    use_cuda = (len(args.gpuid) >= 1)
    print("{0} GPU(s) are available".format(cuda.device_count()))

    # Load dataset
    splits = ['train', 'valid']
    if data.has_binary_files(args.data, splits):
        dataset = data.load_dataset(
            args.data, splits, args.src_lang, args.trg_lang, args.fixed_max_len)
    else:
        dataset = data.load_raw_text_dataset(
            args.data, splits, args.src_lang, args.trg_lang, args.fixed_max_len)
    if args.src_lang is None or args.trg_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.src_lang, args.trg_lang = dataset.src, dataset.dst

    print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))

    for split in splits:
        print('| {} {} {} examples'.format(args.data, split, len(dataset.splits[split])))

    g_logging_meters = OrderedDict()
    g_logging_meters['train_loss'] = AverageMeter()
    g_logging_meters['valid_loss'] = AverageMeter()
    g_logging_meters['train_acc'] = AverageMeter()
    g_logging_meters['valid_acc'] = AverageMeter()
    g_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    d_logging_meters = OrderedDict()
    d_logging_meters['train_loss'] = AverageMeter()
    d_logging_meters['valid_loss'] = AverageMeter()
    d_logging_meters['train_acc'] = AverageMeter()
    d_logging_meters['valid_acc'] = AverageMeter()
    d_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    # Set model parameters
    args.encoder_embed_dim = 1000
    args.encoder_layers = 2  # 4
    args.encoder_dropout_out = 0.3
    args.decoder_embed_dim = 1000
    args.decoder_layers = 2  # 4
    args.decoder_out_embed_dim = 1000
    args.decoder_dropout_out = 0.3
    args.bidirectional = False

    generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda)
    print("Generator loaded successfully!")
    discriminator_h = Discriminator_h(args.decoder_embed_dim, args.discriminator_hidden_size, args.discriminator_linear_size, args.discriminator_lin_dropout, use_cuda=use_cuda)
    print("Discriminator_h loaded successfully!")
    discriminator_s = Discriminator_s(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda)
    print("Discriminator_s loaded successfully!")

    def _calculate_discriminator_loss(tf_scores, ar_scores):
        tf_loss = torch.log(tf_scores + 1e-9) * (-1)
        ar_loss = torch.log(1 - ar_scores + 1e-9) * (-1)
        return tf_loss + ar_loss

    if use_cuda:
        if torch.cuda.device_count() > 1:
            discriminator_h = torch.nn.DataParallel(discriminator_h).cuda()
            discriminator_s = torch.nn.DataParallel(discriminator_s).cuda()
            generator = torch.nn.DataParallel(generator).cuda()
        else:
            generator.cuda()
            discriminator_h.cuda()
            discriminator_s.cuda()
    else:
        discriminator_h.cpu()
        discriminator_s.cpu()
        generator.cpu()

    # adversarial training checkpoints saving path
    if not os.path.exists('checkpoints/professor2'):
        os.makedirs('checkpoints/professor2')
    checkpoints_path = 'checkpoints/professor2/'

    # define loss function
    g_criterion = torch.nn.NLLLoss(ignore_index=dataset.dst_dict.pad(), reduction='sum')
    d_criterion = torch.nn.BCELoss()
    pg_criterion = PGLoss(ignore_index=dataset.dst_dict.pad(), size_average=True, reduce=True)

    # fix discriminator_s word embeddings (as Wu et al. do)
    for p in discriminator_s.embed_src_tokens.parameters():
        p.requires_grad = False
    for p in discriminator_s.embed_trg_tokens.parameters():
        p.requires_grad = False

    # define optimizer
    g_optimizer = eval("torch.optim." + args.g_optimizer)(filter(lambda x: x.requires_grad,
                                                                 generator.parameters()),
                                                          args.g_learning_rate)

    d_optimizer_h = eval("torch.optim." + args.d_optimizer)(filter(lambda x: x.requires_grad,
                                                                 discriminator_h.parameters()),
                                                          args.d_learning_rate,
                                                          momentum=args.momentum,
                                                          nesterov=True)

    d_optimizer_s = eval("torch.optim." + args.d_optimizer)(filter(lambda x: x.requires_grad,
                                                                 discriminator_s.parameters()),
                                                          args.d_learning_rate,
                                                          momentum=args.momentum,
                                                          nesterov=True)

    # start joint training
    best_dev_loss = math.inf
    num_update = 0
    # main training loop
    for epoch_i in range(1, args.epochs + 1):
        logging.info("At {0}-th epoch.".format(epoch_i))

        seed = args.seed + epoch_i
        torch.manual_seed(seed)

        max_positions_train = (args.fixed_max_len, args.fixed_max_len)

        # Initialize dataloader, starting at batch_offset
        trainloader = dataset.train_dataloader(
            'train',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_train,
            # seed=seed,
            epoch=epoch_i,
            sample_without_replacement=args.sample_without_replacement,
            sort_by_source_size=(epoch_i <= args.curriculum),
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        for key, val in g_logging_meters.items():
            if val is not None:
                val.reset()
        for key, val in d_logging_meters.items():
            if val is not None:
                val.reset()

        # set training mode
        generator.train()
        discriminator_h.train()
        discriminator_s.train()
        update_learning_rate(num_update, 8e4, args.g_learning_rate, args.lr_shrink, g_optimizer)

        for i, sample in enumerate(trainloader):

            if use_cuda:
                # wrap input tensors in cuda tensors
                sample = utils.make_variable(sample, cuda=cuda)

            ## part I: use gradient policy method to train the generator
            # print("Policy Gradient Training")
            sys_out_batch_PG, p_PG, hidden_list_PG = generator('PG', epoch_i, sample)  # 64 X 50 X 6632

            out_batch_PG = sys_out_batch_PG.contiguous().view(-1, sys_out_batch_PG.size(-1))  # (64 * 50) X 6632

            _, prediction = out_batch_PG.topk(1)
            prediction = prediction.squeeze(1)  # 64*50 = 3200
            prediction = torch.reshape(prediction, sample['net_input']['src_tokens'].shape)  # 64 X 50

            with torch.no_grad():
                reward = discriminator_s(sample['net_input']['src_tokens'], prediction)  # 64 X 1

            train_trg_batch_PG = sample['target']  # 64 x 50

            pg_loss_PG = pg_criterion(sys_out_batch_PG, train_trg_batch_PG, reward, use_cuda)
            sample_size_PG = sample['target'].size(0) if args.sentence_avg else sample['ntokens']  # 64
            logging_loss_PG = pg_loss_PG / math.log(2)
            g_logging_meters['train_loss'].update(logging_loss_PG.item(), sample_size_PG)
            logging.debug(
                f"G policy gradient loss at batch {i}: {pg_loss_PG.item():.3f}, lr={g_optimizer.param_groups[0]['lr']}")
            g_optimizer.zero_grad()
            pg_loss_PG.backward(retain_graph=True)
            torch.nn.utils.clip_grad_norm_(generator.parameters(), args.clip_norm)
            g_optimizer.step()

            # print("MLE Training")
            sys_out_batch_MLE, p_MLE, hidden_list_MLE = generator("MLE", epoch_i, sample)

            out_batch_MLE = sys_out_batch_MLE.contiguous().view(-1, sys_out_batch_MLE.size(-1))  # (64 X 50) X 6632

            train_trg_batch_MLE = sample['target'].view(-1)  # 64*50 = 3200
            loss_MLE = g_criterion(out_batch_MLE, train_trg_batch_MLE)

            sample_size_MLE = sample['target'].size(0) if args.sentence_avg else sample['ntokens']
            nsentences = sample['target'].size(0)
            logging_loss_MLE = loss_MLE.data / sample_size_MLE / math.log(2)
            g_logging_meters['bsz'].update(nsentences)
            g_logging_meters['train_loss'].update(logging_loss_MLE, sample_size_MLE)
            logging.debug(
                f"G MLE loss at batch {i}: {g_logging_meters['train_loss'].avg:.3f}, lr={g_optimizer.param_groups[0]['lr']}")
            g_optimizer.zero_grad()
            loss_MLE.backward(retain_graph=True)
            # all-reduce grads and rescale by grad_denom
            for p in generator.parameters():
                # print(p.size())
                if p.requires_grad:
                    p.grad.data.div_(sample_size_MLE)
            torch.nn.utils.clip_grad_norm_(generator.parameters(), args.clip_norm)
            g_optimizer.step()

            num_update += 1


            #  part II: train the discriminator

            # discriminator_h
            if num_update % 5 == 0:

                d_MLE = discriminator_h(hidden_list_MLE)
                d_PG = discriminator_h(hidden_list_PG)
                d_loss = _calculate_discriminator_loss(d_MLE, d_PG).sum()
                logging.debug(f"D_h training loss {d_loss} at batch {i}")

                d_optimizer_h.zero_grad()
                d_loss.backward()
                torch.nn.utils.clip_grad_norm_(discriminator_h.parameters(), args.clip_norm)
                d_optimizer_h.step()




                #discriminator_s
                bsz = sample['target'].size(0)  # batch_size = 64

                src_sentence = sample['net_input']['src_tokens']  # 64 x max-len i.e 64 X 50

                # now train with machine translation output i.e generator output
                true_sentence = sample['target'].view(-1)  # 64*50 = 3200

                true_labels = torch.ones(sample['target'].size(0)).float()  # 64 length vector
                with torch.no_grad():
                    sys_out_batch, p, hidden_list = generator('MLE', epoch_i, sample)  # 64 X 50 X 6632

                out_batch = sys_out_batch.contiguous().view(-1, sys_out_batch.size(-1))  # (64 X 50) X 6632

                _, prediction = out_batch.topk(1)
                prediction = prediction.squeeze(1)  # 64 * 50 = 6632

                fake_labels = torch.zeros(sample['target'].size(0)).float()  # 64 length vector

                fake_sentence = torch.reshape(prediction, src_sentence.shape)  # 64 X 50
                true_sentence = torch.reshape(true_sentence, src_sentence.shape)
                if use_cuda:
                    fake_labels = fake_labels.cuda()
                    true_labels = true_labels.cuda()

                fake_disc_out = discriminator_s(src_sentence, fake_sentence)  # 64 X 1
                true_disc_out = discriminator_s(src_sentence, true_sentence)

                fake_d_loss = d_criterion(fake_disc_out.squeeze(1), fake_labels)
                true_d_loss = d_criterion(true_disc_out.squeeze(1), true_labels)

                acc = torch.sum(torch.round(fake_disc_out).squeeze(1) == fake_labels).float() / len(fake_labels)

                d_loss = fake_d_loss + true_d_loss

                d_logging_meters['train_acc'].update(acc)
                d_logging_meters['train_loss'].update(d_loss)
                logging.debug(
                    f"D_s training loss {d_logging_meters['train_loss'].avg:.3f}, acc {d_logging_meters['train_acc'].avg:.3f} at batch {i}")
                d_optimizer_s.zero_grad()
                d_loss.backward()
                d_optimizer_s.step()

        # validation
        # set validation mode
        generator.eval()
        discriminator_h.eval()
        discriminator_s.eval()
        # Initialize dataloader
        max_positions_valid = (args.fixed_max_len, args.fixed_max_len)
        valloader = dataset.eval_dataloader(
            'valid',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_valid,
            skip_invalid_size_inputs_valid_test=True,
            descending=True,  # largest batch first to warm the caching allocator
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        for key, val in g_logging_meters.items():
            if val is not None:
                val.reset()
        for key, val in d_logging_meters.items():
            if val is not None:
                val.reset()

        for i, sample in enumerate(valloader):

            with torch.no_grad():
                if use_cuda:
                    # wrap input tensors in cuda tensors
                    sample = utils.make_variable(sample, cuda=cuda)

                # generator validation
                sys_out_batch_test, p_test, hidden_list_test = generator('test', epoch_i, sample)
                out_batch_test = sys_out_batch_test.contiguous().view(-1, sys_out_batch_test.size(-1))  # (64 X 50) X 6632
                dev_trg_batch = sample['target'].view(-1)  # 64*50 = 3200

                loss_test = g_criterion(out_batch_test, dev_trg_batch)
                sample_size_test = sample['target'].size(0) if args.sentence_avg else sample['ntokens']
                loss_test = loss_test / sample_size_test / math.log(2)
                g_logging_meters['valid_loss'].update(loss_test, sample_size_test)
                logging.debug(f"G dev loss at batch {i}: {g_logging_meters['valid_loss'].avg:.3f}")

                # # discriminator_h validation
                # bsz = sample['target'].size(0)
                # src_sentence = sample['net_input']['src_tokens']
                # # train with half human-translation and half machine translation
                # true_sentence = sample['target']
                # true_labels = torch.ones(sample['target'].size(0)).float()
                # with torch.no_grad():
                #     sys_out_batch_PG, p, hidden_list = generator('test', epoch_i, sample)
                #
                # out_batch = sys_out_batch_PG.contiguous().view(-1, sys_out_batch_PG.size(-1))  # (64 X 50) X 6632
                # _, prediction = out_batch.topk(1)
                # prediction = prediction.squeeze(1)  # 64 * 50 = 6632
                # fake_labels = torch.zeros(sample['target'].size(0)).float()
                # fake_sentence = torch.reshape(prediction, src_sentence.shape)  # 64 X 50
                # if use_cuda:
                #     fake_labels = fake_labels.cuda()
                # disc_out = discriminator_h(src_sentence, fake_sentence)
                # d_loss = d_criterion(disc_out.squeeze(1), fake_labels)
                # acc = torch.sum(torch.round(disc_out).squeeze(1) == fake_labels).float() / len(fake_labels)
                # d_logging_meters['valid_acc'].update(acc)
                # d_logging_meters['valid_loss'].update(d_loss)
                # logging.debug(
                #     f"D dev loss {d_logging_meters['valid_loss'].avg:.3f}, acc {d_logging_meters['valid_acc'].avg:.3f} at batch {i}")

        torch.save(generator,
                   open(checkpoints_path + f"sampling_{g_logging_meters['valid_loss'].avg:.3f}.epoch_{epoch_i}.pt",
                        'wb'), pickle_module=dill)

        if g_logging_meters['valid_loss'].avg < best_dev_loss:
            best_dev_loss = g_logging_meters['valid_loss'].avg
            torch.save(generator, open(checkpoints_path + "best_gmodel.pt", 'wb'), pickle_module=dill)
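
update_learning_rate(num_update, 8e4, args.g_learning_rate, args.lr_shrink, g_optimizer) is called in the training loop above but not defined in this listing. One plausible implementation, assuming the schedule simply holds the base rate until the update threshold is reached and then shrinks it by lr_shrink (the project's actual schedule may differ):

def update_learning_rate(num_update, threshold, base_lr, lr_shrink, optimizer):
    # Hypothetical schedule: constant base_lr, then base_lr * lr_shrink after `threshold` updates.
    lr = base_lr if num_update < threshold else base_lr * lr_shrink
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr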
Code example #17
File: gan_model.py Project: Olafyii/Motion_Transfer
class gan(nn.Module):
    # def __init__(self, params, save_dir, g_weight_dir, d_weight_dir, d_update_freq=1, start_epoch=0, g_lr=2e-4, d_lr=2e-4, use_cuda=True):
    def __init__(self, params, args):
        super(gan, self).__init__()
        self.G = MModel(params, use_cuda=True)
        self.D = Discriminator(params, bias=True)
        self.vgg_loss = VGGPerceptualLoss()
        self.L1_loss = nn.L1Loss()
        if args.use_cuda:
            self.G = self.G.cuda()
            self.D = self.D.cuda()
            self.vgg_loss = self.vgg_loss.cuda()
            self.L1_loss = self.L1_loss.cuda()
        if args.g_weight_dir:
            self.G.load_state_dict(torch.load(args.g_weight_dir), strict=True)
        if args.d_weight_dir:
            self.D.load_state_dict(torch.load(args.d_weight_dir), strict=False)

        self.optimizer_G = torch.optim.Adam(self.G.parameters(), lr=args.g_lr)
        self.optimizer_D = torch.optim.Adam(self.D.parameters(), lr=args.d_lr)

        self.save_dir = args.save_dir
        if not os.path.exists(args.save_dir):
            os.makedirs(args.save_dir)

        self.d_update_freq = args.d_update_freq
        self.save_freq = args.save_freq
        self.writer = SummaryWriter('runs/' + args.save_dir)
        self.use_cuda = args.use_cuda
        self.start_epoch = args.start_epoch

    def G_loss(self, input, target):
        vgg = self.vgg_loss(input, target)
        L1 = self.L1_loss(input, target)
        # return vgg
        return vgg + L1

    def update_D(self, loss, epoch):
        if epoch % self.d_update_freq == 0:
            loss.backward()
            self.optimizer_D.step()

    def get_patch_weight(self, pose, size=62):
        heads = pose[:, 0, :, :]
        heads = heads.unsqueeze(1)
        heads = torch.nn.functional.interpolate(heads, size=size)
        heads = heads * 5 + torch.ones_like(heads)
        return heads

    def gan_loss(self, out, label, pose):
        # weight = self.get_patch_weight(pose)
        # return nn.BCELoss(weight=weight)(out, torch.ones_like(out) if label==1 else torch.zeros_like(out))
        return nn.BCELoss()(
            out, torch.ones_like(out) if label == 1 else torch.zeros_like(out))

    def train(self, dl, epoch):  # i -- current epoch
        cnt = 0
        loss_D_real_sum, loss_D_fake_sum, loss_D_sum, loss_G_gan_sum, loss_G_img_sum, loss_G_sum = 0, 0, 0, 0, 0, 0
        for iter, (src_img, y, src_pose, tgt_pose, src_mask_prior, x_trans,
                   src_mask_gt, tgt_face, tgt_face_box,
                   src_face_box) in enumerate(dl):
            print('epoch:', epoch, 'iter:', iter)
            self.optimizer_D.zero_grad()
            if self.use_cuda:
                src_img, y, src_pose, tgt_pose, src_mask_prior, x_trans = src_img.cuda(
                ), y.cuda(), src_pose.cuda(), tgt_pose.cuda(
                ), src_mask_prior.cuda(), x_trans.cuda()

            out = self.G(src_img, src_pose, tgt_pose, src_mask_prior, x_trans)
            gen = out[0]
            loss_D_real = self.gan_loss(self.D(y, tgt_pose), 1, tgt_pose)
            loss_D_fake = self.gan_loss(self.D(gen.detach(), tgt_pose), 0,
                                        tgt_pose)
            loss_D = loss_D_real + loss_D_fake
            self.update_D(loss_D, epoch)

            if False and epoch < 10:
                loss_G_gan = torch.zeros((1))
                loss_G_img = torch.zeros((1))
                loss_G = loss_G_gan + loss_G_img
            else:
                self.optimizer_G.zero_grad()
                loss_G_gan = self.gan_loss(self.D(gen, tgt_pose), 1, tgt_pose)
                loss_G_img = self.G_loss(gen, y)  # vgg_loss + L1_loss
                loss_G = loss_G_gan + loss_G_img
                loss_G.backward()
                self.optimizer_G.step()

            loss_D_real_sum += loss_D_real.item()
            loss_D_fake_sum += loss_D_fake.item()
            loss_D_sum += loss_D.item()
            loss_G_gan_sum += loss_G_gan.item()
            loss_G_img_sum += loss_G_img.item()
            loss_G_sum += loss_G.item()
            cnt += 1

            # if epoch % self.save_freq == 0 and iter < 3:
            #     self.writer.add_images('gen/epoch%d'%epoch, gen*0.5+0.5)
            #     self.writer.add_images('y/epoch%d'%epoch, y*0.5+0.5)
            #     self.writer.add_images('src_mask/epoch%d'%epoch, out[2].view((out[2].size(0)*out[2].size(1), 1, out[2].size(2), out[2].size(3))))
            #     self.writer.add_images('warped/epoch%d'%epoch, out[3].view((out[3].size(0)*11, 3, out[3].size(2), out[3].size(3)))*0.5+0.5)

        self.writer.add_scalar('loss_D_real', loss_D_real_sum / cnt, epoch)
        self.writer.add_scalar('loss_D_fake', loss_D_fake_sum / cnt, epoch)
        self.writer.add_scalar('loss_D', loss_D_sum / cnt, epoch)
        self.writer.add_scalar('loss_G_gan', loss_G_gan_sum / cnt, epoch)
        self.writer.add_scalar('loss_G_img', loss_G_img_sum / cnt, epoch)
        self.writer.add_scalar('loss_G', loss_G_sum / cnt, epoch)
        self.writer.add_scalars('DG', {
            'D': loss_D_sum / cnt,
            'G': loss_G_sum / cnt
        }, epoch)
        if epoch % self.save_freq == 0:
            torch.save(self.G.state_dict(),
                       os.path.join(self.save_dir, 'g_epoch_%d.pth' % epoch))
            torch.save(self.D.state_dict(),
                       os.path.join(self.save_dir, 'd_epoch_%d.pth' % epoch))

    def test(self, test_dl, epoch):
        self.G.eval()
        for iter, (src_img, y, src_pose, tgt_pose, src_mask_prior, x_trans,
                   src_mask_gt, tgt_face, tgt_face_box,
                   src_face_box) in enumerate(test_dl):
            print('test', 'epoch:', epoch, 'iter:', iter)
            if self.use_cuda:
                src_img, y, src_pose, tgt_pose, src_mask_prior, x_trans = src_img.cuda(
                ), y.cuda(), src_pose.cuda(), tgt_pose.cuda(
                ), src_mask_prior.cuda(), x_trans.cuda()
            with torch.no_grad():
                out = self.G(src_img, src_pose, tgt_pose, src_mask_prior,
                             x_trans)
            gen = out[0]
            if iter == 0:
                self.writer.add_images('test_gen/epoch%d' % epoch,
                                       gen * 0.5 + 0.5)
                self.writer.add_images('test_y/epoch%d' % epoch, y * 0.5 + 0.5)
                self.writer.add_images('test_src/epoch%d' % epoch,
                                       src_img * 0.5 + 0.5)
                self.writer.add_images(
                    'test_src_mask/epoch%d' % epoch, out[2].view(
                        (out[2].size(0) * out[2].size(1), 1, out[2].size(2),
                         out[2].size(3))))
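
VGGPerceptualLoss, used inside G_loss above, is not defined in this listing. A compact sketch of the common formulation — L1 distance between feature maps of a frozen, pre-trained VGG16 — is given below; the chosen layers and the absence of input normalisation are assumptions, and the project's class may differ.

import torch.nn as nn
import torchvision

class VGGPerceptualLoss(nn.Module):
    def __init__(self, layers=(3, 8, 15, 22)):
        super().__init__()
        # Frozen VGG16 feature extractor; gradients flow only into the inputs.
        vgg = torchvision.models.vgg16(pretrained=True).features.eval()
        for p in vgg.parameters():
            p.requires_grad = False
        self.vgg = vgg
        self.layers = set(layers)
        self.criterion = nn.L1Loss()

    def forward(self, input, target):
        # Accumulate L1 distances between selected VGG16 feature maps.
        loss = 0.0
        x, y = input, target
        for i, layer in enumerate(self.vgg):
            x, y = layer(x), layer(y)
            if i in self.layers:
                loss = loss + self.criterion(x, y)
        return loss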
Code example #18
def main(args):
    use_cuda = (len(args.gpuid) >= 1)
    print("{0} GPU(s) are available".format(cuda.device_count()))

    # Load dataset
    splits = ['train', 'valid']
    if data.has_binary_files(args.data, splits):
        dataset = data.load_dataset(args.data, splits, args.src_lang,
                                    args.trg_lang, args.fixed_max_len)
    else:
        dataset = data.load_raw_text_dataset(args.data, splits, args.src_lang,
                                             args.trg_lang, args.fixed_max_len)
    if args.src_lang is None or args.trg_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.src_lang, args.trg_lang = dataset.src, dataset.dst

    print('| [{}] dictionary: {} types'.format(dataset.src,
                                               len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst,
                                               len(dataset.dst_dict)))

    for split in splits:
        print('| {} {} {} examples'.format(args.data, split,
                                           len(dataset.splits[split])))

    g_logging_meters = OrderedDict()
    g_logging_meters['train_loss'] = AverageMeter()
    g_logging_meters['valid_loss'] = AverageMeter()
    g_logging_meters['train_acc'] = AverageMeter()
    g_logging_meters['valid_acc'] = AverageMeter()
    g_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    d_logging_meters = OrderedDict()
    d_logging_meters['train_loss'] = AverageMeter()
    d_logging_meters['valid_loss'] = AverageMeter()
    d_logging_meters['train_acc'] = AverageMeter()
    d_logging_meters['valid_acc'] = AverageMeter()
    d_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    # Set model parameters
    args.encoder_embed_dim = 1000
    args.encoder_layers = 2  # 4
    args.encoder_dropout_out = 0
    args.decoder_embed_dim = 1000
    args.decoder_layers = 2  # 4
    args.decoder_out_embed_dim = 1000
    args.decoder_dropout_out = 0
    args.bidirectional = False

    generator = LSTMModel(args,
                          dataset.src_dict,
                          dataset.dst_dict,
                          use_cuda=use_cuda)
    print("Generator loaded successfully!")
    discriminator = Discriminator(args,
                                  dataset.src_dict,
                                  dataset.dst_dict,
                                  use_cuda=use_cuda)
    print("Discriminator loaded successfully!")

    g_model_path = 'checkpoints/zhenwarm/generator.pt'
    assert os.path.exists(g_model_path)
    # generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda)
    model_dict = generator.state_dict()
    model = torch.load(g_model_path)
    pretrained_dict = model.state_dict()
    # 1. filter out unnecessary keys
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in model_dict
    }
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    generator.load_state_dict(model_dict)
    print("pre-trained Generator loaded successfully!")
    #
    # Load discriminator model
    d_model_path = 'checkpoints/zhenwarm/discri.pt'
    assert os.path.exists(d_model_path)
    # generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda)
    d_model_dict = discriminator.state_dict()
    d_model = torch.load(d_model_path)
    d_pretrained_dict = d_model.state_dict()
    # 1. filter out unnecessary keys
    d_pretrained_dict = {
        k: v
        for k, v in d_pretrained_dict.items() if k in d_model_dict
    }
    # 2. overwrite entries in the existing state dict
    d_model_dict.update(d_pretrained_dict)
    # 3. load the new state dict
    discriminator.load_state_dict(d_model_dict)
    print("pre-trained Discriminator loaded successfully!")

    if use_cuda:
        if torch.cuda.device_count() > 1:
            discriminator = torch.nn.DataParallel(discriminator).cuda()
            generator = torch.nn.DataParallel(generator).cuda()
        else:
            generator.cuda()
            discriminator.cuda()
    else:
        discriminator.cpu()
        generator.cpu()

    # adversarial training checkpoints saving path
    if not os.path.exists('checkpoints/myzhencli5'):
        os.makedirs('checkpoints/myzhencli5')
    checkpoints_path = 'checkpoints/myzhencli5/'

    # define loss function
    g_criterion = torch.nn.NLLLoss(ignore_index=dataset.dst_dict.pad(),
                                   reduction='sum')
    d_criterion = torch.nn.BCELoss()
    pg_criterion = PGLoss(ignore_index=dataset.dst_dict.pad(),
                          size_average=True,
                          reduce=True)

    # fix discriminator word embedding (as Wu et al. do)
    for p in discriminator.embed_src_tokens.parameters():
        p.requires_grad = False
    for p in discriminator.embed_trg_tokens.parameters():
        p.requires_grad = False

    # define optimizer
    g_optimizer = eval("torch.optim." + args.g_optimizer)(filter(
        lambda x: x.requires_grad, generator.parameters()),
                                                          args.g_learning_rate)

    d_optimizer = eval("torch.optim." + args.d_optimizer)(
        filter(lambda x: x.requires_grad, discriminator.parameters()),
        args.d_learning_rate,
        momentum=args.momentum,
        nesterov=True)

    # start joint training
    best_dev_loss = math.inf
    num_update = 0
    # main training loop
    for epoch_i in range(1, args.epochs + 1):
        logging.info("At {0}-th epoch.".format(epoch_i))

        seed = args.seed + epoch_i
        torch.manual_seed(seed)

        max_positions_train = (args.fixed_max_len, args.fixed_max_len)

        # Initialize dataloader, starting at batch_offset
        trainloader = dataset.train_dataloader(
            'train',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_train,
            # seed=seed,
            epoch=epoch_i,
            sample_without_replacement=args.sample_without_replacement,
            sort_by_source_size=(epoch_i <= args.curriculum),
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        for key, val in g_logging_meters.items():
            if val is not None:
                val.reset()
        for key, val in d_logging_meters.items():
            if val is not None:
                val.reset()

        for i, sample in enumerate(trainloader):

            # set training mode
            generator.train()
            discriminator.train()
            update_learning_rate(num_update, 8e4, args.g_learning_rate,
                                 args.lr_shrink, g_optimizer)

            if use_cuda:
                # wrap input tensors in cuda tensors
                sample = utils.make_variable(sample, cuda=cuda)

            ## part I: use gradient policy method to train the generator

            # use policy gradient training when random.random() > 50%
            if random.random() >= 0.5:

                print("Policy Gradient Training")

                sys_out_batch = generator(sample)  # 64 X 50 X 6632

                out_batch = sys_out_batch.contiguous().view(
                    -1, sys_out_batch.size(-1))  # (64 * 50) X 6632

                _, prediction = out_batch.topk(1)
                prediction = prediction.squeeze(1)  # 64*50 = 3200
                prediction = torch.reshape(
                    prediction,
                    sample['net_input']['src_tokens'].shape)  # 64 X 50

                with torch.no_grad():
                    reward = discriminator(sample['net_input']['src_tokens'],
                                           prediction)  # 64 X 1

                train_trg_batch = sample['target']  # 64 x 50

                pg_loss = pg_criterion(sys_out_batch, train_trg_batch, reward,
                                       use_cuda)
                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']  # 64
                logging_loss = pg_loss / math.log(2)
                g_logging_meters['train_loss'].update(logging_loss.item(),
                                                      sample_size)
                logging.debug(
                    f"G policy gradient loss at batch {i}: {pg_loss.item():.3f}, lr={g_optimizer.param_groups[0]['lr']}"
                )
                g_optimizer.zero_grad()
                pg_loss.backward()
                torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                               args.clip_norm)
                g_optimizer.step()

            else:
                # MLE training
                print("MLE Training")

                sys_out_batch = generator(sample)

                out_batch = sys_out_batch.contiguous().view(
                    -1, sys_out_batch.size(-1))  # (64 X 50) X 6632

                train_trg_batch = sample['target'].view(-1)  # 64*50 = 3200

                loss = g_criterion(out_batch, train_trg_batch)

                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']
                nsentences = sample['target'].size(0)
                logging_loss = loss.data / sample_size / math.log(2)
                g_logging_meters['bsz'].update(nsentences)
                g_logging_meters['train_loss'].update(logging_loss,
                                                      sample_size)
                logging.debug(
                    f"G MLE loss at batch {i}: {g_logging_meters['train_loss'].avg:.3f}, lr={g_optimizer.param_groups[0]['lr']}"
                )
                g_optimizer.zero_grad()
                loss.backward()
                # rescale grads by grad_denom (the sample size)
                for p in generator.parameters():
                    if p.requires_grad:
                        p.grad.data.div_(sample_size)
                torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                               args.clip_norm)
                g_optimizer.step()

            num_update += 1

            # part II: train the discriminator
            if num_update % 5 == 0:
                bsz = sample['target'].size(0)  # batch_size = 64

                src_sentence = sample['net_input'][
                    'src_tokens']  # 64 x max-len i.e 64 X 50

                # now train with machine translation output i.e generator output
                true_sentence = sample['target'].view(-1)  # 64*50 = 3200

                true_labels = Variable(
                    torch.ones(
                        sample['target'].size(0)).float())  # 64 length vector

                with torch.no_grad():
                    sys_out_batch = generator(sample)  # 64 X 50 X 6632

                out_batch = sys_out_batch.contiguous().view(
                    -1, sys_out_batch.size(-1))  # (64 X 50) X 6632

                _, prediction = out_batch.topk(1)
                prediction = prediction.squeeze(1)  # 64 * 50 = 3200

                fake_labels = Variable(
                    torch.zeros(
                        sample['target'].size(0)).float())  # 64 length vector

                fake_sentence = torch.reshape(prediction,
                                              src_sentence.shape)  # 64 X 50
                true_sentence = torch.reshape(true_sentence,
                                              src_sentence.shape)
                if use_cuda:
                    fake_labels = fake_labels.cuda()
                    true_labels = true_labels.cuda()

                # fake_disc_out = discriminator(src_sentence, fake_sentence)  # 64 X 1
                # true_disc_out = discriminator(src_sentence, true_sentence)
                #
                # fake_d_loss = d_criterion(fake_disc_out.squeeze(1), fake_labels)
                # true_d_loss = d_criterion(true_disc_out.squeeze(1), true_labels)
                #
                # fake_acc = torch.sum(torch.round(fake_disc_out).squeeze(1) == fake_labels).float() / len(fake_labels)
                # true_acc = torch.sum(torch.round(true_disc_out).squeeze(1) == true_labels).float() / len(true_labels)
                # acc = (fake_acc + true_acc) / 2
                #
                # d_loss = fake_d_loss + true_d_loss
                if random.random() > 0.5:
                    fake_disc_out = discriminator(src_sentence, fake_sentence)
                    fake_d_loss = d_criterion(fake_disc_out.squeeze(1),
                                              fake_labels)
                    fake_acc = torch.sum(
                        torch.round(fake_disc_out).squeeze(1) ==
                        fake_labels).float() / len(fake_labels)
                    d_loss = fake_d_loss
                    acc = fake_acc
                else:
                    true_disc_out = discriminator(src_sentence, true_sentence)
                    true_d_loss = d_criterion(true_disc_out.squeeze(1),
                                              true_labels)
                    true_acc = torch.sum(
                        torch.round(true_disc_out).squeeze(1) ==
                        true_labels).float() / len(true_labels)
                    d_loss = true_d_loss
                    acc = true_acc

                d_logging_meters['train_acc'].update(acc)
                d_logging_meters['train_loss'].update(d_loss)
                logging.debug(
                    f"D training loss {d_logging_meters['train_loss'].avg:.3f}, acc {d_logging_meters['train_acc'].avg:.3f} at batch {i}"
                )
                d_optimizer.zero_grad()
                d_loss.backward()
                d_optimizer.step()

            if num_update % 10000 == 0:

                # validation
                # set validation mode
                generator.eval()
                discriminator.eval()
                # Initialize dataloader
                max_positions_valid = (args.fixed_max_len, args.fixed_max_len)
                valloader = dataset.eval_dataloader(
                    'valid',
                    max_tokens=args.max_tokens,
                    max_sentences=args.joint_batch_size,
                    max_positions=max_positions_valid,
                    skip_invalid_size_inputs_valid_test=True,
                    descending=
                    True,  # largest batch first to warm the caching allocator
                    shard_id=args.distributed_rank,
                    num_shards=args.distributed_world_size,
                )

                # reset meters
                for key, val in g_logging_meters.items():
                    if val is not None:
                        val.reset()
                for key, val in d_logging_meters.items():
                    if val is not None:
                        val.reset()

                for i, sample in enumerate(valloader):

                    with torch.no_grad():
                        if use_cuda:
                            # wrap input tensors in cuda tensors
                            sample = utils.make_variable(sample, cuda=cuda)

                        # generator validation
                        sys_out_batch = generator(sample)
                        out_batch = sys_out_batch.contiguous().view(
                            -1, sys_out_batch.size(-1))  # (64 X 50) X 6632
                        dev_trg_batch = sample['target'].view(
                            -1)  # 64*50 = 3200

                        loss = g_criterion(out_batch, dev_trg_batch)
                        sample_size = sample['target'].size(
                            0) if args.sentence_avg else sample['ntokens']
                        loss = loss / sample_size / math.log(2)
                        g_logging_meters['valid_loss'].update(
                            loss, sample_size)
                        logging.debug(
                            f"G dev loss at batch {i}: {g_logging_meters['valid_loss'].avg:.3f}"
                        )

                        # discriminator validation
                        bsz = sample['target'].size(0)
                        src_sentence = sample['net_input']['src_tokens']
                        # train with half human-translation and half machine translation

                        true_sentence = sample['target']
                        true_labels = Variable(
                            torch.ones(sample['target'].size(0)).float())

                        with torch.no_grad():
                            sys_out_batch = generator(sample)

                        out_batch = sys_out_batch.contiguous().view(
                            -1, sys_out_batch.size(-1))  # (64 X 50) X 6632

                        _, prediction = out_batch.topk(1)
                        prediction = prediction.squeeze(1)  # 64 * 50 = 3200

                        fake_labels = Variable(
                            torch.zeros(sample['target'].size(0)).float())

                        fake_sentence = torch.reshape(
                            prediction, src_sentence.shape)  # 64 X 50
                        true_sentence = torch.reshape(true_sentence,
                                                      src_sentence.shape)
                        if use_cuda:
                            fake_labels = fake_labels.cuda()
                            true_labels = true_labels.cuda()

                        fake_disc_out = discriminator(src_sentence,
                                                      fake_sentence)  # 64 X 1
                        true_disc_out = discriminator(src_sentence,
                                                      true_sentence)

                        fake_d_loss = d_criterion(fake_disc_out.squeeze(1),
                                                  fake_labels)
                        true_d_loss = d_criterion(true_disc_out.squeeze(1),
                                                  true_labels)
                        d_loss = fake_d_loss + true_d_loss
                        fake_acc = torch.sum(
                            torch.round(fake_disc_out).squeeze(1) ==
                            fake_labels).float() / len(fake_labels)
                        true_acc = torch.sum(
                            torch.round(true_disc_out).squeeze(1) ==
                            true_labels).float() / len(true_labels)
                        acc = (fake_acc + true_acc) / 2
                        d_logging_meters['valid_acc'].update(acc)
                        d_logging_meters['valid_loss'].update(d_loss)
                        logging.debug(
                            f"D dev loss {d_logging_meters['valid_loss'].avg:.3f}, acc {d_logging_meters['valid_acc'].avg:.3f} at batch {i}"
                        )

                # torch.save(discriminator,
                #            open(checkpoints_path + f"numupdate_{num_update/10000}k.discri_{d_logging_meters['valid_loss'].avg:.3f}.pt",'wb'), pickle_module=dill)

                # if d_logging_meters['valid_loss'].avg < best_dev_loss:
                #     best_dev_loss = d_logging_meters['valid_loss'].avg
                #     torch.save(discriminator, open(checkpoints_path + "best_dmodel.pt", 'wb'), pickle_module=dill)

                torch.save(
                    generator,
                    open(
                        checkpoints_path +
                        f"numupdate_{num_update/10000}k.joint_{g_logging_meters['valid_loss'].avg:.3f}.pt",
                        'wb'),
                    pickle_module=dill)
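For reference, here is a minimal sketch of a policy-gradient criterion of the kind pg_criterion implements above: the per-token negative log-likelihood of the sampled translation is weighted by the sentence-level reward returned by the discriminator. The class name, tensor shapes, and masking are assumptions made for illustration, not the repository's exact PGLoss.

import torch
import torch.nn as nn
import torch.nn.functional as F


class PGLossSketch(nn.Module):
    """Illustrative policy-gradient loss: reward-weighted NLL with padding masked out."""

    def __init__(self, ignore_index):
        super().__init__()
        self.ignore_index = ignore_index

    def forward(self, net_output, target, reward, use_cuda=False):
        # net_output: (B, T, V) generator scores; target: (B, T); reward: (B, 1)
        bsz, seqlen, vocab = net_output.size()
        lprobs = F.log_softmax(net_output, dim=-1).view(-1, vocab)
        target = target.view(-1)
        nll = F.nll_loss(lprobs, target,
                         ignore_index=self.ignore_index,
                         reduction='none')                     # (B*T,)
        token_reward = reward.view(bsz, 1).expand(bsz, seqlen).reshape(-1)
        mask = target.ne(self.ignore_index).float()
        return (nll * token_reward * mask).sum() / mask.sum().clamp(min=1.0)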
Code Example #19
0
File: wgan.py Project: Qingyan1218/GAN
parser.add_argument("--clip_value", type=float, default=0.01, help="lower and upper clip value for disc. weights")
parser.add_argument("--sample_interval", type=int, default=400, help="interval betwen image samples")
opt = parser.parse_args()
print(opt)

img_shape = (opt.channels, opt.img_size, opt.img_size)

cuda = True if torch.cuda.is_available() else False

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()

# Configure data loader
os.makedirs("./data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "./data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size),
             transforms.ToTensor(),
             transforms.Normalize([0.5], [0.5])]    # single-channel mean=0.5, std=0.5
                                                    # => img = (img - 0.5) / 0.5 per channel
        ),
    ),
    batch_size=opt.batch_size,  # batch_size/shuffle assumed from the standard wgan.py script
    shuffle=True,
)
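The --clip_value argument parsed above is the weight-clipping bound from the original WGAN; the training loop that uses it is not part of this excerpt. As an editorial sketch, the clipping is typically applied to the critic's weights right after each discriminator update:

# illustrative only: enforce the WGAN Lipschitz constraint after optimizer_D.step()
for p in discriminator.parameters():
    p.data.clamp_(-opt.clip_value, opt.clip_value)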
Code Example #20
0
class LOSS(nn.Module):
    def __init__(self, lr, batch_size, alpha, beta, image_size, K, T, gpu):
        super(LOSS, self).__init__()

        self.K = K
        self.T = T
        self.alpha = alpha
        self.beta = beta
        self.batch_size = batch_size

        # define network and criterion
        self.mcnet = MCnet()
        self.discriminator = Discriminator(K, T)
        self.criterion = nn.BCELoss()

        # pre-allocate tensors for training; convenient when the same data feeds multiple networks
        self.true_data = torch.FloatTensor(batch_size, K + T, image_size,
                                           image_size)
        self.true_data_seq = torch.FloatTensor(batch_size, 1, image_size,
                                               image_size, K + T)
        self.fake_data_diff = torch.FloatTensor(batch_size, 1, image_size,
                                                image_size, K - 1)
        self.fake_data_xt = torch.FloatTensor(batch_size, 1, image_size,
                                              image_size)
        self.label = torch.FloatTensor(batch_size)
        self.real_label = 1
        self.fake_label = 0

        if gpu:
            self.mcnet.cuda()
            self.discriminator.cuda()
            self.true_data = self.true_data.cuda()
            self.true_data_seq = self.true_data_seq.cuda()
            self.fake_data_diff = self.fake_data_diff.cuda()
            self.fake_data_xt = self.fake_data_xt.cuda()
            self.label = self.label.cuda()

        self.true_data = Variable(self.true_data)
        self.true_data_seq = Variable(self.true_data_seq)
        self.fake_data_diff = Variable(self.fake_data_diff)
        self.fake_data_xt = Variable(self.fake_data_xt)
        self.label = Variable(self.label)

        # define optimizer for each network to update weight
        self.optimizer_D = optim.Adam(self.discriminator.parameters(), lr)
        self.optimizer_G = optim.Adam(self.mcnet.parameters(), lr)

    #def forward(self, diff_batch, seq_batch):
    def forward(self, diff_batch, seq_batch, pic_batch, train=True):
        """
        compute loss of Mcnet
        :param diff_batch: subtraction between of t and t-1 frame
        :param seq_batch: video sequence of T+K frame
        :return: discrimination loss and generation loss and predict value with cpu
        """
        if train:
            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ############################

            self.discriminator.zero_grad()
            # train with real
            true_data_cpu = seq_batch.permute(0, 4, 2, 3, 1).contiguous(
            )[:, :, :, :, 0]  # sequence as channel [batch,seq+channel,H,W]
            self.true_data.data.resize_(true_data_cpu.size()).copy_(
                true_data_cpu)  # copy data as Variable with gpu
            self.label.data.resize_(self.batch_size).fill_(
                self.real_label)  # copy label as Variable with gpu

            # the discriminator sees the first K frames
            self.true_data = self.true_data[:, 0:self.K, :, :]

            true_dis = self.discriminator(
                self.true_data)  # truth data for discriminator
            d_loss_real = self.criterion(
                true_dis, self.label)  # cross entropy for criterion
            d_loss_real.backward()  # compute gradients

            # train with fake
            #xt_cpu = seq_batch[:, :, :, :, self.K - 1] # picture of last frame
            xt_cpu = pic_batch
            self.fake_data_diff.data.resize_(diff_batch.size()).copy_(
                diff_batch)  # copy diff data as Variable with gpu
            self.fake_data_xt.data.resize_(xt_cpu.size()).copy_(
                xt_cpu)  # copy last frame data as Variable with gpu
            self.true_data_seq.data.resize_(seq_batch.size()).copy_(
                seq_batch)  # copy seq data as Variable with gpu

            output_list, gram = self.mcnet(
                self.fake_data_diff,
                self.fake_data_xt)  # generate data of Mcnet
            predict = torch.cat(
                output_list,
                4)  # concatenate gen data of T seq [batch,channel,H,W,seq]
            gen_data = torch.cat(
                [self.true_data_seq[:, :, :, :, :self.K], predict
                 ],  # concatenate prior K data and sequence as channel
                4).permute(0, 4, 2, 3, 1).contiguous()[:, :, :, :, 0]
            self.label.data.fill_(self.fake_label)

            # the discriminator sees the generated T frames
            gen_data = gen_data[:, self.K:self.K + self.T, :, :]

            gen_dis = self.discriminator(gen_data.detach())
            d_loss_fake = self.criterion(gen_dis, self.label)
            d_loss_fake.backward()

            self.optimizer_D.step()  # Adam update weight
            d_loss = d_loss_fake + d_loss_real  # discrimination loss

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            self.mcnet.zero_grad()
            self.label.data.fill_(self.real_label)
            gen_dis = self.discriminator(gen_data)
            d_loss_gan = self.criterion(gen_dis, self.label)

            #L_img = self.loss_img(self.true_data_seq, predict) # compute L_img

            # (3) Gram Matrix Loss
            gram_loss = self.loss_gram(gram)

            #g_loss = self.alpha * L_img + self.beta * d_loss_gan # generation loss
            g_loss = self.alpha * gram_loss + self.beta * d_loss_gan  # generation loss
            g_loss.backward()

            self.optimizer_G.step()

            # show the parameter of Discriminator
            #D_bias = self.discriminator.conv1.bias.data[0:3].cpu().view(1,-1).numpy()
            #D_weight = self.discriminator.conv1.weight.data[0,0,0,0:3].cpu().view(1,-1).numpy()

            #return d_loss, g_loss, predict.data.cpu()
            return d_loss, g_loss, gram_loss, predict.data.cpu()
        else:
            xt_cpu = pic_batch
            self.fake_data_diff.data.resize_(diff_batch.size()).copy_(
                diff_batch)  # copy diff data as Variable with gpu
            self.fake_data_xt.data.resize_(xt_cpu.size()).copy_(
                xt_cpu)  # copy last frame data as Variable with gpu
            self.true_data_seq.data.resize_(seq_batch.size()).copy_(
                seq_batch)  # copy seq data as Variable with gpu

            output_list, gram = self.mcnet(
                self.fake_data_diff,
                self.fake_data_xt)  # generate data of Mcnet
            predict = torch.cat(
                output_list,
                4)  # concatenate gen data of T seq [batch,channel,H,W,seq]

            return predict

    def loss_img(self, target, predict):
        # convert data to gray with 3 channel and combine batch and sequence
        true_sim = target[:, :, :, :, self.K:].add(1.0).div(2.0)
        #true_sim = target[:, :, :, :, :self.K].add(1.0).div(2.0)
        true_sim = true_sim.repeat(1, 3, 1, 1, 1).permute(0, 4, 1, 2,
                                                          3).contiguous()

        true_sim = true_sim.view(-1, true_sim.size(2), true_sim.size(3),
                                 true_sim.size(4))

        gen_sim = predict.add(1.0).div(2.0)
        gen_sim = gen_sim.repeat(1, 3, 1, 1, 1).permute(0, 4, 1, 2,
                                                        3).contiguous()
        gen_sim = gen_sim.view(-1, gen_sim.size(2), gen_sim.size(3),
                               gen_sim.size(4))

        loss_p = self.loss_p(target[:, :, :, :, self.K:], predict, 2.0)
        #loss_p = self.loss_p(target[:, :, :, :, :self.K], predict, 2.0)
        loss_gld = self.loss_gld(true_sim, gen_sim, 1.0)
        L_img = loss_p + loss_gld

        return L_img

    def loss_p(self, tar, pre, p):
        """
        loss_p = mean((pre - tar)^p)
        :param tar: ground truth value
        :param pre: predicted value
        :param p: exponent hyper-parameter of loss_p
        :return: loss_p
        """
        return torch.mean((pre - tar)**p)

    def loss_gld(self, tar, pre, alpha):
        """
        match the gradients of the pixel values:
        mean(|(|y_{i,j}-y_{i-1,j}| - |z_{i,j}-z_{i-1,j}|)|^alpha +
             |(|y_{i,j-1}-y_{i,j}| - |z_{i,j-1}-z_{i,j}|)|^alpha)

        :param tar: ground truth value
        :param pre: predicted value
        :param alpha: exponent hyper-parameter of loss_gld
        :return: loss_gld
        """
        pos = torch.eye(3)
        neg = -1 * pos

        # weight for conv is [out_channel,in_channel,kH,kW]
        # subtraction between center and left
        weight_x = torch.zeros([3, 3, 1, 2])
        weight_x[:, :, 0, 0] = neg
        weight_x[:, :, 0, 1] = pos

        # subtraction between center and up
        weight_y = torch.zeros([3, 3, 2, 1])
        weight_y[:, :, 0, 0] = pos
        weight_y[:, :, 1, 0] = neg

        weight_x = Variable(weight_x.cuda())
        weight_y = Variable(weight_y.cuda())

        gen_dx = torch.abs(F.conv2d(pre, weight_x, padding=1))
        gen_dy = torch.abs(F.conv2d(pre, weight_y, padding=1))
        true_dx = torch.abs(F.conv2d(tar, weight_x, padding=1))
        true_dy = torch.abs(F.conv2d(tar, weight_y, padding=1))

        grad_diff_x = torch.abs(true_dx - gen_dx)
        grad_diff_y = torch.abs(true_dy - gen_dy)

        return torch.mean(grad_diff_x**alpha + grad_diff_y**alpha)

    def loss_gram(self, gram):

        loss = 0.0
        for t in range(len(gram)):
            # gram_s,gram_C:top to bottom. gram_f bottom to top
            (gram_s, gram_c, gram_f) = gram[t]

            # loss and grad of style
            (loss_s, grad_s) = style_transfer(gram_f, gram_s)

            # loss and grad of content
            (loss_c, grad_c) = content_transfer(gram_f, gram_c)

            for l in range(len(loss_s)):
                loss = loss + loss_s[l] + loss_c[l]

        loss = loss / len(gram)

        return loss
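The loss_gram term above depends on external style_transfer and content_transfer helpers that are not shown in this excerpt. For orientation, a minimal sketch of the Gram-matrix style distance such helpers are usually built on (gram_matrix and style_distance are hypothetical names, not the repository's functions):

import torch


def gram_matrix(feat):
    # feat: (B, C, H, W) feature map -> (B, C, C) normalized Gram matrix
    b, c, h, w = feat.size()
    f = feat.view(b, c, h * w)
    return torch.bmm(f, f.transpose(1, 2)) / (c * h * w)


def style_distance(gen_feat, style_feat):
    # mean squared difference between the Gram matrices of generated and target features
    return torch.mean((gram_matrix(gen_feat) - gram_matrix(style_feat)) ** 2)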
Code Example #21
0
def main(pretrain_dataset, rl_dataset, args):
    ##############################################################################
    # Setup
    ##############################################################################
    # set random seeds
    random.seed(const.SEED)
    np.random.seed(const.SEED)

    # load datasets
    pt_train_loader, pt_valid_loader = SplitDataLoader(
        pretrain_dataset, batch_size=const.BATCH_SIZE, drop_last=True).split()

    # Define Networks
    generator = Generator(const.VOCAB_SIZE, const.GEN_EMBED_DIM,
                          const.GEN_HIDDEN_DIM, device, args.cuda)
    discriminator = Discriminator(const.VOCAB_SIZE, const.DSCR_EMBED_DIM,
                                  const.DSCR_FILTER_LENGTHS,
                                  const.DSCR_NUM_FILTERS,
                                  const.DSCR_NUM_CLASSES, const.DSCR_DROPOUT)

    # if torch.cuda.device_count() > 1:
    # print("Using", torch.cuda.device_count(), "GPUs.")
    # generator = nn.DataParallel(generator)
    # discriminator = nn.DataParallel(discriminator)
    generator.to(device)
    discriminator.to(device)

    # set CUDA
    if args.cuda and torch.cuda.is_available():
        generator = generator.cuda()
        discriminator = discriminator.cuda()
    ##############################################################################

    ##############################################################################
    # Pre-Training
    ##############################################################################
    # Pretrain and save Generator using MLE, Load the Pretrained generator and display training stats
    # if it already exists.
    print('#' * 80)
    print('Generator Pretraining')
    print('#' * 80)
    if (not (args.force_pretrain
             or args.force_pretrain_gen)) and op.exists(GEN_MODEL_CACHE):
        print('Loading Pretrained Generator ...')
        checkpoint = torch.load(GEN_MODEL_CACHE)
        generator.load_state_dict(checkpoint['state_dict'])
        print('::INFO:: DateTime - %s.' % checkpoint['datetime'])
        print('::INFO:: Model was trained for %d epochs.' %
              checkpoint['epochs'])
        print('::INFO:: Final Training Loss - %.5f' % checkpoint['train_loss'])
        print('::INFO:: Final Validation Loss - %.5f' %
              checkpoint['valid_loss'])
    else:
        try:
            print('Pretraining Generator with MLE ...')
            GeneratorPretrainer(generator, pt_train_loader, pt_valid_loader,
                                PT_CACHE_DIR, device, args).train()
        except KeyboardInterrupt:
            print('Stopped Generator Pretraining Early.')

    # Pretrain Discriminator on real data and data from the pretrained generator. If a pretrained Discriminator
    # already exists, load it and display its stats
    print('#' * 80)
    print('Discriminator Pretraining')
    print('#' * 80)
    if (not (args.force_pretrain
             or args.force_pretrain_dscr)) and op.exists(DSCR_MODEL_CACHE):
        print("Loading Pretrained Discriminator ...")
        checkpoint = torch.load(DSCR_MODEL_CACHE)
        discriminator.load_state_dict(checkpoint['state_dict'])
        print('::INFO:: DateTime - %s.' % checkpoint['datetime'])
        print('::INFO:: Model was trained on %d data generations.' %
              checkpoint['data_gens'])
        print('::INFO:: Model was trained for %d epochs per data generation.' %
              checkpoint['epochs_per_gen'])
        print('::INFO:: Final Loss - %.5f' % checkpoint['loss'])
    else:
        print('Pretraining Discriminator ...')
        try:
            DiscriminatorPretrainer(discriminator, rl_dataset, PT_CACHE_DIR,
                                    TEMP_DATA_DIR, device,
                                    args).train(generator)
        except KeyboardInterrupt:
            print('Stopped Discriminator Pretraining Early.')
    ##############################################################################

    ##############################################################################
    # Adversarial Training
    ##############################################################################
    print('#' * 80)
    print('Adversarial Training')
    print('#' * 80)
    AdversarialRLTrainer(generator, discriminator, rl_dataset, TEMP_DATA_DIR,
                         pt_valid_loader, device, args).train()
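GeneratorPretrainer, DiscriminatorPretrainer, and AdversarialRLTrainer are project classes not shown in this excerpt. As a rough sketch of the kind of teacher-forced MLE step a generator pretrainer performs (the function name and tensor shapes are assumptions, not the repository's API):

import torch
import torch.nn.functional as F


def mle_pretrain_step(generator, batch, optimizer):
    # batch: (B, T) token ids; predict each token from its prefix (teacher forcing)
    inputs, targets = batch[:, :-1], batch[:, 1:]
    logits = generator(inputs)                              # assumed (B, T-1, vocab) output
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                           targets.reshape(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()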
Code Example #22
0
File: gan.py Project: dawnonme/Eureka
def train_D_With_G():
    aD = Discriminator()
    aD.cuda()

    aG = Generator()
    aG.cuda()

    optimizer_g = torch.optim.Adam(aG.parameters(), lr=0.0001, betas=(0, 0.9))
    optimizer_d = torch.optim.Adam(aD.parameters(), lr=0.0001, betas=(0, 0.9))

    criterion = nn.CrossEntropyLoss()

    n_z = 100
    n_classes = 10
    np.random.seed(352)
    label = np.asarray(list(range(10)) * 10)
    noise = np.random.normal(0, 1, (100, n_z))
    label_onehot = np.zeros((100, n_classes))
    label_onehot[np.arange(100), label] = 1
    noise[np.arange(100), :n_classes] = label_onehot[np.arange(100)]
    noise = noise.astype(np.float32)

    save_noise = torch.from_numpy(noise)
    save_noise = Variable(save_noise).cuda()
    start_time = time.time()

    # Train the model
    num_epochs = 500
    loss1 = []
    loss2 = []
    loss3 = []
    loss4 = []
    loss5 = []
    acc1 = []
    for epoch in range(0, num_epochs):

        aG.train()
        aD.train()
        avoidOverflow(optimizer_d)
        avoidOverflow(optimizer_g)
        for batch_idx, (X_train_batch,
                        Y_train_batch) in enumerate(trainloader):

            if (Y_train_batch.shape[0] < batch_size):
                continue
            # train G
            if batch_idx % gen_train == 0:
                for p in aD.parameters():
                    p.requires_grad_(False)

                aG.zero_grad()

                label = np.random.randint(0, n_classes, batch_size)
                noise = np.random.normal(0, 1, (batch_size, n_z))
                label_onehot = np.zeros((batch_size, n_classes))
                label_onehot[np.arange(batch_size), label] = 1
                noise[np.arange(batch_size), :n_classes] = label_onehot[
                    np.arange(batch_size)]
                noise = noise.astype(np.float32)
                noise = torch.from_numpy(noise)
                noise = Variable(noise).cuda()
                fake_label = Variable(torch.from_numpy(label)).cuda()

                fake_data = aG(noise)
                gen_source, gen_class = aD(fake_data)

                gen_source = gen_source.mean()
                gen_class = criterion(gen_class, fake_label)

                gen_cost = -gen_source + gen_class
                gen_cost.backward()

                optimizer_g.step()

            # train D
            for p in aD.parameters():
                p.requires_grad_(True)

            aD.zero_grad()

            # train discriminator with input from generator
            label = np.random.randint(0, n_classes, batch_size)
            noise = np.random.normal(0, 1, (batch_size, n_z))
            label_onehot = np.zeros((batch_size, n_classes))
            label_onehot[np.arange(batch_size), label] = 1
            noise[np.arange(batch_size), :n_classes] = label_onehot[np.arange(
                batch_size)]
            noise = noise.astype(np.float32)
            noise = torch.from_numpy(noise)
            noise = Variable(noise).cuda()
            fake_label = Variable(torch.from_numpy(label)).cuda()
            with torch.no_grad():
                fake_data = aG(noise)

            disc_fake_source, disc_fake_class = aD(fake_data)

            disc_fake_source = disc_fake_source.mean()
            disc_fake_class = criterion(disc_fake_class, fake_label)

            # train discriminator with real data
            real_data = Variable(X_train_batch).cuda()
            real_label = Variable(Y_train_batch).cuda()

            disc_real_source, disc_real_class = aD(real_data)

            prediction = disc_real_class.data.max(1)[1]
            accuracy = (float(prediction.eq(real_label.data).sum()) /
                        float(batch_size)) * 100.0

            disc_real_source = disc_real_source.mean()
            disc_real_class = criterion(disc_real_class, real_label)

            gradient_penalty = calc_gradient_penalty(aD, real_data, fake_data)

            disc_cost = disc_fake_source - disc_real_source + disc_real_class + disc_fake_class + gradient_penalty
            disc_cost.backward()

            optimizer_d.step()
            loss1.append(gradient_penalty.item())
            loss2.append(disc_fake_source.item())
            loss3.append(disc_real_source.item())
            loss4.append(disc_real_class.item())
            loss5.append(disc_fake_class.item())
            acc1.append(accuracy)
            if batch_idx % 50 == 0:
                print(epoch, batch_idx, "%.2f" % np.mean(loss1),
                      "%.2f" % np.mean(loss2), "%.2f" % np.mean(loss3),
                      "%.2f" % np.mean(loss4), "%.2f" % np.mean(loss5),
                      "%.2f" % np.mean(acc1))
        # Test the model
        aD.eval()
        with torch.no_grad():
            test_accu = []
            for batch_idx, (X_test_batch,
                            Y_test_batch) in enumerate(testloader):
                X_test_batch, Y_test_batch = Variable(
                    X_test_batch).cuda(), Variable(Y_test_batch).cuda()

                with torch.no_grad():
                    _, output = aD(X_test_batch)

                prediction = output.data.max(1)[
                    1]  # max(1)[1] holds the argmax class indices
                accuracy = (float(prediction.eq(Y_test_batch.data).sum()) /
                            float(batch_size)) * 100.0
                test_accu.append(accuracy)
                accuracy_test = np.mean(test_accu)
        print('Testing', accuracy_test, time.time() - start_time)

        # save output
        with torch.no_grad():
            aG.eval()
            samples = aG(save_noise)
            samples = samples.data.cpu().numpy()
            samples += 1.0
            samples /= 2.0
            samples = samples.transpose(0, 2, 3, 1)
            aG.train()
        fig = plot(samples)
        plt.savefig('output/%s.png' % str(epoch).zfill(3), bbox_inches='tight')
        plt.close(fig)

        if (epoch + 1) % 1 == 0:
            torch.save(aG, 'tempG.model')
            torch.save(aD, 'tempD.model')

    torch.save(aG, 'generator.model')
    torch.save(aD, 'discriminator.model')
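calc_gradient_penalty is defined elsewhere in this project. A minimal sketch of the standard WGAN-GP penalty such a helper typically computes (editorial sketch; the repository's implementation and penalty weight may differ, and this discriminator returns a (source, class) pair):

import torch


def calc_gradient_penalty_sketch(netD, real_data, fake_data, lambda_gp=10.0):
    bsz = real_data.size(0)
    eps = torch.rand(bsz, 1, 1, 1, device=real_data.device)
    interpolates = (eps * real_data + (1.0 - eps) * fake_data).requires_grad_(True)
    disc_source, _ = netD(interpolates)            # penalize the critic ("source") output only
    grads = torch.autograd.grad(outputs=disc_source.sum(), inputs=interpolates,
                                create_graph=True)[0]
    grad_norm = grads.view(bsz, -1).norm(2, dim=1)
    return lambda_gp * ((grad_norm - 1.0) ** 2).mean()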
Code Example #23
0
    model_d = Discriminator()
    model_g = Generator(args.nz)
    criterion = nn.BCELoss()
    input = torch.FloatTensor(args.batch_size, INPUT_SIZE)
    noise = torch.FloatTensor(args.batch_size, (args.nz))
    
    fixed_noise = torch.FloatTensor(SAMPLE_SIZE, args.nz).normal_(0,1)
    fixed_labels = torch.zeros(SAMPLE_SIZE, NUM_LABELS)
    for i in range(NUM_LABELS):
        for j in range(SAMPLE_SIZE // NUM_LABELS):
            fixed_labels[i*(SAMPLE_SIZE // NUM_LABELS) + j, i] = 1.0
    
    label = torch.FloatTensor(args.batch_size)
    one_hot_labels = torch.FloatTensor(args.batch_size, 10)
    if args.cuda:
        model_d.cuda()
        model_g.cuda()
        input, label = input.cuda(), label.cuda()
        noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
        one_hot_labels = one_hot_labels.cuda()
        fixed_labels = fixed_labels.cuda()

    optim_d = optim.SGD(model_d.parameters(), lr=args.lr)
    optim_g = optim.SGD(model_g.parameters(), lr=args.lr)
    fixed_noise = Variable(fixed_noise)
    fixed_labels = Variable(fixed_labels)

    real_label = 1
    fake_label = 0

    for epoch_idx in range(args.epochs):
Code Example #24
0
print_every = 400

sample_size = 16

if torch.cuda.is_available():
    cuda = True
else:
    cuda = False

fixed_z = generate_z_vector(sample_size, z_size, cuda)

D.train()
G.train()

if cuda:
    D.cuda()
    G.cuda()


def train_discriminator(real_images, optimizer, batch_size, z_size):
    optimizer.zero_grad()

    if cuda:
        real_images = real_images.cuda()

    # Loss for real image
    d_real_loss = real_loss(D(real_images), cuda, smooth=True)

    # Loss for fake image
    fake_images = G(generate_z_vector(batch_size, z_size, cuda))
    d_fake_loss = fake_loss(D(fake_images), cuda)
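real_loss and fake_loss are helpers defined outside this excerpt. A minimal sketch of what such helpers usually compute, including the label smoothing toggled by smooth=True (a BCE-with-logits formulation is assumed here; the original may apply plain BCELoss to sigmoid outputs):

import torch
import torch.nn as nn


def real_loss(d_out, cuda, smooth=False):
    # target 0.9 instead of 1.0 when label smoothing is requested
    labels = torch.full((d_out.size(0),), 0.9 if smooth else 1.0)
    if cuda:
        labels = labels.cuda()
    return nn.BCEWithLogitsLoss()(d_out.squeeze(), labels)


def fake_loss(d_out, cuda):
    labels = torch.zeros(d_out.size(0))
    if cuda:
        labels = labels.cuda()
    return nn.BCEWithLogitsLoss()(d_out.squeeze(), labels)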
Code Example #25
0
class trainer(object):
    def __init__(self, cfg):
        self.cfg = cfg
        self.OldLabel_generator = U_Net(in_ch=cfg.DATASET.N_CLASS,
                                        out_ch=cfg.DATASET.N_CLASS,
                                        side='out')
        self.Image_generator = U_Net(in_ch=3,
                                     out_ch=cfg.DATASET.N_CLASS,
                                     side='in')
        self.discriminator = Discriminator(cfg.DATASET.N_CLASS + 3,
                                           cfg.DATASET.IMGSIZE,
                                           patch=True)

        self.criterion_G = GeneratorLoss(cfg.LOSS.LOSS_WEIGHT[0],
                                         cfg.LOSS.LOSS_WEIGHT[1],
                                         cfg.LOSS.LOSS_WEIGHT[2],
                                         ignore_index=cfg.LOSS.IGNORE_INDEX)
        self.criterion_D = DiscriminatorLoss()

        train_dataset = BaseDataset(cfg, split='train')
        valid_dataset = BaseDataset(cfg, split='val')
        self.train_dataloader = data.DataLoader(
            train_dataset,
            batch_size=cfg.DATASET.BATCHSIZE,
            num_workers=8,
            shuffle=True,
            drop_last=True)
        self.valid_dataloader = data.DataLoader(
            valid_dataset,
            batch_size=cfg.DATASET.BATCHSIZE,
            num_workers=8,
            shuffle=True,
            drop_last=True)

        self.ckpt_outdir = os.path.join(cfg.TRAIN.OUTDIR, 'checkpoints')
        if not os.path.isdir(self.ckpt_outdir):
            os.mkdir(self.ckpt_outdir)
        self.val_outdir = os.path.join(cfg.TRAIN.OUTDIR, 'val')
        if not os.path.isdir(self.val_outdir):
            os.mkdir(self.val_outdir)
        self.start_epoch = cfg.TRAIN.RESUME
        self.n_epoch = cfg.TRAIN.N_EPOCH

        self.optimizer_G = torch.optim.Adam(
            [{
                'params': self.OldLabel_generator.parameters()
            }, {
                'params': self.Image_generator.parameters()
            }],
            lr=cfg.OPTIMIZER.G_LR,
            betas=(cfg.OPTIMIZER.BETA1, cfg.OPTIMIZER.BETA2),
            # betas=(cfg.OPTIMIZER.BETA1, cfg.OPTIMIZER.BETA2),
            weight_decay=cfg.OPTIMIZER.WEIGHT_DECAY)

        self.optimizer_D = torch.optim.Adam(
            [{
                'params': self.discriminator.parameters(),
                'initial_lr': cfg.OPTIMIZER.D_LR
            }],
            lr=cfg.OPTIMIZER.D_LR,
            betas=(cfg.OPTIMIZER.BETA1, cfg.OPTIMIZER.BETA2),
            # betas=(cfg.OPTIMIZER.BETA1, cfg.OPTIMIZER.BETA2),
            weight_decay=cfg.OPTIMIZER.WEIGHT_DECAY)

        iter_per_epoch = len(train_dataset) // cfg.DATASET.BATCHSIZE
        lambda_poly = lambda iters: pow(
            (1.0 - iters / (cfg.TRAIN.N_EPOCH * iter_per_epoch)), 0.9)
        self.scheduler_G = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_G,
            lr_lambda=lambda_poly,
        )
        # last_epoch=(self.start_epoch+1)*iter_per_epoch)
        self.scheduler_D = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_D,
            lr_lambda=lambda_poly,
        )
        # last_epoch=(self.start_epoch+1)*iter_per_epoch)

        self.logger = logger(cfg.TRAIN.OUTDIR, name='train')
        self.running_metrics = runningScore(n_classes=cfg.DATASET.N_CLASS)

        if self.start_epoch >= 0:
            self.OldLabel_generator.load_state_dict(
                torch.load(
                    os.path.join(cfg.TRAIN.OUTDIR, 'checkpoints',
                                 '{}epoch.pth'.format(
                                     self.start_epoch)))['model_G_N'])
            self.Image_generator.load_state_dict(
                torch.load(
                    os.path.join(cfg.TRAIN.OUTDIR, 'checkpoints',
                                 '{}epoch.pth'.format(
                                     self.start_epoch)))['model_G_I'])
            self.discriminator.load_state_dict(
                torch.load(
                    os.path.join(cfg.TRAIN.OUTDIR, 'checkpoints',
                                 '{}epoch.pth'.format(
                                     self.start_epoch)))['model_D'])
            self.optimizer_G.load_state_dict(
                torch.load(
                    os.path.join(cfg.TRAIN.OUTDIR, 'checkpoints',
                                 '{}epoch.pth'.format(
                                     self.start_epoch)))['optimizer_G'])
            self.optimizer_D.load_state_dict(
                torch.load(
                    os.path.join(cfg.TRAIN.OUTDIR, 'checkpoints',
                                 '{}epoch.pth'.format(
                                     self.start_epoch)))['optimizer_D'])

            log = "Using the {}th checkpoint".format(self.start_epoch)
            self.logger.info(log)
        self.Image_generator = self.Image_generator.cuda()
        self.OldLabel_generator = self.OldLabel_generator.cuda()
        self.discriminator = self.discriminator.cuda()
        self.criterion_G = self.criterion_G.cuda()
        self.criterion_D = self.criterion_D.cuda()

    def train(self):
        all_train_iter_total_loss = []
        all_train_iter_corr_loss = []
        all_train_iter_recover_loss = []
        all_train_iter_change_loss = []
        all_train_iter_gan_loss_gen = []
        all_train_iter_gan_loss_dis = []
        all_val_epo_iou = []
        all_val_epo_acc = []
        iter_num = [0]
        epoch_num = []
        num_batches = len(self.train_dataloader)

        for epoch_i in range(self.start_epoch + 1, self.n_epoch):
            iter_total_loss = AverageTracker()
            iter_corr_loss = AverageTracker()
            iter_recover_loss = AverageTracker()
            iter_change_loss = AverageTracker()
            iter_gan_loss_gen = AverageTracker()
            iter_gan_loss_dis = AverageTracker()
            batch_time = AverageTracker()
            tic = time.time()

            # train
            self.OldLabel_generator.train()
            self.Image_generator.train()
            self.discriminator.train()
            for i, meta in enumerate(self.train_dataloader):

                image, old_label, new_label = meta[0].cuda(), meta[1].cuda(
                ), meta[2].cuda()
                recover_pred, feats = self.OldLabel_generator(
                    label2onehot(old_label, self.cfg.DATASET.N_CLASS))
                corr_pred = self.Image_generator(image, feats)

                # -------------------
                # Train Discriminator
                # -------------------
                self.discriminator.set_requires_grad(True)
                self.optimizer_D.zero_grad()

                fake_sample = torch.cat((image, corr_pred), 1).detach()
                real_sample = torch.cat(
                    (image, label2onehot(new_label, self.cfg.DATASET.N_CLASS)), 1)

                score_fake_d = self.discriminator(fake_sample)
                score_real = self.discriminator(real_sample)

                gan_loss_dis = self.criterion_D(pred_score=score_fake_d,
                                                real_score=score_real)
                gan_loss_dis.backward()
                self.optimizer_D.step()
                self.scheduler_D.step()

                # ---------------
                # Train Generator
                # ---------------
                self.discriminator.set_requires_grad(False)
                self.optimizer_G.zero_grad()

                score_fake = self.discriminator(
                    torch.cat((image, corr_pred), 1))

                total_loss, corr_loss, recover_loss, change_loss, gan_loss_gen = self.criterion_G(
                    corr_pred, recover_pred, score_fake, old_label, new_label)

                total_loss.backward()
                self.optimizer_G.step()
                self.scheduler_G.step()

                iter_total_loss.update(total_loss.item())
                iter_corr_loss.update(corr_loss.item())
                iter_recover_loss.update(recover_loss.item())
                iter_change_loss.update(change_loss.item())
                iter_gan_loss_gen.update(gan_loss_gen.item())
                iter_gan_loss_dis.update(gan_loss_dis.item())
                batch_time.update(time.time() - tic)
                tic = time.time()

                log = '{}: Epoch: [{}][{}/{}], Time: {:.2f}, ' \
                      'Total Loss: {:.6f}, Corr Loss: {:.6f}, Recover Loss: {:.6f}, Change Loss: {:.6f}, GAN_G Loss: {:.6f}, GAN_D Loss: {:.6f}'.format(
                    datetime.now(), epoch_i, i, num_batches, batch_time.avg,
                    total_loss.item(), corr_loss.item(), recover_loss.item(), change_loss.item(), gan_loss_gen.item(), gan_loss_dis.item())
                print(log)

                if (i + 1) % 10 == 0:
                    all_train_iter_total_loss.append(iter_total_loss.avg)
                    all_train_iter_corr_loss.append(iter_corr_loss.avg)
                    all_train_iter_recover_loss.append(iter_recover_loss.avg)
                    all_train_iter_change_loss.append(iter_change_loss.avg)
                    all_train_iter_gan_loss_gen.append(iter_gan_loss_gen.avg)
                    all_train_iter_gan_loss_dis.append(iter_gan_loss_dis.avg)
                    iter_total_loss.reset()
                    iter_corr_loss.reset()
                    iter_recover_loss.reset()
                    iter_change_loss.reset()
                    iter_gan_loss_gen.reset()
                    iter_gan_loss_dis.reset()

                    vis.line(X=np.column_stack(
                        np.repeat(np.expand_dims(iter_num, 0), 6, axis=0)),
                             Y=np.column_stack((all_train_iter_total_loss,
                                                all_train_iter_corr_loss,
                                                all_train_iter_recover_loss,
                                                all_train_iter_change_loss,
                                                all_train_iter_gan_loss_gen,
                                                all_train_iter_gan_loss_dis)),
                             opts={
                                 'legend': [
                                     'total_loss', 'corr_loss', 'recover_loss',
                                     'change_loss', 'gan_loss_gen',
                                     'gan_loss_dis'
                                 ],
                                 'linecolor':
                                 np.array([[255, 0, 0], [0, 255, 0],
                                           [0, 0, 255], [255, 255, 0],
                                           [0, 255, 255], [255, 0, 255]]),
                                 'title':
                                 'Train loss of generator and discriminator'
                             },
                             win='Train loss of generator and discriminator')
                    iter_num.append(iter_num[-1] + 1)

            # eval
            self.OldLabel_generator.eval()
            self.Image_generator.eval()
            self.discriminator.eval()
            with torch.no_grad():
                for j, meta in enumerate(self.valid_dataloader):
                    image, old_label, new_label = meta[0].cuda(), meta[1].cuda(
                    ), meta[2].cuda()
                    recover_pred, feats = self.OldLabel_generator(
                        label2onehot(old_label, self.cfg.DATASET.N_CLASS))
                    corr_pred = self.Image_generator(image, feats)
                    preds = np.argmax(corr_pred.cpu().detach().numpy().copy(),
                                      axis=1)
                    target = new_label.cpu().detach().numpy().copy()
                    self.running_metrics.update(target, preds)

                    if j == 0:
                        color_map1 = gen_color_map(preds[0, :]).astype(
                            np.uint8)
                        color_map2 = gen_color_map(preds[1, :]).astype(
                            np.uint8)
                        color_map = cv2.hconcat([color_map1, color_map2])
                        cv2.imwrite(
                            os.path.join(
                                self.val_outdir, '{}epoch*{}*{}.png'.format(
                                    epoch_i, meta[3][0], meta[3][1])),
                            color_map)

            score = self.running_metrics.get_scores()
            oa = score['Overall Acc: \t']
            precision = score['Precision: \t'][1]
            recall = score['Recall: \t'][1]
            iou = score['Class IoU: \t'][1]
            miou = score['Mean IoU: \t']
            self.running_metrics.reset()

            epoch_num.append(epoch_i)
            all_val_epo_acc.append(oa)
            all_val_epo_iou.append(miou)
            vis.line(X=np.column_stack(
                np.repeat(np.expand_dims(epoch_num, 0), 2, axis=0)),
                     Y=np.column_stack((all_val_epo_acc, all_val_epo_iou)),
                     opts={
                         'legend':
                         ['val epoch Overall Acc', 'val epoch Mean IoU'],
                         'linecolor': np.array([[255, 0, 0], [0, 255, 0]]),
                         'title': 'Validate Accuracy and IoU'
                     },
                     win='validate Accuracy and IoU')

            log = '{}: Epoch Val: [{}], ACC: {:.2f}, Recall: {:.2f}, mIoU: {:.4f}' \
                .format(datetime.now(), epoch_i, oa, recall, miou)
            self.logger.info(log)

            state = {
                'epoch': epoch_i,
                "acc": oa,
                "recall": recall,
                "iou": miou,
                'model_G_N': self.OldLabel_generator.state_dict(),
                'model_G_I': self.Image_generator.state_dict(),
                'model_D': self.discriminator.state_dict(),
                'optimizer_G': self.optimizer_G.state_dict(),
                'optimizer_D': self.optimizer_D.state_dict()
            }
            save_path = os.path.join(self.cfg.TRAIN.OUTDIR, 'checkpoints',
                                     '{}epoch.pth'.format(epoch_i))
            torch.save(state, save_path)
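label2onehot, used in both the training and validation loops above, is defined elsewhere in the project. A hypothetical sketch under the assumed semantics of mapping a (B, H, W) integer label map to a (B, N_CLASS, H, W) one-hot tensor:

import torch


def label2onehot_sketch(labels, n_class):
    # labels: (B, H, W) integer class map -> (B, n_class, H, W) one-hot float tensor
    b, h, w = labels.size()
    onehot = torch.zeros(b, n_class, h, w, device=labels.device)
    return onehot.scatter_(1, labels.long().unsqueeze(1), 1.0)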
Code Example #26
0
def adversarial():
    # use the root logger
    logger = logging.getLogger("lan2720")
    
    argparser = argparse.ArgumentParser(add_help=False)
    argparser.add_argument('--load_path', '-p', type=str, required=True)
    # TODO: load best
    argparser.add_argument('--load_epoch', '-e', type=int, required=True)
    
    argparser.add_argument('--filter_num', type=int, required=True)
    argparser.add_argument('--filter_sizes', type=str, required=True)

    argparser.add_argument('--training_ratio', type=int, default=2)
    argparser.add_argument('--g_learning_rate', '-glr', type=float, default=0.001)
    argparser.add_argument('--d_learning_rate', '-dlr', type=float, default=0.001)
    argparser.add_argument('--batch_size', '-b', type=int, default=168)
    
    # new arguments used in adversarial
    new_args = argparser.parse_args()
    
    # load default arguments
    default_arg_file = os.path.join(new_args.load_path, 'args.pkl')
    if not os.path.exists(default_arg_file):
        raise RuntimeError('No default argument file in %s' % new_args.load_path)
    else:
        with open(default_arg_file, 'rb') as f:
            args = pickle.load(f)
    
    args.mode = 'adversarial'
    #args.d_learning_rate  = 0.0001
    args.print_every = 1
    args.g_learning_rate = new_args.g_learning_rate
    args.d_learning_rate = new_args.d_learning_rate
    args.batch_size = new_args.batch_size

    # add new arguments
    args.load_path = new_args.load_path
    args.load_epoch = new_args.load_epoch
    args.filter_num = new_args.filter_num
    args.filter_sizes = new_args.filter_sizes
    args.training_ratio = new_args.training_ratio
    


    # set up the output directory
    exp_dirname = os.path.join(args.exp_dir, args.mode, time.strftime("%Y-%m-%d-%H-%M-%S"))
    os.makedirs(exp_dirname)

    # set up the logger
    tqdm_logging.config(logger, os.path.join(exp_dirname, 'adversarial.log'), 
                        mode='w', silent=False, debug=True)

    # load vocabulary
    vocab, rev_vocab = load_vocab(args.vocab_file, max_vocab=args.max_vocab_size)

    vocab_size = len(vocab)

    word_embeddings = nn.Embedding(vocab_size, args.emb_dim, padding_idx=SYM_PAD)
    E = EncoderRNN(vocab_size, args.emb_dim, args.hidden_dim, args.n_layers, args.dropout_rate, bidirectional=True, variable_lengths=True)
    G = Generator(vocab_size, args.response_max_len, args.emb_dim, 2*args.hidden_dim, args.n_layers, dropout_p=args.dropout_rate)
    D = Discriminator(args.emb_dim, args.filter_num, eval(args.filter_sizes))
    
    if args.use_cuda:
        word_embeddings.cuda()
        E.cuda()
        G.cuda()
        D.cuda()

    # define optimizer
    opt_G = torch.optim.Adam(G.rnn.parameters(), lr=args.g_learning_rate)
    opt_D = torch.optim.Adam(D.parameters(), lr=args.d_learning_rate)
    
    logger.info('----------------------------------')
    logger.info('Adversarial a neural conversation model')
    logger.info('----------------------------------')

    logger.info('Args:')
    logger.info(str(args))
    
    logger.info('Vocabulary from ' + args.vocab_file)
    logger.info('vocabulary size: %d' % vocab_size)
    logger.info('Loading text data from ' + args.train_query_file + ' and ' + args.train_response_file)
   
    
    reload_model(args.load_path, args.load_epoch, word_embeddings, E, G)
    #    start_epoch = args.resume_epoch + 1
    #else:
    #    start_epoch = 0

    # dump args
    with open(os.path.join(exp_dirname, 'args.pkl'), 'wb') as f:
        pickle.dump(args, f)


    # TODO: num_epoch is old one
    for e in range(args.num_epoch):
        train_data_generator = batcher(args.batch_size, args.train_query_file, args.train_response_file)
        logger.info("Epoch: %d/%d" % (e, args.num_epoch))
        step = 0
        cur_time = time.time() 
        while True:
            try:
                post_sentences, response_sentences = next(train_data_generator)
            except StopIteration:
                # save model
                save_model(exp_dirname, e, word_embeddings, E, G, D) 
                ## evaluation
                #eval(args.valid_query_file, args.valid_response_file, args.batch_size, 
                #        word_embeddings, E, G, loss_func, args.use_cuda, vocab, args.response_max_len)
                break
            
            # prepare data
            post_ids = [sentence2id(sent, vocab) for sent in post_sentences]
            response_ids = [sentence2id(sent, vocab) for sent in response_sentences]
            posts_var, posts_length = padding_inputs(post_ids, None)
            responses_var, responses_length = padding_inputs(response_ids, args.response_max_len)
            # sort by post length
            posts_length, perms_idx = posts_length.sort(0, descending=True)
            posts_var = posts_var[perms_idx]
            responses_var = responses_var[perms_idx]
            responses_length = responses_length[perms_idx]

            if args.use_cuda:
                posts_var = posts_var.cuda()
                responses_var = responses_var.cuda()

            embedded_post = word_embeddings(posts_var)
            real_responses = word_embeddings(responses_var)

            # forward
            _, dec_init_state = E(embedded_post, input_lengths=posts_length.numpy())
            fake_responses = G(dec_init_state, word_embeddings) # [B, T, emb_size]

            prob_real = D(embedded_post, real_responses)
            prob_fake = D(embedded_post, fake_responses)
        
            # loss
            D_loss = - torch.mean(torch.log(prob_real) + torch.log(1. - prob_fake)) 
            G_loss = torch.mean(torch.log(1. - prob_fake))
            
            if step % args.training_ratio == 0:
                opt_D.zero_grad()
                D_loss.backward(retain_graph=True)
                opt_D.step()
            
            opt_G.zero_grad()
            G_loss.backward()
            opt_G.step()
            
            if step % args.print_every == 0:
                logger.info('Step %5d: D accuracy=%.2f (0.5 for D to converge) D score=%.2f (-1.38 for G to converge) (%.1f iters/sec)' % (
                    step, 
                    prob_real.cpu().data.numpy().mean(), 
                    -D_loss.item(),
                    args.print_every/(time.time()-cur_time)))
                cur_time = time.time()
            step = step + 1
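The generator objective above minimizes mean(log(1 - D(fake))), the saturating form of the original GAN objective, which gives weak gradients while the discriminator is confident. A commonly used non-saturating alternative (an editorial note, not part of this repository) maximizes log D(fake) instead:

# non-saturating generator loss; the small epsilon guards against log(0)
G_loss_ns = -torch.mean(torch.log(prob_fake + 1e-10))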
Code Example #27
0
File: train.py Project: senior-sigan/cppn_vae_gan
def train(opt: Options):
    real_label = 1
    fake_label = 0

    netG = Generator(opt)
    netD = Discriminator(opt)
    print(netG)
    print(netD)

    netG.apply(weights_init_g)
    netD.apply(weights_init_d)

    # summary(netD, (opt.c_dim, opt.x_dim, opt.y_dim))

    dataloader = load_data(opt.data_root, opt.x_dim, opt.y_dim, opt.batch_size, opt.workers)

    x, y, r = get_coordinates(x_dim=opt.x_dim, y_dim=opt.y_dim, scale=opt.scale, batch_size=opt.batch_size)

    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))

    criterion = nn.BCELoss()
    # criterion = nn.L1Loss()

    noise = torch.FloatTensor(opt.batch_size, opt.z_dim)
    ones = torch.ones(opt.batch_size, opt.x_dim * opt.y_dim, 1)
    input_ = torch.FloatTensor(opt.batch_size, opt.c_dim, opt.x_dim, opt.y_dim)
    label = torch.FloatTensor(opt.batch_size, 1)

    input_ = Variable(input_)
    label = Variable(label)
    noise = Variable(noise)

    if opt.use_cuda:
        netG = netG.cuda()
        netD = netD.cuda()
        x = x.cuda()
        y = y.cuda()
        r = r.cuda()
        ones = ones.cuda()
        criterion = criterion.cuda()
        input_ = input_.cuda()
        label = label.cuda()
        noise = noise.cuda()

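    # fixed noise vector, broadcast to all x_dim * y_dim coordinates via bmm; used for the periodic sample images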
    noise.data.normal_()
    fixed_seed = torch.bmm(ones, noise.unsqueeze(1))

    def _update_discriminator(data):
        # for p in netD.parameters():
        #     p.requires_grad = True  # to avoid computation
        netD.zero_grad()
        real_cpu, _ = data
        input_.data.copy_(real_cpu)
        label.data.fill_(real_label-0.1)  # use smooth label for discriminator

        output = netD(input_)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.data.mean()

        # train with fake
        noise.data.normal_()
        seed = torch.bmm(ones, noise.unsqueeze(1))

        fake = netG(x, y, r, seed)
        label.data.fill_(fake_label)
        output = netD(fake.detach())  # add ".detach()" to avoid backprop through G
        errD_fake = criterion(output, label)
        errD_fake.backward()  # gradients for fake/real will be accumulated
        D_G_z1 = output.data.mean()
        errD = errD_real + errD_fake
        optimizerD.step()  # .step() can be called once the gradients are computed

        return fake, D_G_z1, errD, D_x

    def _update_generator(fake):
        # for p in netD.parameters():
        #     p.requires_grad = False  # to avoid computation
        netG.zero_grad()

        label.data.fill_(real_label)  # fake labels are real for generator cost

        output = netD(fake)
        errG = criterion(output, label)
        errG.backward()  # fake was not detached here, so gradients flow back through netG
        D_G_z2 = output.data.mean()
        optimizerG.step()

        return D_G_z2, errG

    def _save_model(epoch):
        os.makedirs(opt.models_root, exist_ok=True)
        if epoch % 1 == 0:
            torch.save(netG.state_dict(), os.path.join(opt.models_root, "G-cppn-wgan-anime_{}.pth".format(epoch)))
            torch.save(netD.state_dict(), os.path.join(opt.models_root, "D-cppn-wgan-anime_{}.pth".format(epoch)))

    def _log(i, epoch, errD, errG, D_x, D_G_z1, D_G_z2, delta_time):
        print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f Elapsed %.2f s'
              % (epoch, opt.iterations, i, len(dataloader), errD.data.item(), errG.data.item(), D_x, D_G_z1, D_G_z2,
                 delta_time))

    def _save_images(i, epoch):
        os.makedirs(opt.images_root, exist_ok=True)
        if i % 100 == 0:
            fake = netG(x, y, r, fixed_seed)
            fname = os.path.join(opt.images_root, "fake_samples_{:02}-{:04}.png".format(epoch, i))
            vutils.save_image(fake.data[0:64, :, :, :], fname, nrow=8)

    def _start():
        print("Start training")
        for epoch in range(opt.iterations):
            for i, data in enumerate(dataloader, 0):
                start_iter = time.time()

                fake, D_G_z1, errD, D_x = _update_discriminator(data)
                D_G_z2, errG = _update_generator(fake)

                end_iter = time.time()

                _log(i, epoch, errD, errG, D_x, D_G_z1, D_G_z2, end_iter - start_iter)
                _save_images(i, epoch)
            _save_model(epoch)

    _start()
コード例 #28
0
ファイル: main.py プロジェクト: chenyangh/SeqGAN-PyTorch
def main():
    random.seed(SEED)
    np.random.seed(SEED)

    # Define Networks
    generator = Generator(VOCAB_SIZE, g_emb_dim, g_hidden_dim, opt.cuda)
    discriminator = Discriminator(d_num_class, VOCAB_SIZE, d_emb_dim,
                                  d_filter_sizes, d_num_filters, d_dropout)
    target_lstm = TargetLSTM(VOCAB_SIZE, g_emb_dim, g_hidden_dim, opt.cuda)
    if opt.cuda:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        target_lstm = target_lstm.cuda()
    # Generate toy data using target lstm
    print('Generating data ...')
    generate_samples(target_lstm, BATCH_SIZE, GENERATED_NUM, POSITIVE_FILE)

    # Load data from file
    gen_data_iter = GenDataIter(POSITIVE_FILE, BATCH_SIZE)

    # Pretrain Generator using MLE
    gen_criterion = nn.NLLLoss(size_average=False)
    gen_optimizer = optim.Adam(generator.parameters())
    if opt.cuda:
        gen_criterion = gen_criterion.cuda()
    print('Pretrain with MLE ...')
    for epoch in range(PRE_EPOCH_NUM):
        loss = train_epoch(generator, gen_data_iter, gen_criterion,
                           gen_optimizer)
        print('Epoch [%d] Model Loss: %f' % (epoch, loss))
        generate_samples(generator, BATCH_SIZE, GENERATED_NUM, EVAL_FILE)
        eval_iter = GenDataIter(EVAL_FILE, BATCH_SIZE)
        loss = eval_epoch(target_lstm, eval_iter, gen_criterion)
        print('Epoch [%d] True Loss: %f' % (epoch, loss))

    # Pretrain Discriminator
    dis_criterion = nn.NLLLoss(size_average=False)
    dis_optimizer = optim.Adam(discriminator.parameters())
    if opt.cuda:
        dis_criterion = dis_criterion.cuda()
    print('Pretrain Discriminator ...')
    for epoch in range(5):
        generate_samples(generator, BATCH_SIZE, GENERATED_NUM, NEGATIVE_FILE)
        dis_data_iter = DisDataIter(POSITIVE_FILE, NEGATIVE_FILE, BATCH_SIZE)
        for _ in range(3):
            loss = train_epoch(discriminator, dis_data_iter, dis_criterion,
                               dis_optimizer)
            print('Epoch [%d], loss: %f' % (epoch, loss))
    # Adversarial Training
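    # rollout module (update rate 0.8) for Monte Carlo reward estimation on partial sequences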
    rollout = Rollout(generator, 0.8)
    print('#####################################################')
    print('Start Adversarial Training...\n')
    gen_gan_loss = GANLoss()
    gen_gan_optm = optim.Adam(generator.parameters())
    if opt.cuda:
        gen_gan_loss = gen_gan_loss.cuda()
    gen_criterion = nn.NLLLoss(size_average=False)
    if opt.cuda:
        gen_criterion = gen_criterion.cuda()
    dis_criterion = nn.NLLLoss(size_average=False)
    dis_optimizer = optim.Adam(discriminator.parameters())
    if opt.cuda:
        dis_criterion = dis_criterion.cuda()
    for total_batch in range(TOTAL_BATCH):
        # Train the generator for one step
        for it in range(1):
            samples = generator.sample(BATCH_SIZE, g_sequence_len)
            # construct the input to the generator, add zeros before samples and delete the last column
            zeros = torch.zeros((BATCH_SIZE, 1)).type(torch.LongTensor)
            if samples.is_cuda:
                zeros = zeros.cuda()
            inputs = Variable(
                torch.cat([zeros, samples.data], dim=1)[:, :-1].contiguous())
            targets = Variable(samples.data).contiguous().view((-1, ))
            # calculate the reward
            rewards = rollout.get_reward(samples, 16, discriminator)
            rewards = Variable(torch.Tensor(rewards))
            if opt.cuda:
                rewards = rewards.cuda()
            rewards = torch.exp(rewards).contiguous().view((-1, ))
            prob = generator.forward(inputs)
            loss = gen_gan_loss(prob, targets, rewards)
            gen_gan_optm.zero_grad()
            loss.backward()
            gen_gan_optm.step()

        if total_batch % 1 == 0 or total_batch == TOTAL_BATCH - 1:
            generate_samples(generator, BATCH_SIZE, GENERATED_NUM, EVAL_FILE)
            eval_iter = GenDataIter(EVAL_FILE, BATCH_SIZE)
            loss = eval_epoch(target_lstm, eval_iter, gen_criterion)
            print('Batch [%d] True Loss: %f' % (total_batch, loss))
        rollout.update_params()

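        # discriminator phase: regenerate negative samples 4 times and train D for 2 epochs on each refresh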
        for _ in range(4):
            generate_samples(generator, BATCH_SIZE, GENERATED_NUM,
                             NEGATIVE_FILE)
            dis_data_iter = DisDataIter(POSITIVE_FILE, NEGATIVE_FILE,
                                        BATCH_SIZE)
            for _ in range(2):
                loss = train_epoch(discriminator, dis_data_iter, dis_criterion,
                                   dis_optimizer)
コード例 #29
0
if is_gpu_mode:
    ones_label = Variable(torch.ones(BATCH_SIZE).cuda())
    zeros_label = Variable(torch.zeros(BATCH_SIZE).cuda())
else:
    ones_label = Variable(torch.ones(BATCH_SIZE))
    zeros_label = Variable(torch.zeros(BATCH_SIZE))

if __name__ == "__main__":
    print('main')

    gen_model = Tiramisu()
    disc_model = Discriminator()

    if is_gpu_mode:
        gen_model.cuda()
        disc_model.cuda()
        # gen_model = torch.nn.DataParallel(gen_model).cuda()
        # disc_model = torch.nn.DataParallel(disc_model).cuda()

    optimizer_gen = torch.optim.Adam(gen_model.parameters(), lr=LEARNING_RATE_GENERATOR)
    optimizer_disc = torch.optim.Adam(disc_model.parameters(), lr=LEARNING_RATE_DISCRIMINATOR)

    # read imgs
    image_buff_read_index = 0

    # pre-allocate image batch buffers in PyTorch's channel-first layout
    input_img = np.empty(shape=(BATCH_SIZE, 3, data_loader.INPUT_IMAGE_WIDTH, data_loader.INPUT_IMAGE_HEIGHT))
    
    answer_img = np.empty(shape=(BATCH_SIZE, 3, data_loader.INPUT_IMAGE_WIDTH, data_loader.INPUT_IMAGE_HEIGHT))
    
    motion_vec_img = np.empty(shape=(BATCH_SIZE, 1, data_loader.INPUT_IMAGE_WIDTH, data_loader.INPUT_IMAGE_HEIGHT))
コード例 #30
0
out_dir="results"
generator_AB=Generator(l=2,n_filters=8)
generator_BA=Generator(l=2,n_filters=8)
discriminator_A=Discriminator(h,w,c)
discriminator_B=Discriminator(h,w,c)
gan_loss=nn.MSELoss()
cycle_loss=nn.L1Loss()
ident_loss=nn.L1Loss()
generator_AB.apply(weight_init)
generator_BA.apply(weight_init)
discriminator_A.apply(weight_init)
discriminator_B.apply(weight_init)
if cuda:
	generator_AB=generator_AB.cuda()
	generator_BA=generator_BA.cuda()
	discriminator_A=discriminator_A.cuda()
	discriminator_B=discriminator_B.cuda()
	gan_loss.cuda()
	cycle_loss.cuda()
	ident_loss.cuda()

os.makedirs(out_dir,exist_ok=True)
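# expected PatchGAN output shape, assuming the discriminators downsample by a factor of 2**4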
patch=(1,h//2**4,w//2**4)
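# preprocessing: resize to (h, w), convert to tensor, and normalize each channel to [-1, 1]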
transforms_=[
	transforms.Resize((h,w),Image.BICUBIC),
	transforms.ToTensor(),
	transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
]

train_dataloader=torch.utils.data.DataLoader(
	ImageDataset("../data/{}".format(dataset),transforms_=transforms_),
コード例 #31
0
def main(args):
    use_cuda = (len(args.gpuid) >= 1)
    print("{0} GPU(s) are available".format(cuda.device_count()))

    print("======printing args========")
    print(args)
    print("=================================")

    # Load dataset
    splits = ['train', 'valid']
    if data.has_binary_files(args.data, splits):
        print("Loading bin dataset")
        dataset = data.load_dataset(args.data, splits, args.src_lang,
                                    args.trg_lang, args.fixed_max_len)
        #args.data, splits, args.src_lang, args.trg_lang)
    else:
        print(f"Loading raw text dataset {args.data}")
        dataset = data.load_raw_text_dataset(args.data, splits, args.src_lang,
                                             args.trg_lang, args.fixed_max_len)
        #args.data, splits, args.src_lang, args.trg_lang)
    if args.src_lang is None or args.trg_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.src_lang, args.trg_lang = dataset.src, dataset.dst
    print('| [{}] dictionary: {} types'.format(dataset.src,
                                               len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst,
                                               len(dataset.dst_dict)))
    for split in splits:
        print('| {} {} {} examples'.format(args.data, split,
                                           len(dataset.splits[split])))

    g_logging_meters = OrderedDict()
    g_logging_meters['train_loss'] = AverageMeter()
    g_logging_meters['valid_loss'] = AverageMeter()
    g_logging_meters['train_acc'] = AverageMeter()
    g_logging_meters['valid_acc'] = AverageMeter()
    g_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    d_logging_meters = OrderedDict()
    d_logging_meters['train_loss'] = AverageMeter()
    d_logging_meters['valid_loss'] = AverageMeter()
    d_logging_meters['train_acc'] = AverageMeter()
    d_logging_meters['valid_acc'] = AverageMeter()
    d_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    # Set model parameters
    args.encoder_embed_dim = 1000
    args.encoder_layers = 4
    args.encoder_dropout_out = 0
    args.decoder_embed_dim = 1000
    args.decoder_layers = 4
    args.decoder_out_embed_dim = 1000
    args.decoder_dropout_out = 0
    args.bidirectional = False

    # try to load generator model
    g_model_path = 'checkpoints/generator/best_gmodel.pt'
    if not os.path.exists(g_model_path):
        print("Start training generator!")
        train_g(args, dataset)
    assert os.path.exists(g_model_path)
    generator = LSTMModel(args,
                          dataset.src_dict,
                          dataset.dst_dict,
                          use_cuda=use_cuda)
    model_dict = generator.state_dict()
    pretrained_dict = torch.load(g_model_path)
    #print(f"First dict: {pretrained_dict}")
    # 1. filter out unnecessary keys
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in model_dict
    }
    #print(f"Second dict: {pretrained_dict}")
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    #print(f"model dict: {model_dict}")
    # 3. load the new state dict
    generator.load_state_dict(model_dict)

    print("Generator has successfully loaded!")

    # try to load discriminator model
    d_model_path = 'checkpoints/discriminator/best_dmodel.pt'
    if not os.path.exists(d_model_path):
        print("Start training discriminator!")
        train_d(args, dataset)
    assert os.path.exists(d_model_path)
    discriminator = Discriminator(args,
                                  dataset.src_dict,
                                  dataset.dst_dict,
                                  use_cuda=use_cuda)
    model_dict = discriminator.state_dict()
    pretrained_dict = torch.load(d_model_path)
    # 1. filter out unnecessary keys
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in model_dict
    }
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    discriminator.load_state_dict(model_dict)

    print("Discriminator has successfully loaded!")

    #return
    print("starting main training loop")

    torch.autograd.set_detect_anomaly(True)

    if use_cuda:
        if torch.cuda.device_count() > 1:
            discriminator = torch.nn.DataParallel(discriminator).cuda()
            generator = torch.nn.DataParallel(generator).cuda()
        else:
            generator.cuda()
            discriminator.cuda()
    else:
        discriminator.cpu()
        generator.cpu()

    # adversarial training checkpoints saving path
    if not os.path.exists('checkpoints/joint'):
        os.makedirs('checkpoints/joint')
    checkpoints_path = 'checkpoints/joint/'

    # define loss function
    g_criterion = torch.nn.NLLLoss(size_average=False,
                                   ignore_index=dataset.dst_dict.pad(),
                                   reduce=True)
    d_criterion = torch.nn.BCEWithLogitsLoss()
    pg_criterion = PGLoss(ignore_index=dataset.dst_dict.pad(),
                          size_average=True,
                          reduce=True)

    # fix discriminator word embedding (as Wu et al. do)
    for p in discriminator.embed_src_tokens.parameters():
        p.requires_grad = False
    for p in discriminator.embed_trg_tokens.parameters():
        p.requires_grad = False

    # define optimizer
    g_optimizer = eval("torch.optim." + args.g_optimizer)(
        filter(lambda x: x.requires_grad, generator.parameters()),
        args.g_learning_rate)

    d_optimizer = eval("torch.optim." + args.d_optimizer)(
        filter(lambda x: x.requires_grad, discriminator.parameters()),
        args.d_learning_rate,
        momentum=args.momentum,
        nesterov=True)

    # start joint training
    best_dev_loss = math.inf
    num_update = 0
    # main training loop
    for epoch_i in range(1, args.epochs + 1):
        logging.info("At {0}-th epoch.".format(epoch_i))

        # seed = args.seed + epoch_i
        # torch.manual_seed(seed)

        max_positions_train = (args.fixed_max_len, args.fixed_max_len)

        # Initialize dataloader, starting at batch_offset
        itr = dataset.train_dataloader(
            'train',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_train,
            # seed=seed,
            epoch=epoch_i,
            sample_without_replacement=args.sample_without_replacement,
            sort_by_source_size=(epoch_i <= args.curriculum),
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        for key, val in g_logging_meters.items():
            if val is not None:
                val.reset()
        for key, val in d_logging_meters.items():
            if val is not None:
                val.reset()

        # set training mode
        generator.train()
        discriminator.train()
        update_learning_rate(num_update, 8e4, args.g_learning_rate,
                             args.lr_shrink, g_optimizer)

        for i, sample in enumerate(itr):
            if use_cuda:
                # wrap input tensors in cuda tensors
                sample = utils.make_variable(sample, cuda=use_cuda)

            ## part I: use gradient policy method to train the generator

            # use policy gradient training when rand > 50%
            rand = random.random()
            if rand >= 0.5:
                # policy gradient training
                generator.decoder.is_testing = True
                sys_out_batch, prediction, _ = generator(sample)
                generator.decoder.is_testing = False
                with torch.no_grad():
                    n_i = sample['net_input']['src_tokens']
                    #print(f"net input:\n{n_i}, pred: \n{prediction}")
                    reward = discriminator(
                        sample['net_input']['src_tokens'],
                        prediction)  # dataset.dst_dict.pad())
                train_trg_batch = sample['target']
                #print(f"sys_out_batch: {sys_out_batch.shape}:\n{sys_out_batch}")
                pg_loss = pg_criterion(sys_out_batch, train_trg_batch, reward,
                                       use_cuda)
                # logging.debug("G policy gradient loss at batch {0}: {1:.3f}, lr={2}".format(i, pg_loss.item(), g_optimizer.param_groups[0]['lr']))
                g_optimizer.zero_grad()
                pg_loss.backward()
                torch.nn.utils.clip_grad_norm(generator.parameters(),
                                              args.clip_norm)
                g_optimizer.step()

                # oracle valid
                _, _, loss = generator(sample)
                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']
                logging_loss = loss.data / sample_size / math.log(2)
                g_logging_meters['train_loss'].update(logging_loss,
                                                      sample_size)
                logging.debug(
                    "G MLE loss at batch {0}: {1:.3f}, lr={2}".format(
                        i, g_logging_meters['train_loss'].avg,
                        g_optimizer.param_groups[0]['lr']))
            else:
                # MLE training
                #print(f"printing sample: \n{sample}")
                _, _, loss = generator(sample)
                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']
                nsentences = sample['target'].size(0)
                logging_loss = loss.data / sample_size / math.log(2)
                g_logging_meters['bsz'].update(nsentences)
                g_logging_meters['train_loss'].update(logging_loss,
                                                      sample_size)
                logging.debug(
                    "G MLE loss at batch {0}: {1:.3f}, lr={2}".format(
                        i, g_logging_meters['train_loss'].avg,
                        g_optimizer.param_groups[0]['lr']))
                g_optimizer.zero_grad()
                loss.backward()
                # all-reduce grads and rescale by grad_denom
                for p in generator.parameters():
                    if p.requires_grad:
                        p.grad.data.div_(sample_size)
                torch.nn.utils.clip_grad_norm(generator.parameters(),
                                              args.clip_norm)
                g_optimizer.step()
            num_update += 1

            # part II: train the discriminator
            bsz = sample['target'].size(0)
            src_sentence = sample['net_input']['src_tokens']
            # train with half human-translation and half machine translation

            true_sentence = sample['target']
            true_labels = Variable(
                torch.ones(sample['target'].size(0)).float())

            with torch.no_grad():
                generator.decoder.is_testing = True
                _, prediction, _ = generator(sample)
                generator.decoder.is_testing = False
            fake_sentence = prediction
            fake_labels = Variable(
                torch.zeros(sample['target'].size(0)).float())

            trg_sentence = torch.cat([true_sentence, fake_sentence], dim=0)
            labels = torch.cat([true_labels, fake_labels], dim=0)

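            # shuffle the concatenated real/fake examples and keep one batch of size bsz for the D update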
            indices = np.random.permutation(2 * bsz)
            trg_sentence = trg_sentence[indices][:bsz]
            labels = labels[indices][:bsz]

            if use_cuda:
                labels = labels.cuda()

            disc_out = discriminator(src_sentence,
                                     trg_sentence)  #, dataset.dst_dict.pad())
            #print(f"disc out: {disc_out.shape}, labels: {labels.shape}")
            #print(f"labels: {labels}")
            d_loss = d_criterion(disc_out, labels)
            acc = torch.sum(torch.sigmoid(disc_out).round() == labels).float() / len(labels)
            d_logging_meters['train_acc'].update(acc)
            d_logging_meters['train_loss'].update(d_loss)
            # logging.debug("D training loss {0:.3f}, acc {1:.3f} at batch {2}: ".format(d_logging_meters['train_loss'].avg,
            #                                                                            d_logging_meters['train_acc'].avg,
            #                                                                            i))
            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

        # validation
        # set validation mode
        generator.eval()
        discriminator.eval()
        # Initialize dataloader
        max_positions_valid = (args.fixed_max_len, args.fixed_max_len)
        itr = dataset.eval_dataloader(
            'valid',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_valid,
            skip_invalid_size_inputs_valid_test=True,
            descending=True,  # largest batch first to warm the caching allocator
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        for key, val in g_logging_meters.items():
            if val is not None:
                val.reset()
        for key, val in d_logging_meters.items():
            if val is not None:
                val.reset()

        for i, sample in enumerate(itr):
            with torch.no_grad():
                if use_cuda:
                    sample['id'] = sample['id'].cuda()
                    sample['net_input']['src_tokens'] = sample['net_input'][
                        'src_tokens'].cuda()
                    sample['net_input']['src_lengths'] = sample['net_input'][
                        'src_lengths'].cuda()
                    sample['net_input']['prev_output_tokens'] = sample[
                        'net_input']['prev_output_tokens'].cuda()
                    sample['target'] = sample['target'].cuda()

                # generator validation
                _, _, loss = generator(sample)
                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']
                loss = loss / sample_size / math.log(2)
                g_logging_meters['valid_loss'].update(loss, sample_size)
                logging.debug("G dev loss at batch {0}: {1:.3f}".format(
                    i, g_logging_meters['valid_loss'].avg))

                # discriminator validation
                bsz = sample['target'].size(0)
                src_sentence = sample['net_input']['src_tokens']
                # validate with half human-translation and half machine-translation pairs

                true_sentence = sample['target']
                true_labels = Variable(
                    torch.ones(sample['target'].size(0)).float())

                with torch.no_grad():
                    generator.decoder.is_testing = True
                    _, prediction, _ = generator(sample)
                    generator.decoder.is_testing = False
                fake_sentence = prediction
                fake_labels = Variable(
                    torch.zeros(sample['target'].size(0)).float())

                trg_sentence = torch.cat([true_sentence, fake_sentence], dim=0)
                labels = torch.cat([true_labels, fake_labels], dim=0)

                indices = np.random.permutation(2 * bsz)
                trg_sentence = trg_sentence[indices][:bsz]
                labels = labels[indices][:bsz]

                if use_cuda:
                    labels = labels.cuda()

                disc_out = discriminator(src_sentence, trg_sentence,
                                         dataset.dst_dict.pad())
                d_loss = d_criterion(disc_out, labels)
                acc = torch.sum(torch.sigmoid(disc_out).round() == labels).float() / len(labels)
                d_logging_meters['valid_acc'].update(acc)
                d_logging_meters['valid_loss'].update(d_loss)
                # logging.debug("D dev loss {0:.3f}, acc {1:.3f} at batch {2}".format(d_logging_meters['valid_loss'].avg,
                #                                                                     d_logging_meters['valid_acc'].avg, i))

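        # checkpoint the full generator module (serialized with dill) at the end of every epoch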
        torch.save(generator,
                   open(
                       checkpoints_path + "joint_{0:.3f}.epoch_{1}.pt".format(
                           g_logging_meters['valid_loss'].avg, epoch_i), 'wb'),
                   pickle_module=dill)

        if g_logging_meters['valid_loss'].avg < best_dev_loss:
            best_dev_loss = g_logging_meters['valid_loss'].avg
            torch.save(generator,
                       open(checkpoints_path + "best_gmodel.pt", 'wb'),
                       pickle_module=dill)