def disc_gp(self, x):
    """Extract CNN features from ``x`` and compute the discriminator's
    gradient penalty on the flattened per-location feature vectors.

    Returns (cnn feature map, gradient-penalty scalar).
    """
    with torch.no_grad():
        # Spatial feature map from the convolutional trunk.
        feats = self.conv(x)
        n, channels, height, width = feats.size()

        # Flatten to one row per spatial location, one column per channel.
        # NOTE(review): this views the (n, c, h, w) map as (c, -1) and
        # transposes — assumes that memory layout is what the
        # discriminator expects; confirm against the original training code.
        disc_inp = feats.view(channels, -1).t()

    # Penalty is computed with gradients enabled (outside no_grad).
    penalty = utils.calc_gradient_penalty(self.disc, disc_inp)

    return feats, penalty
Example #2
0
File: train.py  Project: yuanluw/DCGAN
def train(train_data, g_net, d_net, criterion, g_optimizer, d_optimizer, epoch, logger):
    """Run one WGAN-GP training epoch.

    train_data: iterable of (image batch, label) pairs; labels are ignored.
    g_net / d_net: generator and discriminator (critic).
    criterion: unused here (WGAN losses are computed directly), kept for
        interface compatibility with the caller.
    Returns (average generator loss, average discriminator loss).
    """
    g_net.train()
    d_net.train()
    g_losses = AverageMeter()
    d_losses = AverageMeter()

    for i, (img, _) in enumerate(train_data):
        mini_batch = img.size()[0]

        # ---- Discriminator (critic) step ----
        # NOTE: torch.autograd.Variable is deprecated (no-op since PyTorch
        # 0.4); tensors participate in autograd directly.
        x_ = img.cuda()
        z_ = torch.randn(mini_batch, config.G_input_dim).view(-1, config.G_input_dim, 1, 1)
        z_ = z_.cuda()

        # detach(): no generator gradients during the critic update.
        gen_img = g_net(z_).detach()
        gradient_penalty = calc_gradient_penalty(d_net, x_.data, gen_img.data)
        # WGAN critic loss: maximize D(real) - D(fake), plus gradient penalty.
        d_loss = -d_net(x_).mean() + d_net(gen_img).mean() + config.lambda_gp * gradient_penalty

        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()
        # Weight clipping (original WGAN) intentionally disabled in favor of
        # the gradient penalty above.
        # for p in d_net.parameters():
        #     p.data.clamp_(-config.clip, config.clip)

        # ---- Generator step: only every n_critic iterations ----
        if i % config.n_critic == 0:
            z_ = torch.randn(mini_batch, config.G_input_dim).view(-1, config.G_input_dim, 1, 1)
            z_ = z_.cuda()
            gen_img = g_net(z_)
            g_loss = -d_net(gen_img).mean()

            g_optimizer.zero_grad()
            g_loss.backward()
            g_optimizer.step()

            # Meters are (deliberately) only updated on generator steps.
            g_losses.update(g_loss.item())
            d_losses.update(d_loss.item())

            if i % config.print_freq == 0:
                logger.info('Epoch: [{0}][{1}][{2}]\t'
                            'g_loss {g_loss.val:.4f} ({g_loss.avg:.4f})\t'
                            'd_loss {d_loss.val:.3f} ({d_loss.avg:.3f})'
                            .format(epoch, i, len(train_data), g_loss=g_losses, d_loss=d_losses))

    return g_losses.avg, d_losses.avg
