Example #1
	def train(self):
		loss_epoch = 0.
		num_batches = 0
		# Put all encoders and mesh-update networks in training mode.
		for e in encoders:
			e.train()
		for m in mesh_updates:
			m.train()

		# Train loop
		for i, data in enumerate(tqdm(dataloader_train), 0):
			optimizer.zero_grad()
			
			###############################
			####### data creation #########
			###############################
			tgt_points = data['points'].to(args.device)
			inp_images = data['imgs'].to(args.device)
			cam_mat = data['cam_mat'].to(args.device)
			cam_pos = data['cam_pos'].to(args.device)
			# Skip the incomplete final batch; the per-sample loop below assumes a full batch.
			if (tgt_points.shape[0] != args.batch_size) or (inp_images.shape[0] != args.batch_size) \
					or (cam_mat.shape[0] != args.batch_size) or (cam_pos.shape[0] != args.batch_size):
				continue
			surf_loss, edge_loss, lap_loss, loss, f_loss = 0, 0, 0, 0, 0
			###############################
			########## inference ##########
			###############################
			img_features = [e(inp_images) for e in encoders]
			for bn in range(args.batch_size):
				reset_meshes(meshes)
				##### layer_1 ##### 
				pool_indices = get_pooling_index(meshes['init'][0].vertices, cam_mat[bn], cam_pos[bn], encoding_dims)
				projected_image_features = pooling(img_features[0], pool_indices, bn)
				full_vert_features = torch.cat((meshes['init'][0].vertices, projected_image_features), dim=1)
				
				delta, future_features = mesh_updates[0](full_vert_features, meshes['adjs'][0])
				meshes['update'][0].vertices = (meshes['init'][0].vertices + delta.clone())
				future_features = split_meshes(meshes, future_features, 0)

				##### layer_2 ##### 
				pool_indices = get_pooling_index(meshes['init'][1].vertices, cam_mat[bn], cam_pos[bn], encoding_dims)
				projected_image_features = pooling(img_features[1], pool_indices, bn)
				full_vert_features = torch.cat((meshes['init'][1].vertices, projected_image_features, future_features), dim=1)
				
				delta, future_features = mesh_updates[1](full_vert_features, meshes['adjs'][1])
				meshes['update'][1].vertices = (meshes['init'][1].vertices + delta.clone())
				future_features = split_meshes(meshes, future_features, 1)

				##### layer_3 ##### 
				pool_indices = get_pooling_index(meshes['init'][2].vertices, cam_mat[bn], cam_pos[bn], encoding_dims)
				projected_image_features = pooling(img_features[2], pool_indices, bn)
				full_vert_features = torch.cat((meshes['init'][2].vertices, projected_image_features, future_features), dim=1)
				delta, future_features = mesh_updates[2](full_vert_features, meshes['adjs'][2])
				meshes['update'][2].vertices = (meshes['init'][2].vertices + delta.clone())
				

				if args.latent_loss:
					inds = data['adj_indices'][bn]
					vals = data['adj_values'][bn]
					gt_verts = data['verts'][bn].to(args.device)
					vert_len = gt_verts.shape[0]
					gt_adj = torch.sparse.FloatTensor(inds, vals, torch.Size([vert_len, vert_len])).to(args.device)

					predicted_latent = mesh_encoder(meshes['update'][2].vertices, meshes['adjs'][2])  
					gt_latent = mesh_encoder(gt_verts, gt_adj)  
					# L1 distance between the two mesh embeddings, down-weighted by 0.2.
					latent_loss = torch.mean(torch.abs(predicted_latent - gt_latent)) * 0.2


				###############################
				########## losses #############
				###############################
				# Weighted loss terms, averaged over the batch.
				surf_loss += (6000 * loss_surf(meshes, tgt_points[bn]) / float(args.batch_size))
				edge_loss += (300 * 0.6 * loss_edge(meshes) / float(args.batch_size))
				lap_loss += (1500 * loss_lap(meshes) / float(args.batch_size))
				f_loss += nvl.metrics.point.f_score(.57 * meshes['update'][2].sample(2466)[0], .57 * tgt_points[bn], extend=False) / float(args.batch_size)


				loss = surf_loss + edge_loss + lap_loss
				if args.latent_loss: 
					loss += latent_loss
			loss.backward()
			loss_epoch += float(surf_loss.item())

			# logging
			num_batches += 1
			if i % args.print_every == 0:
				message = f'[TRAIN] Epoch {self.cur_epoch:03d}, Batch {i:03d}: Loss: {surf_loss.item():4.3f}, '
				message += f'Lap: {lap_loss.item():3.3f}, Edge: {edge_loss.item():3.3f}'
				message += f', F: {f_loss.item():3.3f}'
				if args.latent_loss: 
					message = message + f', Lat: {(latent_loss.item()):3.3f}'
				tqdm.write(message)

			optimizer.step()
		
		
		loss_epoch = loss_epoch / num_batches
		self.train_loss.append(loss_epoch)
		self.cur_epoch += 1
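
The loop above leans on helpers defined elsewhere in the repository (get_pooling_index, pooling, reset_meshes, split_meshes, the loss_* functions). As a rough guide to what the perceptual-pooling pair does, here is a minimal sketch that projects vertices with the camera matrix and position and bilinearly samples image features at the projected locations; the signatures and the pinhole projection are assumptions, not the repository's actual implementation:

import torch
import torch.nn.functional as F

def get_pooling_index(vertices, cam_mat, cam_pos, encoding_dims):
    # ASSUMED: rotate/translate vertices into the camera frame, then do a
    # perspective divide to get normalized sampling coordinates in [-1, 1].
    # encoding_dims (the feature-map sizes) is unused in this simplified sketch.
    pts = (vertices - cam_pos) @ cam_mat.t()          # (V, 3) camera-frame points
    xy = pts[:, :2] / pts[:, 2:].clamp(min=1e-8)      # perspective divide
    return xy.clamp(-1., 1.)                          # (V, 2) grid coordinates

def pooling(img_feature, pool_indices, bn=0):
    # ASSUMED: bilinearly sample one (B, C, H, W) feature map at the
    # projected vertex locations for batch element bn, giving (V, C).
    grid = pool_indices.view(1, 1, -1, 2)
    sampled = F.grid_sample(img_feature[bn:bn + 1], grid, align_corners=True)
    return sampled.squeeze(0).squeeze(1).t()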
Example #2
	def validate(self):
		# Put all encoders and mesh-update networks in evaluation mode.
		for e in encoders:
			e.eval()
		for m in mesh_updates:
			m.eval()
		with torch.no_grad():	
			num_batches = 0
			loss_epoch = 0.
			loss_f = 0.
			# Validation loop
			for i, data in enumerate(tqdm(dataloader_val), 0):
				optimizer.zero_grad()
				
				###############################
				####### data creation #########
				###############################
				tgt_points = data['points'].to(args.device)
				inp_images = data['imgs'].to(args.device)
				cam_mat = data['cam_mat'].to(args.device)
				cam_pos = data['cam_pos'].to(args.device)
				# Skip the incomplete final batch.
				if (tgt_points.shape[0] != args.batch_size) or (inp_images.shape[0] != args.batch_size) \
						or (cam_mat.shape[0] != args.batch_size) or (cam_pos.shape[0] != args.batch_size):
					continue
				surf_loss = 0
				###############################
				########## inference ##########
				###############################
				img_features = [e(inp_images) for e in encoders]
				for bn in range(args.batch_size):
					reset_meshes(meshes)
					##### layer_1 ##### 
					pool_indices = get_pooling_index(meshes['init'][0].vertices, cam_mat[bn], cam_pos[bn], encoding_dims)
					projected_image_features = pooling(img_features[0], pool_indices, bn)
					full_vert_features = torch.cat((meshes['init'][0].vertices, projected_image_features), dim=1)
					
					delta, future_features = mesh_updates[0](full_vert_features, meshes['adjs'][0])
					meshes['update'][0].vertices = (meshes['init'][0].vertices + delta.clone())
					future_features = split_meshes(meshes, future_features, 0)

					##### layer_2 ##### 
					pool_indices = get_pooling_index(meshes['init'][1].vertices, cam_mat[bn], cam_pos[bn], encoding_dims)
					projected_image_features = pooling(img_features[1], pool_indices, bn)
					full_vert_features = torch.cat((meshes['init'][1].vertices, projected_image_features, future_features), dim=1)
					
					delta, future_features = mesh_updates[1](full_vert_features, meshes['adjs'][1])
					meshes['update'][1].vertices = (meshes['init'][1].vertices + delta.clone())
					future_features = split_meshes(meshes, future_features, 1)

					##### layer_3 ##### 
					pool_indices = get_pooling_index(meshes['init'][2].vertices, cam_mat[bn], cam_pos[bn], encoding_dims)
					projected_image_features = pooling(img_features[2], pool_indices, bn)
					full_vert_features = torch.cat((meshes['init'][2].vertices, projected_image_features, future_features), dim=1)

					delta, future_features = mesh_updates[2](full_vert_features, meshes['adjs'][2])
					meshes['update'][2].vertices = (meshes['init'][2].vertices + delta.clone())
					pred_points, _ = meshes['update'][2].sample(10000)
					###############################
					########## losses #############
					###############################

					# Chamfer distance on 10,000 sampled surface points (weight 3000).
					surf_loss = 3000 * nvl.metrics.point.chamfer_distance(pred_points, tgt_points[bn])
					loss_f += (nvl.metrics.point.f_score(.57 * meshes['update'][2].sample(2466)[0], .57 * tgt_points[bn], extend=False).item() / float(args.batch_size))

					loss_epoch += (surf_loss.item() / float(args.batch_size))

				# logging
				num_batches += 1
				if i % args.print_every == 0:
					out_loss = loss_epoch / float(num_batches)
					out_f_loss = loss_f / float(num_batches)
					tqdm.write(f'[VAL] Epoch {self.cur_epoch:03d}, Batch {i:03d}: loss: {out_loss:3.3f}, F: {out_f_loss:3.3f}')
			
			out_f_loss = loss_f / float(num_batches)
			out_loss = loss_epoch / float(num_batches)
			tqdm.write(f'[VAL Total] Epoch {self.cur_epoch:03d}, Batch {i:03d}: loss: {out_loss:3.3f}, F: {out_f_loss:3.3f}')

			self.val_loss.append(out_f_loss)
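
Both validation passes in this listing report a weighted Chamfer distance and an F-score from the library's point-metric module (nvl/kal.metrics.point). The following from-scratch sketch shows what those two metrics compute on small point sets; it is not kaolin's implementation, and the default radius is an assumption:

import torch

def chamfer_distance(p1, p2):
    # Symmetric mean of squared nearest-neighbour distances between
    # two point sets of shape (N, 3) and (M, 3).
    d = torch.cdist(p1, p2)                          # (N, M) pairwise distances
    return (d.min(dim=1).values ** 2).mean() + (d.min(dim=0).values ** 2).mean()

def f_score(pred, gt, radius=0.01):
    # Harmonic mean of precision and recall at a distance threshold:
    # a predicted point counts as correct if a ground-truth point lies
    # within `radius`, and vice versa.
    d = torch.cdist(pred, gt)
    precision = (d.min(dim=1).values < radius).float().mean()
    recall = (d.min(dim=0).values < radius).float().mean()
    return 2 * precision * recall / (precision + recall).clamp(min=1e-8)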
Example #3
    def validate(self):
        # Put all encoders and mesh-update networks in evaluation mode.
        for e in encoders:
            e.eval()
        for m in mesh_updates:
            m.eval()
        with torch.no_grad():
            num_batches = 0
            loss_epoch = 0.
            f_score = 0
            # Validation loop
            for i, sample in enumerate(tqdm(dataloader_val), 0):
                data = sample['data']
                optimizer.zero_grad()

                # Data Creation
                tgt_points = data['points'].to(args.device)
                inp_images = data['images'].to(args.device)
                cam_mat = data['params']['cam_mat'].to(args.device)
                cam_pos = data['params']['cam_pos'].to(args.device)
                # Skip the incomplete final batch.
                if (tgt_points.shape[0] != args.batch_size) or (inp_images.shape[0] != args.batch_size) \
                        or (cam_mat.shape[0] != args.batch_size) or (cam_pos.shape[0] != args.batch_size):
                    continue
                surf_loss = 0

                # Inference
                img_features = [e(inp_images) for e in encoders]
                for bn in range(args.batch_size):
                    reset_meshes(meshes)

                    # Layer_1
                    pool_indices = get_pooling_index(
                        meshes['init'][0].vertices, cam_mat[bn], cam_pos[bn],
                        encoding_dims)
                    projected_image_features = pooling(img_features[0],
                                                       pool_indices, bn)
                    full_vert_features = torch.cat(
                        (meshes['init'][0].vertices, projected_image_features),
                        dim=1)

                    delta, future_features = mesh_updates[0](
                        full_vert_features, meshes['adjs'][0])
                    meshes['update'][0].vertices = (
                        meshes['init'][0].vertices + delta.clone())
                    future_features = split_meshes(meshes,
                                                   future_features,
                                                   0,
                                                   angle=ANGLE_THRESHOLD)

                    # Layer_2
                    pool_indices = get_pooling_index(
                        meshes['init'][1].vertices, cam_mat[bn], cam_pos[bn],
                        encoding_dims)
                    projected_image_features = pooling(img_features[1],
                                                       pool_indices, bn)
                    full_vert_features = torch.cat(
                        (meshes['init'][1].vertices, projected_image_features,
                         future_features),
                        dim=1)

                    delta, future_features = mesh_updates[1](
                        full_vert_features, meshes['adjs'][1])
                    meshes['update'][1].vertices = (
                        meshes['init'][1].vertices + delta.clone())
                    future_features = split_meshes(meshes,
                                                   future_features,
                                                   1,
                                                   angle=ANGLE_THRESHOLD)

                    # Layer_3
                    pool_indices = get_pooling_index(
                        meshes['init'][2].vertices, cam_mat[bn], cam_pos[bn],
                        encoding_dims)
                    projected_image_features = pooling(img_features[2],
                                                       pool_indices, bn)
                    full_vert_features = torch.cat(
                        (meshes['init'][2].vertices, projected_image_features,
                         future_features),
                        dim=1)

                    delta, future_features = mesh_updates[2](
                        full_vert_features, meshes['adjs'][2])
                    meshes['update'][2].vertices = (
                        meshes['init'][2].vertices + delta.clone())
                    pred_points, _ = meshes['update'][2].sample(10000)

                    # Losses
                    surf_loss = weights[
                        'surface'] * kal.metrics.point.chamfer_distance(
                            pred_points, tgt_points[bn])

                    # F-Score
                    f_score += (kal.metrics.point.f_score(
                        .57 * meshes['update'][2].sample(2466)[0],
                        .57 * tgt_points[bn],
                        extend=False).item() / args.batch_size)

                    loss_epoch += surf_loss.item() / args.batch_size

                # logging
                num_batches += 1
                if i % args.print_every == 0:
                    out_loss = loss_epoch / num_batches
                    out_f_score = f_score / num_batches
                    tqdm.write(
                        f'[VAL]\tEpoch {self.cur_epoch:03d}, Batch {i:03d}: Loss: {out_loss:3.3f}, F-Score: {out_f_score:3.3f}'
                    )

            out_loss = loss_epoch / num_batches
            out_f_score = f_score / num_batches
            tqdm.write(
                f'[VAL Total] Epoch {self.cur_epoch:03d}, Batch {i:03d}: Loss: {out_loss:3.3f}, F-Score: {out_f_score:3.3f}'
            )

            self.val_score[self.cur_epoch] = out_f_score
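
Examples 3 and 4 read their loss multipliers from a `weights` dict defined outside the snippet. Judging from the hard-coded constants in Examples 1 and 2, a plausible configuration looks like this (the exact values are an assumption, not taken from the repository):

# Assumed loss-weight configuration, mirroring the hard-coded
# multipliers that appear in Examples 1 and 2.
weights = {
    'surface': 3000.,   # Chamfer / surface distance term
    'edge': 300.,       # edge-length regularizer
    'laplace': 1500.,   # Laplacian smoothness term
    'latent': 0.2,      # latent embedding loss (GEOMetrics-style)
}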
Example #4
    def train(self):
        loss_epoch = 0.
        num_batches = 0

        # Put all encoders and mesh-update networks in training mode.
        for e in encoders:
            e.train()
        for m in mesh_updates:
            m.train()

        # Train loop
        for i, sample in enumerate(tqdm(dataloader_train), 0):
            data = sample['data']
            optimizer.zero_grad()

            # Data Creation
            tgt_points = data['points'].to(args.device)
            inp_images = data['images'].to(args.device)
            cam_mat = data['params']['cam_mat'].to(args.device)
            cam_pos = data['params']['cam_pos'].to(args.device)
            # Skip the incomplete final batch.
            if (tgt_points.shape[0] != args.batch_size) or (inp_images.shape[0] != args.batch_size) \
                    or (cam_mat.shape[0] != args.batch_size) or (cam_pos.shape[0] != args.batch_size):
                continue
            surf_loss, edge_loss, lap_loss, latent_loss, loss, f_score = 0, 0, 0, 0, 0, 0

            # Inference
            img_features = [e(inp_images) for e in encoders]
            for bn in range(args.batch_size):
                reset_meshes(meshes)

                # Layer_1
                pool_indices = get_pooling_index(meshes['init'][0].vertices,
                                                 cam_mat[bn], cam_pos[bn],
                                                 encoding_dims)
                projected_image_features = pooling(img_features[0],
                                                   pool_indices, bn)
                full_vert_features = torch.cat(
                    (meshes['init'][0].vertices, projected_image_features),
                    dim=1)

                delta, future_features = mesh_updates[0](full_vert_features,
                                                         meshes['adjs'][0])
                meshes['update'][0].vertices = (meshes['init'][0].vertices +
                                                delta.clone())
                future_features = split_meshes(meshes,
                                               future_features,
                                               0,
                                               angle=ANGLE_THRESHOLD)

                # Layer_2
                pool_indices = get_pooling_index(meshes['init'][1].vertices,
                                                 cam_mat[bn], cam_pos[bn],
                                                 encoding_dims)
                projected_image_features = pooling(img_features[1],
                                                   pool_indices, bn)
                full_vert_features = torch.cat(
                    (meshes['init'][1].vertices, projected_image_features,
                     future_features),
                    dim=1)

                delta, future_features = mesh_updates[1](full_vert_features,
                                                         meshes['adjs'][1])
                meshes['update'][1].vertices = (meshes['init'][1].vertices +
                                                delta.clone())
                future_features = split_meshes(meshes,
                                               future_features,
                                               1,
                                               angle=ANGLE_THRESHOLD)

                # Layer_3
                pool_indices = get_pooling_index(meshes['init'][2].vertices,
                                                 cam_mat[bn], cam_pos[bn],
                                                 encoding_dims)
                projected_image_features = pooling(img_features[2],
                                                   pool_indices, bn)
                full_vert_features = torch.cat(
                    (meshes['init'][2].vertices, projected_image_features,
                     future_features),
                    dim=1)
                delta, future_features = mesh_updates[2](full_vert_features,
                                                         meshes['adjs'][2])
                meshes['update'][2].vertices = (meshes['init'][2].vertices +
                                                delta.clone())

                if args.latent_loss:
                    inds = data['adj']['indices'][bn]
                    vals = data['adj']['values'][bn]
                    gt_verts = data['vertices'][bn].to(args.device)
                    vert_len = gt_verts.shape[0]
                    gt_adj = torch.sparse.FloatTensor(
                        inds, vals, torch.Size([vert_len,
                                                vert_len])).to(args.device)

                    predicted_latent = mesh_encoder(
                        meshes['update'][2].vertices, meshes['adjs'][2])
                    gt_latent = mesh_encoder(gt_verts, gt_adj)
                    latent_loss += weights['latent'] * torch.mean(
                        torch.abs(predicted_latent -
                                  gt_latent)) / args.batch_size

                # Losses
                surf_loss += weights['surface'] * loss_surf(
                    meshes, tgt_points[bn]) / args.batch_size
                edge_loss += weights['edge'] * loss_edge(
                    meshes) / args.batch_size
                lap_loss += weights['laplace'] * loss_lap(
                    meshes) / args.batch_size

                # F-Score
                f_score += kal.metrics.point.f_score(
                    .57 * tgt_points[bn],
                    .57 * meshes['update'][2].sample(2466)[0],
                    extend=False) / args.batch_size

                loss = surf_loss + edge_loss + lap_loss
                if args.latent_loss:
                    loss += latent_loss
            loss.backward()
            loss_epoch += float(loss.item())

            # logging
            num_batches += 1
            if i % args.print_every == 0:
                message = f'[TRAIN]\tEpoch {self.cur_epoch:03d}, Batch {i:03d} | Total Loss: {loss.item():4.3f}, '
                message += f'Surf: {(surf_loss.item()):3.3f}, Lap: {(lap_loss.item()):3.3f}, '
                message += f'Edge: {(edge_loss.item()):3.3f}'
                if args.latent_loss:
                    message = message + f', Latent: {(latent_loss.item()):3.3f}'
                message = message + f', F-score: {(f_score.item()):3.3f}'
                tqdm.write(message)

            optimizer.step()

        loss_epoch = loss_epoch / num_batches
        self.train_loss[self.cur_epoch] = loss_epoch
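
When args.latent_loss is set, the trainer compares latent codes produced by a pretrained mesh_encoder for the predicted and ground-truth meshes, a GEOMetrics-style embedding loss. A minimal sketch of such an encoder, with assumed layer widths and a max-pool readout (the real network is defined elsewhere in the repository):

import torch
import torch.nn as nn

class MeshEncoder(nn.Module):
    # Sketch of a graph encoder mapping (V, 3) vertices plus a sparse
    # (V, V) adjacency matrix to a fixed-size latent code. Layer widths
    # and the pooling choice are assumptions.
    def __init__(self, latent_dim=50):
        super().__init__()
        self.fc1 = nn.Linear(3, 64)
        self.fc2 = nn.Linear(64, latent_dim)

    def forward(self, verts, adj):
        x = torch.relu(torch.sparse.mm(adj, self.fc1(verts)))  # neighbourhood mixing
        x = torch.sparse.mm(adj, self.fc2(x))
        return x.max(dim=0).values  # order-invariant readout over vertices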
Example #5
		inp_images = data['imgs'].to(args.device).unsqueeze(0)
		cam_mat = data['cam_mat'].to(args.device)
		cam_pos = data['cam_pos'].to(args.device)
		tgt_verts = data['verts'].to(args.device)
		tgt_faces = data['faces'].to(args.device)

		###############################
		########## inference ##########
		###############################
		img_features = [e(inp_images) for e in encoders]
		
		reset_meshes(meshes)
		##### layer_1 ##### 
		pool_indices = get_pooling_index(meshes['init'][0].vertices, cam_mat, cam_pos, encoding_dims)
		projected_image_features = pooling(img_features[0], pool_indices, 0)
		full_vert_features = torch.cat((meshes['init'][0].vertices, projected_image_features), dim=1)
		
		delta, future_features = mesh_updates[0](full_vert_features, meshes['adjs'][0])
		meshes['update'][0].vertices = (meshes['init'][0].vertices + delta.clone())
		future_features = split_meshes(meshes, future_features, 0)

		##### layer_2 ##### 
		pool_indices = get_pooling_index(meshes['init'][1].vertices, cam_mat, cam_pos, encoding_dims)
		projected_image_features = pooling(img_features[1], pool_indices, 0)
		full_vert_features = torch.cat((meshes['init'][1].vertices, projected_image_features, future_features), dim=1)
		
		delta, future_features = mesh_updates[1](full_vert_features, meshes['adjs'][1])
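
This snippet breaks off in the middle of the second refinement stage. Based on Examples 1 and 2, which run the same pipeline, the continuation would look like this (assumed, mirroring those examples):

# Assumed continuation, mirroring Examples 1 and 2 in this listing.
meshes['update'][1].vertices = (meshes['init'][1].vertices + delta.clone())
future_features = split_meshes(meshes, future_features, 1)

##### layer_3 #####
pool_indices = get_pooling_index(meshes['init'][2].vertices, cam_mat, cam_pos, encoding_dims)
projected_image_features = pooling(img_features[2], pool_indices, 0)
full_vert_features = torch.cat((meshes['init'][2].vertices, projected_image_features, future_features), dim=1)
delta, future_features = mesh_updates[2](full_vert_features, meshes['adjs'][2])
meshes['update'][2].vertices = (meshes['init'][2].vertices + delta.clone())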
Example #6
    def validate(self):
        # Put the encoder and mesh-update networks in evaluation mode.
        encoder.eval()
        for m in mesh_updates:
            m.eval()
        with torch.no_grad():
            num_batches = 0
            loss_epoch = 0.
            f_loss = 0.

            # Validation loop
            for i, data in enumerate(tqdm(dataloader_val), 0):
                optimizer.zero_grad()

                # data creation
                tgt_points = data['points'].to(args.device)[0]
                inp_images = data['imgs'].to(args.device)
                cam_mat = data['cam_mat'].to(args.device)[0]
                cam_pos = data['cam_pos'].to(args.device)[0]

                ###############################
                ########## inference ##########
                ###############################
                img_features = encoder(inp_images)

                ##### layer_1 #####
                pool_indices = get_pooling_index(meshes['init'][0].vertices,
                                                 cam_mat, cam_pos,
                                                 encoding_dims)
                projected_image_features = pooling(img_features, pool_indices)
                full_vert_features = torch.cat(
                    (meshes['init'][0].vertices, projected_image_features),
                    dim=1)

                pred_verts, future_features = mesh_updates[0](
                    full_vert_features, meshes['adjs'][0])
                meshes['update'][0].vertices = pred_verts.clone()

                ##### layer_2 #####
                future_features = split(meshes, future_features, 0)
                pool_indices = get_pooling_index(meshes['init'][1].vertices,
                                                 cam_mat, cam_pos,
                                                 encoding_dims)
                projected_image_features = pooling(img_features, pool_indices)
                full_vert_features = torch.cat(
                    (meshes['init'][1].vertices, projected_image_features,
                     future_features),
                    dim=1)

                pred_verts, future_features = mesh_updates[1](
                    full_vert_features, meshes['adjs'][1])
                meshes['update'][1].vertices = pred_verts.clone()

                ##### layer_3 #####
                future_features = split(meshes, future_features, 1)
                pool_indices = get_pooling_index(meshes['init'][2].vertices,
                                                 cam_mat, cam_pos,
                                                 encoding_dims)
                projected_image_features = pooling(img_features, pool_indices)
                full_vert_features = torch.cat(
                    (meshes['init'][2].vertices, projected_image_features,
                     future_features),
                    dim=1)

                pred_verts, future_features = mesh_updates[2](
                    full_vert_features, meshes['adjs'][2])
                meshes['update'][2].vertices = pred_verts.clone()

                f_loss += kal.metrics.point.f_score(
                    meshes['update'][2].sample(2466)[0],
                    tgt_points,
                    extend=False)

                ###############################
                ########## losses #############
                ###############################
                surf_loss = 3000 * kal.metrics.point.chamfer_distance(
                    pred_verts.clone(), tgt_points)
                loss_epoch += surf_loss.item()

                # logging
                num_batches += 1
                if i % args.print_every == 0:
                    out_loss = loss_epoch / float(num_batches)
                    f_out_loss = f_loss / float(num_batches)
                    tqdm.write(
                        f'[VAL] Epoch {self.cur_epoch:03d}, Batch {i:03d}: loss: {out_loss:3.3f}, F: {(f_out_loss.item()):3.3f}'
                    )

            out_loss = loss_epoch / float(num_batches)
            f_out_loss = f_loss / float(num_batches)
            tqdm.write(
                f'[VAL Total] Epoch {self.cur_epoch:03d}, Batch {i:03d}: loss: {out_loss:3.3f}, F: {(f_out_loss.item()):3.3f}'
            )

            self.val_loss.append(out_loss)
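
Examples 1 through 5 call reset_meshes(meshes) before deforming each sample, so the per-sample deformation always starts from the same template. A plausible sketch, assuming the meshes dict pairs template ('init') and working ('update') meshes at each refinement level:

def reset_meshes(meshes):
    # Restore each working mesh to its template geometry so that every
    # sample's deformation starts from the same initial shape.
    for init_mesh, update_mesh in zip(meshes['init'], meshes['update']):
        update_mesh.vertices = init_mesh.vertices.clone()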
Example #7
    def train(self):
        loss_epoch = 0.
        num_batches = 0
        # Put the encoder and mesh-update networks in training mode.
        encoder.train()
        for m in mesh_updates:
            m.train()

        # Train loop
        for i, data in enumerate(tqdm(dataloader_train), 0):
            optimizer.zero_grad()

            ###############################
            ####### data creation #########
            ###############################
            tgt_points = data['points'].to(args.device)[0]
            tgt_norms = data['normals'].to(args.device)[0]
            inp_images = data['imgs'].to(args.device)
            cam_mat = data['cam_mat'].to(args.device)[0]
            cam_pos = data['cam_pos'].to(args.device)[0]

            ###############################
            ########## inference ##########
            ###############################
            img_features = encoder(inp_images)

            ##### layer_1 #####
            pool_indices = get_pooling_index(meshes['init'][0].vertices,
                                             cam_mat, cam_pos, encoding_dims)
            projected_image_features = pooling(img_features, pool_indices)
            full_vert_features = torch.cat(
                (meshes['init'][0].vertices, projected_image_features), dim=1)

            pred_verts, future_features = mesh_updates[0](full_vert_features,
                                                          meshes['adjs'][0])
            meshes['update'][0].vertices = pred_verts.clone()

            ##### layer_2 #####
            future_features = split(meshes, future_features, 0)
            pool_indices = get_pooling_index(meshes['init'][1].vertices,
                                             cam_mat, cam_pos, encoding_dims)
            projected_image_features = pooling(img_features, pool_indices)
            full_vert_features = torch.cat(
                (meshes['init'][1].vertices, projected_image_features,
                 future_features),
                dim=1)

            pred_verts, future_features = mesh_updates[1](full_vert_features,
                                                          meshes['adjs'][1])
            meshes['update'][1].vertices = pred_verts.clone()

            ##### layer_3 #####
            future_features = split(meshes, future_features, 1)
            pool_indices = get_pooling_index(meshes['init'][2].vertices,
                                             cam_mat, cam_pos, encoding_dims)
            projected_image_features = pooling(img_features, pool_indices)
            full_vert_features = torch.cat(
                (meshes['init'][2].vertices, projected_image_features,
                 future_features),
                dim=1)

            pred_verts, future_features = mesh_updates[2](full_vert_features,
                                                          meshes['adjs'][2])
            meshes['update'][2].vertices = pred_verts.clone()

            ###############################
            ########## losses #############
            ###############################
            surf_loss = 3000 * loss_surf(meshes, tgt_points)  # Chamfer / surface term
            edge_loss = 300 * loss_edge(meshes)  # edge-length regularizer
            lap_loss = 1500 * loss_lap(meshes)  # Laplacian smoothness
            norm_loss = 0.5 * loss_norm(meshes, tgt_points, tgt_norms)  # normal consistency
            loss = surf_loss + edge_loss + lap_loss + norm_loss
            loss.backward()
            loss_epoch += float(surf_loss.item())

            # logging
            num_batches += 1
            if i % args.print_every == 0:
                f_loss = kal.metrics.point.f_score(
                    meshes['update'][2].sample(2466)[0],
                    tgt_points,
                    extend=False)
                message = f'[TRAIN] Epoch {self.cur_epoch:03d}, Batch {i:03d}: Loss: {surf_loss.item():4.3f}, '
                message += f'Lap: {lap_loss.item():3.3f}, Edge: {edge_loss.item():3.3f}, Norm: {norm_loss.item():3.3f}'
                message += f', F: {f_loss.item():3.3f}'
                tqdm.write(message)
            optimizer.step()

        loss_epoch = loss_epoch / num_batches
        self.train_loss.append(loss_epoch)
        self.cur_epoch += 1
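
A hypothetical driver that strings the train() and validate() methods together over epochs; the Engine class name, args.epochs, and the checkpoint paths below are illustrative assumptions, not part of the examples above:

import torch

trainer = Engine()  # hypothetical wrapper class holding the methods above
for epoch in range(args.epochs):
    trainer.train()
    trainer.validate()
    # Keep the parameters that achieve the lowest validation loss so far.
    if trainer.val_loss[-1] == min(trainer.val_loss):
        torch.save(encoder.state_dict(), 'best_encoder.pth')
        torch.save([m.state_dict() for m in mesh_updates], 'best_mesh_updates.pth')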