Code example #1
File: loss.py Project: bcaitech1/p1-img-MaiHon
    def forward(self, inp, tar):
        with autocast():
            # self.ce must return per-sample losses (reduction='none'),
            # so log_p is -log p_t for each sample
            log_p = self.ce(inp, tar)
            p = torch.exp(-log_p)

            # focal loss: alpha * (1 - p_t)^gamma * CE, averaged over the batch
            loss = self.alpha * (1 - p)**self.gamma * log_p
            return loss.mean()
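For context, this focal-loss forward only behaves correctly when self.ce yields per-sample losses. A minimal sketch of the surrounding module, assuming nn.CrossEntropyLoss(reduction='none') as the base criterion; the class name and the alpha/gamma defaults are illustrative, not taken from the project:

import torch
import torch.nn as nn
from torch.cuda.amp import autocast

class FocalLossSketch(nn.Module):
    def __init__(self, alpha=1.0, gamma=2.0):  # illustrative defaults
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        # reduction='none' keeps per-sample losses, as the forward above requires
        self.ce = nn.CrossEntropyLoss(reduction='none')

    def forward(self, inp, tar):
        with autocast():
            log_p = self.ce(inp, tar)   # per-sample -log p_t
            p = torch.exp(-log_p)       # recovered p_t
            loss = self.alpha * (1 - p) ** self.gamma * log_p
            return loss.mean()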
Code example #2
    def run_step(self):
        assert self.model.training, f"[{self.__class__.__name__}] model was changed to eval mode!"
        if self.cfg.SOLVER.AMP.ENABLED:
            assert torch.cuda.is_available(), f"[{self.__class__.__name__}] CUDA is required for AMP training!"
            from torch.cuda.amp.autocast_mode import autocast

        start = time.perf_counter()
        data = next(self._data_loader_iter)
        data_time = time.perf_counter() - start

        if self.cfg.SOLVER.AMP.ENABLED:
            with autocast():
                loss_dict = self.model(data)
                losses = sum(loss_dict.values())
            self.optimizer.zero_grad()
            self.grad_scaler.scale(losses).backward()

            self._write_metrics(loss_dict, data_time)

            self.grad_scaler.step(self.optimizer)
            self.grad_scaler.update()
        else:
            loss_dict = self.model(data)
            losses = sum(loss_dict.values())
            self.optimizer.zero_grad()
            losses.backward()

            self._write_metrics(loss_dict, data_time)

            self.optimizer.step()

        if isinstance(self.param_wrapper, ContiguousParams):
            self.param_wrapper.assert_buffer_is_valid()
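The run_step above follows the canonical torch.cuda.amp recipe: scale the loss for backward, let the scaler unscale and step, then update the scale. A minimal self-contained sketch of that loop with generic names (not from fast-reid):

import torch
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast

model = torch.nn.Linear(8, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler()

for _ in range(3):
    x = torch.randn(4, 8, device='cuda')
    y = torch.randint(0, 2, (4,), device='cuda')
    optimizer.zero_grad()
    with autocast():                      # forward in mixed precision
        loss = F.cross_entropy(model(x), y)
    scaler.scale(loss).backward()         # scale to avoid fp16 gradient underflow
    scaler.step(optimizer)                # unscales grads, skips the step on inf/nan
    scaler.update()                       # adapt the scale for the next iteration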
Code example #3
File: tacotron.py Project: stjordanis/TTS
    def train_step(self, batch, criterion):
        """Perform a single training step by fetching the right set if samples from the batch.

        Args:
            batch (Dict): A dictionary of input tensors.
            criterion (nn.Module): Callable criterion to compute model loss.
        """
        text_input = batch["text_input"]
        text_lengths = batch["text_lengths"]
        mel_input = batch["mel_input"]
        mel_lengths = batch["mel_lengths"]
        linear_input = batch["linear_input"]
        stop_targets = batch["stop_targets"]
        stop_target_lengths = batch["stop_target_lengths"]
        speaker_ids = batch["speaker_ids"]
        d_vectors = batch["d_vectors"]

        # forward pass model
        outputs = self.forward(
            text_input,
            text_lengths,
            mel_input,
            mel_lengths,
            aux_input={"speaker_ids": speaker_ids, "d_vectors": d_vectors},
        )

        # set the [alignment] lengths wrt reduction factor for guided attention
        if mel_lengths.max() % self.decoder.r != 0:
            alignment_lengths = (
                mel_lengths + (self.decoder.r - (mel_lengths.max() % self.decoder.r))
            ) // self.decoder.r
        else:
            alignment_lengths = mel_lengths // self.decoder.r

        aux_input = {"speaker_ids": speaker_ids, "d_vectors": d_vectors}
        outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input)

        # compute loss
        with autocast(enabled=False):  # use float32 for the criterion
            loss_dict = criterion(
                outputs["model_outputs"].float(),
                outputs["decoder_outputs"].float(),
                mel_input.float(),
                linear_input.float(),
                outputs["stop_tokens"].float(),
                stop_targets.float(),
                stop_target_lengths,
                mel_lengths,
                None if outputs["decoder_outputs_backward"] is None else outputs["decoder_outputs_backward"].float(),
                outputs["alignments"].float(),
                alignment_lengths,
                None if outputs["alignments_backward"] is None else outputs["alignments_backward"].float(),
                text_lengths,
            )

        # compute alignment error (the lower, the better)
        align_error = 1 - alignment_diagonal_score(outputs["alignments"])
        loss_dict["align_error"] = align_error
        return outputs, loss_dict
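The autocast(enabled=False) block above is the standard way to force a numerically sensitive region back to float32 inside an AMP forward pass; tensors produced under autocast stay half precision, hence the explicit .float() casts. A minimal sketch of the nesting:

import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast

layer = torch.nn.Linear(8, 2).cuda()
x = torch.randn(4, 8, device='cuda')
target = torch.randn(4, 2, device='cuda')

with autocast():
    y = layer(x)                      # float16 under autocast
    with autocast(enabled=False):     # autocast off for the criterion only
        # tensors created in the autocast region must be cast back by hand
        loss = F.mse_loss(y.float(), target.float())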
Code example #4
    def train_step(self, batch: Dict, criterion: torch.nn.Module):
        """A single training step. Forward pass and loss computation.

        Args:
            batch ([Dict]): A dictionary of input tensors.
            criterion ([type]): Callable criterion to compute model loss.
        """
        text_input = batch["text_input"]
        text_lengths = batch["text_lengths"]
        mel_input = batch["mel_input"]
        mel_lengths = batch["mel_lengths"]
        stop_targets = batch["stop_targets"]
        stop_target_lengths = batch["stop_target_lengths"]
        speaker_ids = batch["speaker_ids"]
        d_vectors = batch["d_vectors"]

        # forward pass model
        outputs = self.forward(
            text_input,
            text_lengths,
            mel_input,
            mel_lengths,
            aux_input={"speaker_ids": speaker_ids, "d_vectors": d_vectors},
        )

        # set the [alignment] lengths wrt reduction factor for guided attention
        if mel_lengths.max() % self.decoder.r != 0:
            alignment_lengths = (
                mel_lengths + (self.decoder.r - (mel_lengths.max() % self.decoder.r))
            ) // self.decoder.r
        else:
            alignment_lengths = mel_lengths // self.decoder.r

        aux_input = {"speaker_ids": speaker_ids, "d_vectors": d_vectors}
        outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input)

        # compute loss
        with autocast(enabled=False):  # use float32 for the criterion
            loss_dict = criterion(
                outputs["model_outputs"].float(),
                outputs["decoder_outputs"].float(),
                mel_input.float(),
                None,
                outputs["stop_tokens"].float(),
                stop_targets.float(),
                stop_target_lengths,
                mel_lengths,
                None if outputs["decoder_outputs_backward"] is None else outputs["decoder_outputs_backward"].float(),
                outputs["alignments"].float(),
                alignment_lengths,
                None if outputs["alignments_backward"] is None else outputs["alignments_backward"].float(),
                text_lengths,
            )

        # compute alignment error (the lower, the better)
        align_error = 1 - alignment_diagonal_score(outputs["alignments"])
        loss_dict["align_error"] = align_error
        return outputs, loss_dict
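The padding branch before the integer division (identical in examples #3 and #4) rounds every length in the batch up by the amount that makes the batch maximum divisible by the reduction factor r. A small worked example:

import torch

mel_lengths = torch.tensor([5, 6, 7])
r = 2  # decoder reduction factor

if mel_lengths.max() % r != 0:
    alignment_lengths = (mel_lengths + (r - mel_lengths.max() % r)) // r
else:
    alignment_lengths = mel_lengths // r

print(alignment_lengths)  # tensor([3, 3, 4]): lengths padded to [6, 7, 8], then // 2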
Code example #5
    def train_step(self, batch: dict, criterion: nn.Module):
        """A single training step. Forward pass and loss computation. Run data depended initialization for the
        first `config.data_dep_init_steps` steps.

        Args:
            batch (dict): A dictionary of input tensors.
            criterion (nn.Module): Callable criterion to compute model loss.
        """
        text_input = batch["text_input"]
        text_lengths = batch["text_lengths"]
        mel_input = batch["mel_input"]
        mel_lengths = batch["mel_lengths"]
        d_vectors = batch["d_vectors"]
        speaker_ids = batch["speaker_ids"]

        if self.run_data_dep_init and self.training:
            # compute data-dependent initialization of activation norm layers
            self.unlock_act_norm_layers()
            with torch.no_grad():
                _ = self.forward(
                    text_input,
                    text_lengths,
                    mel_input,
                    mel_lengths,
                    aux_input={
                        "d_vectors": d_vectors,
                        "speaker_ids": speaker_ids
                    },
                )
            outputs = None
            loss_dict = None
            self.lock_act_norm_layers()
        else:
            # normal training step
            outputs = self.forward(
                text_input,
                text_lengths,
                mel_input,
                mel_lengths,
                aux_input={
                    "d_vectors": d_vectors,
                    "speaker_ids": speaker_ids
                },
            )

            with autocast(enabled=False):  # avoid mixed_precision in criterion
                loss_dict = criterion(
                    outputs["z"].float(),
                    outputs["y_mean"].float(),
                    outputs["y_log_scale"].float(),
                    outputs["logdet"].float(),
                    mel_lengths,
                    outputs["durations_log"].float(),
                    outputs["total_durations_log"].float(),
                    text_lengths,
                )
        return outputs, loss_dict
Code example #6
File: trainer.py Project: bcaitech1/p1-img-MaiHon
 def step(self, sample, scaler=None, valid=False):
     images = sample['image'].to(self.cfg.device)
     labels = sample['label'].to(self.cfg.device)
     
     if scaler is not None:
         with autocast():
             logits = self.model(images)
             if not valid:
                 loss = self.trn_crit(logits, labels)
             else:
                 loss = self.val_crit(logits, labels)
         
         if not valid:
             self.optim.zero_grad()
             scaler.scale(loss).backward()
             # clipping point -> AGC acts as a replacement for batchnorm
             scaler.unscale_(self.optim)
             if self.cfg.clipping:
                 timm.utils.adaptive_clip_grad(self.model.parameters())
             scaler.step(self.optim)
             scaler.update()
     else:
         logits = self.model(images)
         if not valid:
             loss = self.trn_crit(logits, labels)
             
             self.optim.zero_grad()
             loss.backward()
             
             if self.cfg.clipping:
                 timm.utils.adaptive_clip_grad(self.model.parameters())
             
             self.optim.step()
         else:
             loss = self.val_crit(logits, labels)
         
         
         if self.cfg.nosiy_elimination:
             logit_preds = -F.log_softmax(logits, dim=-1)
             indexs = sample['idx'].detach().cpu().numpy()
             # decay the accumulated predictions, then add the new ones;
             # index with (rows, cols) in one step so the assignment writes
             # back (a chained fancy index would modify a copy)
             self.prediction_by_idx[indexs, :-1] *= 1.2
             self.prediction_by_idx[indexs, :-1] += logit_preds.detach().cpu().numpy()
             self.prediction_by_idx[indexs, -1] = labels.detach().cpu().numpy()
         
     batch_acc = self.accuracy(logits, labels)
     batch_f1  = self.f1_score(logits, labels)
     result = {
         'logit': logits,
         'loss': loss,
         'batch_acc': batch_acc,
         'batch_f1' : batch_f1
     }
     return result        
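The scaler.unscale_ call before clipping matches the gradient-clipping recipe in the PyTorch AMP docs: gradients must be unscaled first, or the threshold is applied to scaled values. The same structure with the built-in norm clipper, as a standalone sketch with generic names:

import torch
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast

model = torch.nn.Linear(8, 2).cuda()
optim = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler()

x = torch.randn(4, 8, device='cuda')
y = torch.randint(0, 2, (4,), device='cuda')

optim.zero_grad()
with autocast():
    loss = F.cross_entropy(model(x), y)
scaler.scale(loss).backward()
scaler.unscale_(optim)   # bring grads back to their true magnitude first
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optim)       # grads already unscaled, so step() will not unscale again
scaler.update()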
Code example #7
File: forward_tts.py Project: gerazov/TTS
    def train_step(self, batch: dict, criterion: nn.Module):
        text_input = batch["text_input"]
        text_lengths = batch["text_lengths"]
        mel_input = batch["mel_input"]
        mel_lengths = batch["mel_lengths"]
        pitch = batch["pitch"] if self.args.use_pitch else None
        d_vectors = batch["d_vectors"]
        speaker_ids = batch["speaker_ids"]
        durations = batch["durations"]
        aux_input = {"d_vectors": d_vectors, "speaker_ids": speaker_ids}

        # forward pass
        outputs = self.forward(text_input,
                               text_lengths,
                               mel_lengths,
                               y=mel_input,
                               dr=durations,
                               pitch=pitch,
                               aux_input=aux_input)
        # use aligner's output as the duration target
        if self.use_aligner:
            durations = outputs["o_alignment_dur"]
        # use float32 in AMP
        with autocast(enabled=False):
            # compute loss
            loss_dict = criterion(
                decoder_output=outputs["model_outputs"],
                decoder_target=mel_input,
                decoder_output_lens=mel_lengths,
                dur_output=outputs["durations_log"],
                dur_target=durations,
                pitch_output=outputs["pitch_avg"] if self.use_pitch else None,
                pitch_target=outputs["pitch_avg_gt"]
                if self.use_pitch else None,
                input_lens=text_lengths,
                alignment_logprob=outputs["alignment_logprob"]
                if self.use_aligner else None,
                alignment_soft=outputs["alignment_soft"]
                if self.use_binary_alignment_loss else None,
                alignment_hard=outputs["alignment_mas"]
                if self.use_binary_alignment_loss else None,
            )
            # compute duration error
            durations_pred = outputs["durations"]
            duration_error = torch.abs(durations - durations_pred).sum() / text_lengths.sum()
            loss_dict["duration_error"] = duration_error

        return outputs, loss_dict
Code example #8
File: imix_engine.py Project: inspur-hsslab/iMIX
    def run_train_iter(self, batch_data=None, data_time=None):
        assert self.model.training, '[CommonEngine] model was changed to eval mode!'

        if self.batch_processor is not None:
            self.output = self.batch_processor(batch_data)
        else:
            with autocast(enabled=is_mixed_precision()):
                self.model_output = self.model(
                    batch_data,
                    cur_epoch=getattr(self, 'epoch', None),
                    cur_iter=self.iter,
                    inner_iter=getattr(self, 'inner_iter', None))
                self.output = self.loss_fn(self.model_output)

        metrics_dict = {'data_time': data_time}
        metrics_dict.update(self.output)
        write_metrics(metrics_dict)
Code example #9
    def forward(self, z_obj, z_cam_mid, z_obj_mid, camera):
        with autocast(enabled=self.training):
            num_views = z_obj.shape[1]

            h = z_obj[:, 0]
            if self.conv_module == EqualizedConv2d:
                # Concatenate pixel coords if 2d.
                coords = utils.get_normalized_pixel_coords(h)
            else:
                coords = utils.get_normalized_voxel_coords(h)

            for i in range(1, num_views):
                x = torch.cat((z_obj[:, i], coords), dim=1)
                h = self.gru(x, h)

            h = h.unsqueeze(1)

            return h, {}
Code example #10
File: loss.py Project: bcaitech1/p1-img-MaiHon
    def forward(self, cosine, label):
        with autocast():
            # --------------------------- cos(theta) & phi(theta) ---------------------------
            sine = torch.sqrt((torch.sub(1.0, cosine * cosine)).clamp(0, 1))
            phi = torch.mul(cosine, self.cos_m) - torch.mul(sine, self.sin_m)
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
            # --------------------------- convert label to one-hot ---------------------------
            one_hot = torch.zeros(cosine.size(), device=cosine.device)
            one_hot.scatter_(1, label.view(-1, 1).long(), 1)
            # ------------- out_i = x_i if condition_i else y_i (torch.where) -------------
            output = (one_hot * phi) + (
                (1.0 - one_hot) * cosine
            )  # you can use torch.where if your torch.__version__ is 0.4
            output *= self.s

            loss = self.crit(output, label)
            if self.reduction == "mean": loss = loss.mean()
            elif self.reduction == "sum": loss = loss.sum()

            return loss
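The phi computed above is just the angle-addition identity cos(theta + m) = cos(theta)*cos(m) - sin(theta)*sin(m) applied to the ArcFace margin m. A quick numeric check of that identity:

import math
import torch

m = 0.5                                  # illustrative margin
theta = torch.rand(5) * math.pi          # angles in [0, pi), so sin(theta) >= 0
cosine = torch.cos(theta)

sine = torch.sqrt((1.0 - cosine * cosine).clamp(0, 1))
phi = cosine * math.cos(m) - sine * math.sin(m)

assert torch.allclose(phi, torch.cos(theta + m), atol=1e-5)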
Code example #11
File: uda_base.py Project: X-funbean/fast-reid
    def run_step(self):
        assert self.model.training, f"[{self.__class__.__name__}] model was changed to eval mode!"
        if self.cfg.SOLVER.AMP.ENABLED:
            assert torch.cuda.is_available(), f"[{self.__class__.__name__}] CUDA is required for AMP training!"
            from torch.cuda.amp.autocast_mode import autocast

        start = time.perf_counter()

        # load data
        tgt_inputs = self.pseudo_tgt_train_loader.next()

        def _parse_data(inputs):
            imgs, _, pids, _ = inputs
            return imgs.cuda(), pids.cuda()

        # process inputs
        t_inputs, t_targets = _parse_data(tgt_inputs)

        data_time = time.perf_counter() - start

        def _forward():
            outputs = self.model(t_inputs)
            f_out_t = outputs['features']
            p_out_t = outputs['pred_class_logits'][:, :self.num_clusters]

            loss_dict = {}

            loss_ce = cross_entropy_loss(pred_class_outputs=p_out_t,
                                         gt_classes=t_targets,
                                         eps=self.cfg.MODEL.LOSSES.CE.EPSILON,
                                         alpha=self.cfg.MODEL.LOSSES.CE.ALPHA)
            loss_dict.update({'loss_ce': loss_ce})

            if 'TripletLoss' in self.cfg.MODEL.LOSSES.NAME:
                loss_tri = triplet_loss(f_out_t,
                                        t_targets,
                                        margin=0.0,
                                        norm_feat=True,
                                        hard_mining=False)
                loss_dict.update({'loss_tri': loss_tri})

            return loss_dict

        if self.cfg.SOLVER.AMP.ENABLED:
            with autocast():
                loss_dict = _forward()
                losses = sum(loss_dict.values())

            self.optimizer.zero_grad()
            self.grad_scaler.scale(losses).backward()

            self._write_metrics(loss_dict, data_time)

            self.grad_scaler.step(self.optimizer)
            self.grad_scaler.update()
        else:
            loss_dict = _forward()
            losses = sum(loss_dict.values())

            self.optimizer.zero_grad()
            losses.backward()

            self._write_metrics(loss_dict, data_time)

            self.optimizer.step()

        if isinstance(self.param_wrapper, ContiguousParams):
            self.param_wrapper.assert_buffer_is_valid()
Code example #12
    def train_step(self, batch: dict, criterion: nn.Module,
                   optimizer_idx: int) -> Tuple[Dict, Dict]:
        """Perform a single training step. Run the model forward pass and compute losses.

        Args:
            batch (Dict): Input tensors.
            criterion (nn.Module): Loss layer designed for the model.
            optimizer_idx (int): Index of optimizer to use. 0 for the generator and 1 for the discriminator networks.

        Returns:
            Tuple[Dict, Dict]: Model outputs and computed losses.
        """
        # pylint: disable=attribute-defined-outside-init
        if optimizer_idx not in [0, 1]:
            raise ValueError(" [!] Unexpected `optimizer_idx`.")

        if self.args.freeze_encoder:
            for param in self.text_encoder.parameters():
                param.requires_grad = False

            if hasattr(self, "emb_l"):
                for param in self.emb_l.parameters():
                    param.requires_grad = False

        if self.args.freeze_PE:
            for param in self.posterior_encoder.parameters():
                param.requires_grad = False

        if self.args.freeze_DP:
            for param in self.duration_predictor.parameters():
                param.requires_grad = False

        if self.args.freeze_flow_decoder:
            for param in self.flow.parameters():
                param.requires_grad = False

        if self.args.freeze_waveform_decoder:
            for param in self.waveform_decoder.parameters():
                param.requires_grad = False

        if optimizer_idx == 0:
            text_input = batch["text_input"]
            text_lengths = batch["text_lengths"]
            mel_lengths = batch["mel_lengths"]
            linear_input = batch["linear_input"]
            d_vectors = batch["d_vectors"]
            speaker_ids = batch["speaker_ids"]
            language_ids = batch["language_ids"]
            waveform = batch["waveform"]

            # generator pass
            outputs = self.forward(
                text_input,
                text_lengths,
                linear_input.transpose(1, 2),
                mel_lengths,
                waveform.transpose(1, 2),
                aux_input={
                    "d_vectors": d_vectors,
                    "speaker_ids": speaker_ids,
                    "language_ids": language_ids
                },
            )

            # cache tensors for the discriminator pass
            self.y_disc_cache = outputs["model_outputs"]
            self.wav_seg_disc_cache = outputs["waveform_seg"]

            # compute discriminator scores and features
            (
                outputs["scores_disc_fake"],
                outputs["feats_disc_fake"],
                _,
                outputs["feats_disc_real"],
            ) = self.disc(outputs["model_outputs"], outputs["waveform_seg"])

            # compute losses
            with autocast(enabled=False):  # use float32 for the criterion
                loss_dict = criterion[optimizer_idx](
                    waveform_hat=outputs["model_outputs"].float(),
                    waveform=outputs["waveform_seg"].float(),
                    z_p=outputs["z_p"].float(),
                    logs_q=outputs["logs_q"].float(),
                    m_p=outputs["m_p"].float(),
                    logs_p=outputs["logs_p"].float(),
                    z_len=mel_lengths,
                    scores_disc_fake=outputs["scores_disc_fake"],
                    feats_disc_fake=outputs["feats_disc_fake"],
                    feats_disc_real=outputs["feats_disc_real"],
                    loss_duration=outputs["loss_duration"],
                    use_speaker_encoder_as_loss=self.args.use_speaker_encoder_as_loss,
                    gt_spk_emb=outputs["gt_spk_emb"],
                    syn_spk_emb=outputs["syn_spk_emb"],
                )

        elif optimizer_idx == 1:
            # discriminator pass
            outputs = {}

            # compute scores and features
            (
                outputs["scores_disc_fake"],
                _,
                outputs["scores_disc_real"],
                _,
            ) = self.disc(self.y_disc_cache.detach(), self.wav_seg_disc_cache)

            # compute loss
            with autocast(enabled=False):  # use float32 for the criterion
                loss_dict = criterion[optimizer_idx](
                    outputs["scores_disc_real"],
                    outputs["scores_disc_fake"],
                )
        return outputs, loss_dict
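The optimizer_idx convention in this train_step mirrors an alternating GAN update: index 0 runs the generator pass and caches tensors, index 1 reuses them for the discriminator. A hedged sketch of the calling side; the loop, the optimizer names, and the "loss" key are illustrative, not taken from the project:

for batch in loader:
    # generator step: the forward pass also caches y_disc_cache / wav_seg_disc_cache
    outputs, losses = model.train_step(batch, criterion, optimizer_idx=0)
    opt_g.zero_grad()
    losses["loss"].backward()   # "loss" as aggregate key is an assumption about the criterion
    opt_g.step()

    # discriminator step: scores the cached (detached) generator outputs
    outputs, losses = model.train_step(batch, criterion, optimizer_idx=1)
    opt_d.zero_grad()
    losses["loss"].backward()
    opt_d.step()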
Code example #13
    def run_step(self):
        assert self.model.training, f"[{self.__class__.__name__}] model was changed to eval mode!"
        if self.cfg.SOLVER.AMP.ENABLED:
            assert torch.cuda.is_available(), f"[{self.__class__.__name__}] CUDA is required for AMP training!"
            from torch.cuda.amp.autocast_mode import autocast

        start = time.perf_counter()

        # load data
        src_inputs = self.src_train_loader.next()
        tgt_inputs = self.pseudo_tgt_train_loader.next()

        def _parse_data(inputs):
            imgs, _, pids, _, indices = inputs
            return imgs.cuda(), pids.cuda(), indices

        # process inputs
        s_inputs, s_targets, s_indices = _parse_data(src_inputs)
        t_inputs, t_targets, t_indices = _parse_data(tgt_inputs)

        # arrange batch for domain-specific BNP
        device_num = torch.cuda.device_count()
        B, C, H, W = s_inputs.size()

        def reshape(inputs):
            return inputs.view(device_num, -1, C, H, W)

        s_inputs, t_inputs = reshape(s_inputs), reshape(t_inputs)
        inputs = torch.cat((s_inputs, t_inputs), 1).view(-1, C, H, W)

        data_time = time.perf_counter() - start

        def _forward():
            outputs = self.model(inputs)
            if isinstance(outputs, dict):
                f_out = outputs['features']
            else:
                f_out = outputs

            # de-arrange batch
            f_out = f_out.view(device_num, -1, f_out.size(-1))

            f_out_s, f_out_t = f_out.split(f_out.size(1) // 2, dim=1)
            f_out_s, f_out_t = f_out_s.contiguous().view(
                -1,
                f_out.size(-1)), f_out_t.contiguous().view(-1, f_out.size(-1))

            # compute loss with the hybrid memory
            loss_s = self.hm(f_out_s, s_targets)
            loss_t = self.hm(f_out_t, t_indices + self.src_pid_nums)

            loss_dict = {'loss_s': loss_s, 'loss_t': loss_t}
            return loss_dict

        if self.cfg.SOLVER.AMP.ENABLED:
            with autocast():
                loss_dict = _forward()
                losses = sum(loss_dict.values())

            self.optimizer.zero_grad()
            self.grad_scaler.scale(losses).backward()

            self._write_metrics(loss_dict, data_time)

            self.grad_scaler.step(self.optimizer)
            self.grad_scaler.update()
        else:
            loss_dict = _forward()
            losses = sum(loss_dict.values())

            self.optimizer.zero_grad()
            losses.backward()

            self._write_metrics(loss_dict, data_time)

            self.optimizer.step()

        if isinstance(self.param_wrapper, ContiguousParams):
            self.param_wrapper.assert_buffer_is_valid()
Code example #14
 def autocast(self):
     return autocast(enabled=self.enabled)
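This one-liner lets a single flag decide whether a region runs under mixed precision, since autocast(enabled=False) behaves as a plain passthrough context. A usage sketch, assuming the method lives on a small holder object (names are illustrative):

import torch
from torch.cuda.amp import autocast

class AmpSwitch:
    def __init__(self, enabled: bool):
        self.enabled = enabled

    def autocast(self):
        return autocast(enabled=self.enabled)

amp = AmpSwitch(enabled=True)
model = torch.nn.Linear(8, 2).cuda()
x = torch.randn(4, 8, device='cuda')
with amp.autocast():   # no-op context when enabled=False
    logits = model(x)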
Code example #15
def train(train_loader, model, criterion, optimizer, epoch, args, scaler=None):
    # train for one epoch
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    psnr_out = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()

    if args.lr_policy == 'naive':
        local_lr = adjust_learning_rate_naive(optimizer, epoch, args)
    elif args.lr_policy == 'step':
        local_lr = adjust_learning_rate(optimizer, epoch, args)
    elif args.lr_policy == 'epoch_poly':
        local_lr = adjust_learning_rate_epoch_poly(optimizer, epoch, args)

    for i, (target, input_group) in enumerate(train_loader):

        # set random task
        task_id = random.randint(0, 5) if not args.task else task_map[args.task]
        input = input_group[task_id]
        model.module.set_task(task_id)
        #print(f"Iter {i}, task_id: {task_id}")
        #for m in model.module.modules():
        # if isinstance(m, )
        #print(m.weight.device)
        global_iter = epoch * args.epoch_size + i

        if args.lr_policy == 'iter_poly':
            local_lr = adjust_learning_rate_poly(optimizer, global_iter, args)
        elif args.lr_policy == 'cosine':
            local_lr = adjust_learning_rate_cosine(optimizer, global_iter,
                                                   args)

        # measure data loading time
        data_time.update(time.time() - end)

        if args.gpu is not None:
            input = input.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)

        target = target.cuda()
        if scaler is None:
            # compute output
            output = model(input)
            loss = criterion(output, target)
        else:
            with autocast():
                # compute output
                output = model(input)
                loss = criterion(output, target)

        # measure accuracy and record loss
        output = (output * 0.5 + 0.5) * 255.
        target = (target * 0.5 + 0.5) * 255.
        psnr = PSNR()(output, target)
        losses.update(loss.item(), input.size(0))
        psnr_out.update(psnr.item(), input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()

        if scaler is None:
            # compute gradient and do SGD step
            loss.backward()
            optimizer.step()
        else:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'PSNR {psnr.val:.3f} ({psnr.avg:.3f})\t'
                  'LR: {lr: .6f}'.format(epoch,
                                         i,
                                         args.epoch_size,
                                         batch_time=batch_time,
                                         data_time=data_time,
                                         loss=losses,
                                         psnr=psnr_out,
                                         lr=local_lr))
Code example #16
File: vits.py Project: synesthesiam/opentts
    def train_step(self, batch: dict, criterion: nn.Module,
                   optimizer_idx: int) -> Tuple[Dict, Dict]:
        """Perform a single training step. Run the model forward pass and compute losses.

        Args:
            batch (Dict): Input tensors.
            criterion (nn.Module): Loss layer designed for the model.
            optimizer_idx (int): Index of optimizer to use. 0 for the generator and 1 for the discriminator networks.

        Returns:
            Tuple[Dict, Dict]: Model outputs and computed losses.
        """
        # pylint: disable=attribute-defined-outside-init
        if optimizer_idx not in [0, 1]:
            raise ValueError(" [!] Unexpected `optimizer_idx`.")

        if optimizer_idx == 0:
            text_input = batch["text_input"]
            text_lengths = batch["text_lengths"]
            mel_lengths = batch["mel_lengths"]
            linear_input = batch["linear_input"]
            d_vectors = batch["d_vectors"]
            speaker_ids = batch["speaker_ids"]
            waveform = batch["waveform"]

            # generator pass
            outputs = self.forward(
                text_input,
                text_lengths,
                linear_input.transpose(1, 2),
                mel_lengths,
                aux_input={
                    "d_vectors": d_vectors,
                    "speaker_ids": speaker_ids
                },
            )

            # cache tensors for the discriminator pass
            self.y_disc_cache = outputs["model_outputs"]
            wav_seg = segment(
                waveform.transpose(1, 2),
                outputs["slice_ids"] * self.config.audio.hop_length,
                self.args.spec_segment_size * self.config.audio.hop_length,
            )
            self.wav_seg_disc_cache = wav_seg
            outputs["waveform_seg"] = wav_seg

            # compute discriminator scores and features
            (
                outputs["scores_disc_fake"],
                outputs["feats_disc_fake"],
                _,
                outputs["feats_disc_real"],
            ) = self.disc(outputs["model_outputs"], wav_seg)

            # compute losses
            with autocast(enabled=False):  # use float32 for the criterion
                loss_dict = criterion[optimizer_idx](
                    waveform_hat=outputs["model_outputs"].float(),
                    waveform=wav_seg.float(),
                    z_p=outputs["z_p"].float(),
                    logs_q=outputs["logs_q"].float(),
                    m_p=outputs["m_p"].float(),
                    logs_p=outputs["logs_p"].float(),
                    z_len=mel_lengths,
                    scores_disc_fake=outputs["scores_disc_fake"],
                    feats_disc_fake=outputs["feats_disc_fake"],
                    feats_disc_real=outputs["feats_disc_real"],
                    loss_duration=outputs["loss_duration"],
                )

        elif optimizer_idx == 1:
            # discriminator pass
            outputs = {}

            # compute scores and features
            (
                outputs["scores_disc_fake"],
                _,
                outputs["scores_disc_real"],
                _,
            ) = self.disc(self.y_disc_cache.detach(), self.wav_seg_disc_cache)

            # compute loss
            with autocast(enabled=False):  # use float32 for the criterion
                loss_dict = criterion[optimizer_idx](
                    outputs["scores_disc_real"],
                    outputs["scores_disc_fake"],
                )
        return outputs, loss_dict
Code example #17
    def run_iteration(self, batch, train, is_step):
        self.mark_time()
        # Update depth criterion k if applicable.
        if 'hard_' in self.g_depth_recon_loss_type:
            self._g_depth_recon_criterion.k = int(
                self._g_depth_recon_k_scheduler.get(self.epoch))

        batch = process_batch(batch, self.cube_size, self.camera_dist,
                              self._sculptor.in_size, self.device,
                              self.random_orientation)

        if self.reconstruct_input:
            recon_camera = Camera.vcat(
                (batch['in_gt']['camera'], batch['out_gt']['camera']),
                batch_size=self.batch_size)
            recon_mask = torch.cat(
                (batch['in_gt']['mask'], batch['out_gt']['mask']), dim=1)
            recon_image = torch.cat(
                (batch['in_gt']['image'], batch['out_gt']['image']), dim=1)
            recon_depth = torch.cat(
                (batch['in_gt']['depth'], batch['out_gt']['depth']), dim=1)
        else:
            recon_camera = batch['out_gt']['camera']
            recon_mask = batch['out_gt']['mask']
            recon_image = batch['out_gt']['image']
            recon_depth = batch['out_gt']['depth']

        if not self.color_random_background or self.crop_random_background:
            batch['in']['image'] = batch['in']['image'] * batch['in']['mask']

        if not self.depth_random_background or self.crop_random_background:
            batch['in']['depth'] = mask_normalized_depth(
                batch['in']['depth'], batch['in']['mask'])

        depth_in = None
        if self.generator_input_depth:
            depth_noise = self._depth_noise_dist.sample(
                batch['in']['depth'].size()).to(self.device)
            depth_in = (batch['in']['depth'] + depth_noise).clamp(-1, 1)

        data_process_time = self.mark_time()

        with autocast():
            # Evaluate generator.
            z_obj, z_extra = self._sculptor.encode(
                self._fuser,
                camera=batch['in']['camera'],
                color=batch['in']['image'],
                depth=depth_in,
                mask=batch['in']['mask'],
                data_parallel=self.data_parallel)
            fake_image, fake_depth, fake_mask, fake_mask_logits, fake_vox_depth = \
                self._run_photographer(z_obj, recon_camera, recon_mask)

            if 'blend_weights' in z_extra:
                z_weights = z_extra['blend_weights']
            else:
                z_weights = None

            # Train discriminator.
            if self._discriminator:
                d_real, d_fake_d, d_fake_g = self._run_discriminator(
                    fake_image, fake_depth, fake_mask, recon_image,
                    recon_depth, recon_mask)
                loss_d_real = multiscale_lsgan_loss(d_real, 1)
                loss_d_fake = multiscale_lsgan_loss(d_fake_d, 0)
                loss_d = loss_d_real + loss_d_fake
                loss_g_gan = multiscale_lsgan_loss(d_fake_g, 1)

                if train:
                    loss_d.backward()
                    if is_step:
                        self._optimizers['discriminator'].step()

                self.plotter.put_scalar('loss/discriminator/real', loss_d_real)
                self.plotter.put_scalar('loss/discriminator/fake', loss_d_fake)
                self.plotter.put_scalar('loss/discriminator/total', loss_d)
            else:
                loss_g_gan = torch.tensor(0.0, device=self.device)

            # Train generator.
            if self.predict_color:
                loss_g_color_recon = reduce_loss(
                    self._g_color_recon_criterion(fake_image, recon_image))
            else:
                loss_g_color_recon = torch.tensor(0.0, device=self.device)

            if self.predict_depth or self.use_occlusion_depth:
                loss_g_depth_recon = reduce_loss(
                    self._g_depth_recon_criterion(fake_depth, recon_depth))
            else:
                loss_g_depth_recon = torch.tensor(0.0, device=self.device)

            if self.predict_mask:
                if self.g_mask_recon_loss_type == 'binary_cross_entropy':
                    y_mask = fake_mask_logits
                else:
                    y_mask = fake_mask
                loss_g_mask_recon = reduce_loss(
                    self._g_mask_recon_criterion(y_mask, recon_mask))
                loss_g_mask_beta = beta_prior_loss(
                    fake_mask,
                    alpha=self.g_mask_beta_loss_param,
                    beta=self.g_mask_beta_loss_param)
            else:
                loss_g_mask_recon = torch.tensor(0.0, device=self.device)
                loss_g_mask_beta = torch.tensor(0.0, device=self.device)

            loss_g = (self.g_gan_loss_weight * loss_g_gan +
                      self.g_color_recon_loss_weight * loss_g_color_recon +
                      self.g_depth_recon_loss_weight * loss_g_depth_recon +
                      self.g_mask_recon_loss_weight * loss_g_mask_recon +
                      self.g_mask_beta_loss_weight *
                      loss_g_mask_beta) / self.batch_groups

        if train:
            if self.kwargs.get('use_amp', False):
                self._scaler.scale(loss_g).backward()
            else:
                loss_g.backward()

            if is_step:
                if self.kwargs.get('use_amp', False):
                    self._scaler.step(self._optimizers['generator'])
                    self._scaler.update()
                else:
                    self._optimizers['generator'].step()

        with torch.no_grad():
            if self.predict_depth:
                self.plotter.put_scalar('error/depth/l1',
                                        F.l1_loss(fake_depth, recon_depth))
            if self.reconstruct_input:
                self.plotter.put_scalar(
                    'error/depth/input_l1',
                    F.l1_loss(fake_depth[:, :self.num_input_views],
                              batch['in_gt']['depth']))
                self.plotter.put_scalar(
                    'error/depth/output_l1',
                    F.l1_loss(fake_depth[:, self.num_input_views:],
                              batch['out_gt']['depth']))
            if self.predict_mask:
                self.plotter.put_scalar(
                    'error/mask/cross_entropy',
                    F.binary_cross_entropy_with_logits(fake_mask_logits,
                                                       recon_mask))
                self.plotter.put_scalar('error/mask/l1',
                                        F.l1_loss(fake_mask, recon_mask))

        compute_time = self.mark_time()

        self.plotter.put_scalar('loss/generator/gan', loss_g_gan)
        self.plotter.put_scalar('loss/generator/recon/color',
                                loss_g_color_recon)
        self.plotter.put_scalar('loss/generator/recon/depth',
                                loss_g_depth_recon)
        self.plotter.put_scalar('loss/generator/recon/mask', loss_g_mask_recon)
        self.plotter.put_scalar('loss/generator/recon/mask_beta',
                                loss_g_mask_beta)
        self.plotter.put_scalar('loss/generator/total', loss_g)

        self.plotter.put_scalar('params/input_noise_weight',
                                self.input_noise_weight)
        if hasattr(self._g_depth_recon_criterion, 'k'):
            self.plotter.put_scalar('params/depth_loss_k',
                                    self._g_depth_recon_criterion.k)

        self.plotter.put_scalar('time/data_process', data_process_time)
        self.plotter.put_scalar('time/compute', compute_time)
        plot_scalar_time = self.mark_time()
        self.plotter.put_scalar('time/plot/scalars', plot_scalar_time)

        if self.plotter.is_it_time_yet('histogram'):
            if self.predict_color:
                self.plotter.put_histogram('image_fake', fake_image)
                self.plotter.put_histogram('image_real', recon_image)
            if self.predict_mask:
                self.plotter.put_histogram('mask_fake', fake_mask)
            self.plotter.put_histogram('z_obj', z_obj)
            if z_weights is not None:
                self.plotter.put_histogram('z_weights', z_weights)
        plot_histogram_time = self.mark_time()
        self.plotter.put_scalar('time/plot/histogram', plot_histogram_time)

        if self.plotter.is_it_time_yet('show'):
            self.plotter.put_image(
                'inputs',
                viz.make_grid([
                    gan_denormalize(batch['in']['image']),
                    viz.colorize_depth(batch['in']['depth'])
                    if self.generator_input_depth else None,
                    viz.colorize_tensor(batch['in']['mask'])
                    if self.generator_input_mask else None,
                ],
                              row_size=4,
                              stride=2,
                              output_size=64))
            with torch.no_grad():
                self.plotter.put_image(
                    'reconstruction',
                    viz.make_grid([
                        gan_denormalize(recon_image),
                        gan_denormalize(fake_image) if
                        (fake_image is not None) else None,
                        viz.colorize_depth(recon_depth),
                        viz.colorize_depth(fake_depth) if
                        (fake_depth is not None) else None,
                        viz.colorize_tensor(
                            (recon_depth.cpu() - fake_depth.cpu()).abs()) if
                        (fake_depth is not None) else None,
                        viz.colorize_tensor(recon_mask),
                        viz.colorize_tensor(fake_mask) if
                        (fake_mask is not None) else None,
                        viz.colorize_tensor(
                            (recon_mask.cpu() - fake_mask.cpu()).abs()) if
                        (fake_mask is not None) else None,
                    ],
                                  stride=8))
        plot_images_time = self.mark_time()
        self.plotter.put_scalar('time/plot/images', plot_images_time)
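One asymmetry worth noting in run_iteration: autocast() is entered unconditionally, while loss scaling and scaler.step are gated on kwargs.get('use_amp', False). If true full-precision runs are wanted as well, the same flag can be threaded into autocast itself; a fragment sketch of the pattern, reusing names from the example:

from torch.cuda.amp import autocast

use_amp = self.kwargs.get('use_amp', False)   # same flag the backward path checks
with autocast(enabled=use_amp):               # float32 end to end when use_amp is False
    ...                                       # encoder, photographer, and losses as above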
Code example #18
    def trn_step(self, epoch, sample_l, sample_u, scaler=None):
        self.optim.zero_grad()

        images_l = sample_l['image'].to(self.cfg.device)
        images_o = sample_u["image_ori"].to(self.cfg.device)
        images_a = sample_u["image_aug"].to(self.cfg.device)
        labels = sample_l['label'].to(self.cfg.device)

        batch_s = images_l.size(0)
        images_t = torch.cat([images_l, images_o, images_a])
        if scaler is not None:
            with autocast():
                logits_t = self.model(images_t)

                logits_l = logits_t[:batch_s]
                logits_o, logits_a = logits_t[batch_s:].chunk(2)
                del logits_t

                preds_o = F.softmax(logits_o, dim=-1).detach()
                preds_a = F.log_softmax(logits_a, dim=-1)
                kl_loss = F.kl_div(preds_a, preds_o, reduction='none')
                kl_loss = torch.mean(torch.sum(kl_loss, dim=-1))

                l_loss = self.trn_crit(logits_l, labels)

                if self.cfg.ratio_mode == 'constant':
                    t_loss = l_loss + self.cfg.ratio * kl_loss
                elif self.cfg.ratio_mode == "gradual":
                    t_loss = epoch / self.cfg.t_epoch * self.cfg.ratio * kl_loss + l_loss

            scaler.scale(t_loss).backward()
            # clipping point -> AGC acts as a replacement for batchnorm
            scaler.unscale_(self.optim)
            if self.cfg.clipping:
                timm.utils.adaptive_clip_grad(self.model.parameters())
            scaler.step(self.optim)
            scaler.update()
        else:
            logits_t = self.model(images_t)
            logits_l = logits_t[:batch_s]
            logits_o, logits_a = logits_t[batch_s:].chunk(2)
            del logits_t

            preds_o = F.softmax(logits_o, dim=-1).detach()
            preds_a = F.log_softmax(logits_a, dim=-1)
            kl_loss = F.kl_div(preds_a, preds_o, reduction='none')
            kl_loss = torch.mean(torch.sum(kl_loss, dim=-1))

            l_loss = self.trn_crit(logits_l, labels)

            if self.cfg.ratio_mode == 'constant':
                t_loss = l_loss + self.cfg.ratio * kl_loss
            elif self.cfg.ratio_mode == "gradual":
                t_loss = epoch / self.cfg.t_epoch * self.cfg.ratio * kl_loss + l_loss

            t_loss.backward()

            if self.cfg.clipping:
                timm.utils.adaptive_clip_grad(self.model.parameters())

            self.optim.step()

        batch_acc = self.accuracy(logits_l, labels)
        batch_f1 = self.f1_score(logits_l, labels)
        result = {
            'l_loss': l_loss,
            't_loss': t_loss,
            'kl_loss': kl_loss,
            'batch_acc': batch_acc,
            'batch_f1': batch_f1
        }
        return result
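F.kl_div is easy to misuse: its first argument must be log-probabilities and its second plain probabilities, exactly as trn_step builds them (log_softmax for the augmented view, softmax for the original). A small standalone check against the definition:

import torch
import torch.nn.functional as F

logits_o = torch.randn(3, 5)
logits_a = torch.randn(3, 5)

preds_o = F.softmax(logits_o, dim=-1)        # target: probabilities
preds_a = F.log_softmax(logits_a, dim=-1)    # input: log-probabilities

kl = F.kl_div(preds_a, preds_o, reduction='none')
kl = torch.mean(torch.sum(kl, dim=-1))       # batch mean of per-sample KL

# definition: KL(p_o || p_a) = sum p_o * (log p_o - log p_a)
manual = (preds_o * (preds_o.log() - preds_a)).sum(dim=-1).mean()
assert torch.allclose(kl, manual, atol=1e-6)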