Example #3
0
def train(vocabs, char_vocab, tag_vocab, train_sets, dev_sets, test_sets, unlabeled_sets):
    """
    train_sets, dev_sets, test_sets: dict[lang] -> AmazonDataset
    For unlabeled langs, no train_sets are available

    Train a multilingual tagger with a shared feature extractor (F_s), a
    per-language private extractor with a mixture-of-experts (F_p), a tagger
    head (C), and an adversarial language discriminator (D) trained WGAN-style
    with n_critic inner steps.  When opt.test_only is set, loads saved weights
    and returns {'valid': acc, 'test': test_acc}; otherwise returns the dict
    of best accuracies observed during training.
    """
    # dataset loaders
    train_loaders, unlabeled_loaders = {}, {}
    train_iters, unlabeled_iters, d_unlabeled_iters = {}, {}, {}
    dev_loaders, test_loaders = {}, {}
    my_collate = utils.sorted_collate if opt.model=='lstm' else utils.unsorted_collate
    for lang in opt.langs:
        train_loaders[lang] = DataLoader(train_sets[lang],
                opt.batch_size, shuffle=True, collate_fn = my_collate)
        train_iters[lang] = iter(train_loaders[lang])
    for lang in opt.dev_langs:
        dev_loaders[lang] = DataLoader(dev_sets[lang],
                opt.batch_size, shuffle=False, collate_fn = my_collate)
        test_loaders[lang] = DataLoader(test_sets[lang],
                opt.batch_size, shuffle=False, collate_fn = my_collate)
    for lang in opt.all_langs:
        if lang in opt.unlabeled_langs:
            uset = unlabeled_sets[lang]
        else:
            # for labeled langs, consider which data to use as unlabeled set
            if opt.unlabeled_data == 'both':
                uset = ConcatDataset([train_sets[lang], unlabeled_sets[lang]])
            elif opt.unlabeled_data == 'unlabeled':
                uset = unlabeled_sets[lang]
            elif opt.unlabeled_data == 'train':
                uset = train_sets[lang]
            else:
                raise Exception(f'Unknown options for the unlabeled data usage: {opt.unlabeled_data}')
        unlabeled_loaders[lang] = DataLoader(uset,
                opt.batch_size, shuffle=True, collate_fn = my_collate)
        # Two independent iterators over the same loader: one consumed by the
        # D-training inner loop, one by the F&C adversarial update.
        unlabeled_iters[lang] = iter(unlabeled_loaders[lang])
        d_unlabeled_iters[lang] = iter(unlabeled_loaders[lang])

    # embeddings
    emb = MultiLangWordEmb(vocabs, char_vocab, opt.use_wordemb, opt.use_charemb).to(opt.device)
    # models
    F_s = None
    F_p = None
    C, D = None, None
    # One expert per labeled language, plus one extra expert if expert_sp.
    num_experts = len(opt.langs)+1 if opt.expert_sp else len(opt.langs)
    if opt.model.lower() == 'lstm':
        if opt.shared_hidden_size > 0:
            F_s = LSTMFeatureExtractor(opt.total_emb_size, opt.F_layers, opt.shared_hidden_size,
                                       opt.word_dropout, opt.dropout, opt.bdrnn)
        if opt.private_hidden_size > 0:
            if not opt.concat_sp:
                assert opt.shared_hidden_size == opt.private_hidden_size, "shared dim != private dim when using add_sp!"
            F_p = nn.Sequential(
                    LSTMFeatureExtractor(opt.total_emb_size, opt.F_layers, opt.private_hidden_size,
                            opt.word_dropout, opt.dropout, opt.bdrnn),
                    MixtureOfExperts(opt.MoE_layers, opt.private_hidden_size,
                            len(opt.langs), opt.private_hidden_size,
                            opt.private_hidden_size, opt.dropout, opt.MoE_bn, False)
                    )
    else:
        raise Exception(f'Unknown model architecture {opt.model}')

    if opt.C_MoE:
        C = SpMixtureOfExperts(opt.C_layers, opt.shared_hidden_size, opt.private_hidden_size, opt.concat_sp,
                num_experts, opt.shared_hidden_size + opt.private_hidden_size, len(tag_vocab),
                opt.mlp_dropout, opt.C_bn)
    else:
        C = SpMlpTagger(opt.C_layers, opt.shared_hidden_size, opt.private_hidden_size, opt.concat_sp,
                opt.shared_hidden_size + opt.private_hidden_size, len(tag_vocab),
                opt.mlp_dropout, opt.C_bn)
    # The discriminator only exists when there is a shared space to
    # adversarially align and n_critic > 0.
    if opt.shared_hidden_size > 0 and opt.n_critic > 0:
        if opt.D_model.lower() == 'lstm':
            d_args = {
                'num_layers': opt.D_lstm_layers,
                'input_size': opt.shared_hidden_size,
                'hidden_size': opt.shared_hidden_size,
                'word_dropout': opt.D_word_dropout,
                'dropout': opt.D_dropout,
                'bdrnn': opt.D_bdrnn,
                'attn_type': opt.D_attn
            }
        elif opt.D_model.lower() == 'cnn':
            d_args = {
                'num_layers': 1,
                'input_size': opt.shared_hidden_size,
                'hidden_size': opt.shared_hidden_size,
                'kernel_num': opt.D_kernel_num,
                'kernel_sizes': opt.D_kernel_sizes,
                'word_dropout': opt.D_word_dropout,
                'dropout': opt.D_dropout
            }
        else:
            d_args = None

        if opt.D_model.lower() == 'mlp':
            D = MLPLanguageDiscriminator(opt.D_layers, opt.shared_hidden_size,
                    opt.shared_hidden_size, len(opt.all_langs), opt.loss, opt.D_dropout, opt.D_bn)
        else:
            D = LanguageDiscriminator(opt.D_model, opt.D_layers,
                    opt.shared_hidden_size, opt.shared_hidden_size,
                    len(opt.all_langs), opt.D_dropout, opt.D_bn, d_args)
    if opt.use_data_parallel:
        F_s, C, D = nn.DataParallel(F_s).to(opt.device) if F_s else None, nn.DataParallel(C).to(opt.device), nn.DataParallel(D).to(opt.device) if D else None
    else:
        F_s, C, D = F_s.to(opt.device) if F_s else None, C.to(opt.device), D.to(opt.device) if D else None
    if F_p:
        if opt.use_data_parallel:
            F_p = nn.DataParallel(F_p).to(opt.device)
        else:
            F_p = F_p.to(opt.device)
    # optimizers
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, itertools.chain(*map(list,
        [emb.parameters(), F_s.parameters() if F_s else [], \
        C.parameters(), F_p.parameters() if F_p else []]))),
        lr=opt.learning_rate,
        weight_decay=opt.weight_decay)
    if D:
        optimizerD = optim.Adam(D.parameters(), lr=opt.D_learning_rate, weight_decay=opt.D_weight_decay)

    # testing
    if opt.test_only:
        log.info(f'Loading model from {opt.model_save_file}...')
        if F_s:
            F_s.load_state_dict(torch.load(os.path.join(opt.model_save_file,
                f'netF_s.pth')))
        # FIX: the original looped over opt.all_langs and reloaded the same
        # net_F_p.pth checkpoint once per language (and crashed when F_p is
        # None).  F_p is a single language-agnostic module; load it once.
        if F_p:
            F_p.load_state_dict(torch.load(os.path.join(opt.model_save_file,
                'net_F_p.pth')))
        C.load_state_dict(torch.load(os.path.join(opt.model_save_file,
            f'netC.pth')))
        if D:
            D.load_state_dict(torch.load(os.path.join(opt.model_save_file,
                f'netD.pth')))
        # NOTE(review): emb weights are saved as net_emb.pth during training
        # but never restored here — confirm whether that is intentional.

        log.info('Evaluating validation sets:')
        acc = {}
        log.info(dev_loaders)
        log.info(vocabs)
        for lang in opt.all_langs:
            acc[lang] = evaluate(f'{lang}_dev', dev_loaders[lang], vocabs[lang], tag_vocab,
                    emb, lang, F_s, F_p, C)
        avg_acc = sum([acc[d] for d in opt.dev_langs]) / len(opt.dev_langs)
        log.info(f'Average validation accuracy: {avg_acc}')
        log.info('Evaluating test sets:')
        test_acc = {}
        for lang in opt.all_langs:
            test_acc[lang] = evaluate(f'{lang}_test', test_loaders[lang], vocabs[lang], tag_vocab,
                    emb, lang, F_s, F_p, C)
        avg_test_acc = sum([test_acc[d] for d in opt.dev_langs]) / len(opt.dev_langs)
        log.info(f'Average test accuracy: {avg_test_acc}')
        return {'valid': acc, 'test': test_acc}

    # training
    best_acc, best_avg_acc = defaultdict(float), 0.0
    epochs_since_decay = 0
    # lambda scheduling
    if opt.lambd > 0 and opt.lambd_schedule:
        opt.lambd_orig = opt.lambd
    # One "epoch" is the geometric mean of the per-language loader lengths.
    num_iter = int(utils.gmean([len(train_loaders[l]) for l in opt.langs]))
    # adapt max_epoch
    if opt.max_epoch > 0 and num_iter * opt.max_epoch < 15000:
        opt.max_epoch = 15000 // num_iter
        log.info(f"Setting max_epoch to {opt.max_epoch}")
    for epoch in range(opt.max_epoch):
        emb.train()
        if F_s:
            F_s.train()
        C.train()
        if D:
            D.train()
        if F_p:
            F_p.train()
            
        # lambda scheduling: ramp the adversarial weight at epochs 5 and 15.
        if hasattr(opt, 'lambd_orig') and opt.lambd_schedule:
            if epoch == 0:
                opt.lambd = opt.lambd_orig
            elif epoch == 5:
                opt.lambd = 10 * opt.lambd_orig
            elif epoch == 15:
                opt.lambd = 100 * opt.lambd_orig
            log.info(f'Scheduling lambda = {opt.lambd}')

        # training accuracy
        correct, total = defaultdict(int), defaultdict(int)
        gate_correct = defaultdict(int)
        c_gate_correct = defaultdict(int)
        # D accuracy
        d_correct, d_total = 0, 0
        for i in tqdm(range(num_iter), ascii=True):
            # D iterations: train D on frozen features.
            if opt.shared_hidden_size > 0:
                utils.freeze_net(emb)
                utils.freeze_net(F_s)
                utils.freeze_net(F_p)
                utils.freeze_net(C)
                utils.unfreeze_net(D)
                # WGAN n_critic trick since D trains slower
                n_critic = opt.n_critic
                if opt.wgan_trick:
                    if opt.n_critic>0 and ((epoch==0 and i<25) or i%500==0):
                        n_critic = 100

                for _ in range(n_critic):
                    D.zero_grad()
                    loss_d = {}
                    lang_features = {}
                    # train on both labeled and unlabeled langs
                    for lang in opt.all_langs:
                        # targets not used
                        d_inputs, _ = utils.endless_get_next_batch(
                                unlabeled_loaders, d_unlabeled_iters, lang)
                        d_inputs, d_lengths, mask, d_chars, d_char_lengths = d_inputs
                        d_embeds = emb(lang, d_inputs, d_chars, d_char_lengths)
                        shared_feat = F_s((d_embeds, d_lengths))
                        if opt.grad_penalty != 'none':
                            lang_features[lang] = shared_feat.detach()
                        if opt.D_model.lower() == 'mlp':
                            d_outputs = D(shared_feat)
                            # if token-level D, we can reuse the gate label generator
                            d_targets = utils.get_gate_label(d_outputs, lang, mask, False, all_langs=True)
                            d_total += torch.sum(d_lengths).item()
                        else:
                            d_outputs = D((shared_feat, d_lengths))
                            d_targets = utils.get_lang_label(opt.loss, lang, len(d_lengths))
                            d_total += len(d_lengths)
                        # D accuracy
                        _, pred = torch.max(d_outputs, -1)
                        # d_total += len(d_lengths)
                        d_correct += (pred==d_targets).sum().item()
                        if opt.use_data_parallel:
                            l_d = functional.nll_loss(d_outputs.view(-1, D.module.num_langs),
                                    d_targets.view(-1), ignore_index=-1)
                        else:
                            l_d = functional.nll_loss(d_outputs.view(-1, D.num_langs),
                                    d_targets.view(-1), ignore_index=-1)

                        l_d.backward()
                        loss_d[lang] = l_d.item()
                    # gradient penalty
                    if opt.grad_penalty != 'none':
                        gp = utils.calc_gradient_penalty(D, lang_features,
                                onesided=opt.onesided_gp, interpolate=(opt.grad_penalty=='wgan'))
                        gp.backward()
                    optimizerD.step()

            # F&C iteration: train feature extractors and tagger with D frozen.
            utils.unfreeze_net(emb)
            if opt.use_wordemb and opt.fix_emb:
                for lang in emb.langs:
                    emb.wordembs[lang].weight.requires_grad = False
            if opt.use_charemb and opt.fix_charemb:
                emb.charemb.weight.requires_grad = False
            utils.unfreeze_net(F_s)
            utils.unfreeze_net(F_p)
            utils.unfreeze_net(C)
            utils.freeze_net(D)
            emb.zero_grad()
            if F_s:
                F_s.zero_grad()
            if F_p:
                F_p.zero_grad()
            C.zero_grad()
            # optimizer.zero_grad()
            for lang in opt.langs:
                inputs, targets = utils.endless_get_next_batch(
                        train_loaders, train_iters, lang)
                inputs, lengths, mask, chars, char_lengths = inputs
                bs, seq_len = inputs.size()
                embeds = emb(lang, inputs, chars, char_lengths)
                shared_feat, private_feat = None, None
                if opt.shared_hidden_size > 0:
                    shared_feat = F_s((embeds, lengths))
                if opt.private_hidden_size > 0:
                    private_feat, gate_outputs = F_p((embeds, lengths))
                if opt.C_MoE:
                    c_outputs, c_gate_outputs = C((shared_feat, private_feat))
                else:
                    c_outputs = C((shared_feat, private_feat))
                # targets are padded with -1
                l_c = functional.nll_loss(c_outputs.view(bs*seq_len, -1),
                        targets.view(-1), ignore_index=-1)
                # gate loss
                if F_p:
                    gate_targets = utils.get_gate_label(gate_outputs, lang, mask, False)
                    l_gate = functional.cross_entropy(gate_outputs.view(bs*seq_len, -1),
                            gate_targets.view(-1), ignore_index=-1)
                    l_c += opt.gate_loss_weight * l_gate
                    _, gate_pred = torch.max(gate_outputs.view(bs*seq_len, -1), -1)
                    gate_correct[lang] += (gate_pred == gate_targets.view(-1)).sum().item()
                if opt.C_MoE and opt.C_gate_loss_weight > 0:
                    c_gate_targets = utils.get_gate_label(c_gate_outputs, lang, mask, opt.expert_sp)
                    _, c_gate_pred = torch.max(c_gate_outputs.view(bs*seq_len, -1), -1)
                    if opt.expert_sp:
                        l_c_gate = functional.binary_cross_entropy_with_logits(
                                mask.unsqueeze(-1) * c_gate_outputs, c_gate_targets)
                        c_gate_correct[lang] += torch.index_select(c_gate_targets.view(bs*seq_len, -1),
                                -1, c_gate_pred.view(bs*seq_len)).sum().item()
                    else:
                        l_c_gate = functional.cross_entropy(c_gate_outputs.view(bs*seq_len, -1),
                                c_gate_targets.view(-1), ignore_index=-1)
                        c_gate_correct[lang] += (c_gate_pred == c_gate_targets.view(-1)).sum().item()
                    l_c += opt.C_gate_loss_weight * l_c_gate
                l_c.backward()
                _, pred = torch.max(c_outputs, -1)
                total[lang] += torch.sum(lengths).item()
                correct[lang] += (pred == targets).sum().item()

            # update F with D gradients on all langs
            if D:
                for lang in opt.all_langs:
                    inputs, _ = utils.endless_get_next_batch(
                            unlabeled_loaders, unlabeled_iters, lang)
                    inputs, lengths, mask, chars, char_lengths = inputs
                    embeds = emb(lang, inputs, chars, char_lengths)
                    shared_feat = F_s((embeds, lengths))
                    # d_outputs = D((shared_feat, lengths))
                    if opt.D_model.lower() == 'mlp':
                        d_outputs = D(shared_feat)
                        # if token-level D, we can reuse the gate label generator
                        d_targets = utils.get_gate_label(d_outputs, lang, mask, False, all_langs=True)
                    else:
                        d_outputs = D((shared_feat, lengths))
                        d_targets = utils.get_lang_label(opt.loss, lang, len(lengths))
                    if opt.use_data_parallel:
                        l_d = functional.nll_loss(d_outputs.view(-1, D.module.num_langs),
                                d_targets.view(-1), ignore_index=-1)
                    else:
                        l_d = functional.nll_loss(d_outputs.view(-1, D.num_langs),
                                d_targets.view(-1), ignore_index=-1)
                    # Gradient-reversal via negated, scaled loss.
                    if opt.lambd > 0:
                        l_d *= -opt.lambd
                    l_d.backward()

            optimizer.step()

        # end of epoch
        log.info('Ending epoch {}'.format(epoch+1))
        if d_total > 0:
            log.info('D Training Accuracy: {}%'.format(100.0*d_correct/d_total))
        log.info('Training accuracy:')
        log.info('\t'.join(opt.langs))
        log.info('\t'.join([str(100.0*correct[d]/total[d]) for d in opt.langs]))
        log.info('Gate accuracy:')
        log.info('\t'.join([str(100.0*gate_correct[d]/total[d]) for d in opt.langs]))
        log.info('Tagger Gate accuracy:')
        log.info('\t'.join([str(100.0*c_gate_correct[d]/total[d]) for d in opt.langs]))
        log.info('Evaluating validation sets:')
        acc = {}
        for lang in opt.dev_langs:
            acc[lang] = evaluate(f'{lang}_dev', dev_loaders[lang], vocabs[lang], tag_vocab,
                    emb, lang, F_s, F_p, C)
        avg_acc = sum([acc[d] for d in opt.dev_langs]) / len(opt.dev_langs)
        log.info(f'Average validation accuracy: {avg_acc}')
        log.info('Evaluating test sets:')
        test_acc = {}
        for lang in opt.dev_langs:
            test_acc[lang] = evaluate(f'{lang}_test', test_loaders[lang], vocabs[lang], tag_vocab,
                    emb, lang, F_s, F_p, C)
        avg_test_acc = sum([test_acc[d] for d in opt.dev_langs]) / len(opt.dev_langs)
        log.info(f'Average test accuracy: {avg_test_acc}')

        # Checkpoint on new best validation accuracy; otherwise count toward
        # LR decay.
        if avg_acc > best_avg_acc:
            epochs_since_decay = 0
            log.info(f'New best average validation accuracy: {avg_acc}')
            best_acc['valid'] = acc
            best_acc['test'] = test_acc
            best_avg_acc = avg_acc
            with open(os.path.join(opt.model_save_file, 'options.pkl'), 'wb') as ouf:
                pickle.dump(opt, ouf)
            if F_s:
                torch.save(F_s.state_dict(),
                        '{}/netF_s.pth'.format(opt.model_save_file))
            torch.save(emb.state_dict(),
                    '{}/net_emb.pth'.format(opt.model_save_file))
            if F_p:
                torch.save(F_p.state_dict(),
                        '{}/net_F_p.pth'.format(opt.model_save_file))
            torch.save(C.state_dict(),
                    '{}/netC.pth'.format(opt.model_save_file))
            if D:
                torch.save(D.state_dict(),
                        '{}/netD.pth'.format(opt.model_save_file))
        else:
            epochs_since_decay += 1
            if opt.lr_decay < 1 and epochs_since_decay >= opt.lr_decay_epochs:
                epochs_since_decay = 0
                old_lr = optimizer.param_groups[0]['lr']
                optimizer.param_groups[0]['lr'] = old_lr * opt.lr_decay
                log.info(f'Decreasing LR to {old_lr * opt.lr_decay}')

    # end of training
    log.info(f'Best average validation accuracy: {best_avg_acc}')
    return best_acc
