def forward(self, inp, tar):
    """Return the mean focal-weighted loss of ``inp`` against ``tar``.

    The value returned by ``self.ce(inp, tar)`` is treated as a negative
    log-likelihood ``nll``; the matching probability is ``exp(-nll)`` and the
    loss is ``alpha * (1 - p) ** gamma * nll`` averaged over all elements.
    The whole computation runs inside an ``autocast`` region.
    """
    with autocast():
        nll = self.ce(inp, tar)
        prob = torch.exp(-nll)
        # Down-weight elements that are already predicted with high probability.
        modulation = (1 - prob) ** self.gamma
        weighted = self.alpha * modulation * nll
        return weighted.mean()
def run_step(self):
    """Run one optimization step, optionally with CUDA mixed precision.

    Fetches the next batch from ``self._data_loader_iter``, calls
    ``self.model`` on it (the result is used as a mapping of loss terms),
    sums the terms, back-propagates and steps ``self.optimizer``.  When
    ``self.cfg.SOLVER.AMP.ENABLED`` is set the forward pass runs under
    ``autocast`` and ``self.grad_scaler`` handles loss scaling, the optimizer
    step and the scale update.  Metrics are written between the backward
    pass and the optimizer step.
    """
    assert self.model.training, f"[{self.__class__.__name__}] model was changed to eval mode!"

    amp_enabled = self.cfg.SOLVER.AMP.ENABLED
    if amp_enabled:
        assert torch.cuda.is_available(
        ), f"[{self.__class__.__name__}] CUDA is required for AMP training!"
        from torch.cuda.amp.autocast_mode import autocast

    # Time only the data fetch.
    tic = time.perf_counter()
    data = next(self._data_loader_iter)
    data_time = time.perf_counter() - tic

    if not amp_enabled:
        loss_dict = self.model(data)
        losses = sum(loss_dict.values())

        self.optimizer.zero_grad()
        losses.backward()
        self._write_metrics(loss_dict, data_time)
        self.optimizer.step()
    else:
        with autocast():
            loss_dict = self.model(data)
            losses = sum(loss_dict.values())

        self.optimizer.zero_grad()
        self.grad_scaler.scale(losses).backward()
        self._write_metrics(loss_dict, data_time)
        self.grad_scaler.step(self.optimizer)
        self.grad_scaler.update()

    if isinstance(self.param_wrapper, ContiguousParams):
        self.param_wrapper.assert_buffer_is_valid()
def train_step(self, batch, criterion):
    """Perform a single training step: one forward pass plus loss computation.

    Args:
        batch (dict): Must provide ``"text_input"``, ``"text_lengths"``,
            ``"mel_input"``, ``"mel_lengths"``, ``"linear_input"``,
            ``"stop_targets"``, ``"stop_target_lengths"``, ``"speaker_ids"``
            and ``"d_vectors"``.
        criterion (callable): Called positionally with the model outputs and
            targets; must return a dict of losses.

    Returns:
        tuple: ``(outputs, loss_dict)`` where ``outputs`` is whatever
        ``self.forward`` returned and ``loss_dict`` additionally carries
        ``"align_error"``.
    """
    text_input = batch["text_input"]
    text_lengths = batch["text_lengths"]
    mel_input = batch["mel_input"]
    mel_lengths = batch["mel_lengths"]
    linear_input = batch["linear_input"]
    stop_targets = batch["stop_targets"]
    stop_target_lengths = batch["stop_target_lengths"]
    speaker_ids = batch["speaker_ids"]
    d_vectors = batch["d_vectors"]

    # Single forward pass.  The model used to be run twice with identical
    # arguments and the first result thrown away, doubling the cost of every
    # training step.
    aux_input = {"speaker_ids": speaker_ids, "d_vectors": d_vectors}
    outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input=aux_input)

    # set the [alignment] lengths wrt reduction factor for guided attention
    r = self.decoder.r
    if mel_lengths.max() % r != 0:
        # Pad by the amount needed to make the longest item divisible by r.
        alignment_lengths = (mel_lengths + (r - (mel_lengths.max() % r))) // r
    else:
        alignment_lengths = mel_lengths // r

    # compute loss
    with autocast(enabled=False):  # use float32 for the criterion
        decoder_backward = outputs["decoder_outputs_backward"]
        alignments_backward = outputs["alignments_backward"]
        loss_dict = criterion(
            outputs["model_outputs"].float(),
            outputs["decoder_outputs"].float(),
            mel_input.float(),
            linear_input.float(),
            outputs["stop_tokens"].float(),
            stop_targets.float(),
            stop_target_lengths,
            mel_lengths,
            None if decoder_backward is None else decoder_backward.float(),
            outputs["alignments"].float(),
            alignment_lengths,
            None if alignments_backward is None else alignments_backward.float(),
            text_lengths,
        )

    # compute alignment error (the lower the better)
    loss_dict["align_error"] = 1 - alignment_diagonal_score(outputs["alignments"])
    return outputs, loss_dict
def train_step(self, batch: Dict, criterion: torch.nn.Module):
    """A single training step: one forward pass plus loss computation.

    Args:
        batch (Dict): Must provide ``"text_input"``, ``"text_lengths"``,
            ``"mel_input"``, ``"mel_lengths"``, ``"stop_targets"``,
            ``"stop_target_lengths"``, ``"speaker_ids"`` and ``"d_vectors"``.
        criterion (torch.nn.Module): Called positionally with the model
            outputs and targets (the linear-target slot is passed as
            ``None``); must return a dict of losses.

    Returns:
        tuple: ``(outputs, loss_dict)`` where ``outputs`` is whatever
        ``self.forward`` returned and ``loss_dict`` additionally carries
        ``"align_error"``.
    """
    text_input = batch["text_input"]
    text_lengths = batch["text_lengths"]
    mel_input = batch["mel_input"]
    mel_lengths = batch["mel_lengths"]
    stop_targets = batch["stop_targets"]
    stop_target_lengths = batch["stop_target_lengths"]
    speaker_ids = batch["speaker_ids"]
    d_vectors = batch["d_vectors"]

    # Single forward pass.  The model used to be run twice with identical
    # arguments and the first result thrown away, doubling the cost of every
    # training step.
    aux_input = {"speaker_ids": speaker_ids, "d_vectors": d_vectors}
    outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input=aux_input)

    # set the [alignment] lengths wrt reduction factor for guided attention
    r = self.decoder.r
    if mel_lengths.max() % r != 0:
        # Pad by the amount needed to make the longest item divisible by r.
        alignment_lengths = (mel_lengths + (r - (mel_lengths.max() % r))) // r
    else:
        alignment_lengths = mel_lengths // r

    # compute loss
    with autocast(enabled=False):  # use float32 for the criterion
        decoder_backward = outputs["decoder_outputs_backward"]
        alignments_backward = outputs["alignments_backward"]
        loss_dict = criterion(
            outputs["model_outputs"].float(),
            outputs["decoder_outputs"].float(),
            mel_input.float(),
            None,
            outputs["stop_tokens"].float(),
            stop_targets.float(),
            stop_target_lengths,
            mel_lengths,
            None if decoder_backward is None else decoder_backward.float(),
            outputs["alignments"].float(),
            alignment_lengths,
            None if alignments_backward is None else alignments_backward.float(),
            text_lengths,
        )

    # compute alignment error (the lower the better)
    loss_dict["align_error"] = 1 - alignment_diagonal_score(outputs["alignments"])
    return outputs, loss_dict
