Example #1
def train_gan(data_path, save_gen_path, save_disc_path, feedback_fn=None):
    data_loader = load_dataset(path=data_path, batch_size=128)

    gen = Generator().to('cuda')
    disc = Discriminator().to('cuda')

    criterion = torch.nn.BCEWithLogitsLoss()
    lr = 3e-4
    gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=(0.5, 0.999))

    disc_opt = torch.optim.Adam(disc.parameters(), lr=lr, betas=(0.5, 0.999))
    gen = gen.apply(weights_init)
    disc = disc.apply(weights_init)

    n_epochs = 5

    gen, disc = _train_loop(data_loader=data_loader,
                            gen=gen,
                            disc=disc,
                            criterion=criterion,
                            gen_opt=gen_opt,
                            disc_opt=disc_opt,
                            n_epochs=n_epochs,
                            feedback_fn=feedback_fn)

    torch.save(gen.state_dict(), save_gen_path)
    torch.save(disc.state_dict(), save_disc_path)
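
# A minimal usage sketch (the dataset/checkpoint paths below are placeholders, and the
# feedback_fn signature shown is a hypothetical assumption -- the real one is whatever
# _train_loop calls it with):
if __name__ == '__main__':
    def print_progress(epoch, gen_loss, disc_loss):  # hypothetical callback signature
        print('epoch %d: gen_loss=%.4f, disc_loss=%.4f' % (epoch, gen_loss, disc_loss))

    train_gan(data_path='data/train',
              save_gen_path='gen.pth',
              save_disc_path='disc.pth',
              feedback_fn=print_progress)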
Example #2
def main(config):
    # For fast training.
    cudnn.benchmark = True
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # running preparation
    prepareDirs(config)
    writer = SummaryWriter(config.log_dir)
    # Data preparation
    data_loader = getProperLoader(config)

    # Model Initialization
    G = Generator(config.g_conv_dim, config.c_dim, config.g_repeat_num)
    D = Discriminator(config.image_size, config.d_conv_dim, config.c_dim,
                      config.d_repeat_num)
    g_optimizer = torch.optim.Adam(G.parameters(), config.g_lr,
                                   [config.beta1, config.beta2])
    d_optimizer = torch.optim.Adam(D.parameters(), config.d_lr,
                                   [config.beta1, config.beta2])
    G.to(device)
    D.to(device)

    # Training/Test
    if config.mode == 'train':
        train(config, G, D, g_optimizer, d_optimizer, data_loader, device,
              writer)
    elif config.mode == 'test':
        test(config, G, data_loader, device)
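
# A sketch of the config object that main() expects (the field names come from the code
# above; the defaults here are illustrative assumptions, not the project's actual values):
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--mode', type=str, default='train', choices=['train', 'test'])
parser.add_argument('--image_size', type=int, default=128)
parser.add_argument('--g_conv_dim', type=int, default=64)
parser.add_argument('--d_conv_dim', type=int, default=64)
parser.add_argument('--c_dim', type=int, default=5)
parser.add_argument('--g_repeat_num', type=int, default=6)
parser.add_argument('--d_repeat_num', type=int, default=6)
parser.add_argument('--g_lr', type=float, default=1e-4)
parser.add_argument('--d_lr', type=float, default=1e-4)
parser.add_argument('--beta1', type=float, default=0.5)
parser.add_argument('--beta2', type=float, default=0.999)
parser.add_argument('--log_dir', type=str, default='logs')
config = parser.parse_args()
main(config)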
Example #3
criterion = criterion.to(device)
cri_perception = VGGFeatureExtractor().to(device)
cri_gan = GANLoss('vanilla', 1.0, 0.0).to(device)

# textual init
vgg19 = cri_perception.vggNet
loss_layer = [1, 6, 11, 20]
loss_fns = [GramMSELoss()] * len(loss_layer)
if torch.cuda.is_available():
    loss_fns = [loss_fn.cuda() for loss_fn in loss_fns]
style_weights = [1e3 / n**2 for n in [64, 128, 256, 512]]

optimizer = torch.optim.RMSprop(
    filter(lambda p: p.requires_grad, model.parameters()), 0.0001)
optimizer_D = torch.optim.RMSprop(
    filter(lambda p: p.requires_grad, netD.parameters()), 0.0002)
print('# generator parameters:',
      sum(param.numel() for param in model.parameters()))
print('# discriminator parameters:',
      sum(param.numel() for param in netD.parameters()))
print()
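
# training_settings is defined outside this excerpt; judging from the lookups below it is
# assumed to be a list of three per-step dicts, for example (values are illustrative):
# training_settings = [
#     {'nEpochs': 30, 'lr': 1e-4, 'step': 7, 'lr_decay': 0.1, 'lambda_db': 0.5, 'gated': False},
#     {'nEpochs': 60, 'lr': 1e-4, 'step': 7, 'lr_decay': 0.1, 'lambda_db': 0.5, 'gated': True},
#     {'nEpochs': 60, 'lr': 5e-5, 'step': 7, 'lr_decay': 0.1, 'lambda_db': 0.0, 'gated': True},
# ]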

for i in range(opt.start_training_step, 4):
    opt.nEpochs = training_settings[i - 1]['nEpochs']
    opt.lr = training_settings[i - 1]['lr']
    opt.step = training_settings[i - 1]['step']
    opt.lr_decay = training_settings[i - 1]['lr_decay']
    opt.lambda_db = training_settings[i - 1]['lambda_db']
    opt.gated = training_settings[i - 1]['gated']
    print(opt)
    for epoch in range(opt.start_epoch, opt.nEpochs + 1):
Example #4
else:
    model = Net()
    netD = Discriminator()
    mkdir_steptraing()

model = model.to(device)
netD = netD.to(device)
criterion = torch.nn.L1Loss(reduction='mean')
criterion = criterion.to(device)
cri_perception = VGGFeatureExtractor().to(device)
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()), 0.0001,
    [0.9, 0.999])
optimizer_D = torch.optim.Adam(
    filter(lambda p: p.requires_grad, netD.parameters()), 0.0002, [0.9, 0.999])
print()

for i in range(opt.start_training_step, 4):
    opt.nEpochs = training_settings[i - 1]['nEpochs']
    opt.lr = training_settings[i - 1]['lr']
    opt.step = training_settings[i - 1]['step']
    opt.lr_decay = training_settings[i - 1]['lr_decay']
    opt.lambda_db = training_settings[i - 1]['lambda_db']
    opt.gated = training_settings[i - 1]['gated']
    print(opt)
    for epoch in range(opt.start_epoch, opt.nEpochs + 1):
        lr = adjust_learning_rate(epoch - 1)
        trainloader = CreateDataLoader(opt)
        train(trainloader, model, netD, criterion, optimizer, epoch, lr)
        if epoch % 5 == 0:
Example #5
netD = netD.to(device)
criterion = torch.nn.MSELoss(reduction='mean')
criterion = criterion.to(device)
cri_perception = VGGFeatureExtractor().to(device)
cri_gan = GANLoss('vanilla', 1.0, 0.0).to(device)

# textual init
vgg19 = cri_perception.vggNet
loss_layer = [1, 6, 11, 20]
loss_fns = [GramMSELoss()] * len(loss_layer)
if torch.cuda.is_available():
    loss_fns = [loss_fn.cuda() for loss_fn in loss_fns]
style_weights = [1e3 / n ** 2 for n in [64, 128, 256, 512]]

optimizer = torch.optim.RMSprop(filter(lambda p: p.requires_grad, model.parameters()), 0.0001)
optimizer_D = torch.optim.RMSprop(filter(lambda p: p.requires_grad, netD.parameters()), 0.0002)
print('# generator parameters:', sum(param.numel() for param in model.parameters()))
print('# discriminator parameters:', sum(param.numel() for param in netD.parameters()))
print()


for i in range(opt.start_training_step, 4):
    opt.nEpochs   = training_settings[i-1]['nEpochs']
    opt.lr        = training_settings[i-1]['lr']
    opt.step      = training_settings[i-1]['step']
    opt.lr_decay  = training_settings[i-1]['lr_decay']
    opt.lambda_db = training_settings[i-1]['lambda_db']
    opt.gated     = training_settings[i-1]['gated']
    print(opt)
    for epoch in range(opt.start_epoch, opt.nEpochs+1):
        lr = adjust_learning_rate(epoch-1)
Example #6
    # model_ESR = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, norm_type=None)
    netD = Discriminator()
    mkdir_steptraing()

# model_msrn = model_msrn.to(device)
model_GFN_deblur.to(device)
# model_GFN_sr = model_GFN_sr.to(device)
# model_DebulrGAN = model_DebulrGAN.to(device)
# model_ESR = model_ESR.to(device)
netD = netD.to(device)
# print('# MSRN parameters:', sum(param.numel() for param in model_msrn.parameters()))
print('# GFN_deblur parameters:', sum(param.numel() for param in model_GFN_deblur.parameters()))
# print('# sr parameters:', sum(param.numel() for param in model_GFN_sr.parameters()))
# print('# deblurGAN parameters:', sum(param.numel() for param in model_GFN_deblur.parameters()))
# print('# ESR parameters:', sum(param.numel() for param in model_ESR.parameters()))
print('# discriminator parameters:', sum(param.numel() for param in netD.parameters()))

criterion = torch.nn.L1Loss(reduction='mean')
criterion = criterion.to(device)
cri_perception = VGGFeatureExtractor().to(device)
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model_GFN_deblur.parameters()), 0.0001, [0.9, 0.999])
optimizer_D = torch.optim.Adam(filter(lambda p: p.requires_grad, netD.parameters()), 0.0004, [0.9, 0.999])
print()


# for i in range(opt.start_training_step, 4):
#     opt.nEpochs   = training_settings[i-1]['nEpochs']
#     opt.lr        = training_settings[i-1]['lr']
#     opt.step      = training_settings[i-1]['step']
#     opt.lr_decay  = training_settings[i-1]['lr_decay']
#     opt.lambda_db = training_settings[i-1]['lambda_db']
Example #7
        netD = torch.load(opt.resumeD)
        netD.load_state_dict(netD.state_dict())  # no-op: loading a module's own state_dict has no effect
        opt.start_training_step, opt.start_epoch = which_trainingstep_epoch(opt.resume)

else:
    model = Net()
    netD = Discriminator()
    mkdir_steptraing()

model = model.to(device)
netD = netD.to(device)
criterion = torch.nn.L1Loss(reduction='mean')
criterion = criterion.to(device)
cri_perception = VGGFeatureExtractor().to(device)
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), 0.0001)
optimizer_D = torch.optim.Adam(filter(lambda p: p.requires_grad, netD.parameters()), 0.0004)
print()


for i in range(opt.start_training_step, 4):
    opt.nEpochs   = training_settings[i-1]['nEpochs']
    opt.lr        = training_settings[i-1]['lr']
    opt.step      = training_settings[i-1]['step']
    opt.lr_decay  = training_settings[i-1]['lr_decay']
    opt.lambda_db = training_settings[i-1]['lambda_db']
    opt.gated     = training_settings[i-1]['gated']
    print(opt)
    for epoch in range(opt.start_epoch, opt.nEpochs+1):
        lr = adjust_learning_rate(epoch-1)
        trainloader = CreateDataLoader(opt)
        train(trainloader, model, netD, criterion, optimizer, epoch, lr)
Example #8
class LstmGan(object):
    def __init__(self,
                 config=None,
                 train_loader=None,
                 test_loader=None,
                 gt_loader=None):
        self.config = config
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.gt_loader = gt_loader

    # build lstm-gan model
    def build(self):
        # build Summarizier
        self.summarizer = Summarizer(self.config.input_size,
                                     self.config.hidden_size,
                                     self.config.num_layers).cuda()
        # build Discriminator
        self.discriminator = Discriminator(self.config.input_size,
                                           self.config.hidden_size,
                                           self.config.num_layers).cuda()
        # hold and register submodules in a list
        self.model = nn.ModuleList([self.summarizer, self.discriminator])

        if self.config.mode == 'train':
            # build optimizers
            self.s_e_optimizer = optim.Adam(
                list(self.summarizer.slstm.parameters()) +
                list(self.summarizer.vae.elstm.parameters()),
                lr=self.config.sum_learning_rate)
            self.d_optimizer = optim.Adam(list(
                self.summarizer.vae.dlstm.parameters()),
                                          lr=self.config.sum_learning_rate)
            self.c_optimizer = optim.Adam(list(
                self.discriminator.parameters()),
                                          lr=self.config.dis_learning_rate)

            # set the model in training mode
            self.model.train()

            # initialize log writer
            self.writer = LogWriter(self.config.log_dir)

    """
    reconstruct loss
    Args:
        fea_h_last: given original vodieo feature, the output of the last hidden (top) layer of cLSTM, (1, hidden) = (1, 1024)
        dec_h_last: given decoded video feature, the output of the last hidden (top) layer of cLSTM, (1, hidden) = (1, 1024)
    """

    def get_reconst_loss(self, fea_h_last, dec_h_last):
        # L-2 norm
        return torch.norm(fea_h_last - dec_h_last, p=2)

    """
    summary-length regularization
    Args:
        scores: (seq_len, 1) 
    """

    def get_sparsity_loss(self, scores):
        # convert scores to 0-1
        #scores = scores.ge(0.5).float()
        print(scores)
        return torch.abs(torch.mean(scores) - self.config.summary_rate)

    """
    dpp loss
    """

    def get_dpp_loss(self, scores):
        # convert scores to 0-1
        #scores = scores.ge(0.5).float()
        seq_len = len(scores)
        dpp_loss = 0
        for i in range((seq_len - 1)):
            dpp_loss += (scores[i] * scores[i + 1] + (1 - scores[i]) *
                         (1 - scores[i + 1]))

        return torch.log(1 + dpp_loss)

    """
    gan loss of cLSTM
    Args:
        fea_prob: a scalar of original video feature
        dec_prob: a scalar of keyframe-based reconstruction video feature
        rand_dec_prob: a scalar of random-frame-based reconstruction video feature
    """

    def get_gan_loss(self, fea_prob, dec_prob, rand_dec_prob):
        gan_loss = torch.log(fea_prob) + torch.log(1 - dec_prob) + torch.log(
            1 - rand_dec_prob)

        return gan_loss
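
    # Note on the sign conventions used in train() below: d_loss adds gan_loss and is
    # minimized, which pushes dec_prob and rand_dec_prob toward 1 (the dLSTM decoder
    # tries to fool the discriminator); c_loss is -gan_loss, so minimizing it trains the
    # cLSTM discriminator to separate the original feature from the reconstructions.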

    # train model
    def train(self):
        step = 0

        for epoch_i in range(self.config.n_epochs):
            #for epoch_i in trange(self.config.n_epochs, desc = 'Epoch'):
            s_e_loss_history = []
            d_loss_history = []
            c_loss_history = []
            # one video is a batch
            for batch_i, feature in enumerate(
                    tqdm(self.train_loader,
                         desc='Batch',
                         ncols=240,
                         leave=False)):

                # feature: (1, seq_len, input_size) -> (seq_len, 1, 2048)
                feature = feature.view(-1, 1, self.config.input_size)

                # move the CPU tensor to the GPU (Variable wrapping is no longer required)
                feature = feature.cuda()
                """ train sLSTM and eLSTM """
                if self.config.detail_flag:
                    tqdm.write('------Training sLSTM and eLSTM------')

                # decoded: decoded feature generated by dLSTM
                scores, decoded = self.summarizer(feature)

                #scores.masked_fill_(scores >= 0.5, 1)
                #scores.masked_fill_(scores < 0.5, 0)

                # shape of feature and decoded: (seq_len, 1, 2048)
                # shape of fea_h_last and dec_h_last: (1, hidden_size) = (1, 2048)
                fea_h_last, fea_prob = self.discriminator(feature)
                dec_h_last, dec_prob = self.discriminator(decoded)

                tqdm.write('fea_prob: %.3f, dec_prob: %.3f' %
                           (fea_prob.item(), dec_prob.item()))

                # reconstruction loss
                reconst_loss = self.get_reconst_loss(fea_h_last, dec_h_last)
                tqdm.write('reconst_loss: %.3f' % reconst_loss.item())

                # sparsity loss
                sparsity_loss = self.get_sparsity_loss(scores)
                tqdm.write('sparsity_loss: %.3f' % sparsity_loss.item())

                # dpp loss
                dpp_loss = self.get_dpp_loss(scores)
                tqdm.write('diversity_loss: %.3f' % dpp_loss.item())

                # minimize
                s_e_loss = reconst_loss + sparsity_loss + dpp_loss

                self.s_e_optimizer.zero_grad()
                s_e_loss.backward(retain_graph=True)
                self.s_e_optimizer.step()

                # add to loss history
                s_e_loss_history.append(s_e_loss.data.cpu())
                """ train dLSTM """
                if self.config.detail_flag:
                    tqdm.write('------Training dLSTM------')

                # randomly select a subset of frames
                _, rand_decoded = self.summarizer(feature,
                                                  random_score_flag=True)
                # shape of rand_dec_h_last: (1, hidden_size) = (1, 2048)
                rand_dec_h_last, rand_dec_prob = self.discriminator(
                    rand_decoded)
                # gan loss
                gan_loss = self.get_gan_loss(fea_prob, dec_prob, rand_dec_prob)
                tqdm.write('gan_loss: %.3f' % gan_loss.item())

                # minimize
                d_loss = reconst_loss + gan_loss

                self.d_optimizer.zero_grad()
                d_loss.backward(retain_graph=True)
                self.d_optimizer.step()

                # add to loss history
                d_loss_history.append(d_loss.data.cpu())
                """ train cLSTM """
                if batch_i > self.config.dis_start_batch:
                    if self.config.detail_flag:
                        tqdm.write('------Training cLSTM------')

                # maximize
                c_loss = -1 * self.get_gan_loss(fea_prob, dec_prob,
                                                rand_dec_prob)

                self.c_optimizer.zero_grad()
                c_loss.backward()
                self.c_optimizer.step()

                # add to loss history
                c_loss_history.append(c_loss.data.cpu())

                if self.config.detail_flag:
                    tqdm.write('------plotting------')

                self.writer.update_loss(reconst_loss.cpu().data.numpy(), step,
                                        'reconst_loss')
                self.writer.update_loss(sparsity_loss.cpu().data.numpy(), step,
                                        'sparsity_loss')
                self.writer.update_loss(gan_loss.cpu().data.numpy(), step,
                                        'gan_loss')

                self.writer.update_loss(fea_prob.cpu().data.numpy(), step,
                                        'fea_prob')
                self.writer.update_loss(dec_prob.cpu().data.numpy(), step,
                                        'dec_prob')
                self.writer.update_loss(rand_dec_prob.cpu().data.numpy(), step,
                                        'rand_dec_prob')

                # batch over
                step += 1

            # epoch
            s_e_loss = torch.stack(s_e_loss_history).mean()
            d_loss = torch.stack(d_loss_history).mean()
            c_loss = torch.stack(c_loss_history).mean()

            # plot
            if self.config.detail_flag:
                tqdm.write('------plotting------')
            self.writer.update_loss(s_e_loss, epoch_i, 's_e_epoch')
            self.writer.update_loss(d_loss, epoch_i, 'd_loss_epoch')
            self.writer.update_loss(c_loss, epoch_i, 'c_loss_epoch')

            # save parameters at checkpoint
            model_path = str(self.config.model_save_dir) + '/' + 'model.pkl'
            tqdm.write('------save model parameters at %s' % model_path)
            torch.save(self.model.state_dict(), model_path)

    def test(self):
        model_path = str(self.config.model_save_dir) + '/' + 'model.pkl'
        # load model
        self.model.load_state_dict(torch.load(model_path))
        self.model.eval()

        ground_truth = self.gt_loader

        pred = {}
        cnt = 0
        precision = []
        for sample_i, feature in enumerate(
                tqdm(self.test_loader, desc='Evaluate', leave=False)):
            # feature: (seq_len, 1, 2048)
            feature = feature.view(-1, 1, self.config.input_size)
            with torch.no_grad():
                scores = self.summarizer.slstm(feature.cuda())
            # convert scores to a flat numpy array
            scores = scores.data.cpu().numpy().flatten()
            pred[sample_i] = scores

            print(scores)

            # sort scores in descending order and return their indices
            idx_sorted_scores = sorted(range(len(scores)),
                                       key=lambda k: scores[k],
                                       reverse=True)

            num_keyframes = int(len(scores) * self.config.summary_rate)
            idx_keyframes = idx_sorted_scores[:num_keyframes]
            # sort key frames index in ascending order
            idx_keyframes = sorted(idx_keyframes)

            # precision for a single video
            single_precision = float(
                len(set(ground_truth[cnt]) & set(idx_keyframes))) / float(
                    len(ground_truth[cnt]))

            # add single video precision to the whole precision list
            precision.append(single_precision)

            cnt += 1

        # average precision for the dataset
        avg_precision = sum(precision) / float(len(precision))
Example #9
        model.load_state_dict(model.state_dict())  # no-op: loading a module's own state_dict has no effect
        netD = torch.load(opt.resumeD)
        netD.load_state_dict(netD.state_dict())  # likewise a no-op
        # opt.start_epoch = which_trainingstep_epoch(opt.resume)

else:
    model = Generator()
    netD = Discriminator()
    mkdir_steptraing()

model = model.to(device)
netD = netD.to(device)
criterion = torch.nn.MSELoss(reduction='mean')
criterion = criterion.to(device)
optimizer = optim.Adam(model.parameters(), lr=0.0001)
optimizer_D = optim.Adam(netD.parameters(), lr=0.0002)
print()

# opt.start_training_step = 2
# for i in range(opt.start_training_step, 4):
#     opt.nEpochs   = training_settings[i-1]['nEpochs']
#     opt.lr        = training_settings[i-1]['lr']
#     opt.step      = training_settings[i-1]['step']
#     opt.lr_decay  = training_settings[i-1]['lr_decay']
#     opt.lambda_db = training_settings[i-1]['lambda_db']
#     opt.gated     = training_settings[i-1]['gated']
#     print(opt)

opt.start_epoch = 2
opt.nEpochs = 1000
# trainloader = CreateDataLoader(opt)
Example #10
        raise NotImplementedError(
            'learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler


# update learning rate (called once every epoch)
def update_learning_rate(scheduler, optimizer):
    scheduler.step()
    lr = optimizer.param_groups[0]['lr']
    print('learning rate = %.7f' % lr)


optimizer_g = torch.optim.Adam(mysegnet.parameters(),
                               lr=opt.g_lr,
                               betas=(opt.beta1, 0.999))
optimizer_d = torch.optim.Adam(mydisnet.parameters(),
                               lr=opt.d_lr,
                               betas=(opt.beta1, 0.999))
net_g_scheduler = get_scheduler(optimizer_g, opt)
net_d_scheduler = get_scheduler(optimizer_d, opt)
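
# (As noted above, each scheduler is typically advanced once per epoch, e.g. by calling
#  update_learning_rate(net_g_scheduler, optimizer_g) and
#  update_learning_rate(net_d_scheduler, optimizer_d) at the end of every epoch.)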

alpha = opt.alpha

for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
    # train
    for iteration, ((batchX, batchY), (batchY1)) in enumerate(
            zip(training_iterator, train_unpair_iterator)):
        # forward
        # input data
        batchX = batchX.unsqueeze(1).to(device='cuda').float()
        batchY = batchY.unsqueeze(1).to(device='cuda').float()