Example #4
0
File: train.py  Project: w3user/SegDGAN
def train(NetG, NetD, optimizerG, optimizerD, dataloader, epoch):
    """Run one adversarial segmentation training epoch (SegDGAN style).

    NetG: segmentation generator; NetD: critic operating on image*mask
    products.  The critic maximizes the mean-absolute-difference between its
    responses to generated and ground-truth masked images (plus a gradient
    penalty); the generator minimizes that difference plus a dice loss.
    Prints epoch statistics; returns None.
    """
    total_dice = 0
    total_g_loss = 0
    total_g_loss_dice = 0
    total_g_loss_bce = 0
    total_d_loss = 0
    total_d_loss_penalty = 0
    NetG.train()
    NetD.train()

    for i, data in enumerate(dataloader, 1):
        # ---- Train D (G frozen) ----
        optimizerD.zero_grad()
        NetD.zero_grad()
        for p in NetG.parameters():
            p.requires_grad = False
        for p in NetD.parameters():
            p.requires_grad = True

        # FIX: renamed `input` (shadows the builtin) and dropped the
        # deprecated Variable wrapper semantics by operating on tensors.
        inputs, target = Variable(data[0]), Variable(data[1])
        inputs = inputs.float()
        target = target.float()

        if use_cuda:
            inputs = inputs.cuda()
            target = target.cuda()

        output = NetG(inputs)
        # FIX: F.sigmoid is deprecated/removed; use torch.sigmoid.
        output = torch.sigmoid(output)
        output = output.detach()  # no G gradients during the D step

        input_img = inputs.clone()
        output_masked = input_img * output
        if use_cuda:
            output_masked = output_masked.cuda()

        result = NetD(output_masked)

        target_masked = input_img * target
        if use_cuda:
            target_masked = target_masked.cuda()

        target_D = NetD(target_masked)
        # Critic loss: maximize |D(fake) - D(real)| (negated for descent).
        loss_mac = -torch.mean(torch.abs(result - target_D))
        loss_mac.backward()

        # D net gradient_penalty
        batch_size = target_masked.size(0)
        gradient_penalty = utils.calc_gradient_penalty(NetD, target_masked,
                                                       output_masked,
                                                       batch_size, use_cuda,
                                                       inputs.shape)
        gradient_penalty.backward()
        optimizerD.step()

        # ---- Train G (D frozen) ----
        optimizerG.zero_grad()
        NetG.zero_grad()
        for p in NetG.parameters():
            p.requires_grad = True
        for p in NetD.parameters():
            p.requires_grad = False

        output = NetG(inputs)
        output = torch.sigmoid(output)

        target_dice = target.view(-1).long()
        output_dice = output.view(-1)
        loss_dice = utils.dice_loss(output_dice, target_dice)

        output_masked = input_img * output
        if use_cuda:
            output_masked = output_masked.cuda()
        result = NetD(output_masked)

        target_G = NetD(target_masked)
        loss_G = torch.mean(torch.abs(result - target_G))
        loss_G_joint = loss_G + loss_dice
        loss_G_joint.backward()
        optimizerG.step()

        # FIX: `.data[0]` is deprecated and fails on 0-dim tensors in
        # modern PyTorch; use .item() to extract Python scalars.
        total_dice += 1 - loss_dice.item()
        total_g_loss += loss_G_joint.item()
        total_g_loss_dice += loss_dice.item()
        total_g_loss_bce += loss_G.item()
        total_d_loss += loss_mac.item()
        total_d_loss_penalty += gradient_penalty.item()

    # Re-enable gradients on both nets before returning to the caller.
    for p in NetG.parameters():
        p.requires_grad = True
    for p in NetD.parameters():
        p.requires_grad = True

    size = len(dataloader)

    epoch_dice = total_dice / size
    epoch_g_loss = total_g_loss / size
    epoch_g_loss_dice = total_g_loss_dice / size
    epoch_g_loss_bce = total_g_loss_bce / size

    epoch_d_loss = total_d_loss / size
    epoch_d_loss_penalty = total_d_loss_penalty / size

    print_format = [
        epoch, conf.epochs, epoch_dice * 100, epoch_g_loss, epoch_g_loss_dice,
        epoch_g_loss_bce, epoch_d_loss, epoch_d_loss_penalty
    ]
    print('===> Training step {}/{} \tepoch_dice: {:.5f}'
          '\tepoch_g_loss: {:.5f} \tepoch_g_loss_dice: {:.5f}'
          '\tepoch_g_loss_bce: {:.5f} \tepoch_d_loss: {:.5f}'
          '\tepoch_d_loss_penalty: {:.5f}'.format(*print_format))