def train_step(self, batch: dict, criterion: nn.Module):
    """Run one training step, or one data-dependent initialization pass.

    While ``self.run_data_dep_init`` and ``self.training`` are both truthy the
    step only performs a gradient-free forward pass between
    ``unlock_act_norm_layers()`` and ``lock_act_norm_layers()`` and returns
    ``(None, None)``.  Otherwise it runs the forward pass and evaluates
    ``criterion`` in float32 with autocast disabled.

    Args:
        batch (dict): Must provide ``"text_input"``, ``"text_lengths"``,
            ``"mel_input"``, ``"mel_lengths"``, ``"d_vectors"`` and
            ``"speaker_ids"``.
        criterion (nn.Module): Called positionally with eight arguments.

    Returns:
        tuple: ``(outputs, loss_dict)``, or ``(None, None)`` during the
        initialization phase.
    """
    forward_args = (
        batch["text_input"],
        batch["text_lengths"],
        batch["mel_input"],
        batch["mel_lengths"],
    )
    aux_input = {
        "d_vectors": batch["d_vectors"],
        "speaker_ids": batch["speaker_ids"],
    }
    text_lengths, mel_lengths = forward_args[1], forward_args[3]

    if not (self.run_data_dep_init and self.training):
        # normal training step
        outputs = self.forward(*forward_args, aux_input=aux_input)
        with autocast(enabled=False):  # avoid mixed_precision in criterion
            loss_dict = criterion(
                outputs["z"].float(),
                outputs["y_mean"].float(),
                outputs["y_log_scale"].float(),
                outputs["logdet"].float(),
                mel_lengths,
                outputs["durations_log"].float(),
                outputs["total_durations_log"].float(),
                text_lengths,
            )
        return outputs, loss_dict

    # compute data-dependent initialization of activation norm layers
    self.unlock_act_norm_layers()
    with torch.no_grad():
        self.forward(*forward_args, aux_input=aux_input)
    self.lock_act_norm_layers()
    return None, None
def step(self, sample, scaler=None, valid=False):
    """Run one training or validation step on a single batch.

    Args:
        sample (dict): Provides ``'image'`` and ``'label'`` tensors, plus
            ``'idx'`` when ``self.cfg.nosiy_elimination`` is enabled.
        scaler: Gradient scaler.  When not ``None`` the forward pass runs
            under ``autocast`` and the scaler drives backward/step/update.
        valid (bool): When true, ``self.val_crit`` is used and no
            backward pass or optimizer step is performed.

    Returns:
        dict: ``'logit'``, ``'loss'``, ``'batch_acc'`` and ``'batch_f1'``.
    """
    images = sample['image'].to(self.cfg.device)
    labels = sample['label'].to(self.cfg.device)

    if scaler is not None:
        with autocast():
            logits = self.model(images)
            loss = self.val_crit(logits, labels) if valid else self.trn_crit(logits, labels)

        if not valid:
            # Clear stale gradients, exactly as the non-AMP branch does;
            # without this the AMP path accumulated gradients across steps.
            self.optim.zero_grad()
            scaler.scale(loss).backward()

            # Clipping point: adaptive gradient clipping (AGC) takes over the
            # role of batchnorm.  Gradients are unscaled first so that the
            # clip operates on their true magnitudes.
            scaler.unscale_(self.optim)
            if self.cfg.clipping:
                timm.utils.adaptive_clip_grad(self.model.parameters())

            scaler.step(self.optim)
            scaler.update()
    else:
        logits = self.model(images)
        if not valid:
            loss = self.trn_crit(logits, labels)

            self.optim.zero_grad()
            loss.backward()
            if self.cfg.clipping:
                timm.utils.adaptive_clip_grad(self.model.parameters())
            self.optim.step()
        else:
            loss = self.val_crit(logits, labels)

    if self.cfg.nosiy_elimination:
        neg_log_probs = -F.log_softmax(logits, dim=-1)
        indexs = sample['idx'].detach().cpu().numpy()

        # Indexing with an integer array returns a *copy*, so the former
        # `store[indexs][:, :-1] += ...` statements modified a temporary and
        # never reached the store.  Update the gathered rows, then write
        # them back in one assignment.
        # NOTE(review): assumes self.prediction_by_idx is a NumPy array of
        # shape (num_samples, num_classes + 1) — confirm against its creator.
        rows = self.prediction_by_idx[indexs]
        rows[:, :-1] += rows[:, :-1] * 0.2
        rows[:, :-1] += neg_log_probs.detach().cpu().numpy()
        rows[:, -1] = labels.detach().cpu().numpy()
        self.prediction_by_idx[indexs] = rows

    batch_acc = self.accuracy(logits, labels)
    batch_f1 = self.f1_score(logits, labels)

    return {
        'logit': logits,
        'loss': loss,
        'batch_acc': batch_acc,
        'batch_f1': batch_f1,
    }
