Example #1
def compute_loss(pred, target, outputs, output_pred_ind, output_target_ind,
                 output_loss_weight, patch_rot, normal_loss):

    loss = 0

    for oi, o in enumerate(outputs):
        if o == 'unoriented_normals' or o == 'oriented_normals':
            o_pred = pred[:, output_pred_ind[oi]:output_pred_ind[oi] + 3]
            o_target = target[output_target_ind[oi]]

            if patch_rot is not None:
                # transform predictions with inverse transform
                # since we know the transform to be a rotation (QSTN), the transpose is the inverse
                o_pred = torch.bmm(o_pred.unsqueeze(1),
                                   patch_rot.transpose(2, 1)).squeeze(1)

            if o == 'unoriented_normals':
                if normal_loss == 'ms_euclidean':
                    loss += torch.min(
                        (o_pred - o_target).pow(2).sum(1),
                        (o_pred + o_target
                         ).pow(2).sum(1)).mean() * output_loss_weight[oi]
                elif normal_loss == 'ms_oneminuscos':
                    loss += (1 - torch.abs(utils.cos_angle(o_pred, o_target))
                             ).pow(2).mean() * output_loss_weight[oi]
                else:
                    raise ValueError('Unsupported loss type: %s' %
                                     (normal_loss))
            elif o == 'oriented_normals':
                if normal_loss == 'ms_euclidean':
                    loss += (o_pred - o_target
                             ).pow(2).sum(1).mean() * output_loss_weight[oi]
                elif normal_loss == 'ms_oneminuscos':
                    loss += (1 - utils.cos_angle(o_pred, o_target)
                             ).pow(2).mean() * output_loss_weight[oi]
                else:
                    raise ValueError('Unsupported loss type: %s' %
                                     (normal_loss))
            else:
                raise ValueError('Unsupported output type: %s' % (o))

        elif o == 'max_curvature' or o == 'min_curvature':
            o_pred = pred[:, output_pred_ind[oi]:output_pred_ind[oi] + 1]
            o_target = target[output_target_ind[oi]]

            # Rectified mse loss: mean square of (pred - gt) / max(1, |gt|)
            normalized_diff = (o_pred - o_target) / torch.clamp(
                torch.abs(o_target), min=1)
            loss += normalized_diff.pow(2).mean() * output_loss_weight[oi]

        else:
            raise ValueError('Unsupported output type: %s' % (o))

    return loss
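
A hypothetical call illustrating the expected shapes (a sketch, not the original call site; the single-output layout, the index lists, and the identity rotation are assumptions):

import torch

batch = 4
pred = torch.randn(batch, 3, requires_grad=True)   # normals in columns 0:3
target = [torch.nn.functional.normalize(torch.randn(batch, 3), dim=1)]
patch_rot = torch.eye(3).repeat(batch, 1, 1)       # stand-in QSTN rotation
loss = compute_loss(pred, target,
                    outputs=['unoriented_normals'],
                    output_pred_ind=[0],
                    output_target_ind=[0],
                    output_loss_weight=[1.0],
                    patch_rot=patch_rot,
                    normal_loss='ms_euclidean')
loss.backward()
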
Example #2
def compute_loss(pred, target, outputs, output_pred_ind, output_target_ind,
                 output_loss_weight, patch_rot, normal_loss):

    assert len(outputs) == 1, "bad number of outputs"

    losses = []
    for oi, o in enumerate(outputs):
        if o == 'unoriented_normals':
            o_pred = pred[:, output_pred_ind[oi]:output_pred_ind[oi] + 3]
            o_target = target[output_target_ind[oi]]

            if patch_rot is not None:
                # transform predictions with inverse transform
                # since we know the transform to be a rotation (QSTN), the transpose is the inverse
                o_pred = torch.bmm(o_pred.unsqueeze(1),
                                   patch_rot.transpose(2, 1)).squeeze(1)

            if normal_loss == 'ms_oneminuscos':
                per_patch = (1 - torch.abs(utils.cos_angle(
                    o_pred, o_target))).pow(2) * output_loss_weight[oi]
                losses.extend(per_patch[i] for i in range(per_patch.shape[0]))
            else:
                raise ValueError('Unsupported loss type: %s' % (normal_loss))
        else:
            raise ValueError('Unsupported output type: %s' % (o))

    return losses
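
Unlike Example #1, this variant returns a Python list with one scalar tensor per patch, so the caller can reweight or mine individual samples before reducing. A minimal reduction sketch (the stand-in values are hypothetical):

import torch

per_patch = [torch.rand(()) for _ in range(8)]  # stand-in for the returned list
loss = torch.stack(per_patch).mean()            # reduce to a scalar for backward()
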
Example #3
def compute_loss(pred, target, normal_loss):

    if normal_loss == 'ms_euclidean':
        loss = torch.min((pred - target).pow(2).sum(2),
                         (pred + target).pow(2).sum(2))  #* output_loss_weight
    elif normal_loss == 'ms_oneminuscos':
        loss = (1 - torch.abs(utils.cos_angle(
            pred, target))).pow(2)  #* output_loss_weight
    else:
        raise ValueError('Unsupported loss type: %s' % (normal_loss))

    return loss
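
Here pred and target are batched point sets of shape (B, N, 3) (the reduction runs over dim 2), and the result is a per-point loss map rather than a scalar. A runnable sketch of the ms_euclidean path, which needs no utils helper (the shapes are assumptions):

import torch

pred = torch.randn(2, 128, 3)
target = torch.nn.functional.normalize(torch.randn(2, 128, 3), dim=-1)
loss = compute_loss(pred, target, normal_loss='ms_euclidean')  # shape (2, 128)
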
Example #4
def compute_loss(pred, target, normal_loss='ms_euclidean'):

    # if patch_rot is not None:
    #     # transform predictions with inverse transform
    #     # since we know the transform to be a rotation (QSTN), the transpose is the inverse
    #     target = torch.bmm(target, patch_rot.transpose(2, 1))#.squeeze(1)

    if normal_loss == 'ms_euclidean':
        loss = torch.min((pred - target).pow(2).sum(2),
                         (pred + target).pow(2).sum(2))  #* output_loss_weight
    elif normal_loss == 'ms_oneminuscos':
        loss = (1 - torch.abs(utils.cos_angle(pred, target))).pow(
            2)  #* output_loss_weight
    else:
        raise ValueError('Unsupported loss type: %s' % (normal_loss))

    return loss
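
For unit-length vectors the two flavors are closely related: min(||p - t||^2, ||p + t||^2) = 2 - 2|cos(theta)|, so ms_euclidean already acts as a sign-invariant angular penalty. A quick numerical check (self-contained; the cosine is computed directly instead of via utils.cos_angle):

import torch
import torch.nn.functional as F

p = F.normalize(torch.randn(5, 3), dim=1)
t = F.normalize(torch.randn(5, 3), dim=1)
ms_euc = torch.min((p - t).pow(2).sum(1), (p + t).pow(2).sum(1))
cos = F.cosine_similarity(p, t, dim=1)
assert torch.allclose(ms_euc, 2 - 2 * cos.abs(), atol=1e-5)
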
Example #5
def compute_loss(output, target, loss_type, normalize):

    loss = 0

    if normalize:
        output = F.normalize(output, dim=1)
        target = F.normalize(target, dim=1)

    if loss_type == 'mse_loss':
        loss += F.mse_loss(output, target)
    elif loss_type == 'ms_euclidean':
        loss += torch.min((output - target).pow(2).sum(1),
                          (output + target).pow(2).sum(1)).mean()
    elif loss_type == 'ms_oneminuscos':
        loss += (1 - torch.abs(cos_angle(output, target))).pow(2).mean()
    else:
        raise ValueError('Unsupported loss type: {}'.format(loss_type))

    return loss
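
A hypothetical call (the names and shapes are assumptions, and the module is expected to import torch.nn.functional as F, which the function body requires):

import torch

output = torch.randn(16, 3, requires_grad=True)
target = torch.randn(16, 3)
loss = compute_loss(output, target, loss_type='ms_euclidean', normalize=True)
loss.backward()
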
Example #6
def train(args):
    device = torch.device(
        'cpu' if args.gpu_idx < 0 else 'cuda:{}'.format(args.gpu_idx))

    outdir = args.path_model
    indir = os.path.join(args.path_model, args.path_dataset)

    learning_rate = args.lr
    batch_size = args.batch_size

    # Data Loading
    HMP, Nf, Ng, idx_train, idx_val = load_data(path=indir,
                                                id_cluster=args.id_cluster,
                                                split=args.rate_split)

    ds_loader_train, ds_loader_val = data_loader(
        HMP,
        Nf,
        Ng,
        idx_train,
        idx_val,
        BatchSize=args.batch_size,
        sampling=args.sampling_strategy)
    nfeatures = int(Nf.shape[1] / 3)

    # Create model
    net = Net(nfeatures)
    net.to(device)

    optimizer = optim.SGD(net.parameters(),
                          lr=args.lr,
                          weight_decay=args.weight_decay,
                          momentum=args.momentum)
    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[], gamma=0.1)

    if not os.path.exists(outdir):
        os.makedirs(outdir)

    log_filename = os.path.join(outdir,
                                'log_trainC{}.txt'.format(args.id_cluster))
    model_filename = os.path.join(
        outdir, 'model_cluster{}.pth'.format(args.id_cluster))
    args_filename = os.path.join(outdir,
                                 'args_cluster{}.pth'.format(args.id_cluster))

    # Training
    print("Training...")
    '''
    if os.path.exists(model_filename):
        response = input('A training instance ({}) already exists, overwrite? (y/n) '.format(model_filename))
        if response == 'y' or response == 'Y':
            if os.path.exists(log_filename):
                os.remove(log_filename)
            if os.path.exists(model_filename):
                os.remove(model_filename)
        else:
            print('Training exit.')
            sys.exit()'''

    if os.path.exists(model_filename):
        raise ValueError(
            'A training instance already exists: {}'.format(model_filename))

    # LOG
    LOG_file = open(log_filename, 'w')
    log_write(LOG_file, str(args))
    log_write(
        LOG_file, 'data size = {}, train size = {}, val size = {}'.format(
            HMP.shape[0], np.sum(idx_train), np.sum(idx_val)))
    log_write(LOG_file, '***************************\n')

    train_batch_num = len(ds_loader_train)
    val_batch_num = len(ds_loader_val)

    min_error = 180
    epoch_best = -1
    bad_counter = 0

    for epoch in range(args.max_epochs):

        loss_cnt = 0
        err_cnt = 0
        cnt = 0

        # update learning rate (milestones=[] above, so the rate never decays)
        scheduler.step()

        learning_rate = optimizer.param_groups[0]['lr']

        log_write(LOG_file, 'EPOCH #{}'.format(str(epoch + 1)))
        log_write(LOG_file,
                  'lr = {}, batch size = {}'.format(learning_rate, batch_size))

        net.train()

        for i, inputs in enumerate(ds_loader_train):
            x, y, label = inputs

            x = x.to(device)
            y = y.to(device)
            label = label.to(device)

            optimizer.zero_grad()

            # forward backward
            output = net(x, y)

            loss = compute_loss(output,
                                label,
                                loss_type=args.normal_loss,
                                normalize=args.normalize_output)
            loss.backward()
            optimizer.step()

            cnt += x.size(0)
            loss_cnt += loss.item()
            err = torch.abs(cos_angle(output, label)).detach().cpu().numpy()
            # clamp to [0, 1] so rounding error cannot push arccos into NaN
            err = np.rad2deg(np.arccos(np.clip(err, 0.0, 1.0)))
            err_cnt += np.sum(err)

        train_loss = loss_cnt / train_batch_num
        train_err = err_cnt / cnt

        # validate
        net.eval()

        loss_cnt = 0
        err_cnt = 0
        cnt = 0

        for i, inputs in enumerate(ds_loader_val):
            x, y, label = inputs

            x = x.to(device)
            y = y.to(device)
            label = label.to(device)

            # forward
            with torch.no_grad():
                output = net(x, y)

            loss = compute_loss(output,
                                label,
                                loss_type=args.normal_loss,
                                normalize=args.normalize_output)
            loss_cnt += loss.item()
            cnt += x.size(0)

            err = torch.abs(cos_angle(output, label)).detach().cpu().numpy()
            # clamp to [0, 1] so rounding error cannot push arccos into NaN
            err = np.rad2deg(np.arccos(np.clip(err, 0.0, 1.0)))
            err_cnt += np.sum(err)

        val_loss = loss_cnt / val_batch_num
        val_err = err_cnt / cnt

        # log
        log_write(
            LOG_file,
            'train loss = {}, train error = {}'.format(train_loss, train_err))
        log_write(LOG_file,
                  'val loss = {}, val error = {}'.format(val_loss, val_err))

        if min_error > val_err:
            min_error = val_err
            epoch_best = epoch + 1
            bad_counter = 0
            log_write(LOG_file,
                      'Current best epoch #{} saved in file: {}'.format(
                          epoch_best, model_filename),
                      show_info=False)
            torch.save(net.state_dict(), model_filename)
        else:
            bad_counter += 1

        if bad_counter >= args.patience:
            break
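
The attributes train() reads from args can be bundled in a plain namespace; a sketch with placeholder values (only the attribute names come from the code above, every value is an assumption):

from types import SimpleNamespace

args = SimpleNamespace(
    gpu_idx=0,                        # < 0 falls back to CPU
    path_model='out', path_dataset='data',
    id_cluster=0, rate_split=0.9,
    lr=0.01, batch_size=64, weight_decay=1e-4, momentum=0.9,
    sampling_strategy='random',
    normal_loss='ms_euclidean', normalize_output=True,
    max_epochs=100, patience=10)
# train(args)  # also needs load_data, data_loader, Net, log_write in scope
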
Example #7
def compute_loss(pred1, pred2, pred3, target, mask_gt1, mask1, mask_gt2, mask2,
                 mask_gt3, mask3, outputs, output_pred_ind, output_target_ind,
                 output_loss_weight, patch_rot1, patch_rot2, patch_rot3,
                 normal_loss):

    loss = 0

    loss1_1 = torch.nn.functional.nll_loss(mask1, mask_gt1)
    loss1_2 = torch.nn.functional.nll_loss(mask2, mask_gt2)
    loss1_3 = torch.nn.functional.nll_loss(mask3, mask_gt3)

    loss1 = (loss1_1 + loss1_2 + loss1_3) / 3
    for oi, o in enumerate(outputs):
        if o == 'unoriented_normals' or o == 'oriented_normals':
            o_pred1 = pred1[:, output_pred_ind[oi]:output_pred_ind[oi] + 3]
            o_pred2 = pred2[:, output_pred_ind[oi]:output_pred_ind[oi] + 3]
            o_pred3 = pred3[:, output_pred_ind[oi]:output_pred_ind[oi] + 3]

            o_target = target[output_target_ind[oi]]

            if patch_rot1 is not None:
                # transform predictions with inverse transform
                # since we know the transform to be a rotation (QSTN), the transpose is the inverse
                o_pred1 = torch.bmm(o_pred1.unsqueeze(1),
                                    patch_rot1.transpose(2, 1)).squeeze(1)
                o_pred2 = torch.bmm(o_pred2.unsqueeze(1),
                                    patch_rot2.transpose(2, 1)).squeeze(1)
                o_pred3 = torch.bmm(o_pred3.unsqueeze(1),
                                    patch_rot3.transpose(2, 1)).squeeze(1)

            if o == 'unoriented_normals':
                if normal_loss == 'ms_euclidean':
                    loss2_1 = torch.min((o_pred1 - o_target).pow(2).sum(1),
                                        (o_pred1 + o_target).pow(2).sum(1))
                    loss2_2 = torch.min((o_pred2 - o_target).pow(2).sum(1),
                                        (o_pred2 + o_target).pow(2).sum(1))
                    loss2_3 = torch.min((o_pred3 - o_target).pow(2).sum(1),
                                        (o_pred3 + o_target).pow(2).sum(1))

                    tt = (torch.mul(loss2_1, output_loss_weight[:, 0]) +
                          torch.mul(loss2_2, output_loss_weight[:, 1]) +
                          torch.mul(loss2_3, output_loss_weight[:, 2]))

                    loss += tt.mean()

                    # loss += torch.min((o_pred-o_target).pow(2).sum(1), (o_pred+o_target).pow(2).sum(1)).mean() * output_loss_weight[oi]
                elif normal_loss == 'ms_oneminuscos':
                    # blend the three scale predictions per patch,
                    # mirroring the ms_euclidean branch above
                    loss2_1 = (1 - torch.abs(
                        utils.cos_angle(o_pred1, o_target))).pow(2)
                    loss2_2 = (1 - torch.abs(
                        utils.cos_angle(o_pred2, o_target))).pow(2)
                    loss2_3 = (1 - torch.abs(
                        utils.cos_angle(o_pred3, o_target))).pow(2)
                    tt = (torch.mul(loss2_1, output_loss_weight[:, 0]) +
                          torch.mul(loss2_2, output_loss_weight[:, 1]) +
                          torch.mul(loss2_3, output_loss_weight[:, 2]))
                    loss += tt.mean()
                else:
                    raise ValueError('Unsupported loss type: %s' %
                                     (normal_loss))
            elif o == 'oriented_normals':
                # same per-patch blending of the three scale predictions
                if normal_loss == 'ms_euclidean':
                    loss2_1 = (o_pred1 - o_target).pow(2).sum(1)
                    loss2_2 = (o_pred2 - o_target).pow(2).sum(1)
                    loss2_3 = (o_pred3 - o_target).pow(2).sum(1)
                    tt = (torch.mul(loss2_1, output_loss_weight[:, 0]) +
                          torch.mul(loss2_2, output_loss_weight[:, 1]) +
                          torch.mul(loss2_3, output_loss_weight[:, 2]))
                    loss += tt.mean()
                elif normal_loss == 'ms_oneminuscos':
                    loss2_1 = (1 - utils.cos_angle(o_pred1, o_target)).pow(2)
                    loss2_2 = (1 - utils.cos_angle(o_pred2, o_target)).pow(2)
                    loss2_3 = (1 - utils.cos_angle(o_pred3, o_target)).pow(2)
                    tt = (torch.mul(loss2_1, output_loss_weight[:, 0]) +
                          torch.mul(loss2_2, output_loss_weight[:, 1]) +
                          torch.mul(loss2_3, output_loss_weight[:, 2]))
                    loss += tt.mean()
                else:
                    raise ValueError('Unsupported loss type: %s' %
                                     (normal_loss))
            else:
                raise ValueError('Unsupported output type: %s' % (o))

        elif o == 'max_curvature' or o == 'min_curvature':
            o_pred = pred1[:, output_pred_ind[oi]:output_pred_ind[oi] + 1]
            o_target = target[output_target_ind[oi]]

            # Rectified mse loss: mean square of (pred - gt) / max(1, |gt|)
            normalized_diff = (o_pred - o_target) / torch.clamp(
                torch.abs(o_target), min=1)
            loss += normalized_diff.pow(2).mean() * output_loss_weight[oi]

        else:
            raise ValueError('Unsupported output type: %s' % (o))

    return 0.5 * loss + 0.5 * loss1, loss, loss1
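
In this variant output_loss_weight is a per-patch tensor of shape (B, 3) that blends three scale-specific predictions, and the masks feed nll_loss, so each maskN should hold log-probabilities with mask_gtN as integer labels. A shape sketch (all names and sizes are stand-ins):

import torch
import torch.nn.functional as F

B, C = 4, 2
mask1 = F.log_softmax(torch.randn(B, C), dim=1)   # per-patch log-probabilities
mask_gt1 = torch.randint(0, C, (B,))              # integer class labels
nll = F.nll_loss(mask1, mask_gt1)

w = F.softmax(torch.randn(B, 3), dim=1)           # per-patch blend weights
l1, l2, l3 = torch.rand(B), torch.rand(B), torch.rand(B)
blended = (w[:, 0] * l1 + w[:, 1] * l2 + w[:, 2] * l3).mean()
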