def test(test_loader, model, criterion, device):
    """Run one evaluation pass over *test_loader*.

    For each batch the model denoises the input cloud; we accumulate the
    criterion loss (MSE) and the symmetric Chamfer distance, both weighted
    by batch size, and return the dataset-level averages.

    Returns:
        (avg_loss, avg_chamfer_dist) as Python floats.
    """
    model.eval()
    loss_sum = 0.0      # running MSE, weighted by batch size
    chamfer_sum = 0.0   # running Chamfer distance, weighted by batch size
    # TODO: remove unnecessary dependencies (tqdm)
    progress = tqdm(enumerate(test_loader, 0), total=len(test_loader),
                    smoothing=0.9, desc='test', dynamic_ncols=True)
    for i, data in progress:
        assert len(data) == 3
        noised, clean, _ = data
        batch = len(noised)
        noised = noised.to(device)  # [bs, npoints, 3]
        clean = clean.to(device)
        with torch.no_grad():
            cleaned = model(noised)
        assert cleaned.size() == clean.size()
        loss = criterion(cleaned, clean)
        loss_sum += loss.item() * batch
        # also evaluate the Chamfer distance as a secondary metric
        cleaned = cleaned.contiguous()
        clean = clean.contiguous()
        dist1, dist2, _, _ = NND.nnd(cleaned, clean)
        cd = 50 * torch.mean(dist1) + 50 * torch.mean(dist2)
        chamfer_sum += cd.item() * batch
    n_samples = len(test_loader.dataset)
    return loss_sum * 1.0 / n_samples, chamfer_sum * 1.0 / n_samples
def train_one_epoch(train_loader, model, optimizer, criterion, device):
    """Train *model* for a single epoch over *train_loader*.

    Performs the usual zero_grad/backward/step cycle per batch and tracks
    both the criterion loss (MSE) and the symmetric Chamfer distance,
    weighted by batch size.

    Returns:
        (avg_loss, avg_chamfer_dist) averaged over the whole dataset.
    """
    model.train()
    loss_sum = 0.0
    chamfer_sum = 0.0
    # TODO: remove unnecessary dependencies (tqdm)
    progress = tqdm(enumerate(train_loader, 0), total=len(train_loader),
                    smoothing=0.9, desc='train', dynamic_ncols=True)
    for i, data in progress:
        assert len(data) == 3, 'train: expected tuple: (noised, clean, cls)'
        noised, clean, _ = data
        batch = len(noised)
        noised = noised.to(device)  # [bs, npoints, 3]
        clean = clean.to(device)    # [bs, npoints, 3]
        cleaned = model(noised)
        assert cleaned.size() == clean.size()
        loss = criterion(cleaned, clean)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # also evaluate the Chamfer distance as a secondary metric
        cleaned = cleaned.contiguous()
        clean = clean.contiguous()
        dist1, dist2, _, _ = NND.nnd(cleaned, clean)
        cd = 50 * torch.mean(dist1) + 50 * torch.mean(dist2)
        loss_sum += loss.item() * batch     # MSE Loss
        chamfer_sum += cd.item() * batch    # Chamfer Distance
    n_samples = len(train_loader.dataset)
    return loss_sum * 1.0 / n_samples, chamfer_sum * 1.0 / n_samples
def reconstruction_loss(self, data, reconstructions):
    """Symmetric Chamfer loss between *data* and *reconstructions*.

    Both inputs are channel-first clouds; they are transposed to
    point-major layout before the nn-distance call, and the loss is the
    sum of the two mean directed distances.
    """
    target = data.transpose(2, 1).contiguous()
    predicted = reconstructions.transpose(2, 1).contiguous()
    dist_t, dist_p = NND.nnd(target, predicted)
    return torch.mean(dist_t) + torch.mean(dist_p)
        # NOTE(review): tail of a decoder/forward method whose header lies
        # outside this chunk; indentation reconstructed — confirm against the
        # original file.
        reconstructions = self.caps_decoder(latent_capsules)
        return reconstructions


if __name__ == '__main__':
    # Smoke test: push a random batch through PointCapsNet (DataParallel on
    # GPU) and print the symmetric Chamfer distance between the input and its
    # reconstruction.
    USE_CUDA = True
    batch_size = 2  # ORIGINAL IS 8
    prim_caps_size = 1024
    prim_vec_size = 16
    latent_caps_size = 32
    latent_vec_size = 16
    num_points = 2048
    point_caps_ae = PointCapsNet(prim_caps_size, prim_vec_size,
                                 latent_caps_size, latent_vec_size, num_points)
    point_caps_ae = torch.nn.DataParallel(point_caps_ae).cuda()
    rand_data = torch.rand(batch_size, num_points, 3)
    rand_data = Variable(rand_data)
    rand_data = rand_data.transpose(2, 1)  # channel-first layout for the net
    rand_data = rand_data.cuda()
    codewords, reconstruction = point_caps_ae(rand_data)
    # Chamfer distance expects point-major [B, N, 3] contiguous tensors
    rand_data_ = rand_data.transpose(2, 1).contiguous()
    reconstruction_ = reconstruction.transpose(2, 1).contiguous()
    dist1, dist2 = NND.nnd(rand_data_, reconstruction_)
    loss = (torch.mean(dist1)) + (torch.mean(dist2))
    print(loss.item())
import torch
import torch.nn as nn
from torch.autograd import Variable
#from modules.nnd import NNDModule
import torch_nndistance as NND

#dist = NNDModule()

# Gradient smoke test for the compiled nn-distance op, in two flavours:
# (1) wrap on CPU then move to GPU (the wrapped tensors become non-leaf),
# (2) wrap CUDA tensors directly (points1 stays a leaf).
p1 = torch.rand(10, 1000, 3)
p2 = torch.rand(10, 1500, 3)