def train_step(self, batch: dict, criterion: nn.Module):
    """Run the forward pass and compute the training losses for one batch.

    Args:
        batch (dict): Must provide ``"text_input"``, ``"text_lengths"``,
            ``"mel_input"``, ``"mel_lengths"``, ``"d_vectors"``,
            ``"speaker_ids"`` and ``"durations"``; ``"pitch"`` is read only
            when ``self.args.use_pitch`` is truthy.
        criterion (nn.Module): Called with keyword arguments only.

    Returns:
        tuple: ``(outputs, loss_dict)``; ``loss_dict`` additionally carries
        ``"duration_error"``, the summed absolute duration difference
        divided by the summed text lengths.
    """
    text_input = batch["text_input"]
    text_lengths = batch["text_lengths"]
    mel_input = batch["mel_input"]
    mel_lengths = batch["mel_lengths"]
    pitch = None
    if self.args.use_pitch:
        pitch = batch["pitch"]
    aux_input = {"d_vectors": batch["d_vectors"], "speaker_ids": batch["speaker_ids"]}
    durations = batch["durations"]

    # forward pass
    outputs = self.forward(
        text_input,
        text_lengths,
        mel_lengths,
        y=mel_input,
        dr=durations,
        pitch=pitch,
        aux_input=aux_input,
    )

    # use aligner's output as the duration target
    if self.use_aligner:
        durations = outputs["o_alignment_dur"]

    # use float32 in AMP
    with autocast(enabled=False):
        loss_inputs = dict(
            decoder_output=outputs["model_outputs"],
            decoder_target=mel_input,
            decoder_output_lens=mel_lengths,
            dur_output=outputs["durations_log"],
            dur_target=durations,
            pitch_output=outputs["pitch_avg"] if self.use_pitch else None,
            pitch_target=outputs["pitch_avg_gt"] if self.use_pitch else None,
            input_lens=text_lengths,
            alignment_logprob=outputs["alignment_logprob"] if self.use_aligner else None,
            alignment_soft=outputs["alignment_soft"] if self.use_binary_alignment_loss else None,
            alignment_hard=outputs["alignment_mas"] if self.use_binary_alignment_loss else None,
        )
        loss_dict = criterion(**loss_inputs)

        # compute duration error
        abs_diff = torch.abs(durations - outputs["durations"])
        loss_dict["duration_error"] = abs_diff.sum() / text_lengths.sum()

    return outputs, loss_dict
def run_train_iter(self, batch_data=None, data_time=None):
    """Run one training iteration and publish its metrics.

    If ``self.batch_processor`` is set it produces ``self.output`` directly.
    Otherwise the model forward pass and ``self.loss_fn`` run under
    ``autocast`` (enabled according to ``is_mixed_precision()``), storing
    ``self.model_output`` and ``self.output``.  ``self.output`` is then
    merged with ``data_time`` and handed to ``write_metrics``.
    """
    assert self.model.training, '[CommonEngine] model was changed to eval model!'

    if self.batch_processor is None:
        with autocast(enabled=is_mixed_precision()):
            # `epoch` / `inner_iter` may not exist on every engine, hence getattr.
            self.model_output = self.model(
                batch_data,
                cur_epoch=getattr(self, 'epoch', None),
                cur_iter=self.iter,
                inner_iter=getattr(self, 'inner_iter', None),
            )
            self.output = self.loss_fn(self.model_output)
    else:
        self.output = self.batch_processor(batch_data)

    metrics_dict = dict(data_time=data_time)
    metrics_dict.update(self.output)
    write_metrics(metrics_dict)
def forward(self, z_obj, z_cam_mid, z_obj_mid, camera):
    """Fuse the per-view slices of ``z_obj`` with ``self.gru``.

    The slice at index 0 along dimension 1 seeds the hidden state; every
    further slice is concatenated (along dimension 1) with normalized
    coordinates derived from that first slice and fed through ``self.gru``.
    ``z_cam_mid``, ``z_obj_mid`` and ``camera`` are accepted but unused.
    Autocast is enabled only while ``self.training`` is true.

    Returns:
        tuple: ``(hidden, {})`` where ``hidden`` has a singleton dimension
        re-inserted at position 1.
    """
    with autocast(enabled=self.training):
        view_count = z_obj.shape[1]
        hidden = z_obj[:, 0]

        # Concatenate pixel coords if 2d, voxel coords otherwise.
        if self.conv_module == EqualizedConv2d:
            coord_fn = utils.get_normalized_pixel_coords
        else:
            coord_fn = utils.get_normalized_voxel_coords
        coords = coord_fn(hidden)

        for view_idx in range(1, view_count):
            gru_input = torch.cat((z_obj[:, view_idx], coords), dim=1)
            hidden = self.gru(gru_input, hidden)

        hidden = hidden.unsqueeze(1)

    return hidden, {}
def forward(self, cosine, label):
    """Apply an angular margin to the target-class logits and compute the loss.

    For each row the logit of the labelled class is replaced by
    ``cosine * self.cos_m - sine * self.sin_m`` (falling back to
    ``cosine - self.mm`` where ``cosine <= self.th``); all logits are then
    scaled by ``self.s`` and passed to ``self.crit``.
    NOTE(review): ``cos_m``/``sin_m`` are presumably cos(m)/sin(m) of the
    margin, making the replacement ``cos(theta + m)`` — confirm in
    ``__init__``.

    Args:
        cosine: 2-D tensor of cosine similarities (rows are samples; column
            ``label[i]`` is the target entry of row ``i``).
        label: Class indices, one per row of ``cosine``.

    Returns:
        The loss from ``self.crit``, reduced with ``mean`` or ``sum`` when
        ``self.reduction`` says so, otherwise returned as-is.
    """
    with autocast():
        # --------------------------- cos(theta) & phi(theta) ---------------------------
        # Clamp guards against tiny negative values from rounding before sqrt.
        sine = torch.sqrt((1.0 - cosine * cosine).clamp(0, 1))
        phi = cosine * self.cos_m - sine * self.sin_m
        phi = torch.where(cosine > self.th, phi, cosine - self.mm)

        # --------------------------- convert label to one-hot ---------------------------
        # Allocate on the input's own device: a hard-coded 'cuda' crashed on
        # CPU and mismatched inputs living on a non-default GPU.
        one_hot = torch.zeros(cosine.size(), device=cosine.device)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)

        # Margin-adjusted value at the target column, plain cosine elsewhere.
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s

        loss = self.crit(output, label)
        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()
        return loss