Example #5
0
                output = D_a(real_a).to(opt.device)
                errD_real = -1 * (2 + opt.lambda_self) * output.mean()  # -a
                errD_real.backward(retain_graph=True)

                output_a = D_a(mix_g_a.detach())
                output_a2 = D_a(fake_a.detach())
                if opt.lambda_self > 0.0:
                    output_a3 = D_a(self_a.detach())
                    output_a3 = output_a3.mean()
                else:
                    output_a3 = 0
                errD_fake_a = output_a.mean() + output_a2.mean(
                ) + opt.lambda_self * output_a3
                errD_fake_a.backward(retain_graph=True)

                gradient_penalty_a = calc_gradient_penalty(
                    D_a, real_a, mix_g_a, opt.lambda_grad, opt.device)
                gradient_penalty_a += calc_gradient_penalty(
                    D_a, real_a, fake_a, opt.lambda_grad, opt.device)
                if opt.lambda_self > 0.0:
                    gradient_penalty_a += opt.lambda_self * calc_gradient_penalty(
                        D_a, real_a, self_a, opt.lambda_grad, opt.device)
                gradient_penalty_a.backward(retain_graph=True)

                #############################
                ####      Train D_b      ####
                #############################

                D_b.zero_grad()

                output = D_b(real_b).to(opt.device)
                errD_real = -1 * (2 + opt.lambda_self) * output.mean()  # -a
Example #6
0
def train_single_scale(netD,
                       netG,
                       reals,
                       Gs,
                       Zs,
                       in_s,
                       NoiseAmp,
                       opt,
                       centers=None):
    """Train the generator/discriminator pair of one pyramid scale.

    SinGAN-style training loop: for ``opt.niter`` epochs, run ``opt.Dsteps``
    WGAN-GP critic updates followed by ``opt.Gsteps`` generator updates
    (adversarial term plus an ``alpha``-weighted MSE reconstruction loss on a
    fixed noise map ``z_opt``).

    Args:
        netD: discriminator (critic) for the current scale.
        netG: generator for the current scale.
        reals: real-image pyramid; this scale's target is ``reals[len(Gs)]``.
        Gs: trained generators of the coarser scales (``[]`` at the first scale).
        Zs: fixed reconstruction noise maps of the coarser scales.
        in_s: input tensor threaded through the pyramid; initialised here at
            the coarsest scale (or taken as the LR input in ``'SR_train'`` mode).
        NoiseAmp: noise amplitudes of the coarser scales.
        opt: hyper-parameter namespace (``ker_size``, ``num_layer``, ``niter``,
            ``Dsteps``, ``Gsteps``, ``alpha``, ``lambda_grad``, ``lr_d``,
            ``lr_g``, ``mode``, ``outf``, ...). Mutated in place
            (``nzx``/``nzy``/``receptive_field``/``noise_amp``).
        centers: quantisation centers; used only when ``opt.mode`` is
            ``'paint_train'``.

    Returns:
        Tuple ``(z_opt, in_s, netG)``: the fixed reconstruction noise of this
        scale, the (possibly initialised) ``in_s``, and the trained generator.
    """

    real = reals[len(Gs)]
    # Noise spatial size for this scale (matches the real image).
    opt.nzx = real.shape[2]  #+(opt.ker_size-1)*(opt.num_layer)
    opt.nzy = real.shape[3]  #+(opt.ker_size-1)*(opt.num_layer)
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) *
                                          (opt.num_layer - 1)) * opt.stride
    # Zero-pad inputs so the generator's valid convolutions return full size.
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    if opt.mode == 'animation_train':
        # Animation mode generates slightly larger maps and skips noise padding.
        opt.nzx = real.shape[2] + (opt.ker_size - 1) * (opt.num_layer)
        opt.nzy = real.shape[3] + (opt.ker_size - 1) * (opt.num_layer)
        pad_noise = 0
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))

    # Weight of the reconstruction (MSE) loss in the generator update.
    alpha = opt.alpha

    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy],
                                           device=opt.device)
    # z_opt: fixed noise for the reconstruction loss (all-zero except at the
    # coarsest scale, where it is re-drawn in the epoch loop below).
    # NOTE(review): torch.full with integer fill 0 infers an integer dtype on
    # newer PyTorch versions — confirm a float tensor is intended here.
    z_opt = torch.full(fixed_noise.shape, 0, device=opt.device)
    z_opt = m_noise(z_opt)

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(),
                            lr=opt.lr_d,
                            betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(),
                            lr=opt.lr_g,
                            betas=(opt.beta1, 0.999))
    # Single LR drop at iteration 1600 for both networks.
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD,
                                                      milestones=[1600],
                                                      gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG,
                                                      milestones=[1600],
                                                      gamma=opt.gamma)

    # Per-epoch training curves (kept for plotting/inspection by callers).
    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    for epoch in range(opt.niter):
        if (Gs == []) & (opt.mode != 'SR_train'):
            # Coarsest scale: draw single-channel noise and expand it to three
            # channels so every channel shares the same spatial noise.
            z_opt = functions.generate_noise([1, opt.nzx, opt.nzy],
                                             device=opt.device)
            z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy],
                                              device=opt.device)
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        else:
            # Finer scales: independent per-channel noise.
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy],
                                              device=opt.device)
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real: WGAN critic maximises D(real), i.e. minimises
            # -mean(D(real)).
            netD.zero_grad()

            output = netD(real).to(opt.device)
            #D_real_map = output.detach()
            errD_real = -output.mean()  #-a
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()

            # train with fake
            if (j == 0) & (epoch == 0):
                # Very first critic step: build `prev` (input image from the
                # coarser scales) and `z_prev` (reconstruction path input).
                if (Gs == []) & (opt.mode != 'SR_train'):
                    # Coarsest scale: previous image is all zeros; also
                    # initialises in_s for the rest of the pyramid.
                    prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy],
                                      0,
                                      device=opt.device)
                    in_s = prev
                    prev = m_image(prev)
                    z_prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy],
                                        0,
                                        device=opt.device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    # Super-resolution: scale the noise amplitude by the RMSE
                    # between the real image and the upsampled LR input.
                    z_prev = in_s
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    # General case: synthesise the coarser-scale outputs, once
                    # with random noise ('rand') and once along the fixed
                    # reconstruction path ('rec').
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand',
                                       m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec',
                                         m_noise, m_image, opt)
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                # Subsequent steps: fresh random coarser-scale sample only.
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand',
                                   m_noise, m_image, opt)
                prev = m_image(prev)

            if opt.mode == 'paint_train':
                # Quantise the coarser output to the paint palette centers.
                prev = functions.quant2centers(prev, centers)
                plt.imsave('%s/prev.png' % (opt.outf),
                           functions.convert_image_np(prev),
                           vmin=0,
                           vmax=1)

            if (Gs == []) & (opt.mode != 'SR_train'):
                noise = noise_
            else:
                # Finer scales: residual noise scaled by noise_amp, added to
                # the upsampled coarser image.
                noise = opt.noise_amp * noise_ + prev

            # train with fake: minimise +mean(D(fake)); fake is detached so
            # only the critic receives gradients here.
            fake = netG(noise.detach(), prev)
            output = netD(fake.detach())
            errD_fake = output.mean()
            errD_fake.backward(retain_graph=True)
            D_G_z = output.mean().item()

            # WGAN-GP gradient penalty on real/fake interpolates.
            gradient_penalty = functions.calc_gradient_penalty(
                netD, real, fake, opt.lambda_grad, opt.device)
            gradient_penalty.backward()

            errD = errD_real + errD_fake + gradient_penalty
            optimizerD.step()

        errD2plot.append(errD.detach())

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################

        for j in range(opt.Gsteps):
            netG.zero_grad()
            # Adversarial loss on the last `fake` from the critic loop
            # (graph retained above so this backward is valid).
            output = netD(fake)
            #D_fake_map = output.detach()
            errG = -output.mean()
            errG.backward(retain_graph=True)
            if alpha != 0:
                # Reconstruction loss: generating from the fixed noise z_opt
                # must reproduce the real image.
                loss = nn.MSELoss()
                if opt.mode == 'paint_train':
                    z_prev = functions.quant2centers(z_prev, centers)
                    plt.imsave('%s/z_prev.png' % (opt.outf),
                               functions.convert_image_np(z_prev),
                               vmin=0,
                               vmax=1)
                Z_opt = opt.noise_amp * z_opt + z_prev
                rec_loss = alpha * loss(netG(Z_opt.detach(), z_prev), real)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                rec_loss = 0

            optimizerG.step()

        errG2plot.append(errG.detach() + rec_loss)
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        if epoch % 25 == 0 or epoch == (opt.niter - 1):
            print('scale %d:[%d/%d]' % (len(Gs), epoch, opt.niter))

        # Periodically dump sample/reconstruction images and z_opt.
        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            plt.imsave('%s/fake_sample.png' % (opt.outf),
                       functions.convert_image_np(fake.detach()),
                       vmin=0,
                       vmax=1)
            plt.imsave('%s/G(z_opt).png' % (opt.outf),
                       functions.convert_image_np(
                           netG(Z_opt.detach(), z_prev).detach()),
                       vmin=0,
                       vmax=1)
            #plt.imsave('%s/D_fake.png'   % (opt.outf), functions.convert_image_np(D_fake_map))
            #plt.imsave('%s/D_real.png'   % (opt.outf), functions.convert_image_np(D_real_map))
            #plt.imsave('%s/z_opt.png'    % (opt.outf), functions.convert_image_np(z_opt.detach()), vmin=0, vmax=1)
            #plt.imsave('%s/prev.png'     %  (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)
            #plt.imsave('%s/noise.png'    %  (opt.outf), functions.convert_image_np(noise), vmin=0, vmax=1)
            #plt.imsave('%s/z_prev.png'   % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)

            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))

        schedulerD.step()
        schedulerG.step()

    functions.save_networks(netG, netD, z_opt, opt)
    return z_opt, in_s, netG
