def _get_heatmap_seq_loss(self, heatmaps_seq):
    """Return the summed heatmap penalty over every element of ``heatmaps_seq``.

    Each element yielded by iterating ``heatmaps_seq`` is scored with
    ``get_heatmap_penalty`` using ``self.cfg.heatmap_regularization`` as the
    factor; the per-element penalties are stacked and summed into one tensor.
    """
    per_step_penalties = [
        get_heatmap_penalty(step_heatmaps, self.cfg.heatmap_regularization)
        for step_heatmaps in heatmaps_seq
    ]
    return torch.stack(per_step_penalties).sum()
def validation_step(self, batch, batch_idx):
    """Compute the validation losses for one batch.

    Args:
        batch: Mapping whose ``'image'`` entry is passed to ``self.forward``.
        batch_idx: Index of the batch (unused here).

    Returns:
        Dict with:
          * ``'val_loss'``: reconstruction loss + heatmap penalty,
          * ``'val_recon_loss'``: the reconstruction loss,
          * ``'val_hmap_loss'``: heatmap penalty divided by
            ``self.cfg.heatmap_regularization`` (i.e. with the weight removed),
          * ``'progress_bar'``: values to display; note that here the heatmap
            penalty is shown *with* its regularization weight.
    """
    img = batch['image']
    keypoints, heatmaps, reconstructed_img = self.forward(img)

    # Half the summed squared error, divided by img.shape[0]
    # (presumably the batch size — confirm against the data loader).
    reconstruction_loss = F.mse_loss(
        img, reconstructed_img, reduction='sum') / 2.0
    reconstruction_loss /= img.shape[0]
    heatmap_loss = get_heatmap_penalty(heatmaps,
                                       self.cfg.heatmap_regularization)
    loss = reconstruction_loss + heatmap_loss

    tqdm_dict = {
        'val_loss': loss,
        'val_recon_loss': reconstruction_loss,
        'val_hmap_loss': heatmap_loss,
    }
    return {
        'val_loss': loss,
        'val_recon_loss': reconstruction_loss,
        'val_hmap_loss': heatmap_loss / self.cfg.heatmap_regularization,
        'progress_bar': tqdm_dict,
    }
def training_step(self, batch, batch_idx):
    """Compute the training objective for one batch.

    The objective is the reconstruction error (half the summed squared error,
    divided by ``images.shape[0]``) plus the weighted heatmap penalty from
    ``get_heatmap_penalty``.

    Returns:
        Dict with ``'loss'`` (the objective to optimize), ``'recon_loss'``,
        ``'heatmap_loss'`` (penalty divided by
        ``self.cfg.heatmap_regularization``) and ``'log'`` (the un-normalized
        components for the logger).
    """
    images = batch['image']
    _, heatmaps, reconstruction = self.forward(images)

    recon = F.mse_loss(images, reconstruction, reduction='sum') / 2.0
    recon /= images.shape[0]
    hmap = get_heatmap_penalty(heatmaps, self.cfg.heatmap_regularization)
    total = recon + hmap

    return {
        'loss': total,
        'recon_loss': recon,
        'heatmap_loss': hmap / self.cfg.heatmap_regularization,
        'log': {
            'recon_loss': recon,
            'loss': total,
            'heatmap_loss': hmap,
        },
    }
