Exemple #1
0
    def update(self, mi, batch, stats):
        ''' Accumulate a Q-value regression loss over a batch and backpropagate.

        The Q output of the last time step (``Q.max(1)``) is passed to
        ``self.discounted_reward.setR`` first. The remaining steps are then
        visited from ``T - 2`` down to ``0``: for each one the model is
        evaluated, the Q entry of the action the model reports is gathered,
        and its loss against the value returned by
        ``self.discounted_reward.feed`` is added to a running error.

        Args:
            mi(`ModelInterface`): model interface; ``mi["model"]`` is the
                network that gets evaluated.
            batch(dict): batch of data. Keys read here:
                ``s``: state (only its size along dim 0, ``T``, is used),
                ``r``: immediate reward,
                ``terminal``: if game is terminated
            stats(`Stats`): handed to the discounted-reward helper, and
                ``stats["cost"]`` is fed the accumulated error / (T - 1).
        '''
        model = mi["model"]
        q_key = self.args.Q_node
        a_key = self.args.a_node

        horizon = batch["s"].size(0)

        # Seed the discounted-reward helper from the final time step.
        last_out = model(batch.hist(horizon - 1))
        bootstrap = last_out[q_key].squeeze().data.max(1)
        self.discounted_reward.setR(bootstrap, stats)

        total_err = None

        # Walk backwards: T - 2, T - 3, ..., 0.
        for step in reversed(range(horizon - 1)):
            out = model.forward(batch.hist(step))

            q_values = out[q_key].squeeze()
            actions = out[a_key].squeeze()

            target = self.discounted_reward.feed(
                dict(r=batch["r"][step], terminal=batch["terminal"][step]),
                stats=stats)

            # q_values: batchsize * #action; keep only the chosen action's Q.
            chosen_q = q_values.gather(1, actions.view(-1, 1)).squeeze()
            # NOTE(review): torch.nn has no ``L2Loss``; presumably ``nn``
            # resolves to something that provides it -- confirm.
            step_err = nn.L2Loss(chosen_q, Variable(target))
            total_err = add_err(total_err, step_err)

        # NOTE(review): with T == 1 the loop never runs and total_err is still
        # None at this point.
        stats["cost"].feed(total_err.item() / (horizon - 1))
        total_err.backward()