def run_step(self):
    """Run one optimization step on a batch from the pseudo-labelled target loader.

    Loads a batch via ``self.pseudo_tgt_train_loader.next()``, moves images
    and ids to CUDA, computes a cross-entropy loss on the first
    ``self.num_clusters`` class logits and, if ``'TripletLoss'`` is listed in
    ``cfg.MODEL.LOSSES.NAME``, a triplet loss on the features.  The summed
    loss is back-propagated and the optimizer stepped, through
    ``self.grad_scaler`` and ``autocast`` when AMP is enabled.
    """
    assert self.model.training, f"[{self.__class__.__name__}] model was changed to eval mode!"

    amp_enabled = self.cfg.SOLVER.AMP.ENABLED
    if amp_enabled:
        assert torch.cuda.is_available(
        ), f"[{self.__class__.__name__}] CUDA is required for AMP training!"
        from torch.cuda.amp.autocast_mode import autocast

    tic = time.perf_counter()

    # load data
    tgt_batch = self.pseudo_tgt_train_loader.next()

    def _parse_data(batch):
        # Each batch is a 4-tuple; only images and ids are used.
        imgs, _, pids, _ = batch
        return imgs.cuda(), pids.cuda()

    # process inputs
    t_inputs, t_targets = _parse_data(tgt_batch)
    data_time = time.perf_counter() - tic

    def _forward():
        outputs = self.model(t_inputs)
        features = outputs['features']
        cluster_logits = outputs['pred_class_logits'][:, :self.num_clusters]

        ce_cfg = self.cfg.MODEL.LOSSES.CE
        loss_dict = {
            'loss_ce': cross_entropy_loss(pred_class_outputs=cluster_logits,
                                          gt_classes=t_targets,
                                          eps=ce_cfg.EPSILON,
                                          alpha=ce_cfg.ALPHA),
        }
        if 'TripletLoss' in self.cfg.MODEL.LOSSES.NAME:
            loss_dict['loss_tri'] = triplet_loss(features,
                                                 t_targets,
                                                 margin=0.0,
                                                 norm_feat=True,
                                                 hard_mining=False)
        return loss_dict

    if not amp_enabled:
        loss_dict = _forward()
        losses = sum(loss_dict.values())

        self.optimizer.zero_grad()
        losses.backward()
        self._write_metrics(loss_dict, data_time)
        self.optimizer.step()
    else:
        with autocast():
            loss_dict = _forward()
            losses = sum(loss_dict.values())

        self.optimizer.zero_grad()
        self.grad_scaler.scale(losses).backward()
        self._write_metrics(loss_dict, data_time)
        self.grad_scaler.step(self.optimizer)
        self.grad_scaler.update()

    if isinstance(self.param_wrapper, ContiguousParams):
        self.param_wrapper.assert_buffer_is_valid()
def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
    """Run one generator or discriminator training step.

    Before either pass, sub-modules selected by the ``self.args.freeze_*``
    flags have ``requires_grad`` switched off on all their parameters.

    Args:
        batch (Dict): Input tensors (only read for the generator pass).
        criterion (nn.Module): Indexable by ``optimizer_idx``; entry 0 is the
            generator loss, entry 1 the discriminator loss.
        optimizer_idx (int): 0 for the generator pass, 1 for the
            discriminator pass.

    Returns:
        Tuple[Dict, Dict]: Model outputs and computed losses.

    Raises:
        ValueError: If ``optimizer_idx`` is neither 0 nor 1.
    """
    # pylint: disable=attribute-defined-outside-init
    if optimizer_idx not in [0, 1]:
        raise ValueError(" [!] Unexpected `optimizer_idx`.")

    def _freeze(module):
        for param in module.parameters():
            param.requires_grad = False

    if self.args.freeze_encoder:
        _freeze(self.text_encoder)
        if hasattr(self, "emb_l"):
            _freeze(self.emb_l)
    if self.args.freeze_PE:
        _freeze(self.posterior_encoder)
    if self.args.freeze_DP:
        _freeze(self.duration_predictor)
    if self.args.freeze_flow_decoder:
        _freeze(self.flow)
    if self.args.freeze_waveform_decoder:
        _freeze(self.waveform_decoder)

    if optimizer_idx == 1:
        # discriminator pass: reuse what the generator pass cached, detached
        # so no gradient flows back into the generator.
        scores_fake, _, scores_real, _ = self.disc(self.y_disc_cache.detach(),
                                                   self.wav_seg_disc_cache)
        outputs = {"scores_disc_fake": scores_fake, "scores_disc_real": scores_real}

        # compute loss
        with autocast(enabled=False):  # use float32 for the criterion
            loss_dict = criterion[optimizer_idx](
                outputs["scores_disc_real"],
                outputs["scores_disc_fake"],
            )
        return outputs, loss_dict

    # generator pass
    text_input = batch["text_input"]
    text_lengths = batch["text_lengths"]
    mel_lengths = batch["mel_lengths"]
    linear_input = batch["linear_input"]
    aux_input = {
        "d_vectors": batch["d_vectors"],
        "speaker_ids": batch["speaker_ids"],
        "language_ids": batch["language_ids"],
    }
    waveform = batch["waveform"]

    outputs = self.forward(
        text_input,
        text_lengths,
        linear_input.transpose(1, 2),
        mel_lengths,
        waveform.transpose(1, 2),
        aux_input=aux_input,
    )

    # cache tensors for the discriminator
    self.y_disc_cache = None
    self.wav_seg_disc_cache = None
    self.y_disc_cache = outputs["model_outputs"]
    self.wav_seg_disc_cache = outputs["waveform_seg"]

    # compute discriminator scores and features
    scores_fake, feats_fake, _, feats_real = self.disc(outputs["model_outputs"],
                                                       outputs["waveform_seg"])
    outputs["scores_disc_fake"] = scores_fake
    outputs["feats_disc_fake"] = feats_fake
    outputs["feats_disc_real"] = feats_real

    # compute losses
    with autocast(enabled=False):  # use float32 for the criterion
        generator_criterion = criterion[optimizer_idx]
        loss_dict = generator_criterion(
            waveform_hat=outputs["model_outputs"].float(),
            waveform=outputs["waveform_seg"].float(),
            z_p=outputs["z_p"].float(),
            logs_q=outputs["logs_q"].float(),
            m_p=outputs["m_p"].float(),
            logs_p=outputs["logs_p"].float(),
            z_len=mel_lengths,
            scores_disc_fake=outputs["scores_disc_fake"],
            feats_disc_fake=outputs["feats_disc_fake"],
            feats_disc_real=outputs["feats_disc_real"],
            loss_duration=outputs["loss_duration"],
            use_speaker_encoder_as_loss=self.args.use_speaker_encoder_as_loss,
            gt_spk_emb=outputs["gt_spk_emb"],
            syn_spk_emb=outputs["syn_spk_emb"],
        )

    return outputs, loss_dict