# flavour 1: .cuda() applied after wrapping
points1 = Variable(p1, requires_grad=True).cuda()
points2 = Variable(p2).cuda()
dist1, dist2 = NND.nnd(points1, points2)
print(dist1, dist2)
loss = dist1.sum()
print(loss)
loss.backward()
print(points1.grad, points2.grad)

# flavour 2: wrap tensors that already live on the GPU
points1 = Variable(p1.cuda(), requires_grad=True)
points2 = Variable(p2.cuda())
dist1, dist2 = NND.nnd(points1, points2)
print(dist1, dist2)
loss = dist1.sum()
print(loss)
loss.backward()
print(points1.grad, points2.grad)
def main_worker():
    """Train the point-cloud completion model (encoder + generator).

    Loads CLI args and the JSON config, optionally resumes from a checkpoint,
    builds ShapeNet-Part loaders, then runs the train loop with periodic
    testing, checkpointing, best-model tracking and visualization dumps.

    BUGFIX: the per-epoch timing log referenced an undefined `num_epochs`
    (NameError at the end of every epoch) — replaced with `opt.epochs`,
    matching the rest of the function and the twin trainer in this codebase.
    """
    opt, io, tb = get_args()
    start_epoch = -1
    start_time = time.time()
    BASE_DIR = os.path.dirname(os.path.abspath(__file__))  # python script folder
    ckt = None
    if len(opt.restart_from) > 0:
        ckt = torch.load(opt.restart_from)
        start_epoch = ckt['epoch'] - 1

    # load configuration from file
    try:
        with open(opt.config) as cf:
            config = json.load(cf)
    except IOError as error:
        print(error)

    # backup relevant files for reproducibility
    shutil.copy(src=os.path.abspath(__file__),
                dst=os.path.join(opt.save_dir, 'backup_code'))
    shutil.copy(src=os.path.join(BASE_DIR, 'models', 'model_deco.py'),
                dst=os.path.join(opt.save_dir, 'backup_code'))
    shutil.copy(src=os.path.join(BASE_DIR, 'shape_utils.py'),
                dst=os.path.join(opt.save_dir, 'backup_code'))
    shutil.copy(src=opt.config,
                dst=os.path.join(opt.save_dir, 'backup_code', 'config.json.backup'))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if opt.manualSeed is None:
        opt.manualSeed = random.randint(1, 10000)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)
    torch.cuda.manual_seed_all(opt.manualSeed)
    io.cprint(f"Arguments: {str(opt)}")
    io.cprint(f"Configuration: {str(config)}")

    pnum = config['completion_trainer']['num_points']
    class_choice = opt.class_choice

    # datasets + loaders
    if len(class_choice) > 0:
        class_choice = ''.join(opt.class_choice.split()).split(",")  # sanitize + split(",")
        io.cprint("Class choice list: {}".format(str(class_choice)))
    else:
        class_choice = None  # Train on all classes! (if opt.class_choice=='')

    tr_dataset = shapenet_part_loader.PartDataset(root=opt.data_root,
                                                  classification=True,
                                                  class_choice=class_choice,
                                                  npoints=pnum,
                                                  split='train')
    te_dataset = shapenet_part_loader.PartDataset(root=opt.data_root,
                                                  classification=True,
                                                  class_choice=class_choice,
                                                  npoints=pnum,
                                                  split='test')
    tr_loader = torch.utils.data.DataLoader(tr_dataset,
                                            batch_size=opt.batch_size,
                                            shuffle=True,
                                            num_workers=opt.workers,
                                            drop_last=True)
    te_loader = torch.utils.data.DataLoader(te_dataset,
                                            batch_size=64,
                                            shuffle=True,
                                            num_workers=opt.workers)

    num_holes = int(opt.num_holes)
    crop_point_num = int(opt.crop_point_num)
    context_point_num = int(opt.context_point_num)
    # io.cprint("Num holes: {}".format(num_holes))
    # io.cprint("Crop points num: {}".format(crop_point_num))
    # io.cprint("Context points num: {}".format(context_point_num))
    # io.cprint("Pool1 num points selected: {}".format(opt.pool1_points))
    # io.cprint("Pool2 num points selected: {}".format(opt.pool2_points))

    # Models
    gl_encoder = Encoder(conf=config)
    generator = Generator(conf=config,
                          pool1_points=int(opt.pool1_points),
                          pool2_points=int(opt.pool2_points))
    gl_encoder.apply(weights_init_normal)  # affecting only non pretrained
    generator.apply(weights_init_normal)  # not pretrained
    print("Encoder: ", gl_encoder)
    print("Generator: ", generator)

    if ckt is not None:
        # resume: restore encoder and generator weights
        io.cprint(f"Restart Training from epoch {start_epoch}.")
        gl_encoder.load_state_dict(ckt['gl_encoder_state_dict'])
        generator.load_state_dict(ckt['generator_state_dict'])
    else:
        # fresh run: load self-supervised pretrained feature extractors
        io.cprint("Training Completion Task...")
        local_fe_fn = config['completion_trainer']['checkpoint_local_enco']
        global_fe_fn = config['completion_trainer']['checkpoint_global_enco']

        if len(local_fe_fn) > 0:
            local_enco_dict = torch.load(local_fe_fn)['model_state_dict']
            # refactoring pretext-trained local dgcnn encoder state dict keys
            local_enco_dict = remove_prefix_dict(state_dict=local_enco_dict,
                                                 to_remove_str='local_encoder.')
            loc_load_result = gl_encoder.local_encoder.load_state_dict(
                local_enco_dict, strict=False)
            io.cprint(
                f"Local FE pretrained weights - loading res: {str(loc_load_result)}")
        else:
            # Ablation experiments only
            io.cprint("Local FE pretrained weights - NOT loaded", color='r')

        if len(global_fe_fn) > 0:
            global_enco_dict = torch.load(global_fe_fn, )['global_encoder']
            glob_load_result = gl_encoder.global_encoder.load_state_dict(
                global_enco_dict, strict=True)
            io.cprint(
                f"Global FE pretrained weights - loading res: {str(glob_load_result)}",
                color='b')
        else:
            # Ablation experiments only
            io.cprint("Global FE pretrained weights - NOT loaded", color='r')

    io.cprint("Num GPUs: " + str(torch.cuda.device_count()) +
              ", Parallelism: {}".format(opt.parallel))
    if opt.parallel and torch.cuda.device_count() > 1:
        gl_encoder = torch.nn.DataParallel(gl_encoder)
        generator = torch.nn.DataParallel(generator)
    gl_encoder.to(device)
    generator.to(device)

    # Optimizers + schedulers
    opt_E = torch.optim.Adam(
        gl_encoder.parameters(),
        lr=config['completion_trainer']['enco_lr'],  # def: 10e-4
        betas=(0.9, 0.999),
        eps=1e-05,
        weight_decay=0.001)
    sched_E = torch.optim.lr_scheduler.StepLR(
        opt_E,
        step_size=config['completion_trainer']['enco_step'],  # def: 25
        gamma=0.5)
    opt_G = torch.optim.Adam(
        generator.parameters(),
        lr=config['completion_trainer']['gen_lr'],  # def: 10e-4
        betas=(0.9, 0.999),
        eps=1e-05,
        weight_decay=0.001)
    sched_G = torch.optim.lr_scheduler.StepLR(
        opt_G,
        step_size=config['completion_trainer']['gen_step'],  # def: 40
        gamma=0.5)

    if ckt is not None:
        # resume: restore optimizer/scheduler state as well
        opt_E.load_state_dict(ckt['optimizerE_state_dict'])
        opt_G.load_state_dict(ckt['optimizerG_state_dict'])
        sched_E.load_state_dict(ckt['schedulerE_state_dict'])
        sched_G.load_state_dict(ckt['schedulerG_state_dict'])

    if not opt.fps_centroids:
        # 5 viewpoints to crop around - same as in PFNet
        centroids = np.asarray([[1, 0, 0], [0, 0, 1], [1, 0, 1], [-1, 0, 0],
                                [-1, 1, 0]])
    else:
        # (removed unreachable `centroids = None` that followed this raise)
        raise NotImplementedError('experimental')

    io.cprint("Training.. \n")
    best_test = sys.float_info.max
    best_ep = -1
    it = 0  # global iteration counter
    vis_folder = None
    for epoch in range(start_epoch + 1, opt.epochs):
        start_ep_time = time.time()
        count = 0.0
        tot_loss = 0.0
        tot_fine_loss = 0.0
        tot_raw_loss = 0.0
        gl_encoder = gl_encoder.train()
        generator = generator.train()
        for i, data in enumerate(tr_loader, 0):
            it += 1
            points, _ = data
            B, N, dim = points.size()
            count += B

            partials = []
            fine_gts, raw_gts = [], []
            N_partial_points = N - (crop_point_num * num_holes)
            for m in range(B):
                # points[m]: complete shape of size (N,3)
                # partial: partial point cloud to complete
                # fine_gt: missing part ground truth
                # raw_gt: missing part ground truth + frame points
                #         (frame points are points included in partial)
                partial, fine_gt, raw_gt = crop_shape(
                    points[m],
                    centroids=centroids,
                    scales=[crop_point_num,
                            (crop_point_num + context_point_num)],
                    n_c=num_holes)

                if partial.size(0) > N_partial_points:
                    assert num_holes > 1, "Should be no need to resample if not multiple holes case"
                    # sampling without replacement
                    choice = torch.randperm(partial.size(0))[:N_partial_points]
                    partial = partial[choice]

                partials.append(partial)
                fine_gts.append(fine_gt)
                raw_gts.append(raw_gt)

            if i == 1 and epoch % opt.it_test == 0:
                # make some visualization
                vis_folder = os.path.join(opt.vis_dir, "epoch_{}".format(epoch))
                safe_make_dirs([vis_folder])
                print(f"ep {epoch} - Saving visualizations into: {vis_folder}")
                for j in range(len(partials)):
                    np.savetxt(X=partials[j],
                               fname=os.path.join(vis_folder, '{}_cropped.txt'.format(j)),
                               fmt='%.5f', delimiter=';')
                    np.savetxt(X=fine_gts[j],
                               fname=os.path.join(vis_folder, '{}_fine_gt.txt'.format(j)),
                               fmt='%.5f', delimiter=';')
                    np.savetxt(X=raw_gts[j],
                               fname=os.path.join(vis_folder, '{}_raw_gt.txt'.format(j)),
                               fmt='%.5f', delimiter=';')

            partials = torch.stack(partials).to(device).permute(0, 2, 1)  # [B, 3, N-512]
            fine_gts = torch.stack(fine_gts).to(device)  # [B, 512, 3]
            raw_gts = torch.stack(raw_gts).to(device)  # [B, 512 + context, 3]

            if i == 1:  # sanity check
                print("[dbg]: partials: ", partials.size(), ' ', partials.device)
                print("[dbg]: fine grained gts: ", fine_gts.size(), ' ', fine_gts.device)
                print("[dbg]: raw grained gts: ", raw_gts.size(), ' ', raw_gts.device)

            gl_encoder.zero_grad()
            generator.zero_grad()
            feat = gl_encoder(partials)
            # pred_fine (only missing part), pred_intermediate (missing + frame)
            fake_fine, fake_raw = generator(feat)

            # pytorch 1.2 compiled Chamfer (C2C) dist.
            assert fake_fine.size() == fine_gts.size(), "Wrong input shapes to Chamfer module"
            if i == 0:
                if fake_raw.size() != raw_gts.size():
                    warnings.warn(
                        "size dismatch for: raw_pred: {}, raw_gt: {}".format(
                            str(fake_raw.size()), str(raw_gts.size())))

            # fine grained prediction + gt
            fake_fine = fake_fine.contiguous()
            fine_gts = fine_gts.contiguous()
            # raw prediction + gt
            fake_raw = fake_raw.contiguous()
            raw_gts = raw_gts.contiguous()
            dist1, dist2, _, _ = NND.nnd(fake_fine, fine_gts)  # fine grained loss computation
            dist1_raw, dist2_raw, _, _ = NND.nnd(fake_raw, raw_gts)  # raw grained loss computation

            # standard C2C distance loss
            fine_loss = 100 * (0.5 * torch.mean(dist1) + 0.5 * torch.mean(dist2))
            # raw loss: missing part + frame
            raw_loss = 100 * (0.5 * torch.mean(dist1_raw) + 0.5 * torch.mean(dist2_raw))

            # missing part pred loss + alpha * raw reconstruction loss
            loss = fine_loss + opt.raw_weight * raw_loss
            loss.backward()
            opt_E.step()
            opt_G.step()

            tot_loss += loss.item() * B
            tot_fine_loss += fine_loss.item() * B
            tot_raw_loss += raw_loss.item() * B

            if it % 10 == 0:
                io.cprint(
                    '[%d/%d][%d/%d]: loss: %.4f, fine CD: %.4f, interm. CD: %.4f'
                    % (epoch, opt.epochs, i, len(tr_loader), loss.item(),
                       fine_loss.item(), raw_loss.item()))

            # make visualizations
            if i == 1 and epoch % opt.it_test == 0:
                assert (vis_folder is not None and os.path.exists(vis_folder))
                fake_fine = fake_fine.cpu().detach().data.numpy()
                fake_raw = fake_raw.cpu().detach().data.numpy()
                for j in range(len(fake_fine)):
                    np.savetxt(X=fake_fine[j],
                               fname=os.path.join(vis_folder, '{}_pred_fine.txt'.format(j)),
                               fmt='%.5f', delimiter=';')
                    np.savetxt(X=fake_raw[j],
                               fname=os.path.join(vis_folder, '{}_pred_raw.txt'.format(j)),
                               fmt='%.5f', delimiter=';')

        sched_E.step()
        sched_G.step()
        io.cprint(
            '[%d/%d] Ep Train - loss: %.5f, fine cd: %.5f, interm. cd: %.5f' %
            (epoch, opt.epochs, tot_loss * 1.0 / count,
             tot_fine_loss * 1.0 / count, tot_raw_loss * 1.0 / count))
        tb.add_scalar('Train/tot_loss', tot_loss * 1.0 / count, epoch)
        tb.add_scalar('Train/cd_fine', tot_fine_loss * 1.0 / count, epoch)
        tb.add_scalar('Train/cd_interm', tot_raw_loss * 1.0 / count, epoch)

        if epoch % opt.it_test == 0:
            # periodic checkpoint (unwrap DataParallel before saving)
            torch.save(
                {
                    'type_exp': 'dgccn at local encoder',
                    'epoch': epoch + 1,
                    'epoch_train_loss': tot_loss * 1.0 / count,
                    'epoch_train_loss_raw': tot_raw_loss * 1.0 / count,
                    'epoch_train_loss_fine': tot_fine_loss * 1.0 / count,
                    'gl_encoder_state_dict':
                        gl_encoder.module.state_dict()
                        if isinstance(gl_encoder, nn.DataParallel)
                        else gl_encoder.state_dict(),
                    'generator_state_dict':
                        generator.module.state_dict()
                        if isinstance(generator, nn.DataParallel)
                        else generator.state_dict(),
                    'optimizerE_state_dict': opt_E.state_dict(),
                    'optimizerG_state_dict': opt_G.state_dict(),
                    'schedulerE_state_dict': sched_E.state_dict(),
                    'schedulerG_state_dict': sched_G.state_dict(),
                },
                os.path.join(opt.models_dir, 'checkpoint_' + str(epoch) + '.pth'))

        if epoch % opt.it_test == 0:
            # periodic evaluation: only the missing-part (fine) prediction is scored
            test_cd, count = 0.0, 0.0
            for i, data in enumerate(te_loader, 0):
                points, _ = data
                B, N, dim = points.size()
                count += B

                partials = []
                fine_gts = []
                N_partial_points = N - (crop_point_num * num_holes)
                for m in range(B):
                    partial, fine_gt, _ = crop_shape(
                        points[m],
                        centroids=centroids,
                        scales=[crop_point_num,
                                (crop_point_num + context_point_num)],
                        n_c=num_holes)
                    if partial.size(0) > N_partial_points:
                        assert num_holes > 1
                        # sampling Without replacement
                        choice = torch.randperm(partial.size(0))[:N_partial_points]
                        partial = partial[choice]
                    partials.append(partial)
                    fine_gts.append(fine_gt)
                partials = torch.stack(partials).to(device).permute(0, 2, 1)  # [B, 3, N-512]
                fine_gts = torch.stack(fine_gts).to(device).contiguous()  # [B, 512, 3]

                # TEST FORWARD
                # Considering only missing part prediction at Test Time
                gl_encoder.eval()
                generator.eval()
                with torch.no_grad():
                    feat = gl_encoder(partials)
                    fake_fine, _ = generator(feat)

                fake_fine = fake_fine.contiguous()
                assert fake_fine.size() == fine_gts.size()
                dist1, dist2, _, _ = NND.nnd(fake_fine, fine_gts)
                cd_loss = 100 * (0.5 * torch.mean(dist1) + 0.5 * torch.mean(dist2))
                test_cd += cd_loss.item() * B

            test_cd = test_cd * 1.0 / count
            io.cprint('Ep Test [%d/%d] - cd loss: %.5f ' %
                      (epoch, opt.epochs, test_cd), color="b")
            tb.add_scalar('Test/cd_loss', test_cd, epoch)
            is_best = test_cd < best_test
            best_test = min(best_test, test_cd)
            if is_best:
                # best model case
                best_ep = epoch
                io.cprint("New best test %.5f at epoch %d" % (best_test, best_ep))
                shutil.copyfile(
                    src=os.path.join(opt.models_dir, 'checkpoint_' + str(epoch) + '.pth'),
                    dst=os.path.join(opt.models_dir, 'best_model.pth'))

        # BUGFIX: original used undefined `num_epochs` here (NameError);
        # `opt.epochs` is what the rest of this function uses.
        io.cprint('[%d/%d] Epoch time: %s' %
                  (epoch, opt.epochs,
                   time.strftime("%M:%S", time.gmtime(time.time() - start_ep_time))))

    # Script ends
    hours, rem = divmod(time.time() - start_time, 3600)
    minutes, seconds = divmod(rem, 60)
    io.cprint("### Training ended in {:0>2}:{:0>2}:{:05.2f}".format(
        int(hours), int(minutes), seconds))
    io.cprint("### Best val %.6f at epoch %d" % (best_test, best_ep))