Exemple #2
0
    def forward(self, text, melspec, probable_path, text_lengths, mel_lengths,
                criterion, stage):
        """Compute the training loss for the requested training stage.

        ``text`` is cropped along dim 1 to ``text_lengths.max()`` and
        ``melspec`` along dim 2 to ``mel_lengths.max()`` before anything else.

        Args:
            text: batch of text inputs fed to ``self.Prenet``.
            melspec: batch of target mel spectrograms.
            probable_path: alignment path consumed by stages 1 and 3
                (stage 2 recomputes it with ``self.viterbi``).
            text_lengths: tensor of text lengths.
            mel_lengths: tensor of mel lengths.
            criterion: callable used by stages 0 and 2, invoked as
                ``criterion(mu_sigma, melspec, text_lengths, mel_lengths)``
                and unpacked into four values (loss first, log-probability
                matrix second).
            stage: integer selecting the branch (0, 1, 2 or 3).

        Returns:
            Stage 0: the criterion loss. Stage 1: L1 loss between the
            generated and target mel. Stage 2: sum of both. Stage 3: MSE
            between log predicted and log target durations. Any other
            ``stage`` falls through and returns ``None``.
        """
        text = text[:, :text_lengths.max().item()]
        melspec = melspec[:, :, :mel_lengths.max().item()]
        if stage == 0:
            encoder_input = self.Prenet(text)
            hidden_states, _ = self.FFT_lower(encoder_input, text_lengths)
            mu_sigma = self.get_mu_sigma(hidden_states)
            mdn_loss, _, _, _ = criterion(mu_sigma, melspec, text_lengths,
                                          mel_lengths)
            return mdn_loss

        elif stage == 1:
            encoder_input = self.Prenet(text)
            # FFT_lower / get_melspec are methods (see stages 0 and 2); the
            # bare names raised NameError.
            hidden_states, _ = self.FFT_lower(encoder_input, text_lengths)
            mel_out = self.get_melspec(hidden_states, probable_path,
                                       mel_lengths)
            fft_loss = nn.L1Loss()(mel_out, melspec)
            return fft_loss

        elif stage == 2:
            encoder_input = self.Prenet(text)
            hidden_states, _ = self.FFT_lower(encoder_input, text_lengths)
            mu_sigma = self.get_mu_sigma(hidden_states)
            mdn_loss, log_prob_matrix, _, _ = criterion(
                mu_sigma, melspec, text_lengths, mel_lengths)

            probable_path = self.viterbi(log_prob_matrix, text_lengths,
                                         mel_lengths)  # B, T
            mel_out = self.get_melspec(hidden_states, probable_path,
                                       mel_lengths)
            fft_loss = nn.L1Loss()(mel_out, melspec)
            return mdn_loss + fft_loss

        elif stage == 3:
            encoder_input = self.Prenet(text)
            durations_out = self.get_duration(encoder_input.data,
                                              text_lengths)  # gradient cut

            # The index range used to be created via ``hidden_states``, which
            # is never assigned in this branch; build it on probable_path's
            # device instead (only the equality test below depends on it).
            path_ids = torch.arange(probable_path.max() + 1,
                                    device=probable_path.device)
            # Original annotation for the result: B, T, L -- TODO confirm.
            path_onehot = 1.0 * (path_ids.unsqueeze(-1) == probable_path)
            # NOTE(review): ``mel_mask`` is not defined anywhere in this
            # method; this branch still needs it supplied before it can run.
            target_durations = self.align2duration(path_onehot, mel_mask)
            # torch.nn's L2 criterion is MSELoss; there is no nn.L2Loss.
            duration_loss = nn.MSELoss()(torch.log(durations_out),
                                         torch.log(target_durations))
            return duration_loss
def nn_reg_evaluation(loader,
                      net,
                      use_gpu=False,
                      plot_file=None,
                      loss='l1-norm'):
    """Evaluate ``net`` on ``loader`` and return the size-normalised loss.

    Labels are transformed with ``log(label + 1)`` before being compared to
    the raw network outputs. The per-batch criterion values are summed and
    divided by ``len(loader.dataset)``.

    Args:
        loader: iterable of dict batches holding ``'sample'`` and ``'label'``
            tensors; must expose ``loader.dataset`` supporting ``len()``.
        net: callable applied to the features of each batch.
        use_gpu: when true, features and labels are moved with ``.cuda()``.
        plot_file: accepted for interface compatibility; not used.
        loss: ``'l1'`` (``nn.L1Loss``), ``'l2'`` (``nn.MSELoss``) or
            ``'l1-norm'`` (sum of ``|output - label| / |label|``).

    Returns:
        The accumulated criterion value divided by the dataset size.

    Raises:
        ValueError: if ``loss`` is not one of the supported names.
    """
    if loss == 'l1':
        criterion = nn.L1Loss()
    elif loss == 'l2':
        # torch.nn's L2 criterion is MSELoss; there is no nn.L2Loss.
        criterion = nn.MSELoss()
    elif loss == 'l1-norm':

        def criterion(prediction, target):
            # Absolute error relative to |target|, summed (not averaged).
            relative_err = torch.abs(prediction - target) / torch.abs(target)
            return torch.sum(relative_err)
    else:
        # Previously an unknown name left ``criterion`` unbound and only
        # failed on the first batch.
        raise ValueError(
            "Unsupported loss %r; expected 'l1', 'l2' or 'l1-norm'" % (loss,))

    # Dataset loss
    ds_loss = 0.
    ds_size = len(loader.dataset)

    for batch in loader:
        # get the inputs and features
        features, labels = batch['sample'], batch['label']

        if use_gpu:
            features, labels = features.cuda(), labels.cuda()
        # wrap them in Variable
        features, labels = Variable(features), Variable(labels)

        outputs = net(features)
        labels = torch.log(labels + 1.)

        ds_loss = ds_loss + criterion(outputs, labels)

    return ds_loss / ds_size