def train_epoch(epoch, model_dict, cfg):
    """Train the image -> keypoints -> image model for one full epoch.

    Args:
        epoch: Epoch index; used as the TensorBoard step and in the
            checkpoint file name.
        model_dict: Dict with keys ``'models'`` (unpacks into
            ``(images_to_keypoints_net, keypoints_to_images_net)`` and exposes
            ``state_dict()``), ``'optimizer'``, ``'device'``, ``'writer'`` and
            ``'train_loader'``.
        cfg: Config; ``cfg.heatmap_regularization`` is passed to
            ``get_heatmap_penalty``.

    Only the reconstruction loss is back-propagated. The heatmap penalty is
    computed and logged for monitoring but does not contribute to gradients.
    Per-epoch averages are written to ``writer``, printed, and the model state
    is saved to ``SAVE_PATH + str(epoch) + ".pth"``.
    """
    models = model_dict['models']
    images_to_keypoints_net, keypoints_to_images_net = models
    optimizer = model_dict['optimizer']
    device = model_dict['device']
    writer = model_dict['writer']
    train_loader = model_dict['train_loader']

    images_to_keypoints_net.train()
    keypoints_to_images_net.train()

    # Running totals are kept separate from the per-batch loss tensors so the
    # per-batch values cannot overwrite the accumulators.
    train_loss_total = 0.0
    recon_loss_total = 0.0
    heatmap_loss_total = 0.0
    steps = 0
    for data in tqdm(train_loader):
        optimizer.zero_grad()
        img = data['image'].to(device)
        keypoints, heatmaps = images_to_keypoints_net(img)
        reconstructed_img = keypoints_to_images_net(keypoints)

        reconstruction_loss = F.mse_loss(img, reconstructed_img,
                                         reduction='sum')
        reconstruction_loss /= img.shape[0]
        heatmap_loss = get_heatmap_penalty(heatmaps, cfg.heatmap_regularization)

        # NOTE(review): the heatmap penalty is intentionally left out of the
        # optimized objective (it was commented out in the original); use
        # `reconstruction_loss + heatmap_loss` if it should be trained on.
        loss = reconstruction_loss
        loss.backward()
        optimizer.step()

        train_loss_total += loss.item()
        recon_loss_total += reconstruction_loss.item()
        heatmap_loss_total += heatmap_loss.item()
        steps += 1

    # Guard against an empty loader instead of raising ZeroDivisionError.
    denom = max(steps, 1)
    avg_train_loss = train_loss_total / denom
    avg_recon_loss = recon_loss_total / denom
    avg_heatmap_loss = heatmap_loss_total / denom

    writer.add_scalar("train_loss", avg_train_loss, epoch)
    writer.add_scalar("train_recon_loss", avg_recon_loss, epoch)
    # Tag spelling kept as-is so existing TensorBoard runs stay comparable.
    writer.add_scalar("train_hetmap_loss", avg_heatmap_loss, epoch)
    print('\n====> Epoch: {} Average loss: {:.4f} heatmap_loss: {}'.format(
        epoch, avg_train_loss, avg_heatmap_loss))

    path = SAVE_PATH + str(epoch) + ".pth"
    torch.save(models.state_dict(), path)
def _get_heatmap_seq_loss(self, heatmaps_seq):
    """Return the heatmap penalty summed over axis 1 of ``heatmaps_seq``.

    Args:
        heatmaps_seq: Tensor whose slices along dim 1 are scored individually
            (presumably laid out as (batch, time, ...) — confirm against the
            caller).

    Returns:
        Tensor holding the sum of ``get_heatmap_penalty`` over all slices,
        each computed with ``self.cfg.heatmap_regularization`` as the factor
        (matching every other ``get_heatmap_penalty`` call in this code).
    """
    # unbind(dim=1) yields the same views as heatmaps_seq[:, i] for each i.
    penalties = [
        get_heatmap_penalty(step_heatmaps, self.cfg.heatmap_regularization)
        for step_heatmaps in heatmaps_seq.unbind(dim=1)
    ]
    return torch.sum(torch.stack(penalties))
'kernel_size': 3, 'padding': [[0, 0], [1, 1], [1, 1], [0, 0]], 'activation': tf.nn.leaky_relu, 'kernel_initializer': tf.keras.initializers.Constant(w_init) }) m2conv = tf.keras.layers.Conv2D( filters=64, kernel_size=1, padding='valid', activation=tf.nn.softplus, kernel_initializer=tf.keras.initializers.Constant(w_init)) a, axs = m1(xtor) axs.append(m1conv(axs[-1])) ax_pen = losses.get_heatmap_penalty(axs[-1], 1) tr = lambda x: x.permute(0, 2, 3, 1).detach().numpy() axs = [tr(ax) for ax in axs] a = tr(a) b, bxs = m2(xtf) bxs.append(m2conv(bxs[-1])) bx_pen = losses.get_heatmap_penalty_tf(bxs[-1], 1) bxs = [bx.numpy() for bx in bxs] b = b.numpy() for i, (ax, bx) in enumerate(zip(axs, bxs)): print('check', i, np.allclose( ax, bx,