import torch
import torch_nndistance as NND

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Forward/backward smoke test for the compiled nn-distance op:
# two random clouds with different point counts, both requiring grad.
points1 = torch.rand(16, 2048, 3).to(device)
points1.requires_grad = True
points2 = torch.rand(16, 1024, 3).to(device)
points2.requires_grad = True

dist1, dist2, idx1, idx2 = NND.nnd(points1, points2)
# print(dist1, dist2)
print('dist1: ', dist1.size())
print('dist2: ', dist2.size())

loss = dist1.sum()
print(loss)
loss.backward()
print(points1.grad, points2.grad)
print('ok - Test 1')

print("points1 grad:\n", points1.grad)
print("")
print("points2 grad:\n", points2.grad)
# both inputs should have received a gradient from the backward pass
if points1.grad is not None and points2.grad is not None:
    print('ok - Test 2 Gradient')
else:
    print('Fail - Test 2 Gradient')
####################################################### # (3) Update G network: maximize log(D(G(z))) ####################################################### point_netG.zero_grad() label.data.fill_(real_label) # foolish output = point_netD(fake) errG_D = criterion(output, label) errG_l2 = 0 fake = fake.squeeze( 1).contiguous() # [32, 1, 512, 3] -> [32, 512, 3] real_center = real_center.squeeze(1).contiguous() # print("dbg fake: ", fake.size()) # print("dbg real_center: ", real_center.size()) assert fake.size() == real_center.size(), "fail fake shape" d1, d2, _, _ = NND.nnd(fake, real_center) CD_LOSS = 100 * (0.5 * torch.mean(d1) + 0.5 * torch.mean(d2)) # computing also errG_l2 ''' fake center 1 ''' fake_center1 = fake_center1.contiguous() real_center_key1 = real_center_key1.contiguous() # print("dbg fake_center1: ", fake_center1.size()) # print("dbg real_center_key1: ", real_center_key1.size()) assert fake_center1.size() == real_center_key1.size( ), "fail fake 1 {}".format(str(fake_center1.size())) d1, d2, _, _ = NND.nnd(fake_center1, real_center_key1) cd_fake_1 = 100 * (0.5 * torch.mean(d1) + 0.5 * torch.mean(d2)) ''' fake center 2 ''' fake_center2 = fake_center2.contiguous() real_center_key2 = real_center_key2.contiguous()
def main_worker():
    """Train the completion model (encoder + generator) on ShapeNet-Part.

    Pipeline: parse args/config, optionally resume from checkpoint, build
    datasets/loaders, load SSL-pretrained feature extractors, then run the
    training loop with periodic testing, checkpointing and visualizations.
    """
    opt, io, tb = get_args()
    start_epoch = -1
    start_time = time.time()
    ckt = None
    if len(opt.restart_from) > 0:
        ckt = torch.load(opt.restart_from)
        start_epoch = ckt['epoch'] - 1

    # load configuration from file
    try:
        with open(opt.config) as cf:
            config = json.load(cf)
    except IOError as error:
        print(error)

    # backup relevant files
    # NOTE(review): BASE_DIR is not defined in this function — assumed to be
    # a module-level constant; confirm against the full file.
    shutil.copy(src=os.path.abspath(__file__),
                dst=os.path.join(opt.save_dir, 'backup_code'))
    shutil.copy(src=os.path.join(BASE_DIR, 'models', 'model_deco.py'),
                dst=os.path.join(opt.save_dir, 'backup_code'))
    shutil.copy(src=os.path.join(BASE_DIR, 'shape_utils.py'),
                dst=os.path.join(opt.save_dir, 'backup_code'))
    shutil.copy(src=opt.config,
                dst=os.path.join(opt.save_dir, 'backup_code', 'config.json.backup'))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if opt.manualSeed is None:
        opt.manualSeed = random.randint(1, 10000)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)
    torch.cuda.manual_seed_all(opt.manualSeed)
    io.cprint(f"Arguments: {str(opt)}")
    io.cprint(f"Configuration: {str(config)}")

    pnum = config['completion_trainer']['num_points']  # number of points of complete pointcloud
    class_choice = opt.class_choice  # config['completion_trainer']['class_choice']

    # datasets + loaders
    if len(class_choice) > 0:
        class_choice = ''.join(opt.class_choice.split()).split(",")  # sanitize + split(",")
        io.cprint("Class choice list: {}".format(str(class_choice)))
    else:
        class_choice = None  # training on all snpart classes

    tr_dataset = shapenet_part_loader.PartDataset(root=opt.data_root,
                                                  classification=True,
                                                  class_choice=class_choice,
                                                  npoints=pnum,
                                                  split='train')
    te_dataset = shapenet_part_loader.PartDataset(root=opt.data_root,
                                                  classification=True,
                                                  class_choice=class_choice,
                                                  npoints=pnum,
                                                  split='test')
    tr_loader = torch.utils.data.DataLoader(tr_dataset,
                                            batch_size=opt.batch_size,
                                            shuffle=True,
                                            num_workers=opt.workers,
                                            drop_last=True)
    te_loader = torch.utils.data.DataLoader(te_dataset,
                                            batch_size=64,
                                            shuffle=True,
                                            num_workers=opt.workers)

    num_holes = int(opt.num_holes)
    crop_point_num = int(opt.crop_point_num)
    context_point_num = int(opt.context_point_num)
    # io.cprint(f"Completion Setting:\n num classes {len(tr_dataset.cat.keys())}, num holes: {num_holes}, "
    #           f"crop point num: {crop_point_num}, frame/context point num: {context_point_num},\n"
    #           f"num points at pool1: {opt.pool1_points}, num points at pool2: {opt.pool2_points} ")

    # Models
    gl_encoder = Encoder(conf=config)
    generator = Generator(conf=config,
                          pool1_points=int(opt.pool1_points),
                          pool2_points=int(opt.pool2_points))
    gl_encoder.apply(weights_init_normal)  # affecting only non pretrained layers
    generator.apply(weights_init_normal)
    print("Encoder: ", gl_encoder)
    print("Generator: ", generator)

    if ckt is not None:
        # resuming training from intermediate checkpoint
        # restoring both encoder and generator state
        io.cprint(f"Restart Training from epoch {start_epoch}.")
        gl_encoder.load_state_dict(ckt['gl_encoder_state_dict'])
        generator.load_state_dict(ckt['generator_state_dict'])
        io.cprint("Whole model loaded from {}\n".format(opt.restart_from))
    else:
        # training the completion model
        # load local and global encoder pretrained (ssl pretexts) weights
        io.cprint("Training Completion Task...")
        local_fe_fn = config['completion_trainer']['checkpoint_local_enco']
        global_fe_fn = config['completion_trainer']['checkpoint_global_enco']

        if len(local_fe_fn) > 0:
            local_enco_dict = torch.load(local_fe_fn, )['model_state_dict']
            # non-strict load: only matching keys of the local encoder are restored
            loc_load_result = gl_encoder.local_encoder.load_state_dict(
                local_enco_dict, strict=False)
            io.cprint(
                f"Local FE pretrained weights - loading res: {str(loc_load_result)}")
        else:
            # Ablation experiments only
            io.cprint("Local FE pretrained weights - NOT loaded", color='r')

        if len(global_fe_fn) > 0:
            global_enco_dict = torch.load(global_fe_fn, )['global_encoder']
            glob_load_result = gl_encoder.global_encoder.load_state_dict(
                global_enco_dict, strict=True)
            io.cprint(
                f"Global FE pretrained weights - loading res: {str(glob_load_result)}",
                color='b')
        else:
            # Ablation experiments only
            io.cprint("Global FE pretrained weights - NOT loaded", color='r')

    io.cprint("Num GPUs: " + str(torch.cuda.device_count()) +
              ", Parallelism: {}".format(opt.parallel))
    if opt.parallel:
        # TODO: implement DistributedDataParallel training
        assert torch.cuda.device_count() > 1
        gl_encoder = torch.nn.DataParallel(gl_encoder)
        generator = torch.nn.DataParallel(generator)
    gl_encoder.to(device)
    generator.to(device)

    # Optimizers + schedulers
    opt_E = torch.optim.Adam(
        gl_encoder.parameters(),
        lr=config['completion_trainer']['enco_lr'],  # default is: 10e-4
        betas=(0.9, 0.999),
        eps=1e-05,
        weight_decay=0.001)
    sched_E = torch.optim.lr_scheduler.StepLR(
        opt_E,
        step_size=config['completion_trainer']['enco_step'],  # default is: 25
        gamma=0.5)
    opt_G = torch.optim.Adam(
        generator.parameters(),
        lr=config['completion_trainer']['gen_lr'],  # default is: 10e-4
        betas=(0.9, 0.999),
        eps=1e-05,
        weight_decay=0.001)
    sched_G = torch.optim.lr_scheduler.StepLR(
        opt_G,
        step_size=config['completion_trainer']['gen_step'],  # default is: 40
        gamma=0.5)

    if ckt is not None:
        # resuming training from intermediate checkpoint
        # restore optimizers state
        opt_E.load_state_dict(ckt['optimizerE_state_dict'])
        opt_G.load_state_dict(ckt['optimizerG_state_dict'])
        sched_E.load_state_dict(ckt['schedulerE_state_dict'])
        sched_G.load_state_dict(ckt['schedulerG_state_dict'])

    # crop centroids
    if not opt.fps_centroids:
        # 5 viewpoints to crop around - same crop procedure of PFNet - main paper
        centroids = np.asarray([[1, 0, 0], [0, 0, 1], [1, 0, 1], [-1, 0, 0],
                                [-1, 1, 0]])
    else:
        raise NotImplementedError('experimental')
        centroids = None  # NOTE(review): unreachable after raise (kept as in original)

    io.cprint('Centroids: ' + str(centroids))

    # training loop
    io.cprint("Training.. \n")
    best_test = sys.float_info.max
    best_ep, glob_it = -1, 0
    vis_folder = None
    for epoch in range(start_epoch + 1, opt.epochs):
        start_ep_time = time.time()
        count = 0.0
        tot_loss = 0.0
        tot_fine_loss = 0.0
        tot_interm_loss = 0.0
        gl_encoder = gl_encoder.train()
        generator = generator.train()
        for i, data in enumerate(tr_loader, 0):
            glob_it += 1
            points, _ = data
            B, N, dim = points.size()
            count += B

            # build per-shape partial inputs and missing-part ground truths
            partials = []
            fine_gts, interm_gts = [], []
            N_partial_points = N - (crop_point_num * num_holes)
            for m in range(B):
                partial, fine_gt, interm_gt = crop_shape(
                    points[m],
                    centroids=centroids,
                    scales=[crop_point_num,
                            (crop_point_num + context_point_num)],
                    n_c=num_holes)

                if partial.size(0) > N_partial_points:
                    assert num_holes > 1
                    # sampling without replacement
                    choice = torch.randperm(partial.size(0))[:N_partial_points]
                    partial = partial[choice]

                partials.append(partial)
                fine_gts.append(fine_gt)
                interm_gts.append(interm_gt)

            if i == 1 and epoch % opt.it_test == 0:
                # make some visualization
                vis_folder = os.path.join(opt.vis_dir, "epoch_{}".format(epoch))
                safe_make_dirs([vis_folder])
                print(f"ep {epoch} - Saving visualizations into: {vis_folder}")
                for j in range(len(partials)):
                    np.savetxt(X=partials[j],
                               fname=os.path.join(vis_folder, '{}_partial.txt'.format(j)),
                               fmt='%.5f', delimiter=';')
                    np.savetxt(X=fine_gts[j],
                               fname=os.path.join(vis_folder, '{}_fine_gt.txt'.format(j)),
                               fmt='%.5f', delimiter=';')
                    np.savetxt(X=interm_gts[j],
                               fname=os.path.join(vis_folder, '{}_interm_gt.txt'.format(j)),
                               fmt='%.5f', delimiter=';')

            partials = torch.stack(partials).to(device).permute(0, 2, 1)  # [B, 3, N-512]
            fine_gts = torch.stack(fine_gts).to(device)  # [B, 512, 3]
            interm_gts = torch.stack(interm_gts).to(device)  # [B, 1024, 3]

            gl_encoder.zero_grad()
            generator.zero_grad()
            feat = gl_encoder(partials)
            pred_fine, pred_raw = generator(feat)

            # pytorch 1.2 compiled Chamfer (C2C) dist.
            assert pred_fine.size() == fine_gts.size()
            pred_fine, pred_raw = pred_fine.contiguous(), pred_raw.contiguous()
            fine_gts, interm_gts = fine_gts.contiguous(), interm_gts.contiguous()
            dist1, dist2, _, _ = NND.nnd(pred_fine, fine_gts)  # missing part pred loss
            dist1_raw, dist2_raw, _, _ = NND.nnd(pred_raw, interm_gts)  # intermediate pred loss
            fine_loss = 50 * (torch.mean(dist1) + torch.mean(dist2))  # chamfer is weighted by 100
            interm_loss = 50 * (torch.mean(dist1_raw) + torch.mean(dist2_raw))
            loss = fine_loss + opt.raw_weight * interm_loss
            loss.backward()
            opt_E.step()
            opt_G.step()

            tot_loss += loss.item() * B
            tot_fine_loss += fine_loss.item() * B
            tot_interm_loss += interm_loss.item() * B

            if glob_it % 10 == 0:
                header = "[%d/%d][%d/%d]" % (epoch, opt.epochs, i, len(tr_loader))
                io.cprint('%s: loss: %.4f, fine CD: %.4f, interm. CD: %.4f' %
                          (header, loss.item(), fine_loss.item(), interm_loss.item()))

            # make visualizations
            if i == 1 and epoch % opt.it_test == 0:
                assert (vis_folder is not None and os.path.exists(vis_folder))
                pred_fine = pred_fine.cpu().detach().data.numpy()
                pred_raw = pred_raw.cpu().detach().data.numpy()
                for j in range(len(pred_fine)):
                    np.savetxt(X=pred_fine[j],
                               fname=os.path.join(vis_folder, '{}_pred_fine.txt'.format(j)),
                               fmt='%.5f', delimiter=';')
                    np.savetxt(X=pred_raw[j],
                               fname=os.path.join(vis_folder, '{}_pred_raw.txt'.format(j)),
                               fmt='%.5f', delimiter=';')

        sched_E.step()
        sched_G.step()
        io.cprint(
            '[%d/%d] Ep Train - loss: %.5f, fine cd: %.5f, interm. cd: %.5f' %
            (epoch, opt.epochs, tot_loss * 1.0 / count,
             tot_fine_loss * 1.0 / count, tot_interm_loss * 1.0 / count))
        tb.add_scalar('Train/tot_loss', tot_loss * 1.0 / count, epoch)
        tb.add_scalar('Train/cd_fine', tot_fine_loss * 1.0 / count, epoch)
        tb.add_scalar('Train/cd_interm', tot_interm_loss * 1.0 / count, epoch)

        if epoch % opt.it_test == 0:
            # periodic checkpoint (unwrap DataParallel before saving)
            torch.save(
                {
                    'epoch': epoch + 1,
                    'epoch_train_loss': tot_loss * 1.0 / count,
                    'epoch_train_loss_raw': tot_interm_loss * 1.0 / count,
                    'epoch_train_loss_fine': tot_fine_loss * 1.0 / count,
                    'gl_encoder_state_dict':
                        gl_encoder.module.state_dict()
                        if isinstance(gl_encoder, nn.DataParallel)
                        else gl_encoder.state_dict(),
                    'generator_state_dict':
                        generator.module.state_dict()
                        if isinstance(generator, nn.DataParallel)
                        else generator.state_dict(),
                    'optimizerE_state_dict': opt_E.state_dict(),
                    'optimizerG_state_dict': opt_G.state_dict(),
                    'schedulerE_state_dict': sched_E.state_dict(),
                    'schedulerG_state_dict': sched_G.state_dict(),
                },
                os.path.join(opt.models_dir, 'checkpoint_' + str(epoch) + '.pth'))

        if epoch % opt.it_test == 0:
            # periodic evaluation on the test split
            test_cd, count = 0.0, 0.0
            for i, data in enumerate(te_loader, 0):
                points, _ = data
                B, N, dim = points.size()
                count += B

                partials = []
                fine_gts = []
                N_partial_points = N - (crop_point_num * num_holes)
                for m in range(B):
                    partial, fine_gt, _ = crop_shape(
                        points[m],
                        centroids=centroids,
                        scales=[crop_point_num,
                                (crop_point_num + context_point_num)],
                        n_c=num_holes)
                    if partial.size(0) > N_partial_points:
                        assert num_holes > 1
                        # sampling Without replacement
                        choice = torch.randperm(partial.size(0))[:N_partial_points]
                        partial = partial[choice]
                    partials.append(partial)
                    fine_gts.append(fine_gt)
                partials = torch.stack(partials).to(device).permute(0, 2, 1)  # [B, 3, N-512]
                fine_gts = torch.stack(fine_gts).to(device).contiguous()  # [B, 512, 3]

                # TEST FORWARD
                # Considering only missing part prediction at Test Time
                gl_encoder.eval()
                generator.eval()
                with torch.no_grad():
                    feat = gl_encoder(partials)
                    pred_fine, _ = generator(feat)

                pred_fine = pred_fine.contiguous()
                assert pred_fine.size() == fine_gts.size()
                dist1, dist2, _, _ = NND.nnd(pred_fine, fine_gts)
                cd_loss = 50 * (torch.mean(dist1) + torch.mean(dist2))
                test_cd += cd_loss.item() * B

            test_cd = test_cd * 1.0 / count
            io.cprint('Ep Test [%d/%d] - cd loss: %.5f ' %
                      (epoch, opt.epochs, test_cd), color="b")
            tb.add_scalar('Test/cd_loss', test_cd, epoch)
            is_best = test_cd < best_test
            best_test = min(best_test, test_cd)
            if is_best:
                # best model case
                best_ep = epoch
                io.cprint("New best test %.5f at epoch %d" % (best_test, best_ep))
                shutil.copyfile(
                    src=os.path.join(opt.models_dir, 'checkpoint_' + str(epoch) + '.pth'),
                    dst=os.path.join(opt.models_dir, 'best_model.pth'))

        io.cprint('[%d/%d] Epoch time: %s' %
                  (epoch, opt.epochs,
                   time.strftime("%M:%S", time.gmtime(time.time() - start_ep_time))))

    # Script ends
    hours, rem = divmod(time.time() - start_time, 3600)
    minutes, seconds = divmod(rem, 60)
    io.cprint("### Training ended in {:0>2}:{:0>2}:{:05.2f}".format(
        int(hours), int(minutes), seconds))
    io.cprint("### Best val %.6f at epoch %d" % (best_test, best_ep))