Exemple #4
0
    def forward(self, text, melspec, durations, text_lengths, mel_lengths,
                criterion, stage):
        """Compute the training loss for the requested training stage.

        ``text`` is cropped along dim 1 to ``text_lengths.max()`` and
        ``melspec`` along dim 2 to ``mel_lengths.max()`` before anything else.

        Args:
            text: batch of text inputs fed to ``self.Prenet``.
            melspec: batch of target mel spectrograms.
            durations: durations passed to ``self.get_melspec`` (stages 1, 2)
                and used as the regression target in stage 3.
            text_lengths: tensor of text lengths.
            mel_lengths: tensor of mel lengths.
            criterion: callable used by stages 0 and 2.
            stage: integer selecting the branch (0, 1, 2 or 3).

        Returns:
            Stage 0: the criterion loss. Stage 1: L1 loss between the
            generated and target mel. Stage 2: criterion result plus that L1
            loss. Stage 3: MSE between log predicted and log given durations.
            Any other ``stage`` falls through and returns ``None``.
        """
        text = text[:, :text_lengths.max().item()]
        melspec = melspec[:, :, :mel_lengths.max().item()]

        if stage == 0:
            encoder_input = self.Prenet(text)
            hidden_states, _ = self.FFT_lower(encoder_input, text_lengths)
            mu_sigma = self.get_mu_sigma(hidden_states)
            mdn_loss, _ = criterion(mu_sigma, melspec, text_lengths,
                                    mel_lengths)
            return mdn_loss

        # The remaining branches used to test the undefined name ``step``;
        # the parameter is ``stage``.
        elif stage == 1:
            encoder_input = self.Prenet(text)
            # FFT_lower / get_melspec are methods, as in stages 0 and 2.
            hidden_states, _ = self.FFT_lower(encoder_input, text_lengths)
            mel_out = self.get_melspec(hidden_states, durations, mel_lengths)
            fft_loss = nn.L1Loss()(mel_out, melspec)
            return fft_loss

        elif stage == 2:
            encoder_input = self.Prenet(text)
            hidden_states, _ = self.FFT_lower(encoder_input, text_lengths)
            mu_sigma = self.get_mu_sigma(hidden_states)
            # NOTE(review): stage 0 calls criterion with text_lengths as well
            # and unpacks two return values; here it gets three arguments and
            # the result is used directly -- confirm which form is intended.
            mdn_loss = criterion(mu_sigma, melspec, mel_lengths)

            mel_out = self.get_melspec(hidden_states, durations, mel_lengths)
            fft_loss = nn.L1Loss()(mel_out, melspec)
            return mdn_loss + fft_loss

        elif stage == 3:
            encoder_input = self.Prenet(text)
            durations_out = self.get_duration(encoder_input, text_lengths)
            # torch.nn's L2 criterion is MSELoss; there is no nn.L2Loss.
            duration_loss = nn.MSELoss()(torch.log(durations_out),
                                         torch.log(durations))
            return duration_loss
