Example #1
def test(test_loader, model, criterion, device):
    model.eval()
    tot_loss = 0  # MSE
    tot_cd_dist = 0  # Chamfer dist
    # TODO: remove unnecessary dependencies (tqdm)
    for i, data in tqdm(enumerate(test_loader, 0),
                        total=len(test_loader),
                        smoothing=0.9,
                        desc='test',
                        dynamic_ncols=True):
        assert len(data) == 3
        noised, clean, _ = data
        bs = len(noised)
        noised = noised.to(device)  # [bs, npoints, 3]
        clean = clean.to(device)

        with torch.no_grad():
            cleaned = model(noised)
            assert cleaned.size() == clean.size()
            loss = criterion(cleaned, clean)
            tot_loss += loss.item() * bs

            # evaluate also CD distance
            cleaned = cleaned.contiguous()
            clean = clean.contiguous()
            dist1, dist2, _, _ = NND.nnd(cleaned, clean)
            cd_dist = 50 * torch.mean(dist1) + 50 * torch.mean(dist2)
            tot_cd_dist += cd_dist.item() * bs

    tot_loss = tot_loss * 1.0 / len(test_loader.dataset)
    tot_cd_dist = tot_cd_dist * 1.0 / len(test_loader.dataset)
    return tot_loss, tot_cd_dist
Example #2
def train_one_epoch(train_loader, model, optimizer, criterion, device):
    model.train()
    tot_loss = 0
    tot_cd_dist = 0
    # TODO: remove unnecessary dependencies (tqdm)
    for i, data in tqdm(enumerate(train_loader, 0),
                        total=len(train_loader),
                        smoothing=0.9,
                        desc='train',
                        dynamic_ncols=True):
        assert len(data) == 3, 'train: expected tuple: (noised, clean, cls)'
        noised, clean, _ = data
        bs = len(noised)
        noised = noised.to(device)  # [bs, npoints, 3]
        clean = clean.to(device)  # [bs, npoints, 3]
        cleaned = model(noised)
        assert cleaned.size() == clean.size()
        loss = criterion(cleaned, clean)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # evaluate also CD distance
        cleaned = cleaned.contiguous()
        clean = clean.contiguous()
        dist1, dist2, _, _ = NND.nnd(cleaned, clean)
        cd_dist = 50 * torch.mean(dist1) + 50 * torch.mean(dist2)

        tot_loss += loss.item() * bs  # MSE Loss
        tot_cd_dist += cd_dist.item() * bs  # Chamfer Distance

    tot_loss = tot_loss * 1.0 / len(train_loader.dataset)
    tot_cd_dist = tot_cd_dist * 1.0 / len(train_loader.dataset)
    return tot_loss, tot_cd_dist
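The two helpers above form a train/evaluate pair. Below is a minimal driver sketch, not part of the original listing, showing one way they could be wired together; the model class, the two loaders and the epoch budget are placeholders, and only the function signatures, the MSE criterion and the (noised, clean, label) batch format come from the snippets above.

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MyDenoiser().to(device)  # hypothetical point cloud denoising network
criterion = nn.MSELoss()  # MSE criterion, as assumed by train_one_epoch() / test()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

best_cd = float('inf')
for epoch in range(100):  # placeholder epoch budget
    tr_mse, tr_cd = train_one_epoch(train_loader, model, optimizer, criterion, device)
    te_mse, te_cd = test(test_loader, model, criterion, device)
    print(f"[{epoch}] train MSE {tr_mse:.5f} CD {tr_cd:.5f} | test MSE {te_mse:.5f} CD {te_cd:.5f}")
    if te_cd < best_cd:  # keep the checkpoint with the best test Chamfer distance
        best_cd = te_cd
        torch.save(model.state_dict(), 'best_denoiser.pth')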
    def reconstruction_loss(self, data, reconstructions):
        data_ = data.transpose(2, 1).contiguous()
        reconstructions_ = reconstructions.transpose(2, 1).contiguous()
        dist1, dist2 = NND.nnd(data_, reconstructions_)
        loss = torch.mean(dist1) + torch.mean(dist2)
        return loss

        reconstructions = self.caps_decoder(latent_capsules)
        return reconstructions

if __name__ == '__main__':
    USE_CUDA = True
    batch_size = 2  # original is 8

    prim_caps_size = 1024
    prim_vec_size = 16

    latent_caps_size = 32
    latent_vec_size = 16

    num_points = 2048

    point_caps_ae = PointCapsNet(prim_caps_size, prim_vec_size,
                                 latent_caps_size, latent_vec_size, num_points)
    point_caps_ae = torch.nn.DataParallel(point_caps_ae).cuda()

    rand_data = torch.rand(batch_size, num_points, 3)
    rand_data = Variable(rand_data)
    rand_data = rand_data.transpose(2, 1)
    rand_data = rand_data.cuda()

    codewords, reconstruction = point_caps_ae(rand_data)

    rand_data_ = rand_data.transpose(2, 1).contiguous()
    reconstruction_ = reconstruction.transpose(2, 1).contiguous()

    dist1, dist2 = NND.nnd(rand_data_, reconstruction_)
    loss = torch.mean(dist1) + torch.mean(dist2)
    print(loss.item())
import torch
import torch.nn as nn
from torch.autograd import Variable

#from modules.nnd import NNDModule
import torch_nndistance as NND

#dist =  NNDModule()

p1 = torch.rand(10, 1000, 3)
p2 = torch.rand(10, 1500, 3)
points1 = Variable(p1, requires_grad=True)
points2 = Variable(p2)
points1 = points1.cuda()
points2 = points2.cuda()
dist1, dist2 = NND.nnd(points1, points2)
print(dist1, dist2)
loss = torch.sum(dist1)
print(loss)
loss.backward()
print(points1.grad, points2.grad)
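# Note: points1 above was created on the CPU and only afterwards moved with .cuda(),
# so the tensor that enters the graph is a non-leaf and points1.grad stays None.
# In the second run below the tensor is built directly on the GPU as a leaf, so
# points1.grad is populated after backward() (points2 never requires grad, so its
# grad is None in both runs).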

points1 = Variable(p1.cuda(), requires_grad=True)
points2 = Variable(p2.cuda())
dist1, dist2 = NND.nnd(points1, points2)
print(dist1, dist2)
loss = torch.sum(dist1)
print(loss)
loss.backward()
print(points1.grad, points2.grad)
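For sanity-checking the compiled extension, here is a brute-force pure-PyTorch sketch of what NND.nnd is expected to return for [B, N, 3] inputs: the per-point squared distance to the nearest neighbor in the other cloud, plus the matching indices. This is an assumption based on how the distances are used in these examples, not the extension itself, and it is only practical for small point counts.

import torch

def nnd_reference(xyz1, xyz2):
    # xyz1: [B, N, 3], xyz2: [B, M, 3]
    # Assumes squared Euclidean distances, as most nnd/Chamfer kernels return.
    d = torch.cdist(xyz1, xyz2) ** 2  # [B, N, M] pairwise squared distances
    dist1, idx1 = d.min(dim=2)  # nearest neighbor in xyz2 for every point of xyz1
    dist2, idx2 = d.min(dim=1)  # nearest neighbor in xyz1 for every point of xyz2
    return dist1, dist2, idx1, idx2

p1 = torch.rand(4, 128, 3)
p2 = torch.rand(4, 96, 3)
d1, d2, i1, i2 = nnd_reference(p1, p2)
print(d1.size(), d2.size())  # torch.Size([4, 128]) torch.Size([4, 96])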
Example #6
def main_worker():
    opt, io, tb = get_args()
    start_epoch = -1
    start_time = time.time()
    BASE_DIR = os.path.dirname(
        os.path.abspath(__file__))  # python script folder
    ckt = None
    if len(opt.restart_from) > 0:
        ckt = torch.load(opt.restart_from)
        start_epoch = ckt['epoch'] - 1

    # load configuration from file
    try:
        with open(opt.config) as cf:
            config = json.load(cf)
    except IOError as error:
        print(error)
        raise  # config is required below; stop instead of continuing with it undefined

    # backup relevant files
    shutil.copy(src=os.path.abspath(__file__),
                dst=os.path.join(opt.save_dir, 'backup_code'))
    shutil.copy(src=os.path.join(BASE_DIR, 'models', 'model_deco.py'),
                dst=os.path.join(opt.save_dir, 'backup_code'))
    shutil.copy(src=os.path.join(BASE_DIR, 'shape_utils.py'),
                dst=os.path.join(opt.save_dir, 'backup_code'))
    shutil.copy(src=opt.config,
                dst=os.path.join(opt.save_dir, 'backup_code',
                                 'config.json.backup'))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if opt.manualSeed is None:
        opt.manualSeed = random.randint(1, 10000)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)
    torch.cuda.manual_seed_all(opt.manualSeed)

    io.cprint(f"Arguments: {str(opt)}")
    io.cprint(f"Configuration: {str(config)}")

    pnum = config['completion_trainer']['num_points']
    class_choice = opt.class_choice
    # datasets + loaders
    if len(class_choice) > 0:
        class_choice = ''.join(opt.class_choice.split()).split(
            ",")  # sanitize + split(",")
        io.cprint("Class choice list: {}".format(str(class_choice)))
    else:
        class_choice = None  # Train on all classes! (if opt.class_choice=='')

    tr_dataset = shapenet_part_loader.PartDataset(root=opt.data_root,
                                                  classification=True,
                                                  class_choice=class_choice,
                                                  npoints=pnum,
                                                  split='train')

    te_dataset = shapenet_part_loader.PartDataset(root=opt.data_root,
                                                  classification=True,
                                                  class_choice=class_choice,
                                                  npoints=pnum,
                                                  split='test')

    tr_loader = torch.utils.data.DataLoader(tr_dataset,
                                            batch_size=opt.batch_size,
                                            shuffle=True,
                                            num_workers=opt.workers,
                                            drop_last=True)

    te_loader = torch.utils.data.DataLoader(te_dataset,
                                            batch_size=64,
                                            shuffle=True,
                                            num_workers=opt.workers)

    num_holes = int(opt.num_holes)
    crop_point_num = int(opt.crop_point_num)
    context_point_num = int(opt.context_point_num)
    # io.cprint("Num holes: {}".format(num_holes))
    # io.cprint("Crop points num: {}".format(crop_point_num))
    # io.cprint("Context points num: {}".format(context_point_num))
    # io.cprint("Pool1 num points selected: {}".format(opt.pool1_points))
    # io.cprint("Pool2 num points selected: {}".format(opt.pool2_points))
    """" Models """
    gl_encoder = Encoder(conf=config)
    generator = Generator(conf=config,
                          pool1_points=int(opt.pool1_points),
                          pool2_points=int(opt.pool2_points))
    gl_encoder.apply(weights_init_normal)  # affecting only non pretrained
    generator.apply(weights_init_normal)  # not pretrained
    print("Encoder: ", gl_encoder)
    print("Generator: ", generator)

    if ckt is not None:
        io.cprint(f"Restart Training from epoch {start_epoch}.")
        gl_encoder.load_state_dict(ckt['gl_encoder_state_dict'])
        generator.load_state_dict(ckt['generator_state_dict'])
    else:
        io.cprint("Training Completion Task...")
        local_fe_fn = config['completion_trainer']['checkpoint_local_enco']
        global_fe_fn = config['completion_trainer']['checkpoint_global_enco']

        if len(local_fe_fn) > 0:
            local_enco_dict = torch.load(local_fe_fn)['model_state_dict']
            # refactoring pretext-trained local dgcnn encoder state dict keys
            local_enco_dict = remove_prefix_dict(
                state_dict=local_enco_dict, to_remove_str='local_encoder.')
            loc_load_result = gl_encoder.local_encoder.load_state_dict(
                local_enco_dict, strict=False)
            io.cprint(
                f"Local FE pretrained weights - loading res: {str(loc_load_result)}"
            )
        else:
            # Ablation experiments only
            io.cprint("Local FE pretrained weights - NOT loaded", color='r')

        if len(global_fe_fn) > 0:
            global_enco_dict = torch.load(global_fe_fn, )['global_encoder']
            glob_load_result = gl_encoder.global_encoder.load_state_dict(
                global_enco_dict, strict=True)
            io.cprint(
                f"Global FE pretrained weights - loading res: {str(glob_load_result)}",
                color='b')
        else:
            # Ablation experiments only
            io.cprint("Global FE pretrained weights - NOT loaded", color='r')

    io.cprint("Num GPUs: " + str(torch.cuda.device_count()) +
              ", Parallelism: {}".format(opt.parallel))
    if opt.parallel and torch.cuda.device_count() > 1:
        gl_encoder = torch.nn.DataParallel(gl_encoder)
        generator = torch.nn.DataParallel(generator)
    gl_encoder.to(device)
    generator.to(device)

    # Optimizers + schedulers
    opt_E = torch.optim.Adam(
        gl_encoder.parameters(),
        lr=config['completion_trainer']['enco_lr'],  # def: 10e-4
        betas=(0.9, 0.999),
        eps=1e-05,
        weight_decay=0.001)

    sched_E = torch.optim.lr_scheduler.StepLR(
        opt_E,
        step_size=config['completion_trainer']['enco_step'],  # def: 25
        gamma=0.5)

    opt_G = torch.optim.Adam(
        generator.parameters(),
        lr=config['completion_trainer']['gen_lr'],  # def: 10e-4
        betas=(0.9, 0.999),
        eps=1e-05,
        weight_decay=0.001)

    sched_G = torch.optim.lr_scheduler.StepLR(
        opt_G,
        step_size=config['completion_trainer']['gen_step'],  # def: 40
        gamma=0.5)

    if ckt is not None:
        opt_E.load_state_dict(ckt['optimizerE_state_dict'])
        opt_G.load_state_dict(ckt['optimizerG_state_dict'])
        sched_E.load_state_dict(ckt['schedulerE_state_dict'])
        sched_G.load_state_dict(ckt['schedulerG_state_dict'])

    if not opt.fps_centroids:
        # 5 viewpoints to crop around - same as in PFNet
        centroids = np.asarray([[1, 0, 0], [0, 0, 1], [1, 0, 1], [-1, 0, 0],
                                [-1, 1, 0]])
    else:
        raise NotImplementedError('experimental')
        centroids = None

    io.cprint("Training.. \n")
    best_test = sys.float_info.max
    best_ep = -1
    it = 0  # global iteration counter
    vis_folder = None
    for epoch in range(start_epoch + 1, opt.epochs):
        start_ep_time = time.time()
        count = 0.0
        tot_loss = 0.0
        tot_fine_loss = 0.0
        tot_raw_loss = 0.0
        gl_encoder = gl_encoder.train()
        generator = generator.train()
        for i, data in enumerate(tr_loader, 0):
            it += 1
            points, _ = data
            B, N, dim = points.size()
            count += B

            partials = []
            fine_gts, raw_gts = [], []
            N_partial_points = N - (crop_point_num * num_holes)
            for m in range(B):
                # points[m]: complete shape of size (N,3)
                # partial: partial point cloud to complete
                # fine_gt: missing part ground truth
                # raw_gt: missing part ground truth + frame points (where frame points are points included in partial)
                partial, fine_gt, raw_gt = crop_shape(points[m],
                                                      centroids=centroids,
                                                      scales=[
                                                          crop_point_num,
                                                          (crop_point_num +
                                                           context_point_num)
                                                      ],
                                                      n_c=num_holes)

                if partial.size(0) > N_partial_points:
                    assert num_holes > 1, "resampling should only be needed when cropping multiple holes"
                    # sampling without replacement
                    choice = torch.randperm(partial.size(0))[:N_partial_points]
                    partial = partial[choice]

                partials.append(partial)
                fine_gts.append(fine_gt)
                raw_gts.append(raw_gt)

            if i == 1 and epoch % opt.it_test == 0:
                # make some visualization
                vis_folder = os.path.join(opt.vis_dir,
                                          "epoch_{}".format(epoch))
                safe_make_dirs([vis_folder])
                print(f"ep {epoch} - Saving visualizations into: {vis_folder}")
                for j in range(len(partials)):
                    np.savetxt(X=partials[j],
                               fname=os.path.join(vis_folder,
                                                  '{}_cropped.txt'.format(j)),
                               fmt='%.5f',
                               delimiter=';')
                    np.savetxt(X=fine_gts[j],
                               fname=os.path.join(vis_folder,
                                                  '{}_fine_gt.txt'.format(j)),
                               fmt='%.5f',
                               delimiter=';')
                    np.savetxt(X=raw_gts[j],
                               fname=os.path.join(vis_folder,
                                                  '{}_raw_gt.txt'.format(j)),
                               fmt='%.5f',
                               delimiter=';')

            partials = torch.stack(partials).to(device).permute(
                0, 2, 1)  # [B, 3, N-512]
            fine_gts = torch.stack(fine_gts).to(device)  # [B, 512, 3]
            raw_gts = torch.stack(raw_gts).to(device)  # [B, 512 + context, 3]

            if i == 1:  # sanity check
                print("[dbg]: partials: ", partials.size(), ' ',
                      partials.device)
                print("[dbg]: fine grained gts: ", fine_gts.size(), ' ',
                      fine_gts.device)
                print("[dbg]: raw grained gts: ", raw_gts.size(), ' ',
                      raw_gts.device)

            gl_encoder.zero_grad()
            generator.zero_grad()
            feat = gl_encoder(partials)
            fake_fine, fake_raw = generator(
                feat
            )  # pred_fine (only missing part), pred_intermediate (missing + frame)

            # pytorch 1.2 compiled Chamfer (C2C) dist.
            assert fake_fine.size() == fine_gts.size(
            ), "Wrong input shapes to Chamfer module"
            if i == 0:
                if fake_raw.size() != raw_gts.size():
                    warnings.warn(
                        "size dismatch for: raw_pred: {}, raw_gt: {}".format(
                            str(fake_raw.size()), str(raw_gts.size())))

            # fine grained prediction + gt
            fake_fine = fake_fine.contiguous()
            fine_gts = fine_gts.contiguous()
            # raw prediction + gt
            fake_raw = fake_raw.contiguous()
            raw_gts = raw_gts.contiguous()

            dist1, dist2, _, _ = NND.nnd(
                fake_fine, fine_gts)  # fine grained loss computation
            dist1_raw, dist2_raw, _, _ = NND.nnd(
                fake_raw, raw_gts)  # raw grained loss computation

            # standard C2C distance loss
            fine_loss = 100 * (0.5 * torch.mean(dist1) +
                               0.5 * torch.mean(dist2))

            # raw loss: missing part + frame
            raw_loss = 100 * (0.5 * torch.mean(dist1_raw) +
                              0.5 * torch.mean(dist2_raw))

            loss = fine_loss + opt.raw_weight * raw_loss  # missing part pred loss + α * raw reconstruction loss
            loss.backward()
            opt_E.step()
            opt_G.step()
            tot_loss += loss.item() * B
            tot_fine_loss += fine_loss.item() * B
            tot_raw_loss += raw_loss.item() * B

            if it % 10 == 0:
                io.cprint(
                    '[%d/%d][%d/%d]: loss: %.4f, fine CD: %.4f, interm. CD: %.4f'
                    % (epoch, opt.epochs, i, len(tr_loader), loss.item(),
                       fine_loss.item(), raw_loss.item()))

            # make visualizations
            if i == 1 and epoch % opt.it_test == 0:
                assert (vis_folder is not None and os.path.exists(vis_folder))
                fake_fine = fake_fine.cpu().detach().data.numpy()
                fake_raw = fake_raw.cpu().detach().data.numpy()
                for j in range(len(fake_fine)):
                    np.savetxt(X=fake_fine[j],
                               fname=os.path.join(
                                   vis_folder, '{}_pred_fine.txt'.format(j)),
                               fmt='%.5f',
                               delimiter=';')
                    np.savetxt(X=fake_raw[j],
                               fname=os.path.join(vis_folder,
                                                  '{}_pred_raw.txt'.format(j)),
                               fmt='%.5f',
                               delimiter=';')

        sched_E.step()
        sched_G.step()
        io.cprint(
            '[%d/%d] Ep Train - loss: %.5f, fine cd: %.5f, interm. cd: %.5f' %
            (epoch, opt.epochs, tot_loss * 1.0 / count,
             tot_fine_loss * 1.0 / count, tot_raw_loss * 1.0 / count))
        tb.add_scalar('Train/tot_loss', tot_loss * 1.0 / count, epoch)
        tb.add_scalar('Train/cd_fine', tot_fine_loss * 1.0 / count, epoch)
        tb.add_scalar('Train/cd_interm', tot_raw_loss * 1.0 / count, epoch)

        if epoch % opt.it_test == 0:
            torch.save(
                {
                    'type_exp':
                    'dgcnn at local encoder',
                    'epoch':
                    epoch + 1,
                    'epoch_train_loss':
                    tot_loss * 1.0 / count,
                    'epoch_train_loss_raw':
                    tot_raw_loss * 1.0 / count,
                    'epoch_train_loss_fine':
                    tot_fine_loss * 1.0 / count,
                    'gl_encoder_state_dict':
                    gl_encoder.module.state_dict() if isinstance(
                        gl_encoder,
                        nn.DataParallel) else gl_encoder.state_dict(),
                    'generator_state_dict':
                    generator.module.state_dict() if isinstance(
                        generator, nn.DataParallel) else
                    generator.state_dict(),
                    'optimizerE_state_dict':
                    opt_E.state_dict(),
                    'optimizerG_state_dict':
                    opt_G.state_dict(),
                    'schedulerE_state_dict':
                    sched_E.state_dict(),
                    'schedulerG_state_dict':
                    sched_G.state_dict(),
                },
                os.path.join(opt.models_dir,
                             'checkpoint_' + str(epoch) + '.pth'))

        if epoch % opt.it_test == 0:
            test_cd, count = 0.0, 0.0
            for i, data in enumerate(te_loader, 0):
                points, _ = data
                B, N, dim = points.size()
                count += B

                partials = []
                fine_gts = []
                N_partial_points = N - (crop_point_num * num_holes)

                for m in range(B):
                    partial, fine_gt, _ = crop_shape(points[m],
                                                     centroids=centroids,
                                                     scales=[
                                                         crop_point_num,
                                                         (crop_point_num +
                                                          context_point_num)
                                                     ],
                                                     n_c=num_holes)

                    if partial.size(0) > N_partial_points:
                        assert num_holes > 1
                        # sampling Without replacement
                        choice = torch.randperm(
                            partial.size(0))[:N_partial_points]
                        partial = partial[choice]

                    partials.append(partial)
                    fine_gts.append(fine_gt)
                partials = torch.stack(partials).to(device).permute(
                    0, 2, 1)  # [B, 3, N-512]
                fine_gts = torch.stack(fine_gts).to(
                    device).contiguous()  # [B, 512, 3]

                # TEST FORWARD
                # Considering only missing part prediction at Test Time
                gl_encoder.eval()
                generator.eval()
                with torch.no_grad():
                    feat = gl_encoder(partials)
                    fake_fine, _ = generator(feat)

                fake_fine = fake_fine.contiguous()
                assert fake_fine.size() == fine_gts.size()
                dist1, dist2, _, _ = NND.nnd(fake_fine, fine_gts)
                cd_loss = 100 * (0.5 * torch.mean(dist1) +
                                 0.5 * torch.mean(dist2))
                test_cd += cd_loss.item() * B

            test_cd = test_cd * 1.0 / count
            io.cprint('Ep Test [%d/%d] - cd loss: %.5f ' %
                      (epoch, opt.epochs, test_cd),
                      color="b")
            tb.add_scalar('Test/cd_loss', test_cd, epoch)
            is_best = test_cd < best_test
            best_test = min(best_test, test_cd)

            if is_best:
                # best model case
                best_ep = epoch
                io.cprint("New best test %.5f at epoch %d" %
                          (best_test, best_ep))
                shutil.copyfile(src=os.path.join(
                    opt.models_dir, 'checkpoint_' + str(epoch) + '.pth'),
                                dst=os.path.join(opt.models_dir,
                                                 'best_model.pth'))
        io.cprint(
            '[%d/%d] Epoch time: %s' %
            (epoch, opt.epochs,
             time.strftime("%M:%S", time.gmtime(time.time() - start_ep_time))))

    # Script ends
    hours, rem = divmod(time.time() - start_time, 3600)
    minutes, seconds = divmod(rem, 60)
    io.cprint("### Training ended in {:0>2}:{:0>2}:{:05.2f}".format(
        int(hours), int(minutes), seconds))
    io.cprint("### Best val %.6f at epoch %d" % (best_test, best_ep))
Example #7
import torch
import torch_nndistance as NND

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
points1 = torch.rand(16, 2048, 3).to(device)
points1.requires_grad = True
points2 = torch.rand(16, 1024, 3).to(device)
points2.requires_grad = True

dist1, dist2, idx1, idx2 = NND.nnd(points1, points2)
# print(dist1, dist2)
print('dist1: ', dist1.size())
print('dist2: ', dist2.size())

loss = torch.sum(dist1)
print(loss)
loss.backward()
print(points1.grad, points2.grad)
print('ok - Test 1')
print("points1 grad:\n", points1.grad)
print("")
print("points2 grad:\n", points2.grad)

if points1.grad is not None and points2.grad is not None:
    print('ok - Test 2 Gradient')
else:
    print('Fail - Test 2 Gradient')
Example #8
            #######################################################
            # (3) Update G network: maximize log(D(G(z)))
            #######################################################
            point_netG.zero_grad()
            label.data.fill_(real_label)  # fake labels are treated as real for the generator loss
            output = point_netD(fake)
            errG_D = criterion(output, label)
            errG_l2 = 0

            fake = fake.squeeze(
                1).contiguous()  # [32, 1, 512, 3] -> [32, 512, 3]
            real_center = real_center.squeeze(1).contiguous()
            # print("dbg fake: ", fake.size())
            # print("dbg real_center: ", real_center.size())
            assert fake.size() == real_center.size(), "fail fake shape"
            d1, d2, _, _ = NND.nnd(fake, real_center)
            CD_LOSS = 100 * (0.5 * torch.mean(d1) + 0.5 * torch.mean(d2))

            # computing also errG_l2
            ''' fake center 1 '''
            fake_center1 = fake_center1.contiguous()
            real_center_key1 = real_center_key1.contiguous()
            # print("dbg fake_center1: ", fake_center1.size())
            # print("dbg real_center_key1: ", real_center_key1.size())
            assert fake_center1.size() == real_center_key1.size(
            ), "fail fake 1 {}".format(str(fake_center1.size()))
            d1, d2, _, _ = NND.nnd(fake_center1, real_center_key1)
            cd_fake_1 = 100 * (0.5 * torch.mean(d1) + 0.5 * torch.mean(d2))
            ''' fake center 2 '''
            fake_center2 = fake_center2.contiguous()
            real_center_key2 = real_center_key2.contiguous()
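The same Chamfer scaling recurs throughout these examples: 100 * (0.5 * mean(dist1) + 0.5 * mean(dist2)), which equals the 50 * (mean(dist1) + mean(dist2)) form used elsewhere. A small hypothetical helper could factor it out, assuming the four-value NND.nnd interface used in this snippet.

import torch
import torch_nndistance as NND

def chamfer_cd(pred, gt, weight=100.0):
    # pred, gt: [B, N, 3] point clouds; returns the weighted symmetric Chamfer distance.
    d1, d2, _, _ = NND.nnd(pred.contiguous(), gt.contiguous())
    return weight * 0.5 * (torch.mean(d1) + torch.mean(d2))

# e.g. CD_LOSS above would become: CD_LOSS = chamfer_cd(fake, real_center)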
Example #9
def main_worker():
    opt, io, tb = get_args()
    start_epoch = -1
    start_time = time.time()
    ckt = None
    if len(opt.restart_from) > 0:
        ckt = torch.load(opt.restart_from)
        start_epoch = ckt['epoch'] - 1

    # load configuration from file
    try:
        with open(opt.config) as cf:
            config = json.load(cf)
    except IOError as error:
        print(error)
        raise  # config is required below; stop instead of continuing with it undefined

    # backup relevant files
    shutil.copy(src=os.path.abspath(__file__),
                dst=os.path.join(opt.save_dir, 'backup_code'))
    shutil.copy(src=os.path.join(BASE_DIR, 'models', 'model_deco.py'),
                dst=os.path.join(opt.save_dir, 'backup_code'))
    shutil.copy(src=os.path.join(BASE_DIR, 'shape_utils.py'),
                dst=os.path.join(opt.save_dir, 'backup_code'))
    shutil.copy(src=opt.config,
                dst=os.path.join(opt.save_dir, 'backup_code',
                                 'config.json.backup'))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if opt.manualSeed is None:
        opt.manualSeed = random.randint(1, 10000)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)
    torch.cuda.manual_seed_all(opt.manualSeed)

    io.cprint(f"Arguments: {str(opt)}")
    io.cprint(f"Configuration: {str(config)}")

    pnum = config['completion_trainer'][
        'num_points']  # number of points of complete pointcloud
    class_choice = opt.class_choice  # config['completion_trainer']['class_choice']

    # datasets + loaders
    if len(class_choice) > 0:
        class_choice = ''.join(opt.class_choice.split()).split(
            ",")  # sanitize + split(",")
        io.cprint("Class choice list: {}".format(str(class_choice)))
    else:
        class_choice = None  # training on all snpart classes

    tr_dataset = shapenet_part_loader.PartDataset(root=opt.data_root,
                                                  classification=True,
                                                  class_choice=class_choice,
                                                  npoints=pnum,
                                                  split='train')

    te_dataset = shapenet_part_loader.PartDataset(root=opt.data_root,
                                                  classification=True,
                                                  class_choice=class_choice,
                                                  npoints=pnum,
                                                  split='test')

    tr_loader = torch.utils.data.DataLoader(tr_dataset,
                                            batch_size=opt.batch_size,
                                            shuffle=True,
                                            num_workers=opt.workers,
                                            drop_last=True)

    te_loader = torch.utils.data.DataLoader(te_dataset,
                                            batch_size=64,
                                            shuffle=True,
                                            num_workers=opt.workers)

    num_holes = int(opt.num_holes)
    crop_point_num = int(opt.crop_point_num)
    context_point_num = int(opt.context_point_num)

    # io.cprint(f"Completion Setting:\n num classes {len(tr_dataset.cat.keys())}, num holes: {num_holes}, "
    #           f"crop point num: {crop_point_num}, frame/context point num: {context_point_num},\n"
    #           f"num points at pool1: {opt.pool1_points}, num points at pool2: {opt.pool2_points} ")

    # Models
    gl_encoder = Encoder(conf=config)
    generator = Generator(conf=config,
                          pool1_points=int(opt.pool1_points),
                          pool2_points=int(opt.pool2_points))
    gl_encoder.apply(
        weights_init_normal)  # affecting only non pretrained layers
    generator.apply(weights_init_normal)

    print("Encoder: ", gl_encoder)
    print("Generator: ", generator)

    if ckt is not None:
        # resuming training from intermediate checkpoint
        # restoring both encoder and generator state
        io.cprint(f"Restart Training from epoch {start_epoch}.")
        gl_encoder.load_state_dict(ckt['gl_encoder_state_dict'])
        generator.load_state_dict(ckt['generator_state_dict'])
        io.cprint("Whole model loaded from {}\n".format(opt.restart_from))
    else:
        # training the completion model
        # load local and global encoder pretrained (ssl pretexts) weights
        io.cprint("Training Completion Task...")
        local_fe_fn = config['completion_trainer']['checkpoint_local_enco']
        global_fe_fn = config['completion_trainer']['checkpoint_global_enco']

        if len(local_fe_fn) > 0:
            local_enco_dict = torch.load(local_fe_fn, )['model_state_dict']
            loc_load_result = gl_encoder.local_encoder.load_state_dict(
                local_enco_dict, strict=False)
            io.cprint(
                f"Local FE pretrained weights - loading res: {str(loc_load_result)}"
            )
        else:
            # Ablation experiments only
            io.cprint("Local FE pretrained weights - NOT loaded", color='r')

        if len(global_fe_fn) > 0:
            global_enco_dict = torch.load(global_fe_fn, )['global_encoder']
            glob_load_result = gl_encoder.global_encoder.load_state_dict(
                global_enco_dict, strict=True)
            io.cprint(
                f"Global FE pretrained weights - loading res: {str(glob_load_result)}",
                color='b')
        else:
            # Ablation experiments only
            io.cprint("Global FE pretrained weights - NOT loaded", color='r')

    io.cprint("Num GPUs: " + str(torch.cuda.device_count()) +
              ", Parallelism: {}".format(opt.parallel))
    if opt.parallel:
        # TODO: implement DistributedDataParallel training
        assert torch.cuda.device_count() > 1
        gl_encoder = torch.nn.DataParallel(gl_encoder)
        generator = torch.nn.DataParallel(generator)
    gl_encoder.to(device)
    generator.to(device)

    # Optimizers + schedulers
    opt_E = torch.optim.Adam(
        gl_encoder.parameters(),
        lr=config['completion_trainer']['enco_lr'],  # default is: 10e-4
        betas=(0.9, 0.999),
        eps=1e-05,
        weight_decay=0.001)

    sched_E = torch.optim.lr_scheduler.StepLR(
        opt_E,
        step_size=config['completion_trainer']['enco_step'],  # default is: 25
        gamma=0.5)

    opt_G = torch.optim.Adam(
        generator.parameters(),
        lr=config['completion_trainer']['gen_lr'],  # default is: 10e-4
        betas=(0.9, 0.999),
        eps=1e-05,
        weight_decay=0.001)

    sched_G = torch.optim.lr_scheduler.StepLR(
        opt_G,
        step_size=config['completion_trainer']['gen_step'],  # default is: 40
        gamma=0.5)

    if ckt is not None:
        # resuming training from intermediate checkpoint
        # restore optimizers state
        opt_E.load_state_dict(ckt['optimizerE_state_dict'])
        opt_G.load_state_dict(ckt['optimizerG_state_dict'])
        sched_E.load_state_dict(ckt['schedulerE_state_dict'])
        sched_G.load_state_dict(ckt['schedulerG_state_dict'])

    # crop centroids
    if not opt.fps_centroids:
        # 5 viewpoints to crop around - same crop procedure of PFNet - main paper
        centroids = np.asarray([[1, 0, 0], [0, 0, 1], [1, 0, 1], [-1, 0, 0],
                                [-1, 1, 0]])
    else:
        raise NotImplementedError('experimental')
        centroids = None

    io.cprint('Centroids: ' + str(centroids))

    # training loop
    io.cprint("Training.. \n")
    best_test = sys.float_info.max
    best_ep, glob_it = -1, 0
    vis_folder = None
    for epoch in range(start_epoch + 1, opt.epochs):
        start_ep_time = time.time()
        count = 0.0
        tot_loss = 0.0
        tot_fine_loss = 0.0
        tot_interm_loss = 0.0
        gl_encoder = gl_encoder.train()
        generator = generator.train()
        for i, data in enumerate(tr_loader, 0):
            glob_it += 1
            points, _ = data
            B, N, dim = points.size()
            count += B

            partials = []
            fine_gts, interm_gts = [], []
            N_partial_points = N - (crop_point_num * num_holes)
            for m in range(B):
                partial, fine_gt, interm_gt = crop_shape(
                    points[m],
                    centroids=centroids,
                    scales=[
                        crop_point_num, (crop_point_num + context_point_num)
                    ],
                    n_c=num_holes)

                if partial.size(0) > N_partial_points:
                    assert num_holes > 1
                    # sampling without replacement
                    choice = torch.randperm(partial.size(0))[:N_partial_points]
                    partial = partial[choice]

                partials.append(partial)
                fine_gts.append(fine_gt)
                interm_gts.append(interm_gt)

            if i == 1 and epoch % opt.it_test == 0:
                # make some visualization
                vis_folder = os.path.join(opt.vis_dir,
                                          "epoch_{}".format(epoch))
                safe_make_dirs([vis_folder])
                print(f"ep {epoch} - Saving visualizations into: {vis_folder}")
                for j in range(len(partials)):
                    np.savetxt(X=partials[j],
                               fname=os.path.join(vis_folder,
                                                  '{}_partial.txt'.format(j)),
                               fmt='%.5f',
                               delimiter=';')
                    np.savetxt(X=fine_gts[j],
                               fname=os.path.join(vis_folder,
                                                  '{}_fine_gt.txt'.format(j)),
                               fmt='%.5f',
                               delimiter=';')
                    np.savetxt(X=interm_gts[j],
                               fname=os.path.join(
                                   vis_folder, '{}_interm_gt.txt'.format(j)),
                               fmt='%.5f',
                               delimiter=';')

            partials = torch.stack(partials).to(device).permute(
                0, 2, 1)  # [B, 3, N-512]
            fine_gts = torch.stack(fine_gts).to(device)  # [B, 512, 3]
            interm_gts = torch.stack(interm_gts).to(device)  # [B, 1024, 3]

            gl_encoder.zero_grad()
            generator.zero_grad()
            feat = gl_encoder(partials)
            pred_fine, pred_raw = generator(feat)

            # pytorch 1.2 compiled Chamfer (C2C) dist.
            assert pred_fine.size() == fine_gts.size()
            pred_fine, pred_raw = pred_fine.contiguous(), pred_raw.contiguous()
            fine_gts, interm_gts = fine_gts.contiguous(
            ), interm_gts.contiguous()

            dist1, dist2, _, _ = NND.nnd(pred_fine,
                                         fine_gts)  # missing part pred loss
            dist1_raw, dist2_raw, _, _ = NND.nnd(
                pred_raw, interm_gts)  # intermediate pred loss
            fine_loss = 50 * (torch.mean(dist1) + torch.mean(dist2)
                              )  # chamfer is weighted by 100
            interm_loss = 50 * (torch.mean(dist1_raw) + torch.mean(dist2_raw))

            loss = fine_loss + opt.raw_weight * interm_loss
            loss.backward()
            opt_E.step()
            opt_G.step()
            tot_loss += loss.item() * B
            tot_fine_loss += fine_loss.item() * B
            tot_interm_loss += interm_loss.item() * B

            if glob_it % 10 == 0:
                header = "[%d/%d][%d/%d]" % (epoch, opt.epochs, i,
                                             len(tr_loader))
                io.cprint('%s: loss: %.4f, fine CD: %.4f, interm. CD: %.4f' %
                          (header, loss.item(), fine_loss.item(),
                           interm_loss.item()))

            # make visualizations
            if i == 1 and epoch % opt.it_test == 0:
                assert (vis_folder is not None and os.path.exists(vis_folder))
                pred_fine = pred_fine.cpu().detach().data.numpy()
                pred_raw = pred_raw.cpu().detach().data.numpy()
                for j in range(len(pred_fine)):
                    np.savetxt(X=pred_fine[j],
                               fname=os.path.join(
                                   vis_folder, '{}_pred_fine.txt'.format(j)),
                               fmt='%.5f',
                               delimiter=';')
                    np.savetxt(X=pred_raw[j],
                               fname=os.path.join(vis_folder,
                                                  '{}_pred_raw.txt'.format(j)),
                               fmt='%.5f',
                               delimiter=';')

        sched_E.step()
        sched_G.step()
        io.cprint(
            '[%d/%d] Ep Train - loss: %.5f, fine cd: %.5f, interm. cd: %.5f' %
            (epoch, opt.epochs, tot_loss * 1.0 / count,
             tot_fine_loss * 1.0 / count, tot_interm_loss * 1.0 / count))
        tb.add_scalar('Train/tot_loss', tot_loss * 1.0 / count, epoch)
        tb.add_scalar('Train/cd_fine', tot_fine_loss * 1.0 / count, epoch)
        tb.add_scalar('Train/cd_interm', tot_interm_loss * 1.0 / count, epoch)

        if epoch % opt.it_test == 0:
            torch.save(
                {
                    'epoch':
                    epoch + 1,
                    'epoch_train_loss':
                    tot_loss * 1.0 / count,
                    'epoch_train_loss_raw':
                    tot_interm_loss * 1.0 / count,
                    'epoch_train_loss_fine':
                    tot_fine_loss * 1.0 / count,
                    'gl_encoder_state_dict':
                    gl_encoder.module.state_dict() if isinstance(
                        gl_encoder,
                        nn.DataParallel) else gl_encoder.state_dict(),
                    'generator_state_dict':
                    generator.module.state_dict() if isinstance(
                        generator, nn.DataParallel) else
                    generator.state_dict(),
                    'optimizerE_state_dict':
                    opt_E.state_dict(),
                    'optimizerG_state_dict':
                    opt_G.state_dict(),
                    'schedulerE_state_dict':
                    sched_E.state_dict(),
                    'schedulerG_state_dict':
                    sched_G.state_dict(),
                },
                os.path.join(opt.models_dir,
                             'checkpoint_' + str(epoch) + '.pth'))

        if epoch % opt.it_test == 0:
            test_cd, count = 0.0, 0.0
            for i, data in enumerate(te_loader, 0):
                points, _ = data
                B, N, dim = points.size()
                count += B

                partials = []
                fine_gts = []
                N_partial_points = N - (crop_point_num * num_holes)

                for m in range(B):
                    partial, fine_gt, _ = crop_shape(points[m],
                                                     centroids=centroids,
                                                     scales=[
                                                         crop_point_num,
                                                         (crop_point_num +
                                                          context_point_num)
                                                     ],
                                                     n_c=num_holes)

                    if partial.size(0) > N_partial_points:
                        assert num_holes > 1
                        # sampling Without replacement
                        choice = torch.randperm(
                            partial.size(0))[:N_partial_points]
                        partial = partial[choice]

                    partials.append(partial)
                    fine_gts.append(fine_gt)
                partials = torch.stack(partials).to(device).permute(
                    0, 2, 1)  # [B, 3, N-512]
                fine_gts = torch.stack(fine_gts).to(
                    device).contiguous()  # [B, 512, 3]

                # TEST FORWARD
                # Considering only missing part prediction at Test Time
                gl_encoder.eval()
                generator.eval()
                with torch.no_grad():
                    feat = gl_encoder(partials)
                    pred_fine, _ = generator(feat)

                pred_fine = pred_fine.contiguous()
                assert pred_fine.size() == fine_gts.size()
                dist1, dist2, _, _ = NND.nnd(pred_fine, fine_gts)
                cd_loss = 50 * (torch.mean(dist1) + torch.mean(dist2))
                test_cd += cd_loss.item() * B

            test_cd = test_cd * 1.0 / count
            io.cprint('Ep Test [%d/%d] - cd loss: %.5f ' %
                      (epoch, opt.epochs, test_cd),
                      color="b")
            tb.add_scalar('Test/cd_loss', test_cd, epoch)
            is_best = test_cd < best_test
            best_test = min(best_test, test_cd)

            if is_best:
                # best model case
                best_ep = epoch
                io.cprint("New best test %.5f at epoch %d" %
                          (best_test, best_ep))
                shutil.copyfile(src=os.path.join(
                    opt.models_dir, 'checkpoint_' + str(epoch) + '.pth'),
                                dst=os.path.join(opt.models_dir,
                                                 'best_model.pth'))
        io.cprint(
            '[%d/%d] Epoch time: %s' %
            (epoch, opt.epochs,
             time.strftime("%M:%S", time.gmtime(time.time() - start_ep_time))))

    # Script ends
    hours, rem = divmod(time.time() - start_time, 3600)
    minutes, seconds = divmod(rem, 60)
    io.cprint("### Training ended in {:0>2}:{:0>2}:{:05.2f}".format(
        int(hours), int(minutes), seconds))
    io.cprint("### Best val %.6f at epoch %d" % (best_test, best_ep))
Example #10
        frag2_batch = frag2_batch.squeeze().cuda()
        R1 = R1.squeeze().cuda()
        R2 = R2.squeeze().cuda()
        lrf1 = lrf1.squeeze().cuda()
        lrf2 = lrf2.squeeze().cuda()

        optimizer.zero_grad()

        f1, xtrans1, trans1, f2, xtrans2, trans2 = net(frag1_batch,
                                                       frag2_batch)

        # hardest-contrastive loss
        lcontrastive, a, b, c = hardest_contrastive(f1, f2)
        # chamfer loss
        dist1, dist2 = NND.nnd(
            xtrans1.transpose(2, 1).contiguous(),
            xtrans2.transpose(2, 1).contiguous())
        lchamf = .5 * (torch.mean(dist1) + torch.mean(dist2))
        # combination of losses
        loss = lcontrastive + lchamf

        loss.backward()
        optimizer.step()

        writer.add_scalar('loss/train', loss.item(), n_iter)
        writer.add_scalar('hardest_contrastive/positive - train',
                          torch.mean(a).item(), n_iter)
        writer.add_scalar('hardest_contrastive/negative1 - train',
                          torch.mean(b[0]).item(), n_iter)
        writer.add_scalar('hardest_contrastive/negative2 - train',
                          torch.mean(c[0]).item(), n_iter)