Code example #1
def train(train_loader, model, optimizer, args):
    global device
    model.train()  # switch to train mode
    chamfer_loss = AverageMeter()
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        data_displacement, mask_pred_nosigmoid, mask_pred, bandwidth = model(data)
        y_pred = data_displacement + data.pos  # joint candidates: vertices shifted by their predicted offsets
        loss_chamfer = 0.0
        if args.use_bce:
            mask_gt = data.mask.unsqueeze(1)
            loss_chamfer += args.bce_loss_weight * torch.nn.functional.binary_cross_entropy_with_logits(
                mask_pred_nosigmoid, mask_gt.float(), reduction='mean')
        # per-sample losses: the mini-batch is a flat point cloud, so index by data.batch
        for i in range(len(torch.unique(data.batch))):
            y_gt_sample = data.y[data.batch == i, :]
            # data.y is padded per sample; keep only the first num_joint rows
            y_gt_sample = y_gt_sample[:data.num_joint[i], :]
            y_pred_sample = y_pred[data.batch == i, :]
            mask_pred_sample = mask_pred[data.batch == i]
            loss_chamfer += chamfer_distance_with_average(
                y_pred_sample.unsqueeze(0), y_gt_sample.unsqueeze(0))
            clustered_pred = meanshift_cluster(y_pred_sample, bandwidth,
                                               mask_pred_sample, args)
            for j in range(args.meanshift_step):
                loss_chamfer += args.ms_loss_weight * chamfer_distance_with_average(
                    clustered_pred[j].unsqueeze(0), y_gt_sample.unsqueeze(0))
        loss_chamfer.backward()
        optimizer.step()
        chamfer_loss.update(loss_chamfer.item(),
                            n=len(torch.unique(data.batch)))
    return chamfer_loss.avg
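
These examples call several helpers from the RigNet codebase that are not shown on this page. `chamfer_distance_with_average` compares the regressed point set against the ground-truth joints; a minimal sketch consistent with how it is called above (two `1 x N x 3` tensors, returning a scalar) might look like the following. This is an illustrative reimplementation, not necessarily the project's exact code:

import torch

def chamfer_distance_with_average(p1, p2):
    # Symmetric Chamfer distance between point sets p1 (1, N, 3) and
    # p2 (1, M, 3), averaged over both directions.
    d = torch.cdist(p1.squeeze(0), p2.squeeze(0)) ** 2  # (N, M) squared distances
    # each point's distance to its nearest neighbour in the other set
    return (d.min(dim=1)[0].mean() + d.min(dim=0)[0].mean()) / 2.0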
Code example #2
def train(train_loader, model, optimizer, args, epoch):
    global device
    model.train()  # switch to train mode
    loss_meter = AverageMeter()
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        if args.arch == 'masknet':
            mask_pred = model(data)
            mask_gt = data.mask.unsqueeze(1)
            loss = torch.nn.functional.binary_cross_entropy_with_logits(
                mask_pred, mask_gt.float(), reduction='mean')
        elif args.arch == 'jointnet':
            data_displacement = model(data)
            y_pred = data_displacement + data.pos
            loss = 0.0
            for i in range(len(torch.unique(data.batch))):
                y_gt_sample = data.y[data.batch == i, :]
                y_gt_sample = y_gt_sample[:data.num_joint[i], :]
                y_pred_sample = y_pred[data.batch == i, :]
                loss += chamfer_distance_with_average(
                    y_pred_sample.unsqueeze(0), y_gt_sample.unsqueeze(0))
        loss.backward()
        optimizer.step()
        loss_meter.update(loss.item())
    return loss_meter.avg
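
`AverageMeter` is the usual running-average utility from the PyTorch ImageNet example family. The code above relies on `update(val, n=...)`, `.avg`, and `.count`, so a minimal sketch compatible with that usage is:

class AverageMeter(object):
    # Tracks a weighted running average of a scalar (e.g. per-batch loss).
    def __init__(self):
        self.val = 0.0    # most recent value
        self.sum = 0.0    # weighted sum of all values
        self.count = 0    # total weight seen so far
        self.avg = 0.0    # running average

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count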
Code example #3
File: run_joint_finetune.py  Project: zhan-xu/RigNet
def test(test_loader, model, args, save_result=False, best_epoch=None):
    global device
    model.eval()  # switch to test mode
    loss_meter = AverageMeter()
    outdir = args.checkpoint.split('/')[-1]
    for data in test_loader:
        data = data.to(device)
        with torch.no_grad():
            data_displacement, mask_pred_nosigmoid, mask_pred, bandwidth = model(data)
            y_pred = data_displacement + data.pos
            loss_total = 0.0
            for i in range(len(torch.unique(data.batch))):
                joint_gt = data.joints[data.joints_batch == i, :]
                y_pred_i = y_pred[data.batch == i, :]
                mask_pred_i = mask_pred[data.batch == i]
                loss_total += chamfer_distance_with_average(y_pred_i.unsqueeze(0), joint_gt.unsqueeze(0))
                clustered_pred = meanshift_cluster(y_pred_i, bandwidth, mask_pred_i, args)
                loss_ms = 0.0
                for j in range(args.meanshift_step):
                    loss_ms += chamfer_distance_with_average(clustered_pred[j].unsqueeze(0), joint_gt.unsqueeze(0))
                loss_total = loss_total + args.ms_loss_weight * loss_ms / args.meanshift_step
                if save_result:
                    output_point_cloud_ply(y_pred_i, name=str(data.name[i].item()),
                                           output_folder='results/{:s}/best_{:d}/'.format(outdir, best_epoch))
                    np.save('results/{:s}/best_{:d}/{:d}_attn.npy'.format(outdir, best_epoch, data.name[i].item()),
                            mask_pred_i.data.to("cpu").numpy())
                    np.save('results/{:s}/best_{:d}/{:d}_bandwidth.npy'.format(outdir, best_epoch, data.name[i].item()),
                            bandwidth.data.to("cpu").numpy())
            loss_total /= len(torch.unique(data.batch))
            if args.use_bce:
                mask_gt = data.mask.unsqueeze(1)
                loss_total += args.bce_loss_weight * torch.nn.functional.binary_cross_entropy_with_logits(mask_pred_nosigmoid, mask_gt.float(), reduction='mean')
            loss_meter.update(loss_total.item())
    return loss_meter.avg
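
`output_point_cloud_ply` dumps the predicted joint candidates to disk for inspection. A minimal ASCII-PLY writer matching the call signature used above (a tensor of points, a name, an output folder) could be the following sketch; the repository's actual implementation may differ:

import os

def output_point_cloud_ply(points, name, output_folder):
    # Write an (N, 3) tensor of points to output_folder/name.ply as ASCII PLY.
    os.makedirs(output_folder, exist_ok=True)
    pts = points.detach().cpu().numpy()
    with open(os.path.join(output_folder, '{:s}.ply'.format(name)), 'w') as f:
        f.write('ply\nformat ascii 1.0\n')
        f.write('element vertex {:d}\n'.format(len(pts)))
        f.write('property float x\nproperty float y\nproperty float z\n')
        f.write('end_header\n')
        for x, y, z in pts:
            f.write('{:f} {:f} {:f}\n'.format(x, y, z))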
Code example #4
File: run_joint_finetune.py  Project: zhan-xu/RigNet
def train(train_loader, model, optimizer, args):
    global device
    model.train()  # switch to train mode
    loss_meter = AverageMeter()
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        data_displacement, mask_pred_nosigmoid, mask_pred, bandwidth = model(data)
        y_pred = data_displacement + data.pos
        loss_total = 0.0
        for i in range(len(torch.unique(data.batch))):
            joint_gt = data.joints[data.joints_batch == i, :]
            y_pred_i = y_pred[data.batch == i, :]
            mask_pred_i = mask_pred[data.batch == i]
            loss_total += chamfer_distance_with_average(y_pred_i.unsqueeze(0), joint_gt.unsqueeze(0))
            clustered_pred = meanshift_cluster(y_pred_i, bandwidth, mask_pred_i, args)
            loss_ms = 0.0
            for j in range(args.meanshift_step):
                loss_ms += chamfer_distance_with_average(clustered_pred[j].unsqueeze(0), joint_gt.unsqueeze(0))
            loss_total = loss_total + args.ms_loss_weight * loss_ms / args.meanshift_step
        loss_total /= len(torch.unique(data.batch))
        if args.use_bce:
            mask_gt = data.mask.unsqueeze(1)
            loss_total += args.bce_loss_weight * torch.nn.functional.binary_cross_entropy_with_logits(mask_pred_nosigmoid, mask_gt.float(), reduction='mean')
        loss_total.backward()
        optimizer.step()
        loss_meter.update(loss_total.item())
    return loss_meter.avg
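
`meanshift_cluster` refines the regressed points with a few mean-shift iterations using the learned `bandwidth`, weighting neighbours by the attention mask, and returns the intermediate point sets so that every step can be supervised (hence the loop over `args.meanshift_step` above). A simplified sketch of that idea, assuming a Gaussian kernel and a full shift per step, is:

import torch

def meanshift_cluster(pts, bandwidth, weights, args):
    # pts: (N, 3) regressed points; weights: (N, 1) attention mask.
    # Returns a list of the point set after each mean-shift step.
    steps = []
    for _ in range(args.meanshift_step):
        d2 = torch.cdist(pts, pts) ** 2                   # (N, N) squared distances
        kernel = torch.exp(-d2 / (2.0 * bandwidth ** 2))  # Gaussian affinity
        kernel = kernel * weights.view(1, -1)             # attention-weighted neighbours
        pts = kernel @ pts / (kernel.sum(dim=1, keepdim=True) + 1e-8)
        steps.append(pts)
    return steps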
Code example #5
def test(test_loader, model, args):
    global device
    model.eval()  # switch to test mode
    loss_meter = AverageMeter()
    acc_total = 0.0
    for data in test_loader:
        #print(data.name)
        data = data.to(device)
        with torch.no_grad():
            pre_label, label = model(data)
            loss = torch.nn.functional.binary_cross_entropy_with_logits(
                pre_label, label.float())
            loss_meter.update(loss.item(), n=len(torch.unique(data.batch)))
            accumulate_start_id = 0
            for i in range(len(torch.unique(data.batch))):
                pred_root_id = torch.argmax(pre_label[accumulate_start_id:accumulate_start_id + data.num_joint[i]]).item()
                gt_root_id = torch.argmax(label[accumulate_start_id:accumulate_start_id + data.num_joint[i]]).item()
                if pred_root_id == gt_root_id:
                    acc_total += 1.0
                accumulate_start_id += data.num_joint[i]
    return loss_meter.avg, acc_total / loss_meter.count
Code example #6
def test(test_loader, model, args, save_result=False, best_epoch=None):
    global device
    model.eval()  # switch to test mode
    loss_meter = AverageMeter()
    outdir = args.checkpoint.split('/')[1]
    for data in test_loader:
        data = data.to(device)
        with torch.no_grad():
            if args.arch == 'masknet':
                mask_pred = model(data)
                mask_gt = data.mask.unsqueeze(1)
                loss = torch.nn.functional.binary_cross_entropy_with_logits(
                    mask_pred, mask_gt.float(), reduction='mean')
            elif args.arch == 'jointnet':
                data_displacement = model(data)
                y_pred = data_displacement + data.pos
                loss = 0.0
                for i in range(len(torch.unique(data.batch))):
                    y_gt_sample = data.y[data.batch == i, :]
                    y_gt_sample = y_gt_sample[:data.num_joint[i], :]
                    y_pred_sample = y_pred[data.batch == i, :]
                    loss += chamfer_distance_with_average(
                        y_pred_sample.unsqueeze(0), y_gt_sample.unsqueeze(0))
            loss_meter.update(loss.item())

            if save_result:
                output_folder = 'results/{:s}/best_{:d}/'.format(
                    outdir, best_epoch)
                if not os.path.exists(output_folder):
                    mkdir_p(output_folder)
                if args.arch == 'masknet':
                    mask_pred = torch.sigmoid(mask_pred)
                    for i in range(len(torch.unique(data.batch))):
                        mask_pred_sample = mask_pred[data.batch == i]
                        np.save(
                            os.path.join(
                                output_folder,
                                str(data.name[i].item()) + '_attn.npy'),
                            mask_pred_sample.data.cpu().numpy())
                else:
                    for i in range(len(torch.unique(data.batch))):
                        y_pred_sample = y_pred[data.batch == i, :]
                        output_point_cloud_ply(
                            y_pred_sample,
                            name=str(data.name[i].item()),
                            output_folder='results/{:s}/best_{:d}/'.format(
                                outdir, best_epoch))
    return loss_meter.avg
Code example #7
File: run_skinning.py  Project: zhan-xu/RigNet
def test(test_loader, model, args, save_result=False):
    global device
    model.eval()  # switch to test mode
    loss_meter = AverageMeter()
    outdir = args.checkpoint.split('/')[-1]
    for data in test_loader:
        data = data.to(device)
        with torch.no_grad():
            skin_pred = model(data)
            skin_gt = data.skin_label[:, 0:args.nearest_bone]
            loss_mask_batch = data.loss_mask.float()[:, 0:args.nearest_bone]
            skin_gt = skin_gt * loss_mask_batch
            skin_gt = skin_gt / (torch.sum(torch.abs(skin_gt), dim=1, keepdim=True) + 1e-8)
            vert_mask = (torch.abs(skin_gt.sum(dim=1) - 1.0) < 1e-8).float()  # mask out vertices whose skinning is missing from the picked K bones.
            loss = cross_entropy_with_probs(skin_pred, skin_gt, reduction='none')
            loss = (loss * loss_mask_batch * vert_mask.unsqueeze(1)).sum() / (loss_mask_batch * vert_mask.unsqueeze(1)).sum()
            loss_meter.update(loss.item())

            if save_result:
                output_folder = 'results/{:s}/'.format(outdir)
                if not os.path.exists(output_folder):
                    mkdir_p(output_folder)
                for i in range(len(torch.unique(data.batch))):
                    print('output result for model {:d}'.format(data.name[i].item()))
                    skin_pred_i = skin_pred[data.batch == i]
                    bone_names = get_bone_names(os.path.join(args.test_folder, "{:d}_skin.txt".format(data.name[i].item())))
                    tpl_e = np.loadtxt(os.path.join(args.test_folder, "{:d}_tpl_e.txt".format(data.name[i].item()))).T
                    loss_mask_sample = data.loss_mask.float()[data.batch == i, 0:args.nearest_bone]
                    skin_pred_i = torch.softmax(skin_pred_i, dim=1)
                    skin_pred_i = skin_pred_i * loss_mask_sample
                    skin_nn_i = data.skin_nn[data.batch == i, 0:args.nearest_bone]
                    skin_pred_asarray = np.zeros((len(skin_pred_i), len(bone_names)))
                    for v in range(len(skin_pred_i)):
                        for nn_id in range(len(skin_nn_i[v, :])):
                            skin_pred_asarray[v, skin_nn_i[v, nn_id]] = skin_pred_i[v, nn_id]
                    skin_pred_asarray = post_filter(skin_pred_asarray, tpl_e, num_ring=1)
                    skin_pred_asarray[skin_pred_asarray < np.max(skin_pred_asarray, axis=1, keepdims=True) * 0.5] = 0.0
                    skin_pred_asarray = skin_pred_asarray / (skin_pred_asarray.sum(axis=1, keepdims=True) + 1e-10)
                    with open(os.path.join(output_folder, "{:d}_bone_names.txt".format(data.name[i].item())), 'w') as fout:
                        for bone_name in bone_names:
                            fout.write("{:s} {:s}\n".format(bone_name[0], bone_name[1]))
                    np.save(os.path.join(output_folder, "{:d}_full_pred.npy".format(data.name[i].item())), skin_pred_asarray)
                    skel_filename = os.path.join(args.info_folder, "{:d}.txt".format(data.name[i].item()))
                    output_rigging(skel_filename, skin_pred_asarray, output_folder, data.name[i].item())
    return loss_meter.avg
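
`cross_entropy_with_probs` is a cross-entropy that accepts soft (probabilistic) targets rather than class indices. Since its output is multiplied element-wise by the `(N, K)` loss mask above, it presumably returns the per-element terms when `reduction='none'`; a sketch under that assumption:

import torch.nn.functional as F

def cross_entropy_with_probs(input, target, reduction='mean'):
    # input: (N, K) logits; target: (N, K) soft label distribution.
    loss = -target * F.log_softmax(input, dim=1)  # element-wise terms
    if reduction == 'mean':
        return loss.sum(dim=1).mean()
    if reduction == 'sum':
        return loss.sum()
    return loss  # reduction='none': (N, K), caller applies its own masking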
Code example #8
File: run_root_cls.py  Project: zhan-xu/RigNet
def train(train_loader, model, optimizer, args):
    global device
    model.train()  # switch to train mode
    loss_meter = AverageMeter()
    for data in train_loader:
        #print(data.name)
        data = data.to(device)
        optimizer.zero_grad()
        pre_label, label = model(data)
        loss_1 = torch.nn.functional.binary_cross_entropy_with_logits(pre_label, label, reduction='none')
        # online hard-example mining: additionally average the top-k largest per-element losses
        topk_val, _ = torch.topk(loss_1.view(-1), k=int(args.topk * len(pre_label)), dim=0, sorted=False)
        loss_2 = topk_val.mean()
        loss = loss_1.mean() + loss_2
        loss.backward()
        optimizer.step()
        loss_meter.update(loss.item())
    return loss_meter.avg
Code example #9
File: run_skinning.py  Project: zhan-xu/RigNet
def train(train_loader, model, optimizer, args):
    global device
    model.train()  # switch to train mode
    loss_meter = AverageMeter()
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        skin_pred = model(data)
        skin_gt = data.skin_label[:, 0:args.nearest_bone]
        loss_mask_batch = data.loss_mask.float()[:, 0:args.nearest_bone]
        skin_gt = skin_gt * loss_mask_batch
        skin_gt = skin_gt / (torch.sum(torch.abs(skin_gt), dim=1, keepdim=True) + 1e-8)
        vert_mask = (torch.abs(skin_gt.sum(dim=1) - 1.0) < 1e-8).float()  # mask out vertices whose skinning is missing from the picked K bones.
        loss = cross_entropy_with_probs(skin_pred, skin_gt, reduction='none')
        loss = (loss * loss_mask_batch * vert_mask.unsqueeze(1)).sum() / (loss_mask_batch * vert_mask.unsqueeze(1)).sum()
        loss.backward()
        optimizer.step()
        loss_meter.update(loss.item())
    return loss_meter.avg
Code example #10
def train(train_loader, model, criterion, optimizer, epoch):
    
    # log
    loss_log = AverageMeter()
    bar = Bar('Training', max=len(train_loader))

    model.train()
    for i, (inputs, target, mask) in enumerate(train_loader):
        
        # cuda
        inputs = inputs.cuda()
        target = target.cuda()
        mask = mask.cuda()

        # inference
        outputs = model(inputs)

        # calculate loss
        target = torch.masked_select(target, mask)
        loss = 0
        # the model returns multiple outputs; accumulate the masked loss over all of them
        for output in outputs:
            output = torch.masked_select(output, mask)
            loss += criterion(output, target) / inputs.shape[0]
        loss_log.update(loss.item(), inputs.size(0))

        # update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # show progress
        bar.suffix = '({batch}/{size}) | Total: {total:} | ETA: {eta:} | Loss: {loss:.6f}'.format(
            batch=i + 1,
            size=len(train_loader),
            total=bar.elapsed_td,
            eta=bar.eta_td,
            loss=loss_log.avg)
        bar.next()

    bar.finish()

    # save inference image
    cv2.imwrite('images/{0:06d}.jpg'.format(epoch), make_inference_image(inputs, outputs[-1], mask))
Code example #11
def test(test_loader, model, args, save_result=False, best_epoch=None):
    global device
    model.eval()  # switch to test mode
    if save_result:
        output_folder = 'results/{:s}/best_{:d}/'.format(
            args.checkpoint.split('/')[1], best_epoch)
        if not os.path.exists(output_folder):
            mkdir_p(output_folder)
    loss_meter = AverageMeter()
    for data in test_loader:
        data = data.to(device)
        with torch.no_grad():
            pre_label, label = model(data)
            loss = torch.nn.functional.binary_cross_entropy_with_logits(
                pre_label, label.float())
            if save_result:
                connect_prob = torch.sigmoid(pre_label)
                accumulate_start_id = 0
                for i in range(len(torch.unique(data.batch))):
                    pair_idx = data.pairs[accumulate_start_id:accumulate_start_id + data.num_pair[i]].long()
                    connect_prob_i = connect_prob[accumulate_start_id:accumulate_start_id + data.num_pair[i]]
                    accumulate_start_id += data.num_pair[i]
                    cost_matrix = np.zeros((data.num_joint[i], data.num_joint[i]))
                    pair_idx = pair_idx.data.cpu().numpy()
                    cost_matrix[pair_idx[:, 0], pair_idx[:, 1]] = connect_prob_i.data.cpu().numpy().squeeze()
                    cost_matrix = 1 - cost_matrix
                    print('saving: {:s}'.format(
                        str(data.name[i].item()) + '_cost.npy'))
                    np.save(
                        os.path.join(output_folder,
                                     str(data.name[i].item()) + '_cost.npy'),
                        cost_matrix)
            loss_meter.update(loss.item(), n=len(torch.unique(data.batch)))
    return loss_meter.avg
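
Each saved `*_cost.npy` holds `1 - p(connect)` for every candidate joint pair; in RigNet these matrices feed the later skeleton-construction stage, which extracts a minimum spanning tree over the joints. As a rough illustration of consuming such a matrix (an assumption for this page, not the project's exact pipeline):

import numpy as np
from scipy.sparse.csgraph import minimum_spanning_tree

def skeleton_edges_from_cost(cost_matrix):
    # Choose a tree of joint-joint edges minimising total connection cost.
    sym = np.minimum(cost_matrix, cost_matrix.T)  # symmetrise: undirected graph
    mst = minimum_spanning_tree(sym).tocoo()
    return list(zip(mst.row.tolist(), mst.col.tolist()))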