Exemple #5
0
                                drop_last=True)
    #test_loader = RBCLoader(f_path + "/test", batch_size=arg.batch_size,
    #                         transform=None,shuffle=True)

    if arg.model == "fusion":
        net = Fusionnet(arg.in_channel, arg.out_channel, arg.ngf, arg.clamp)
    elif arg.model == "unet":
        net = Unet2D(feature_scale=arg.feature_scale)
    elif arg.model == "unet_sh":
        net = UnetSH2D(arg.sh_size, feature_scale=arg.feature_scale)
    else:
        raise NotImplementedError("Not Implemented Model")

    net = nn.DataParallel(net).to(torch_device)
    if arg.loss == "l2":
        recon_loss = nn.L2Loss()
    elif arg.loss == "l1":
        recon_loss = nn.L1Loss()

    recon_loss = EdgeWeightedLoss(1, 10)
    #recon_loss=nn.BCELoss()

    model = CNNTrainer(arg,
                       net,
                       torch_device,
                       recon_loss=recon_loss,
                       logger=logger)
    model.load(filename="epoch[0123]_losssum[0.005018].pth.tar")
    if arg.test == 0:
        model.train(train_loader, test_loader)
    if arg.test == 1:
Exemple #6
0
def train(
        G_net,
        D_net,
        device,
        epochs=1,
        batch_size=1,
        lr=0.001,
        val_percent=0.1,
        save_cp=True,
        img_scale=0.5):
    """Train a generator against a blur and a gray discriminator.

    Every batch performs one discriminator step (LSGAN loss on guided-filter
    blurred images and on colour-shifted images) followed by one generator
    step whose loss is a weighted sum of the two adversarial terms, a
    superpixel VGG term, a content VGG term and a variation term. After the
    last epoch the three state dicts are written under ``dir_checkpoint``.

    Args:
        G_net: accepted for interface compatibility; not used -- a fresh
            ``generator`` is built here.
        D_net: accepted for interface compatibility; not used -- two fresh
            ``discriminator`` instances are built here.
        device: device the networks, helpers and batches are moved to.
        epochs: number of passes over the training loader.
        batch_size: batch size of the training loader.
        lr: not used; both Adam optimizers are created with ``lr=2e-4``.
        val_percent: not used.
        save_cp: not used; checkpoints are always written.
        img_scale: not used; the dataset is built with a fixed ``0.5``.
    """
    superpixel_name = 'sscolor'
    SuperPixelDict = {
        'slic': slic,
        'adaptive_slic': adaptive_slic,
        'sscolor': sscolor}
    superpixel_kwarg: dict = {'seg_num': 200}

    train_data = Data_set(dir_img, dir_real, 0.5)
    train_loader = DataLoader(train_data, batch_size=batch_size,
                              shuffle=True, num_workers=4)

    G = generator(n_channels=3, n_classes=3).to(device=device)
    D_blur = discriminator(n_channels=3, n_classes=1).to(device=device)
    D_gray = discriminator(n_channels=3, n_classes=1).to(device=device)

    G_optimizer = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.99))
    D_optimizer = torch.optim.Adam(
        itertools.chain(D_blur.parameters(), D_gray.parameters()),
        lr=2e-4, betas=(0.5, 0.99))

    vgg = VGGCaffePreTrained()

    color_shift = ColorShift()
    color_shift.setup(device=device)
    vgg.setup(device=device)

    # The VGG network is only a fixed feature extractor.
    for param in vgg.parameters():
        param.requires_grad = False

    variation_loss = VariationLoss()
    # The VGG terms below call ``l1_loss``, which was never defined; the only
    # criterion built here was an unused ``nn.L2Loss('mean')`` (not a torch.nn
    # class). Define the L1 criterion they need.
    l1_loss = nn.L1Loss(reduction='mean')
    lsgan_loss = LSGanLoss()
    superpixel_fn = partial(SuperPixelDict[superpixel_name], **superpixel_kwarg)

    # Weights of the generator loss terms.
    w1 = 0.1       # blur adversarial
    w2 = 1.0       # gray adversarial
    w3 = 200.0     # superpixel
    w4 = 200.0     # content
    w5 = 10000.0   # variation

    for epoch in range(epochs):
        G.train()
        D_blur.train()
        D_gray.train()
        for batch in train_loader:
            fake = batch['fake'].to(device=device, dtype=torch.float32)
            real = batch['real'].to(device=device, dtype=torch.float32)

            # ---- discriminator step ----
            # NOTE(review): fake_out is not detached, so this backward pass
            # presumably also reaches G; G's gradients are zeroed again before
            # the generator step below.
            fake_out = G(fake)
            fake_output = guide_filter(fake, fake_out, r=1)

            # Part 1. Blur GAN
            fake_blur = guide_filter(fake_output, fake_output, r=5, eps=2e-1)
            real_blur = guide_filter(real, real, r=5, eps=2e-1)

            # Part 2. Gray GAN
            fake_gray, real_gray = color_shift(fake_output, real)

            fake_disc_blur = D_blur(fake_blur)
            real_disc_blur = D_blur(real_blur)

            fake_disc_gray = D_gray(fake_gray)
            real_disc_gray = D_gray(real_gray)

            loss_blur = lsgan_loss._d_loss(real_disc_blur, fake_disc_blur)
            loss_gray = lsgan_loss._d_loss(real_disc_gray, fake_disc_gray)

            all_loss = loss_blur + loss_gray

            D_optimizer.zero_grad()
            all_loss.backward()
            D_optimizer.step()

            # ---- generator step ----
            fake_out = G(fake)
            fake_output = guide_filter(fake, fake_out, r=1)

            fake_blur = guide_filter(fake_output, fake_output, r=5, eps=2e-1)
            fake_disc_blur = D_blur(fake_blur)
            loss_fake_blur = lsgan_loss._g_loss(fake_disc_blur)

            fake_gray, = color_shift(fake_output)
            fake_disc_gray = D_gray(fake_gray)
            loss_fake_gray = lsgan_loss._g_loss(fake_disc_gray)

            VGG1 = vgg(fake)
            VGG2 = vgg(fake_output)

            # Superpixel image is computed on a detached NHWC numpy copy and
            # converted back to an NCHW tensor on the training device.
            superpixel_img = torch.from_numpy(
                simple_superpixel(
                    fake_output.detach().permute((0, 2, 3, 1)).cpu().numpy(),
                    superpixel_fn)
            ).to(device).permute((0, 3, 1, 2))

            VGG3 = vgg(superpixel_img)
            _, c, h, w = VGG2.shape
            loss_superpixel = l1_loss(VGG3, VGG2) / (c * h * w)

            loss_content = l1_loss(VGG1, VGG2) / (c * h * w)

            loss_tv = variation_loss(fake_output)

            loss_sum = (loss_fake_blur * w1 + loss_fake_gray * w2
                        + loss_superpixel * w3 + loss_content * w4
                        + loss_tv * w5)

            G_optimizer.zero_grad()
            loss_sum.backward()
            G_optimizer.step()

    # All three state dicts used to be written to the same path, so only the
    # last one (D_gray) survived. The generator keeps the original file name;
    # the discriminators get their own files.
    torch.save(G.state_dict(),
               dir_checkpoint + f'checkpoint_epoch{epochs}.pth')
    torch.save(D_blur.state_dict(),
               dir_checkpoint + f'D_blur_checkpoint_epoch{epochs}.pth')
    torch.save(D_gray.state_dict(),
               dir_checkpoint + f'D_gray_checkpoint_epoch{epochs}.pth')