示例#7
0
            # compute real data loss for discriminator
            d_loss_real = D(images)
            d_loss_real = d_loss_real.mean()
            d_loss_real.backward(fake_labels)

            # compute fake data loss for discriminator
            noise = make_variable(torch.randn(
                params.batch_size, params.z_dim, 1, 1).normal_(0, 1),
                volatile=True)
            fake_images = make_variable(G(noise).data)
            d_loss_fake = D(fake_images.detach())
            d_loss_fake = d_loss_fake.mean()
            d_loss_fake.backward(real_labels)

            # compute gradient penalty
            gradient_penalty = calc_gradient_penalty(
                D, images.data, fake_images.data)
            gradient_penalty.backward()

            # optimize weights of discriminator
            d_loss = - d_loss_real + d_loss_fake + gradient_penalty
            d_optimizer.step()

        ##########################
        # (2) training generator #
        ##########################
        # avoid to compute gradients for D
        for p in D.parameters():
            p.requires_grad = False

        # zero grad for optimizer of generator
        g_optimizer.zero_grad()
示例#8
0
        D_real_spk.backward(mone)

        # train with real from other speakers
        D_real_nspk = BETA * D_net(real_data_nspk).mean()
        D_real_nspk.backward(one)

        # train with fake data
        noise = autograd.Variable(torch.randn(BATCH_SIZE, 128),
                                  volatile=True).cuda()
        fake_data = autograd.Variable(G_net(noise).data)
        D_fake = D_net(fake_data).mean()
        D_fake.backward(one)

        # train with gradient penalty
        gradient_penalty = calc_gradient_penalty(D_net, real_data_spk.data,
                                                 fake_data.data, BATCH_SIZE,
                                                 LAMBDA)
        gradient_penalty.backward()

        D_cost = D_fake + D_real_nspk - D_real_spk + gradient_penalty
        Wasserstein_D = D_real_spk - D_fake
        D_optimizer.step()

    ############################
    # (2) Update G network
    ###########################
    for p in D_net.parameters():
        p.requires_grad = False  # to avoid computation
    G_net.zero_grad()

    noise = autograd.Variable(torch.randn(BATCH_SIZE, 128)).cuda()
示例#9
0
    def train(self):
        """Train an auxiliary-classifier WGAN-GP (generator + discriminator).

        For each epoch: the generator is updated every ``self.gen_train``
        batches (Wasserstein term plus a cross-entropy class term on
        class-conditioned noise), and the discriminator is updated every batch
        (fake source + real source + both class losses + gradient penalty).
        After each epoch the classifier head is evaluated on the test loader
        and a grid of samples from a fixed noise batch is saved to
        ``../output``; model snapshots go to ``../model``.

        Side effects: reseeds NumPy (``np.random.seed(352)``), writes image
        and model files, prints running statistics.
        """
        start_time = time.time()
        n_classes = 10

        # Running statistics, accumulated across the whole run (the printed
        # values are means over everything so far, not per-epoch).
        loss1 = []  # gradient penalty
        loss2 = []  # D fake source
        loss3 = []  # D real source
        loss4 = []  # D real class loss
        loss5 = []  # D fake class loss
        acc1 = []   # D classifier accuracy on real data

        # Fixed evaluation noise: 10 samples per class, with the one-hot class
        # label written into the first n_classes noise dimensions.
        np.random.seed(352)
        label = np.asarray(list(range(10)) * 10)
        noise = np.random.normal(0, 1, (100, self.n_z))
        label_onehot = np.zeros((100, n_classes))
        label_onehot[np.arange(100), label] = 1
        noise[np.arange(100), :n_classes] = label_onehot[np.arange(100)]
        noise = noise.astype(np.float32)

        save_noise = torch.from_numpy(noise)
        if self.cuda:
            save_noise = save_noise.cuda()
        save_noise = Variable(save_noise)

        # Train the model
        for epoch in range(self.start_epoch, self.start_epoch + self.epochs):
            # turn models to `train` mode
            self.aG.train()
            self.aD.train()

            for batch_idx, (X_train_batch, Y_train_batch) in enumerate(self.trainloader):

                # Skip ragged final batches; the gradient penalty and accuracy
                # computations assume exactly self.batch_size samples.
                if Y_train_batch.shape[0] < self.batch_size:
                    continue

                # train G (only every self.gen_train-th batch)
                if batch_idx % self.gen_train == 0:
                    # Freeze D so generator gradients do not touch it.
                    for p in self.aD.parameters():
                        p.requires_grad_(False)

                    self.aG.zero_grad()

                    # Class-conditioned noise: random labels one-hot encoded
                    # into the first n_classes noise dimensions.
                    label = np.random.randint(0, n_classes, self.batch_size)
                    noise = np.random.normal(0, 1, (self.batch_size, self.n_z))
                    label_onehot = np.zeros((self.batch_size, n_classes))
                    label_onehot[np.arange(self.batch_size), label] = 1
                    noise[np.arange(self.batch_size), :n_classes] = label_onehot[np.arange(
                        self.batch_size)]
                    noise = noise.astype(np.float32)
                    noise = torch.from_numpy(noise)
                    if self.cuda:
                        noise = noise.cuda()
                    noise = Variable(noise)
                    # noise = Variable(noise).cuda()
                    if self.cuda:
                        fake_label = Variable(torch.from_numpy(label)).cuda()
                    else:
                        fake_label = Variable(torch.from_numpy(label))
                    # fake_label = Variable(torch.from_numpy(label)).cuda()

                    fake_data = self.aG(noise)
                    gen_source, gen_class = self.aD(fake_data)

                    # Generator maximises the critic score and minimises the
                    # class loss on its own conditioning labels.
                    gen_source = gen_source.mean()
                    gen_class = self.criterion(gen_class, fake_label)

                    gen_cost = -gen_source + gen_class
                    gen_cost.backward()

                    # NOTE(review): capping Adam's per-parameter 'step' counter
                    # looks like a workaround for a bias-correction overflow in
                    # an older PyTorch — confirm it is still needed.
                    for group in self.optimizer_g.param_groups:
                        for p in group['params']:
                            state = self.optimizer_g.state[p]
                            if('step' in state and state['step'] >= 1024):
                                state['step'] = 1000
                    self.optimizer_g.step()

                # train D (every batch); re-enable its gradients first.
                for p in self.aD.parameters():
                    p.requires_grad_(True)
                self.aD.zero_grad()

                # train discriminator with input from generator
                label = np.random.randint(0, n_classes, self.batch_size)
                noise = np.random.normal(0, 1, (self.batch_size, self.n_z))
                label_onehot = np.zeros((self.batch_size, n_classes))
                label_onehot[np.arange(self.batch_size), label] = 1
                noise[np.arange(self.batch_size), :n_classes] = label_onehot[np.arange(
                    self.batch_size)]
                noise = noise.astype(np.float32)
                noise = torch.from_numpy(noise)
                if self.cuda:
                    noise = noise.cuda()
                noise = Variable(noise)

                if self.cuda:
                    fake_label = Variable(torch.from_numpy(label)).cuda()
                else:
                    fake_label = Variable(torch.from_numpy(label))

                # Generate fakes without building a generator graph.
                with torch.no_grad():
                    fake_data = self.aG(noise)

                disc_fake_source, disc_fake_class = self.aD(fake_data)

                disc_fake_source = disc_fake_source.mean()
                disc_fake_class = self.criterion(disc_fake_class, fake_label)

                # train discriminator with input from the discriminator
                if self.cuda:
                    real_data, real_label = X_train_batch.cuda(), Y_train_batch.cuda()
                else:
                    real_data, real_label = X_train_batch, Y_train_batch
                real_data, real_label = Variable(
                    real_data), Variable(real_label)

                disc_real_source, disc_real_class = self.aD(real_data)

                # Classifier-head accuracy on the real batch (monitoring only).
                prediction = disc_real_class.data.max(1)[1]
                accuracy = (float(prediction.eq(real_label.data).sum()
                                  ) / float(self.batch_size)) * 100.0

                disc_real_source = disc_real_source.mean()
                disc_real_class = self.criterion(disc_real_class, real_label)

                # WGAN-GP penalty on real/fake interpolates.
                gradient_penalty = calc_gradient_penalty(
                    self.aD, real_data, fake_data, self.batch_size, self.cuda)

                disc_cost = disc_fake_source - disc_real_source + \
                    disc_real_class + disc_fake_class + gradient_penalty
                disc_cost.backward()

                # Same Adam 'step'-cap workaround as for the generator above.
                for group in self.optimizer_d.param_groups:
                    for p in group['params']:
                        state = self.optimizer_d.state[p]
                        if('step' in state and state['step'] >= 1024):
                            state['step'] = 1000
                self.optimizer_d.step()

                # within the training loop
                loss1.append(gradient_penalty.item())
                loss2.append(disc_fake_source.item())
                loss3.append(disc_real_source.item())
                loss4.append(disc_real_class.item())
                loss5.append(disc_fake_class.item())
                acc1.append(accuracy)
                # NOTE(review): the log strings contain typos ("Trainig",
                # "Eplased"); they are runtime output, so left unchanged here.
                if batch_idx % 50 == 0:
                    print("Trainig epoch: {} | Accuracy: {} | Batch: {} | Gradient penalty: {} | Discriminator fake source: {} | Discriminator real source: {} | Discriminator real class: {} | Discriminator fake class: {}".format(
                        epoch, np.mean(acc1), batch_idx, np.mean(loss1), np.mean(loss2), np.mean(loss3), np.mean(loss4), np.mean(loss5)))

            # Test the model: classifier accuracy of the D head on held-out data.
            self.aD.eval()
            with torch.no_grad():
                test_accu = []
                for batch_idx, (X_test_batch, Y_test_batch) in enumerate(self.testloader):
                    if self.cuda:
                        X_test_batch, Y_test_batch = X_test_batch.cuda(), Y_test_batch.cuda()
                    X_test_batch, Y_test_batch = Variable(
                        X_test_batch), Variable(Y_test_batch)

                    with torch.no_grad():
                        _, output = self.aD(X_test_batch)

                    # first column has actual prob.
                    prediction = output.data.max(1)[1]
                    accuracy = (
                        float(prediction.eq(Y_test_batch.data).sum()) / float(self.batch_size)) * 100.0
                    test_accu.append(accuracy)
                    accuracy_test = np.mean(test_accu)
            # print('Testing', accuracy_test, time.time() - start_time)
            print("Testing accuracy: {} | Eplased time: {}".format(
                accuracy_test, time.time() - start_time))

            # save output: sample the fixed noise grid, map from [-1, 1] to
            # [0, 1], and move channels last for plotting.
            with torch.no_grad():
                self.aG.eval()
                samples = self.aG(save_noise)
                samples = samples.data.cpu().numpy()
                samples += 1.0
                samples /= 2.0
                samples = samples.transpose(0, 2, 3, 1)
                self.aG.train()

            fig = plot(samples)
            if not os.path.isdir('../output'):
                os.mkdir('../output')
            plt.savefig('../output/%s.png' %
                        str(epoch).zfill(3), bbox_inches='tight')
            plt.close(fig)

            # Snapshot both models every epoch.
            if (epoch + 1) % 1 == 0:
                torch.save(self.aG, '../model/tempG.model')
                torch.save(self.aD, '../model/tempD.model')