# NOTE(review): fragment from inside a training loop; frag1_batch, net,
# optimizer, hardest_contrastive, writer and n_iter are defined by
# surrounding code not visible in this chunk.
frag2_batch = frag2_batch.squeeze().cuda()
R1 = R1.squeeze().cuda()
R2 = R2.squeeze().cuda()
lrf1 = lrf1.squeeze().cuda()
lrf2 = lrf2.squeeze().cuda()
optimizer.zero_grad()
f1, xtrans1, trans1, f2, xtrans2, trans2 = net(frag1_batch, frag2_batch)

# hardest-contrastive loss
lcontrastive, a, b, c = hardest_contrastive(f1, f2)
# chamfer loss between the two transformed fragments (point-major layout)
dist1, dist2 = NND.nnd(
    xtrans1.transpose(2, 1).contiguous(),
    xtrans2.transpose(2, 1).contiguous())
lchamf = .5 * (torch.mean(dist1) + torch.mean(dist2))
# combination of losses
loss = lcontrastive + lchamf
loss.backward()
optimizer.step()

# tensorboard logging of the total and per-term losses
writer.add_scalar('loss/train', loss.item(), n_iter)
writer.add_scalar('hardest_contrastive/positive - train', torch.mean(a).item(), n_iter)
writer.add_scalar('hardest_contrastive/negative1 - train', torch.mean(b[0]).item(), n_iter)
writer.add_scalar('hardest_contrastive/negative2 - train', torch.mean(c[0]).item(), n_iter)