def run_step(self):
    """Run one optimization step on an interleaved source/target batch.

    One batch is drawn from ``self.src_train_loader`` and one from
    ``self.pseudo_tgt_train_loader``.  Their images are reshaped to
    ``(device_num, -1, C, H, W)`` and concatenated along dimension 1, so each
    device chunk holds source samples followed by target samples.  After the
    forward pass the features are split back into the two halves and scored
    by ``self.hm``: source features against their ids, target features
    against their dataset indices offset by ``self.src_pid_nums``.
    """
    assert self.model.training, f"[{self.__class__.__name__}] model was changed to eval mode!"

    amp_enabled = self.cfg.SOLVER.AMP.ENABLED
    if amp_enabled:
        assert torch.cuda.is_available(
        ), f"[{self.__class__.__name__}] CUDA is required for AMP training!"
        from torch.cuda.amp.autocast_mode import autocast

    tic = time.perf_counter()

    # load data
    src_batch = self.src_train_loader.next()
    tgt_batch = self.pseudo_tgt_train_loader.next()

    def _parse_data(batch):
        # Each batch is a 5-tuple; indices are left where they are (not moved to CUDA).
        imgs, _, pids, _, indices = batch
        return imgs.cuda(), pids.cuda(), indices

    # process inputs
    s_inputs, s_targets, _ = _parse_data(src_batch)
    t_inputs, _, t_indices = _parse_data(tgt_batch)

    # arrange batch for domain-specific BNP
    device_num = torch.cuda.device_count()
    _, C, H, W = s_inputs.size()

    def _per_device(images):
        return images.view(device_num, -1, C, H, W)

    s_inputs = _per_device(s_inputs)
    t_inputs = _per_device(t_inputs)
    inputs = torch.cat((s_inputs, t_inputs), 1).view(-1, C, H, W)

    data_time = time.perf_counter() - tic

    def _forward():
        outputs = self.model(inputs)
        f_out = outputs['features'] if isinstance(outputs, dict) else outputs

        # de-arrange batch
        feat_dim = f_out.size(-1)
        f_out = f_out.view(device_num, -1, feat_dim)
        f_out_s, f_out_t = f_out.split(f_out.size(1) // 2, dim=1)
        f_out_s = f_out_s.contiguous().view(-1, feat_dim)
        f_out_t = f_out_t.contiguous().view(-1, feat_dim)

        # compute loss with the hybrid memory
        loss_s = self.hm(f_out_s, s_targets)
        loss_t = self.hm(f_out_t, t_indices + self.src_pid_nums)
        return {'loss_s': loss_s, 'loss_t': loss_t}

    if not amp_enabled:
        loss_dict = _forward()
        losses = sum(loss_dict.values())

        self.optimizer.zero_grad()
        losses.backward()
        self._write_metrics(loss_dict, data_time)
        self.optimizer.step()
    else:
        with autocast():
            loss_dict = _forward()
            losses = sum(loss_dict.values())

        self.optimizer.zero_grad()
        self.grad_scaler.scale(losses).backward()
        self._write_metrics(loss_dict, data_time)
        self.grad_scaler.step(self.optimizer)
        self.grad_scaler.update()

    if isinstance(self.param_wrapper, ContiguousParams):
        self.param_wrapper.assert_buffer_is_valid()
def autocast(self):
    """Return an ``autocast`` context whose ``enabled`` flag mirrors ``self.enabled``."""
    # NOTE(review): the inner name is expected to resolve to the module-level
    # ``autocast`` import, not to this method — that only holds while this
    # definition lives inside a class body; confirm.
    is_enabled = self.enabled
    return autocast(enabled=is_enabled)
def train(train_loader, model, criterion, optimizer, epoch, args, scaler=None):
    """Train ``model`` for one epoch over ``train_loader``.

    Each iteration selects a task (fixed via ``args.task`` or drawn uniformly
    from 0..5), feeds the matching entry of the input group through the
    model, records loss and PSNR, and performs one optimizer step.  When
    ``scaler`` is given, the forward pass runs under ``autocast`` and the
    scaler drives backward/step/update.  The learning rate is adjusted per
    epoch or per iteration depending on ``args.lr_policy``.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    psnr_out = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()

    # Epoch-granularity learning-rate policies.
    lr_policy = args.lr_policy
    if lr_policy == 'naive':
        local_lr = adjust_learning_rate_naive(optimizer, epoch, args)
    elif lr_policy == 'step':
        local_lr = adjust_learning_rate(optimizer, epoch, args)
    elif lr_policy == 'epoch_poly':
        local_lr = adjust_learning_rate_epoch_poly(optimizer, epoch, args)

    for i, (target, input_group) in enumerate(train_loader):
        # set random task
        task_id = task_map[args.task] if args.task else random.randint(0, 5)
        inputs = input_group[task_id]
        model.module.set_task(task_id)

        # Iteration-granularity learning-rate policies.
        global_iter = epoch * args.epoch_size + i
        if args.lr_policy == 'iter_poly':
            local_lr = adjust_learning_rate_poly(optimizer, global_iter, args)
        elif args.lr_policy == 'cosine':
            local_lr = adjust_learning_rate_cosine(optimizer, global_iter, args)

        # measure data loading time
        data_time.update(time.time() - end)

        if args.gpu is not None:
            inputs = inputs.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)
        target = target.cuda()

        # compute output
        if scaler is not None:
            with autocast():
                output = model(inputs)
                loss = criterion(output, target)
        else:
            output = model(inputs)
            loss = criterion(output, target)

        # measure accuracy and record loss
        # Map both tensors through (x * 0.5 + 0.5) * 255 before PSNR.
        output = (output * 0.5 + 0.5) * 255.
        target = (target * 0.5 + 0.5) * 255.
        psnr = PSNR()(output, target)
        batch_size = inputs.size(0)
        losses.update(loss.item(), batch_size)
        psnr_out.update(psnr.item(), batch_size)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'PSNR {psnr.val:.3f} ({psnr.avg:.3f})\t'
                  'LR: {lr: .6f}'.format(epoch,
                                         i,
                                         args.epoch_size,
                                         batch_time=batch_time,
                                         data_time=data_time,
                                         loss=losses,
                                         psnr=psnr_out,
                                         lr=local_lr))
def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
    """Run one generator or discriminator training step.

    The generator pass runs the model, cuts the ground-truth waveform into
    the segment matching the generated audio, caches both for the following
    discriminator pass, and evaluates the generator criterion in float32.
    The discriminator pass scores the cached (detached) generated audio
    against the cached real segment.

    Args:
        batch (Dict): Input tensors (only read for the generator pass).
        criterion (nn.Module): Indexable by ``optimizer_idx``; entry 0 is the
            generator loss, entry 1 the discriminator loss.
        optimizer_idx (int): 0 for the generator pass, 1 for the
            discriminator pass.

    Returns:
        Tuple[Dict, Dict]: Model outputs and computed losses.

    Raises:
        ValueError: If ``optimizer_idx`` is neither 0 nor 1.
    """
    # pylint: disable=attribute-defined-outside-init
    if optimizer_idx not in [0, 1]:
        raise ValueError(" [!] Unexpected `optimizer_idx`.")

    if optimizer_idx == 1:
        # discriminator pass: detach so no gradient reaches the generator.
        scores_fake, _, scores_real, _ = self.disc(self.y_disc_cache.detach(),
                                                   self.wav_seg_disc_cache)
        outputs = {"scores_disc_fake": scores_fake, "scores_disc_real": scores_real}

        # compute loss
        with autocast(enabled=False):  # use float32 for the criterion
            loss_dict = criterion[optimizer_idx](
                outputs["scores_disc_real"],
                outputs["scores_disc_fake"],
            )
        return outputs, loss_dict

    # generator pass
    text_input = batch["text_input"]
    text_lengths = batch["text_lengths"]
    mel_lengths = batch["mel_lengths"]
    linear_input = batch["linear_input"]
    aux_input = {
        "d_vectors": batch["d_vectors"],
        "speaker_ids": batch["speaker_ids"],
    }
    waveform = batch["waveform"]

    outputs = self.forward(
        text_input,
        text_lengths,
        linear_input.transpose(1, 2),
        mel_lengths,
        aux_input=aux_input,
    )

    # cache tensors for the discriminator
    self.y_disc_cache = None
    self.wav_seg_disc_cache = None
    self.y_disc_cache = outputs["model_outputs"]

    # Slice ids and segment size are scaled by the hop length before segmenting.
    wav_seg = segment(
        waveform.transpose(1, 2),
        outputs["slice_ids"] * self.config.audio.hop_length,
        self.args.spec_segment_size * self.config.audio.hop_length,
    )
    self.wav_seg_disc_cache = wav_seg
    outputs["waveform_seg"] = wav_seg

    # compute discriminator scores and features
    scores_fake, feats_fake, _, feats_real = self.disc(outputs["model_outputs"], wav_seg)
    outputs["scores_disc_fake"] = scores_fake
    outputs["feats_disc_fake"] = feats_fake
    outputs["feats_disc_real"] = feats_real

    # compute losses
    with autocast(enabled=False):  # use float32 for the criterion
        generator_criterion = criterion[optimizer_idx]
        loss_dict = generator_criterion(
            waveform_hat=outputs["model_outputs"].float(),
            waveform=wav_seg.float(),
            z_p=outputs["z_p"].float(),
            logs_q=outputs["logs_q"].float(),
            m_p=outputs["m_p"].float(),
            logs_p=outputs["logs_p"].float(),
            z_len=mel_lengths,
            scores_disc_fake=outputs["scores_disc_fake"],
            feats_disc_fake=outputs["feats_disc_fake"],
            feats_disc_real=outputs["feats_disc_real"],
            loss_duration=outputs["loss_duration"],
        )

    return outputs, loss_dict
def run_iteration(self, batch, train, is_step):
    """Run one iteration: preprocess, forward, losses, optional updates, plots.

    Args:
        batch: Raw batch; passed through ``process_batch`` and then indexed
            with the keys ``'in'``, ``'in_gt'`` and ``'out_gt'``.
        train: When truthy, the discriminator and generator losses are
            back-propagated.
        is_step: When truthy (and ``train`` is truthy), the corresponding
            optimizers are stepped after the backward pass.

    NOTE(review): this function was recovered from a flattened source, so the
    nesting below is reconstructed.  In particular the exact extent of the
    ``with autocast()`` region (assumed here to cover the forward passes and
    loss construction up to ``loss_g``) must be confirmed against the
    original file.
    """
    self.mark_time()

    # Update depth criterion k if applicable.
    if 'hard_' in self.g_depth_recon_loss_type:
        self._g_depth_recon_criterion.k = int(
            self._g_depth_recon_k_scheduler.get(self.epoch))

    batch = process_batch(batch, self.cube_size, self.camera_dist,
                          self._sculptor.in_size, self.device,
                          self.random_orientation)

    # Reconstruction targets: output views only, or input + output views
    # concatenated along dim 1 when the inputs are reconstructed as well.
    if self.reconstruct_input:
        recon_camera = Camera.vcat(
            (batch['in_gt']['camera'], batch['out_gt']['camera']),
            batch_size=self.batch_size)
        recon_mask = torch.cat(
            (batch['in_gt']['mask'], batch['out_gt']['mask']), dim=1)
        recon_image = torch.cat(
            (batch['in_gt']['image'], batch['out_gt']['image']), dim=1)
        recon_depth = torch.cat(
            (batch['in_gt']['depth'], batch['out_gt']['depth']), dim=1)
    else:
        recon_camera = batch['out_gt']['camera']
        recon_mask = batch['out_gt']['mask']
        recon_image = batch['out_gt']['image']
        recon_depth = batch['out_gt']['depth']

    # Mask the inputs unless a random background is kept for that modality.
    if not self.color_random_background or self.crop_random_background:
        batch['in']['image'] = batch['in']['image'] * batch['in']['mask']
    if not self.depth_random_background or self.crop_random_background:
        batch['in']['depth'] = mask_normalized_depth(
            batch['in']['depth'], batch['in']['mask'])

    # Optional noisy depth input, clamped to [-1, 1].
    depth_in = None
    if self.generator_input_depth:
        depth_noise = self._depth_noise_dist.sample(
            batch['in']['depth'].size()).to(self.device)
        depth_in = (batch['in']['depth'] + depth_noise).clamp(-1, 1)

    data_process_time = self.mark_time()

    with autocast():
        # Evaluate generator.
        z_obj, z_extra = self._sculptor.encode(
            self._fuser,
            camera=batch['in']['camera'],
            color=batch['in']['image'],
            depth=depth_in,
            mask=batch['in']['mask'],
            data_parallel=self.data_parallel)
        fake_image, fake_depth, fake_mask, fake_mask_logits, fake_vox_depth = \
            self._run_photographer(z_obj, recon_camera, recon_mask)

        if 'blend_weights' in z_extra:
            z_weights = z_extra['blend_weights']
        else:
            z_weights = None

        # Train discriminator.
        if self._discriminator:
            d_real, d_fake_d, d_fake_g = self._run_discriminator(
                fake_image, fake_depth, fake_mask, recon_image, recon_depth,
                recon_mask)
            loss_d_real = multiscale_lsgan_loss(d_real, 1)
            loss_d_fake = multiscale_lsgan_loss(d_fake_d, 0)
            loss_d = loss_d_real + loss_d_fake
            loss_g_gan = multiscale_lsgan_loss(d_fake_g, 1)

            # The discriminator loss is back-propagated directly; it does
            # not go through self._scaler.
            if train:
                loss_d.backward()
                if is_step:
                    self._optimizers['discriminator'].step()

            self.plotter.put_scalar('loss/discriminator/real', loss_d_real)
            self.plotter.put_scalar('loss/discriminator/fake', loss_d_fake)
            self.plotter.put_scalar('loss/discriminator/total', loss_d)
        else:
            loss_g_gan = torch.tensor(0.0, device=self.device)

        # Train generator.
        if self.predict_color:
            loss_g_color_recon = reduce_loss(
                self._g_color_recon_criterion(fake_image, recon_image))
        else:
            loss_g_color_recon = torch.tensor(0.0, device=self.device)

        if self.predict_depth or self.use_occlusion_depth:
            loss_g_depth_recon = reduce_loss(
                self._g_depth_recon_criterion(fake_depth, recon_depth))
        else:
            loss_g_depth_recon = torch.tensor(0.0, device=self.device)

        if self.predict_mask:
            # The binary cross-entropy criterion is fed logits, any other
            # criterion the mask itself.
            if self.g_mask_recon_loss_type == 'binary_cross_entropy':
                y_mask = fake_mask_logits
            else:
                y_mask = fake_mask
            loss_g_mask_recon = reduce_loss(
                self._g_mask_recon_criterion(y_mask, recon_mask))
            loss_g_mask_beta = beta_prior_loss(
                fake_mask,
                alpha=self.g_mask_beta_loss_param,
                beta=self.g_mask_beta_loss_param)
        else:
            loss_g_mask_recon = torch.tensor(0.0, device=self.device)
            loss_g_mask_beta = torch.tensor(0.0, device=self.device)

        # Weighted sum of all generator terms, divided by batch_groups.
        loss_g = (self.g_gan_loss_weight * loss_g_gan +
                  self.g_color_recon_loss_weight * loss_g_color_recon +
                  self.g_depth_recon_loss_weight * loss_g_depth_recon +
                  self.g_mask_recon_loss_weight * loss_g_mask_recon +
                  self.g_mask_beta_loss_weight * loss_g_mask_beta) / self.batch_groups

    if train:
        # With 'use_amp' set in kwargs the generator loss goes through the
        # gradient scaler; otherwise plain backward/step.
        if self.kwargs.get('use_amp', False):
            self._scaler.scale(loss_g).backward()
        else:
            loss_g.backward()
        if is_step:
            if self.kwargs.get('use_amp', False):
                self._scaler.step(self._optimizers['generator'])
                self._scaler.update()
            else:
                self._optimizers['generator'].step()

    # Error metrics for plotting only; no gradients needed.
    with torch.no_grad():
        if self.predict_depth:
            self.plotter.put_scalar('error/depth/l1',
                                    F.l1_loss(fake_depth, recon_depth))
            if self.reconstruct_input:
                # Views before num_input_views are the reconstructed inputs,
                # the rest are the output views.
                self.plotter.put_scalar(
                    'error/depth/input_l1',
                    F.l1_loss(fake_depth[:, :self.num_input_views],
                              batch['in_gt']['depth']))
                self.plotter.put_scalar(
                    'error/depth/output_l1',
                    F.l1_loss(fake_depth[:, self.num_input_views:],
                              batch['out_gt']['depth']))
        if self.predict_mask:
            self.plotter.put_scalar(
                'error/mask/cross_entropy',
                F.binary_cross_entropy_with_logits(fake_mask_logits,
                                                   recon_mask))
            self.plotter.put_scalar('error/mask/l1',
                                    F.l1_loss(fake_mask, recon_mask))

    compute_time = self.mark_time()

    self.plotter.put_scalar('loss/generator/gan', loss_g_gan)
    self.plotter.put_scalar('loss/generator/recon/color', loss_g_color_recon)
    self.plotter.put_scalar('loss/generator/recon/depth', loss_g_depth_recon)
    self.plotter.put_scalar('loss/generator/recon/mask', loss_g_mask_recon)
    self.plotter.put_scalar('loss/generator/recon/mask_beta', loss_g_mask_beta)
    self.plotter.put_scalar('loss/generator/total', loss_g)

    self.plotter.put_scalar('params/input_noise_weight', self.input_noise_weight)
    if hasattr(self._g_depth_recon_criterion, 'k'):
        self.plotter.put_scalar('params/depth_loss_k',
                                self._g_depth_recon_criterion.k)

    self.plotter.put_scalar('time/data_process', data_process_time)
    self.plotter.put_scalar('time/compute', compute_time)

    plot_scalar_time = self.mark_time()
    self.plotter.put_scalar('time/plot/scalars', plot_scalar_time)

    if self.plotter.is_it_time_yet('histogram'):
        if self.predict_color:
            self.plotter.put_histogram('image_fake', fake_image)
            self.plotter.put_histogram('image_real', recon_image)
        if self.predict_mask:
            self.plotter.put_histogram('mask_fake', fake_mask)
        self.plotter.put_histogram('z_obj', z_obj)
        if z_weights is not None:
            self.plotter.put_histogram('z_weights', z_weights)
        plot_histogram_time = self.mark_time()
        self.plotter.put_scalar('time/plot/histogram',
                                plot_histogram_time)

    if self.plotter.is_it_time_yet('show'):
        # Entries that are None correspond to inputs/outputs that are
        # disabled or missing for this configuration.
        self.plotter.put_image(
            'inputs',
            viz.make_grid([
                gan_denormalize(batch['in']['image']),
                viz.colorize_depth(batch['in']['depth'])
                if self.generator_input_depth else None,
                viz.colorize_tensor(batch['in']['mask'])
                if self.generator_input_mask else None,
            ], row_size=4, stride=2, output_size=64))
        with torch.no_grad():
            self.plotter.put_image(
                'reconstruction',
                viz.make_grid([
                    gan_denormalize(recon_image),
                    gan_denormalize(fake_image)
                    if (fake_image is not None) else None,
                    viz.colorize_depth(recon_depth),
                    viz.colorize_depth(fake_depth)
                    if (fake_depth is not None) else None,
                    viz.colorize_tensor(
                        (recon_depth.cpu() - fake_depth.cpu()).abs())
                    if (fake_depth is not None) else None,
                    viz.colorize_tensor(recon_mask),
                    viz.colorize_tensor(fake_mask)
                    if (fake_mask is not None) else None,
                    viz.colorize_tensor(
                        (recon_mask.cpu() - fake_mask.cpu()).abs())
                    if (fake_mask is not None) else None,
                ], stride=8))
        plot_images_time = self.mark_time()
        self.plotter.put_scalar('time/plot/images', plot_images_time)
def trn_step(self, epoch, sample_l, sample_u, scaler=None):
    """Run one semi-supervised training step (supervised loss + consistency KL).

    Labelled images and the original/augmented views of unlabelled images
    are pushed through the model in a single concatenated batch.  The total
    loss is the supervised loss plus a weighted KL divergence between the
    (detached) prediction on the original view and the prediction on the
    augmented view.

    Args:
        epoch: Current epoch; only used by the ``'gradual'`` ratio mode,
            where the KL weight is ``epoch / cfg.t_epoch * cfg.ratio``.
        sample_l (dict): Provides ``'image'`` and ``'label'``.
        sample_u (dict): Provides ``'image_ori'`` and ``'image_aug'``.
        scaler: Gradient scaler.  When not ``None`` the forward pass runs
            under ``autocast`` and the scaler drives backward/step/update.

    Returns:
        dict: ``'l_loss'``, ``'t_loss'``, ``'kl_loss'``, ``'batch_acc'`` and
        ``'batch_f1'``.

    Raises:
        ValueError: If ``cfg.ratio_mode`` is neither ``'constant'`` nor
            ``'gradual'``.
    """
    # Validate up front: an unknown mode used to leave `t_loss` unbound and
    # fail with an opaque UnboundLocalError only after the forward pass.
    ratio_mode = self.cfg.ratio_mode
    if ratio_mode not in ('constant', 'gradual'):
        raise ValueError(
            f"Unsupported ratio_mode {ratio_mode!r}; expected 'constant' or 'gradual'.")

    self.optim.zero_grad()

    images_l = sample_l['image'].to(self.cfg.device)
    images_o = sample_u["image_ori"].to(self.cfg.device)
    images_a = sample_u["image_aug"].to(self.cfg.device)
    labels = sample_l['label'].to(self.cfg.device)
    batch_s = images_l.size(0)

    images_t = torch.cat([images_l, images_o, images_a])

    def _compute_losses():
        """Forward pass and losses, shared by the AMP and full-precision paths."""
        logits_t = self.model(images_t)
        logits_l = logits_t[:batch_s]
        logits_o, logits_a = logits_t[batch_s:].chunk(2)
        del logits_t

        # The original-view prediction is the fixed target (no gradient).
        preds_o = F.softmax(logits_o, dim=-1).detach()
        preds_a = F.log_softmax(logits_a, dim=-1)

        kl_loss = F.kl_div(preds_a, preds_o, reduction='none')
        kl_loss = torch.mean(torch.sum(kl_loss, dim=-1))

        l_loss = self.trn_crit(logits_l, labels)

        if ratio_mode == 'constant':
            kl_weight = self.cfg.ratio
        else:  # 'gradual'
            kl_weight = epoch / self.cfg.t_epoch * self.cfg.ratio
        t_loss = l_loss + kl_weight * kl_loss
        return logits_l, l_loss, kl_loss, t_loss

    if scaler is not None:
        with autocast():
            logits_l, l_loss, kl_loss, t_loss = _compute_losses()

        scaler.scale(t_loss).backward()

        # Clipping point: adaptive gradient clipping (AGC) takes over the
        # role of batchnorm.  Gradients are unscaled first so that the clip
        # operates on their true magnitudes.
        scaler.unscale_(self.optim)
        if self.cfg.clipping:
            timm.utils.adaptive_clip_grad(self.model.parameters())

        scaler.step(self.optim)
        scaler.update()
    else:
        logits_l, l_loss, kl_loss, t_loss = _compute_losses()

        t_loss.backward()
        if self.cfg.clipping:
            timm.utils.adaptive_clip_grad(self.model.parameters())
        self.optim.step()

    batch_acc = self.accuracy(logits_l, labels)
    batch_f1 = self.f1_score(logits_l, labels)

    return {
        'l_loss': l_loss,
        't_loss': t_loss,
        'kl_loss': kl_loss,
        'batch_acc': batch_acc,
        'batch_f1': batch_f1,
    }