示例#10
0
            # Discriminateur D
            optimizerD.zero_grad()
            outputTrue = netD(x_cuda, alpha=(1-alpha_value) if alpha else -1)
            # lossDT = F.binary_cross_entropy_with_logits(outputTrue, real_label)
            lossDT = -torch.mean(outputTrue)

            # with false label
            outputG = netG(Variable(noise))
            outputFalse = netD(outputG.detach(), alpha=(1-alpha_value) if alpha else -1)

            # lossDF = F.binary_cross_entropy_with_logits(outputFalse, fake_label)
            lossDF = torch.mean(outputFalse)
            dTrue.append(F.sigmoid(outputTrue).data.mean())
            dFalse.append(F.sigmoid(outputFalse).data.mean())

            gradient_penalty = utils.calc_gradient_penalty(netD, x_cuda, outputG, batch_size=batchsize, lda=10, view=x_cuda.size())
            (lossDT+lossDF+gradient_penalty).backward()
            optimizerD.step()

            ldf += lossDF
            ldt += lossDT

            # Generateur
            optimizerG.zero_grad()
            outputG = netG(noise, alpha=(1-alpha_value) if alpha else -1)
            outputD = netD(outputG, alpha=(1-alpha_value) if alpha else -1)
            # lossG = F.binary_cross_entropy_with_logits(outputD, real_label)
            lossG = -torch.mean(outputD)
            lossG.backward()
            optimizerG.step()
def train():
    """Train a text-conditioned image generator against a WGAN critic.

    Loads COCO captions + resized images, encodes each caption's word vectors
    (via spaCy ``nlp``) together with the image, decodes a 32x32x3 fake image,
    and alternates one encoder/decoder ("generator") update with five critic
    updates (WGAN-GP loss). Models are pickled to disk every batch and a
    sample image is saved every 100 batches.

    Side effects: reads ``vocab.pkl`` and the COCO data directories, writes
    ``encoder.pkl``/``decoder.pkl``/``critic.pkl`` and PNG samples.
    Relies on module-level names: ``device``, ``nlp``, ``Encoder``,
    ``Decoder``, ``Critic``, ``get_coco_loader``, ``get_sentence``, ``utils``.
    """
    TRAINING_ITERATIONS = 100000  #@param {type:"number"}
    MAX_CONTEXT_POINTS = 50  #@param {type:"number"}
    PLOT_AFTER = 100  #10000 #@param {type:"number"}
    HIDDEN_SIZE = 300  #@param {type:"number"}
    MODEL_TYPE = 'ANP'  #@param ['NP','ANP']
    ATTENTION_TYPE = 'multihead'  #@param ['uniform','laplace','dot_product','multihead']
    batch_size = 64
    X_SIZE = 1
    Y_SIZE = 1
    vocab = pickle.load(open("vocab.pkl", "rb"))
    # Captions used for qualitative checks (not consumed in this loop).
    test_sentences = [
        "Two men seated at an open air restaurant",
        "flowers in a pot sitting on a cement wall",
        "a vase and lids are sitting on a table",
        "a teddy bear that is sitting next to some item on a table",
        "a plant in a vase by the window",
        "a young girl is similing and she has food around her on a table"
    ]

    dataset_train = get_coco_loader("./resized_small_train2014/",
                                    "./annotations/captions_train2014.json",
                                    vocab=vocab,
                                    transform=None,
                                    batch_size=batch_size,
                                    shuffle=True,
                                    num_workers=4)

    dataset_test = dataset_train  # we will need to build out the test dataset soon
    # Sizes of the layers of the MLPs for the encoders and decoder
    # The final output layer of the decoder outputs two values, one for the mean and
    # one for the variance of the prediction at the target location
    latent_encoder_output_sizes = [HIDDEN_SIZE] * 4
    num_latents = HIDDEN_SIZE
    deterministic_encoder_output_sizes = [HIDDEN_SIZE] * 4
    decoder_output_sizes = [32] * 2 + [2]
    use_deterministic_path = True
    xy_size = X_SIZE + Y_SIZE

    # # ANP with multihead attention
    # if MODEL_TYPE == 'ANP':
    #     attention = Attention(rep='mlp', x_size=X_SIZE, r_size=deterministic_encoder_output_sizes[-1], output_sizes=[HIDDEN_SIZE]*2,
    #                         att_type=ATTENTION_TYPE).to(device) # CHANGE: rep was originally 'mlp'
    # # NP - equivalent to uniform attention
    # elif MODEL_TYPE == 'NP':
    #     attention = Attention(rep='identity', x_size=None, output_sizes=None, att_type='uniform').to(device)
    # else:
    #     raise NameError("MODEL_TYPE not among ['ANP,'NP']")

    # # Define the model
    # print("num_latents: {}, latent_encoder_output_sizes: {}, deterministic_encoder_output_sizes: {}, decoder_output_sizes: {}".format(
    #     num_latents, latent_encoder_output_sizes, deterministic_encoder_output_sizes, decoder_output_sizes))
    # decoder_input_size = 2 * HIDDEN_SIZE + X_SIZE
    # model_wass = LatentModel(X_SIZE, Y_SIZE, latent_encoder_output_sizes, num_latents,
    #                     decoder_output_sizes, use_deterministic_path,
    #                     deterministic_encoder_output_sizes, attention, loss_type="wass").to(device)
    encoder = Encoder().to(device)
    decoder = Decoder().to(device)
    critic = Critic().to(device)

    optimizer_critic = torch.optim.Adam(critic.parameters())

    # One optimizer jointly updates encoder and decoder (the "generator").
    optimizer = torch.optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()))
    for epoch in range(10):
        progress = tqdm(enumerate(dataset_train))
        total_loss = 0
        count = 0

        for i, (images, targets, lengths) in progress:
            try:
                # ---- Generator (encoder + decoder) update ----
                # Gradients accumulate over every instance in the batch, then
                # a single optimizer.step() applies them.
                optimizer.zero_grad()
                gen_loss = 0
                for instance in range(batch_size):
                    image, target, length = images[instance].to(
                        device), targets[instance], lengths[instance]
                    sentence = get_sentence(target, vocab)
                    # Word vectors for the caption, skipping special tokens.
                    vectors = torch.Tensor([
                        nlp(word).vector for word in sentence
                        if "<" not in word
                    ]).to(device)
                    vectors = vectors.unsqueeze(0)
                    image = image.unsqueeze(0)
                    r = encoder(vectors, image)

                    # Decoder input: encoding broadcast over the caption
                    # length, concatenated with the word vectors.
                    decoder_input = torch.cat(
                        (r.repeat(1, vectors.shape[1], 1), vectors.float()),
                        -1)
                    out = decoder(decoder_input)
                    fake_image = out.view(32, 32, 3)

                    disc_fake = critic(fake_image)
                    # NOTE(review): backward() on disc_fake descends the critic
                    # score; a WGAN generator usually maximises it. gen_loss =
                    # -disc_fake is computed but never backpropagated — verify
                    # the intended gradient sign.
                    disc_fake.backward()
                    gen_loss = -disc_fake
                optimizer.step()

                # ---- Critic updates (5 per generator update, WGAN-GP) ----
                for t in range(5):
                    optimizer_critic.zero_grad()
                    loss = 0
                    for instance in range(batch_size):
                        image, target, length = images[instance].to(
                            device), targets[instance], lengths[instance]
                        sentence = get_sentence(target, vocab)
                        vectors = torch.Tensor([
                            nlp(word).vector for word in sentence
                            if "<" not in word
                        ]).to(device)
                        vectors = vectors.unsqueeze(0)
                        image = image.unsqueeze(0)
                        r = encoder(vectors, image)
                        decoder_input = torch.cat((r.repeat(
                            1, vectors.shape[1], 1), vectors.float()), -1)
                        out = decoder(decoder_input)
                        fake_image = out.view(32, 32, 3)
                        # fake_image = fake_image.transpose(1,0).transpose(2,1)

                        # Critic loss: fake score - real score + penalty.
                        disc_real = critic(image)
                        disc_fake = critic(fake_image)
                        gradient_penalty = utils.calc_gradient_penalty(
                            critic, image, fake_image)
                        loss = disc_fake - disc_real + gradient_penalty
                        loss.backward()
                        w_dist = disc_real - disc_fake  # Wasserstein estimate (monitoring)
                    optimizer_critic.step()
                progress.set_description("E{} - L{:.4f}".format(
                    epoch, w_dist.item()))

                # NOTE(review): models are pickled after every batch — this is
                # expensive; consider checkpointing less often.
                with open("encoder.pkl", "wb") as of:
                    pickle.dump(encoder, of)

                with open("decoder.pkl", "wb") as of:
                    pickle.dump(decoder, of)

                with open("critic.pkl", "wb") as of:
                    pickle.dump(critic, of)
            except Exception as e:
                # Broad catch keeps training alive on bad samples; the error
                # is printed but otherwise swallowed.
                print(e)
                continue
            try:
                # Every 100 batches, render the last generated image with its
                # caption (special tokens stripped) and save it.
                if i % 100 == 0:
                    with torch.no_grad():
                        decoder_input = torch.cat((r.repeat(
                            1, vectors.shape[1], 1), vectors.float()), -1)
                        out = decoder(decoder_input)
                        fake_image = out.view(32, 32, 3)
                        plt.imshow(fake_image.detach().cpu())
                        plt.xlabel(" ".join([
                            x for x in sentence
                            if x not in {"<end>", "<pad>", "<start>", "<unk>"}
                        ]),
                                   wrap=True)
                        plt.tight_layout()
                        plt.savefig("{}generated{}.png".format(
                            i + 1, sentence[1]))
                        plt.close()
            except:
                # NOTE(review): bare except silently hides plotting errors.
                continue
        print("done")