Exemple #7
0
    def __init__(self, path_state_dict=''):
        """Build the model, loss criteria, filters, optimizer and scheduler.

        All settings are read from the ``hp`` object. If the summary file
        ``hp.logdir / 'summary.txt'`` does not exist yet, it is created and
        ``repr(hp)`` is written to ``hp.logdir / 'hparams.txt'``.

        Args:
            path_state_dict: optional checkpoint path. When non-empty it is
                read with ``torch.load`` and must unpack into
                ``(model_state, optimizer_state, scheduler_state)``.

        Raises:
            NotImplementedError: ``hp.crit`` is neither ``"l1"`` nor ``"l2"``.
            KeyError: ``hp.loss_mode`` is not one of the keys 0..8.
            NameError: ``hp.optimizer`` is not a supported optimizer name.
            Exception: a loaded state dict cannot be applied to the model,
                optimizer or scheduler.
        """
        self.writer: Optional[CustomWriter] = None
        meltrans = create_mel_filterbank(hp.sampling_rate, hp.n_fft,
                                         fmin=hp.mel_fmin, fmax=hp.mel_fmax,
                                         n_mels=hp.mel_freq)

        self.model = melGen(self.writer, hp.n_freq, meltrans, hp.mel_generator)
        count_parameters(self.model)

        self.module = self.model

        self.lws_processor = lws.lws(hp.n_fft, hp.l_hop, mode='speech',
                                     perfectrec=False)

        self.prev_stoi_scores = {}
        self.base_stoi_scores = {}

        if hp.crit == "l1":
            self.criterion = nn.L1Loss(reduction='none')
        elif hp.crit == "l2":
            # torch.nn's L2 criterion is MSELoss; there is no nn.L2Loss.
            self.criterion = nn.MSELoss(reduction='none')
        else:
            # Previously this printed a message and returned, leaving an
            # instance without optimizer, scheduler or filters.
            raise NotImplementedError("Loss not implemented")

        self.criterion2 = nn.L1Loss(reduction='none')

        # Pairs selected by hp.loss_mode; only the first element of each pair
        # is used below (passed to gen_filter).
        self.f_specs = {
            0: [(5, 2), (15, 5)],
            1: [(5, 2)],
            2: [(3, 1)],
            3: [(3, 1), (5, 2)],
            4: [(3, 1), (5, 2), (7, 3)],
            5: [(15, 5)],
            6: [(3, 1), (5, 2), (7, 3), (15, 5), (25, 10)],
            7: [(1, 1)],
            8: [(1, 1), (3, 1), (5, 2), (15, 5), (7, 3), (25, 10), (9, 4),
                (20, 5), (5, 3)],
        }[hp.loss_mode]

        self.filters = [gen_filter(k) for k, _ in self.f_specs]

        if hp.optimizer == "adam":
            self.optimizer = Adam(self.model.parameters(),
                                  lr=hp.learning_rate,
                                  weight_decay=hp.weight_decay,
                                  )
        elif hp.optimizer == "sgd":
            self.optimizer = SGD(self.model.parameters(),
                                 lr=hp.learning_rate,
                                 weight_decay=hp.weight_decay,
                                 )
        elif hp.optimizer == "radam":
            self.optimizer = RAdam(self.model.parameters(),
                                   lr=hp.learning_rate,
                                   weight_decay=hp.weight_decay,
                                   )
        elif hp.optimizer == "novograd":
            self.optimizer = NovoGrad(self.model.parameters(),
                                      lr=hp.learning_rate,
                                      weight_decay=hp.weight_decay
                                      )
        elif hp.optimizer == "sm3":
            raise NameError('sm3 not implemented')
        else:
            raise NameError('optimizer not implemented')

        self.__init_device(hp.device)

        self.scheduler = lr_scheduler.ReduceLROnPlateau(self.optimizer,
                                                        **hp.scheduler)

        self.max_epochs = hp.n_epochs

        self.valid_eval_sample: Dict[str, Any] = dict()

        # Load State Dict
        if path_state_dict:
            st_model, st_optim, st_sched = torch.load(
                path_state_dict, map_location=self.in_device)
            try:
                self.module.load_state_dict(st_model)
                self.optimizer.load_state_dict(st_optim)
                self.scheduler.load_state_dict(st_sched)
            except Exception as err:
                # Keep the original error as the cause instead of hiding it.
                raise Exception(
                    'The model is different from the state dict.') from err

        path_summary = hp.logdir / 'summary.txt'
        if not path_summary.exists():
            with path_summary.open('w') as f:
                f.write('\n')
            with (hp.logdir / 'hparams.txt').open('w') as f:
                f.write(repr(hp))