Example #1
def train(model, train_loader, valid_loader, config):
    model.cuda()
    optimizer = optim.Adam(model.parameters(), lr=config.train_lr, weight_decay=config.weight_decay)
    match_loss = MatchLoss(config)

    checkpoint_path = os.path.join(config.log_path, 'checkpoint.pth')
    config.resume = os.path.isfile(checkpoint_path)
    if config.resume:
        print('==> Resuming from checkpoint..')
        checkpoint = torch.load(checkpoint_path)
        best_acc = checkpoint['best_acc']
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        logger_train = Logger(os.path.join(config.log_path, 'log_train.txt'), title='oan', resume=True)
        logger_valid = Logger(os.path.join(config.log_path, 'log_valid.txt'), title='oan', resume=True)
    else:
        best_acc = -1
        start_epoch = 0
        logger_train = Logger(os.path.join(config.log_path, 'log_train.txt'), title='oan')
        logger_train.set_names(['Learning Rate'] + ['Geo Loss', 'Classif Loss', 'L2 Loss']*(config.iter_num+1))
        logger_valid = Logger(os.path.join(config.log_path, 'log_valid.txt'), title='oan')
        logger_valid.set_names(['Valid Acc'] + ['Geo Loss', 'Classif Loss', 'L2 Loss'])
    train_loader_iter = iter(train_loader)
    for step in trange(start_epoch, config.train_iter, ncols=config.tqdm_width):
        try:
            train_data = next(train_loader_iter)
        except StopIteration:
            train_loader_iter = iter(train_loader)
            train_data = next(train_loader_iter)
        train_data = tocuda(train_data)

        # run training
        cur_lr = optimizer.param_groups[0]['lr']
        loss_vals = train_step(step, optimizer, model, match_loss, train_data)
        logger_train.append([cur_lr] + loss_vals)

        # Check if we want to write validation
        b_save = ((step + 1) % config.save_intv) == 0
        b_validate = ((step + 1) % config.val_intv) == 0
        if b_validate:
            va_res, geo_loss, cla_loss, l2_loss, _, _, _ = valid(valid_loader, model, step, config)
            logger_valid.append([va_res, geo_loss, cla_loss, l2_loss])
            if va_res > best_acc:
                print("Saving best model with va_res = {}".format(va_res))
                best_acc = va_res
                torch.save({
                    'epoch': step + 1,
                    'state_dict': model.state_dict(),
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict(),
                }, os.path.join(config.log_path, 'model_best.pth'))

        if b_save:
            torch.save({
                'epoch': step + 1,
                'state_dict': model.state_dict(),
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
            }, checkpoint_path)
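All of these snippets move their batches to the GPU with a project-specific `tocuda` helper. Its exact definition differs between the source projects; a minimal sketch that covers the usages shown here (dicts in Examples #1, #5 and #6, lists/tuples in Examples #2-#4) could look like the following, assuming it only needs to move tensors and leave everything else untouched:

import torch

def tocuda(data):
    # Recursively move every tensor in a dict, list, or tuple onto the GPU;
    # non-tensor values are returned unchanged.
    if isinstance(data, torch.Tensor):
        return data.cuda()
    if isinstance(data, dict):
        return {k: tocuda(v) for k, v in data.items()}
    if isinstance(data, (list, tuple)):
        return [tocuda(v) for v in data]
    return data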
Example #2
def test(loader, net, epoch):
    N = int(np.ceil(len(loader.dataset) / loader.batch_size))
    net.eval()
    dist = []
    labs = []
    for i, batch in enumerate(loader):
        img1, img2 = utils.tocuda(batch[:-1])
        label = batch[-1]

        with torch.no_grad():
            img1_embed = net(img1)
            img2_embed = net(img2)

            d = torch.pow(img1_embed - img2_embed, 2).sum(1)
            dist += d.cpu().tolist()
            labs += label.cpu().tolist()

        print(f'\rEPOCH {epoch}: test batch {i:04d}/{N}{" "*10}',
              end='',
              flush=True)

    dist = np.array(dist)
    labs = np.array(labs)
    # Pick a threshold halfway between the median distances of the two halves of
    # the test set (the pairs are assumed to be grouped into a matching block and
    # a non-matching block), then compare the thresholded distances against the labels.
    meds = np.median(dist.reshape(2, -1), axis=1)
    threshold = meds[0] + (meds[1] - meds[0]) / 2
    thresh = dist > threshold
    accuracy = np.mean(thresh == labs)
    return accuracy
Example #3
def pair_distance(net, dataloader):
    net.eval()
    dist = []
    labs = []
    for batch in tqdm.tqdm(dataloader):
        img1, img2 = utils.tocuda(batch[:-1])
        label = batch[-1]
        with torch.no_grad():
            emb1 = net(img1)
            emb2 = net(img2)

            d = torch.pow(emb1 - emb2, 2).sum(1)
            dist += d.cpu().tolist()
            labs += label.tolist()

    dist = np.array(dist)
    labs = np.array(labs)
    return dist, labs
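A possible way to consume the (dist, labs) arrays returned by pair_distance is to mirror the median-based threshold from Example #2; this is a sketch only, and assumes the loader yields the matching and non-matching pairs in two equal-sized blocks:

import numpy as np

dist, labs = pair_distance(net, dataloader)
meds = np.median(dist.reshape(2, -1), axis=1)   # per-block median distance
threshold = meds[0] + (meds[1] - meds[0]) / 2   # midpoint between the two medians
accuracy = np.mean((dist > threshold) == labs)  # same decision rule as Example #2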
Example #4
def train(loader, net, loss_fn, opt, epoch):
    N = int(np.ceil(len(loader.dataset) / loader.batch_size))
    loss = 0.0
    violations = 0
    total = 0
    net.train()
    i = 0
    for i, batch in enumerate(loader):
        opt.zero_grad()
        anchor, pos, neg = utils.tocuda(batch[:-1])
        ind = batch[-1]
        anchor.requires_grad_()
        pos.requires_grad_()
        neg.requires_grad_()

        anc_embed = net(anchor)
        pos_embed = net(pos)
        neg_embed = net(neg)

        #batch_loss = loss_fn(anc_embed, pos_embed, neg_embed)
        batch_loss = loss_fn(anc_embed, pos_embed, neg_embed, ind)
        if batch_loss > 0:
            batch_loss.backward()
            opt.step()
            loss += batch_loss.item()

        # v = (torch.pow(anc_embed - pos_embed, 2).sum(1) + 0.2 >
        #      torch.pow(anc_embed - neg_embed, 2).sum(1))
        # violations += v.sum().item()
        # total += anchor.size()[0]
        print(f'\rEPOCH {epoch}: train batch {i:04d}/{N}{" "*10}',
              end='',
              flush=True)
    loss /= (i + 1)

    opt.zero_grad()
    return loss
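The loss_fn passed into this trainer is not shown. Judging from the commented-out violation check (squared embedding distances compared against a 0.2 margin), a compatible choice could be a triplet margin loss on squared distances; the sketch below is illustrative only and simply ignores the project-specific ind argument:

import torch

def triplet_margin_sq_loss(anchor, positive, negative, ind=None, margin=0.2):
    # Hinge on squared L2 distances: the positive pair should be closer than
    # the negative pair by at least `margin`.
    d_pos = torch.pow(anchor - positive, 2).sum(1)
    d_neg = torch.pow(anchor - negative, 2).sum(1)
    return torch.clamp(d_pos - d_neg + margin, min=0).mean()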
Example #5
    def forward(self, inputs):
        '''

        :param inputs: dict: {
            'imgs':                    (Tensor), images,
                                    (batch size, number of views, C, H, W)
            'vol_origin':              (Tensor), origin of the full voxel volume (xyz position of voxel (0, 0, 0)),
                                    (batch size, 3)
            'vol_origin_partial':      (Tensor), origin of the partial voxel volume (xyz position of voxel (0, 0, 0)),
                                    (batch size, 3)
            'world_to_aligned_camera': (Tensor), matrices: transform from world coords to aligned camera coords,
                                    (batch size, number of views, 4, 4)
            'proj_matrices':           (Tensor), projection matrix,
                                    (batch size, number of views, number of scales, 4, 4)
            when we have ground truth:
            'tsdf_list':               (List), tsdf ground truth for each level,
                                    [(batch size, DIM_X, DIM_Y, DIM_Z)]
            'occ_list':                (List), occupancy ground truth for each level,
                                    [(batch size, DIM_X, DIM_Y, DIM_Z)]
            others: unused in network
        }
        :return: outputs: dict: {
            'coords':                  (Tensor), coordinates of voxels,
                                    (number of voxels, 4) (4 : batch ind, x, y, z)
            'tsdf':                    (Tensor), TSDF of voxels,
                                    (number of voxels, 1)
            When saving results:
            'origin':                  (List), origin of the predicted partial volume,
                                    [3]
            'scene_tsdf':              (List), predicted tsdf volume,
                                    [(nx, ny, nz)]
        }
                 loss_dict: dict: {
            'tsdf_occ_loss_X':         (Tensor), multi level loss
            'total_loss':              (Tensor), total loss
        }
        '''
        inputs = tocuda(inputs)
        outputs = {}
        imgs = torch.unbind(inputs['imgs'], 1)

        # image feature extraction
        # in: images; out: feature maps
        features = [self.backbone2d(self.normalizer(img)) for img in imgs]

        # coarse-to-fine decoder: SparseConv and GRU Fusion.
        # in: image feature; out: sparse coords and tsdf
        outputs, loss_dict = self.neucon_net(features, inputs, outputs)

        # fuse to global volume.
        if not self.training and 'coords' in outputs.keys():
            outputs = self.fuse_to_global(outputs['coords'], outputs['tsdf'], inputs, self.n_scales, outputs)

        # gather loss.
        print_loss = 'Loss: '
        for k, v in loss_dict.items():
            print_loss += f'{k}: {v} '  # assembled for logging; not printed in this snippet

        weighted_loss = 0

        for i, (k, v) in enumerate(loss_dict.items()):
            weighted_loss += v * self.cfg.LW[i]

        loss_dict.update({'total_loss': weighted_loss})
        return outputs, loss_dict
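The loss gathering at the end of forward scales each entry of loss_dict by the matching weight in self.cfg.LW, relying on the dict's insertion order, and stores the sum as 'total_loss'. A standalone sketch of just that step, with illustrative names and values:

import torch

def gather_total_loss(loss_dict, level_weights):
    # Relies on loss_dict preserving insertion order so that each per-level
    # loss lines up with its weight.
    total = sum(v * w for v, w in zip(loss_dict.values(), level_weights))
    loss_dict['total_loss'] = total
    return loss_dict

losses = {'tsdf_occ_loss_0': torch.tensor(0.5), 'tsdf_occ_loss_1': torch.tensor(0.3)}
gather_total_loss(losses, [1.0, 0.8])   # total_loss = 0.5 * 1.0 + 0.3 * 0.8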
Example #6
def test_process(mode, model, cur_global_step, data_loader, config):
    model.eval()
    match_loss = MatchLoss(config)
    loader_iter = iter(data_loader)

    # save info given by the network
    network_infor_list = [
        "geo_losses", "cla_losses", "l2_losses", 'precisions', 'recalls',
        'f_scores'
    ]
    network_info = {info: [] for info in network_infor_list}

    results, pool_arg = [], []
    eval_step, eval_step_i, num_processor = 100, 0, 8
    inlier_ratio_mean, inlier_ratio_max, select_inlier_ratio = [], [], []
    with torch.no_grad():
        for test_data in tqdm(loader_iter):
            test_data = tocuda(test_data)
            res_logits, res_e_hat, Ms, stats = model(test_data)
            # inlier_ratio_mean += [stats['inlier_ratio_mean']]
            # inlier_ratio_max += [stats['inlier_ratio_max']]
            # select_inlier_ratio += [stats['select_inlier_ratio']]
            y_hat, e_hat = res_logits[-1], res_e_hat[-1]
            loss, geo_loss, cla_loss, l2_loss, prec, rec = match_loss.run(
                cur_global_step, test_data, y_hat, e_hat, Ms[-1])
            info = [
                geo_loss, cla_loss, l2_loss, prec, rec,
                2 * prec * rec / (prec + rec + 1e-15)
            ]
            for info_idx, value in enumerate(info):
                network_info[network_infor_list[info_idx]].append(value)

            if config.use_fundamental:
                # unnorm F
                e_hat = torch.matmul(
                    torch.matmul(test_data['T2s'].transpose(1, 2),
                                 e_hat.reshape(-1, 3, 3)), test_data['T1s'])
                # get essential matrix from fundamental matrix
                e_hat = torch.matmul(
                    torch.matmul(test_data['K2s'].transpose(1, 2),
                                 e_hat.reshape(-1, 3, 3)),
                    test_data['K1s']).reshape(-1, 9)
                e_hat = e_hat / torch.norm(e_hat, dim=1, keepdim=True)

            for batch_idx in range(e_hat.shape[0]):
                test_xs = test_data['xs'][batch_idx].detach().cpu().numpy()
                if config.use_fundamental:  # back to original
                    x1, x2 = test_xs[0, :, :2], test_xs[0, :, 2:4]
                    T1 = test_data['T1s'][batch_idx].cpu().numpy()
                    T2 = test_data['T2s'][batch_idx].cpu().numpy()
                    x1, x2 = denorm(x1, T1), denorm(x2, T2)  # denormalize coordinates
                    K1 = test_data['K1s'][batch_idx].cpu().numpy()
                    K2 = test_data['K2s'][batch_idx].cpu().numpy()
                    x1, x2 = denorm(x1, K1), denorm(x2, K2)  # normalize coordinates with the intrinsics
                    test_xs = np.concatenate([x1, x2], axis=-1).reshape(1, -1, 4)

                pool_arg += [(test_xs,
                              test_data['Rs'][batch_idx].detach().cpu().numpy(),
                              test_data['ts'][batch_idx].detach().cpu().numpy(),
                              e_hat[batch_idx].detach().cpu().numpy(),
                              y_hat[batch_idx].detach().cpu().numpy(),
                              test_data['ys'][batch_idx, :, 0].detach().cpu().numpy(),
                              config)]

                eval_step_i += 1
                if eval_step_i % eval_step == 0:
                    results += get_pool_result(num_processor, test_sample,
                                               pool_arg)
                    pool_arg = []
        if len(pool_arg) > 0:
            results += get_pool_result(num_processor, test_sample, pool_arg)
    # print(np.mean(inlier_ratio_mean), np.mean(inlier_ratio_max), np.mean(select_inlier_ratio))
    measure_list = ["err_q", "err_t", "num", 'R_hat', 't_hat']
    eval_res = {}
    for measure_idx, measure in enumerate(measure_list):
        eval_res[measure] = np.asarray(
            [result[measure_idx] for result in results])

    if config.res_path == '':
        config.res_path = os.path.join(config.log_path[:-5], mode)
    tag = "ours" if not config.use_ransac else "ours_ransac"
    ret_val = dump_res(measure_list, config.res_path, eval_res, tag)
    return [
        ret_val,
        np.mean(np.asarray(network_info['geo_losses'])),
        np.mean(np.asarray(network_info['cla_losses'])),
        np.mean(np.asarray(network_info['l2_losses'])),
        np.mean(np.asarray(network_info['precisions'])),
        np.mean(np.asarray(network_info['recalls'])),
        np.mean(np.asarray(network_info['f_scores'])),
    ]
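For reference, the unnormalization block in the middle of this example first maps the predicted matrix back through the coordinate-normalization transforms T1/T2 and then folds in the intrinsics (E = K2^T F K1) before renormalizing to unit norm. A standalone sketch of that transformation for batched (B, 3, 3) inputs, with illustrative names:

import torch

def fundamental_to_essential(F_norm, T1, T2, K1, K2):
    # F_norm, T1, T2, K1, K2: (B, 3, 3) tensors
    F = T2.transpose(1, 2) @ F_norm @ T1             # undo the coordinate normalization
    E = (K2.transpose(1, 2) @ F @ K1).reshape(-1, 9)
    return E / torch.norm(E, dim=1, keepdim=True)    # unit-norm 9-vector, as in the code above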