示例#12
0
            D.zero_grad()
            D_optimizer.zero_grad()

            real_pair = torch.cat((imgs, g_truth), dim=1)

            d_real = D(real_pair)
            d_real = d_real.mean()
            d_real.backward(mone)

            fake_pair = torch.cat((imgs, G(imgs).detach()), dim=1)

            d_fake = D(fake_pair)
            d_fake = d_fake.mean()
            d_fake.backward(one)

            gradient_penalty = calc_gradient_penalty(D, real_pair.data, fake_pair.data)
            gradient_penalty.backward()

            D_optimizer.step()

            Wasserstein_D = d_real- d_fake
            D_losses.append(Wasserstein_D.item())
    # train the generator
    for idx, (imgs, g_truth) in tqdm.tqdm(enumerate(train_loader)):
        mini_batch = imgs.size()[0]

        y_real_ = torch.ones(mini_batch)
        y_fake_ = torch.zeros(mini_batch)

        imgs, g_truth, y_real_, y_fake_ = Variable(imgs.cuda()), Variable(g_truth.cuda()), Variable( y_real_.cuda()), Variable(y_fake_.cuda())
        #imgs, g_truth, y_real_, y_fake_ = Variable(imgs), Variable(g_truth), Variable(
示例#13
0
        disc_fake_class = criterion(disc_fake_class, fake_label)

        # calculate discriminator loss with real data
        real_data = Variable(x).cuda()
        real_label = Variable(y).cuda()

        disc_real_source, disc_real_class = aD(real_data)

        prediction = disc_real_class.data.max(1)[1]
        accuracy = (float(prediction.eq(real_label.data).sum()) /
                    float(batch_size)) * 100.0

        disc_real_source = disc_real_source.mean()
        disc_real_class = criterion(disc_real_class, real_label)

        gradient_penalty = calc_gradient_penalty(aD, real_data, fake_data,
                                                 batch_size)

        disc_cost = disc_fake_source - disc_real_source + disc_real_class + disc_fake_class + gradient_penalty
        disc_cost.backward()

        optimizer_d.step()
        """
        Append losses and print
        """
        loss1.append(gradient_penalty.item())
        loss2.append(disc_fake_source.item())
        loss3.append(disc_real_source.item())
        loss4.append(disc_real_class.item())
        loss5.append(disc_fake_class.item())
        acc1.append(accuracy)
        if batch_idx % 50 == 0:
示例#14
0
    def deepinversion_improved(self, use_generator      = False, \
                                     discrete_label     = True,  \
                                     noisify_network    = 0.0, \
                                     knowledge_distill  = 0.0, \
                                     mutual_info        = 0.0, \
                                     batchnorm_transfer = 0.0, \
                                     use_discriminator  = 0.0, \
                                     n_iters = 100):
        """Synthesize samples from the pre-trained classifier (DeepInversion).

        Optimizes either a raw sample tensor (``use_generator=False``) or the
        generator ``self.net_gen`` so that synthesized samples (a) are
        classified as the sampled labels and (b) match the classifier's
        per-layer BatchNorm1d statistics.  Several optional regularizers are
        enabled by passing a positive weight:

        Args:
            use_generator: train ``self.net_gen`` instead of optimizing the
                sample tensor directly.
            discrete_label: draw binary labels if True, else uniform in [0, 1).
            noisify_network: weight for transient uniform noise injected into
                the classifier's parameters each iteration (annealed to 0).
            knowledge_distill: weight for the teacher-student regularizer;
                also trains the student net ``self.net_std`` each iteration.
            mutual_info: weight for latent-reconstruction (mutual information)
                and diversity losses (generator mode).
            batchnorm_transfer: weight for the generator's own batch-norm
                matching loss (generator mode).
            use_discriminator: weight for per-BN-layer WGAN-GP feature
                discriminators (generator mode).
            n_iters: number of optimization iterations.

        Side effects: logs to TensorBoard, appends option tags to
        ``self.imgname``, and saves/plots the final samples.
        """
        tb = SummaryWriter()
        # --- initial samples / labels and the main optimizer -----------------
        if use_generator == True:
            z = torch.randn((self.n_samples, self.latent_dim),
                            requires_grad=False,
                            device=self.device,
                            dtype=torch.float)
            if discrete_label == True:
                y_gt = torch.randint(0,
                                     2, (self.n_samples, self.label_dim),
                                     dtype=torch.float,
                                     device=self.device)
            else:
                y_gt = torch.cuda.FloatTensor(self.n_samples,
                                              self.label_dim).uniform_(0, 1)
            x = self.net_gen(z, y_gt)
            if mutual_info > 0.0:
                ''' declare the optimizer for the encoder network '''
                # encoder is optimized jointly with the generator for the
                # mutual-information (latent reconstruction) term
                optimizer = torch.optim.Adam(list(self.net_gen.parameters()) +
                                             list(self.net_enc.parameters()),
                                             lr=self.lr)
            else:
                optimizer = torch.optim.Adam(self.net_gen.parameters(),
                                             lr=self.lr)
        else:
            # optimize the sample tensor itself (2-D toy data; x requires grad)
            x = torch.randn((self.n_samples, 2),
                            requires_grad=True,
                            device=self.device,
                            dtype=torch.float)
            if discrete_label == True:
                y_gt = torch.randint(0,
                                     2, (self.n_samples, self.label_dim),
                                     dtype=torch.float,
                                     device=self.device)
            else:
                y_gt = torch.cuda.FloatTensor(self.n_samples,
                                              self.label_dim).uniform_(0, 1)
            optimizer = torch.optim.Adam([x], lr=self.lr)

        #update name of output
        self.imgname = self.imgname + "_gen%d" % (use_generator)
        ''' declare the optimizer for the student network '''
        optimizer_std = torch.optim.Adam(self.net_std.parameters(),
                                         lr=self.classifier_lr)

        if self.device == 'cuda':
            x_np = x.cpu().detach().clone().numpy()
        else:
            x_np = x.detach().clone().numpy()

        fig, ax = self.setup_plot_progress(x_np)

        total_loss = []

        # set for testing with batchnorm
        self.net.eval()

        ## Create hooks for feature statistics
        loss_bn_feature_layers = []
        if use_generator == True and use_discriminator > 0.0:
            nets_dis = []
            nets_dis_params = []

        # register a hook per BatchNorm1d layer; optionally pair each layer
        # with a small discriminator operating on its features
        for module in self.net.modules():
            if isinstance(module, nn.BatchNorm1d):
                loss_bn_feature_layers.append(bn1dfeathook(module))
                if use_generator == True and use_discriminator > 0.0:
                    net_dis = netdis(module.running_mean.shape[0],
                                     self.n_hidden, 1).cuda()
                    net_dis.apply(weights_init)
                    nets_dis.append(net_dis)
                    nets_dis_params += list(net_dis.parameters())

        if use_generator == True and use_discriminator > 0.0:
            # WGAN-style betas for the discriminator optimizer
            self.optimizer_dis = torch.optim.Adam(nets_dis_params,
                                                  lr=self.lr,
                                                  betas=(0.5, 0.9))

        ## Create hooks for feature statistics for generator
        if use_generator == True and batchnorm_transfer > 0.0:
            loss_bn_feature_layers_gen = []
            self.compute_loss_bn_gen(loss_bn_feature_layers_gen)

        for it in range(n_iters):
            self.net.zero_grad()
            self.net_gen.zero_grad()
            self.net_std.zero_grad()
            self.net_enc.zero_grad()
            optimizer.zero_grad()
            optimizer_std.zero_grad()

            if use_generator == True:
                ''' randomly sampling latent and labels '''
                # NOTE(review): labels are resampled as discrete here even when
                # discrete_label=False was used for the initial batch — confirm
                z = torch.randn((self.n_samples, self.latent_dim),
                                requires_grad=False,
                                device=self.device,
                                dtype=torch.float)
                y_gt = torch.randint(0,
                                     2, (self.n_samples, self.label_dim),
                                     dtype=torch.float,
                                     device=self.device)

            if use_generator == True:
                ''' generating samples with generator '''
                x = self.net_gen(z, y_gt)
            '''
            **********************************************************************
            To optimize the generated samples or training the generator
            **********************************************************************
            '''
            if noisify_network > 0.0:
                ''' adding noise into the pre-trained classifier '''
                # noise magnitude is annealed linearly to zero over training
                weight = noisify_network * (n_iters - it) / n_iters
                self.net, orig_params = add_noise_to_net(self.net,
                                                         weight=weight,
                                                         noise_type='uniform')

            if it == 0:
                self.imgname = self.imgname + "_nosify%0.3f" % (
                    noisify_network)

            y_pd = self.net(x)
            ''' main loss (cross-entropy loss) '''
            loss_main = self.loss_func(y_pd, y_gt)
            ''' l2 regularization '''
            loss_l2 = torch.norm(x.view(-1, self.n_input_dim), dim=1).mean()
            ''' batch-norm regularization '''
            rescale = [1. for _ in range(len(loss_bn_feature_layers))]
            loss_bn = sum([
                mod.r_feature * rescale[idx]
                for (idx, mod) in enumerate(loss_bn_feature_layers)
            ])
            ''' total loss '''
            # the BN-matching term is down-weighted when a feature
            # discriminator is also enforcing the feature distribution
            if use_generator == True and use_discriminator > 0.0:
                bn_w = 0.05
            else:
                bn_w = 1.0

            loss = loss_main + 0.005 * loss_l2 + bn_w * loss_bn

            if knowledge_distill > 0.0:
                ''' knowledge distillation (teacher-student) based regularization '''
                y_st = self.net_std(x)
                loss_kd = knowledge_distill_loss(y_pd.detach(), y_st)
                loss = loss + knowledge_distill * loss_kd

            if it == 0:
                self.imgname = self.imgname + "_kdistill%0.3f" % (
                    knowledge_distill)

            if mutual_info > 0.0:
                ''' mutual information constraint '''
                # encoder should recover the latent code from the sample
                ze = self.net_enc(x)
                loss_mi = ((z - ze)**2).mean()

                # diversity: a second latent draw should yield distinct samples
                zdiv = torch.randn((self.n_samples, self.latent_dim),
                                   requires_grad=False,
                                   device=self.device,
                                   dtype=torch.float)
                xdiv = self.net_gen(zdiv, y_gt)
                loss_div = diveristy_loss(z, x, zdiv, xdiv)

                loss = loss + mutual_info * loss_mi + 0.1 * mutual_info * loss_div

            if it == 0:
                self.imgname = self.imgname + "_minfo%0.3f" % (mutual_info)

            if use_generator == True and batchnorm_transfer > 0.0:
                ''' batch-norm transfer loss '''
                rescale_gen = [
                    1. for _ in range(len(loss_bn_feature_layers_gen))
                ]
                loss_bn_gen = sum([
                    mod.r_feature * rescale_gen[idx]
                    for (idx, mod) in enumerate(loss_bn_feature_layers_gen)
                ])
                loss = loss + batchnorm_transfer * loss_bn_gen

            if it == 0:
                self.imgname = self.imgname + "_btransfer%0.3f" % (
                    batchnorm_transfer)

            if use_generator == True and use_discriminator > 0.0:
                # train the generator on the hooked BN-layer features
                loss_g = 0
                for (idx, mod) in enumerate(loss_bn_feature_layers):
                    nets_dis[idx].zero_grad()
                    # freeze the discriminator's gradients while the
                    # generator is being updated
                    for p in nets_dis[idx].parameters():
                        p.requires_grad = False  # to avoid computation
                    feat_fake = mod.feat_fake.cuda()
                    d_fake = nets_dis[idx](feat_fake)
                    loss_g = loss_g - d_fake.mean()
                loss = loss + use_discriminator * loss_g

            if use_generator == True and it == 0:
                self.imgname = self.imgname + "_discriminator%0.3f" % (
                    use_discriminator)

            loss.backward()
            optimizer.step()

            if it % 100 == 0:
                tb.add_scalar("Total loss: ", loss, it)
                tb.add_scalar("Loss batchnorm", loss_bn, it)
                tb.add_histogram("Input", x, it)
                for name, param in self.net_gen.named_parameters():
                    tb.add_histogram(name, param.data, it)
                    tb.add_histogram(name + "/gradients", param.grad, it)

            if noisify_network > 0.0:
                ''' reset the network's parameters '''
                reset_params(self.net, orig_params)

            if knowledge_distill > 0.0:
                '''
               **********************************************************************
               To update the student network
               **********************************************************************
               '''
                if use_generator == True:
                    ''' generating samples with generator '''
                    x = self.net_gen(z, y_gt)

                y_pd = self.net(x)
                y_st = self.net_std(x)
                loss_kd = 1. - knowledge_distill_loss(y_pd.detach(), y_st)
                loss_kd.backward()
                optimizer_std.step()
            ''' store the main loss to plot on the figure '''
            total_loss.append(loss.item())

            if use_generator == True and use_discriminator > 0.0:
                # train the discriminators on features (5 critic steps,
                # WGAN loss with gradient penalty)
                for _ in range(5):
                    loss_d = 0
                    x = self.net_gen(z, y_gt)
                    self.net(x)  # forward pass refreshes the BN-feature hooks
                    for (idx, mod) in enumerate(loss_bn_feature_layers):
                        nets_dis[idx].zero_grad()
                        for p in nets_dis[idx].parameters(
                        ):  # reset requires_grad
                            p.requires_grad = True
                        feat_real = mod.feat_real.cuda()
                        feat_fake = mod.feat_fake.cuda()
                        d_real = nets_dis[idx](feat_real)
                        d_fake = nets_dis[idx](feat_fake)
                        penalty = calc_gradient_penalty(nets_dis[idx],
                                                        feat_real,
                                                        feat_fake,
                                                        LAMBDA=1.0)
                        loss_d = loss_d + use_discriminator * (
                            d_fake.mean() - d_real.mean() + penalty)
                    loss_d.backward()
                    self.optimizer_dis.step()

            if it % 10 == 0:
                print('-- iter %d --' % (it))
                print('target loss: %f' % (loss_main.item()))
                print('l2-norm loss: %f' % (loss_l2.item()))
                print('batchnorm loss: %f' % (loss_bn.item()))
                if knowledge_distill > 0.0:
                    # FIX: previously printed loss_bn under the distillation
                    # label (copy-paste bug); report the distillation loss
                    print('distillation loss: %f' % (loss_kd.item()))
                if mutual_info > 0.0:
                    print('mutual information / diversity losses: %f / %f' %
                          (loss_mi.item(), loss_div.item()))
                if batchnorm_transfer > 0.0:
                    print('batch-norm transfer loss: %f ' %
                          (loss_bn_gen.item()))
                if use_generator == True and use_discriminator > 0.0:
                    print('loss d / loss g: %f / %f' %
                          (loss_d.item(), loss_g.item()))
                print('total loss: %f' % (loss.item()))
                ''' realtime plot '''
                ax[0].plot(total_loss, c='b')
                fig.canvas.draw()

        # final snapshot of the synthesized samples
        if self.device == 'cuda':
            x_np = x.cpu().detach().numpy()
        else:
            x_np = x.detach().numpy()
        tb.close()
        ax[1].scatter(x_np[:, 0], x_np[:, 1], c='b', cmap=plt.cm.Accent)
        plt.savefig(self.basedir + "%s.png" % (self.imgname))
        plt.show()
示例#15
0
            # train with fake
            theta = minR + 2 * (np.pi - minR) * torch.rand(
                batch_size, 17, device=device)

            with torch.no_grad():
                z_pred = netG(real_data)
            fake_data = utils.rotate_and_project(real_data, z_pred, theta)

            D_fake = netD(fake_data)
            D_fake = D_fake.mean()
            D_fake.backward(one)

            # gradient penalty
            GP = utils.calc_gradient_penalty(netD,
                                             real_data,
                                             fake_data,
                                             LAMBDA=LAMBDA,
                                             device=device)
            GP.backward()
            GP = GP.item()

            loss_D = (D_fake - D_real + GP).item()
            WD = (D_real - D_fake).item()
            optimizerD.step()

        else:
            ############################
            # (1) Update G network: maximize E[D(G(x))]
            ###########################
            netG.zero_grad()
            for p in netD.parameters():