# Example 1
def train(train_loader, model, optimizer, args):
    """Run one training epoch and return the epoch-average chamfer loss.

    Per batch the loss is: (optional) BCE on the attention-mask logits,
    plus, summed over samples, the chamfer distance between regressed
    joints and ground truth and a weighted chamfer term for every
    mean-shift clustering step.

    Args:
        train_loader: iterable of batched graph samples.
        model: network returning (displacement, mask logits, mask, bandwidth).
        optimizer: torch optimizer stepping the model parameters.
        args: namespace with use_bce, bce_loss_weight, ms_loss_weight and
            meanshift_step.

    Returns:
        Average loss over the epoch (weighted by samples per batch).
    """
    global device
    model.train()  # enable training-mode behavior (dropout, BN updates)
    chamfer_loss = AverageMeter()
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        displacement, mask_logits, mask_prob, bandwidth = model(data)
        joints_pred = displacement + data.pos
        batch_loss = 0.0
        if args.use_bce:
            # Supervise the attention mask on the raw (pre-sigmoid) logits.
            target_mask = data.mask.unsqueeze(1)
            batch_loss += args.bce_loss_weight * torch.nn.functional.binary_cross_entropy_with_logits(
                mask_logits, target_mask.float(), reduction='mean')
        for sample_id in range(len(torch.unique(data.batch))):
            sample_sel = data.batch == sample_id
            gt_joints = data.y[sample_sel, :]
            # data.y is padded per sample; keep only the real joints.
            gt_joints = gt_joints[:data.num_joint[sample_id], :]
            pred_joints = joints_pred[sample_sel, :]
            pred_mask = mask_prob[sample_sel]
            batch_loss += chamfer_distance_with_average(
                pred_joints.unsqueeze(0), gt_joints.unsqueeze(0))
            shifted = meanshift_cluster(pred_joints, bandwidth, pred_mask, args)
            for step in range(args.meanshift_step):
                batch_loss += args.ms_loss_weight * chamfer_distance_with_average(
                    shifted[step].unsqueeze(0), gt_joints.unsqueeze(0))
        batch_loss.backward()
        optimizer.step()
        chamfer_loss.update(batch_loss.item(),
                            n=len(torch.unique(data.batch)))
    return chamfer_loss.avg
# Example 2
def test(test_loader, model, args, save_result=False, best_epoch=None):
    """Evaluate the joint-prediction model and return the average loss.

    Per sample the loss combines the chamfer distance between regressed
    joints and ground-truth joints with a mean-shift-clustered chamfer
    term (weighted by args.ms_loss_weight and averaged over the
    mean-shift steps); an optional BCE term on the mask logits is added
    per batch.

    Args:
        test_loader: iterable of batched graph samples.
        model: network returning (displacement, mask logits, mask, bandwidth).
        args: namespace with meanshift_step, ms_loss_weight, use_bce,
            bce_loss_weight and checkpoint.
        save_result: if True, dump predicted point clouds, attention masks
            and bandwidth under results/<checkpoint name>/best_<best_epoch>/.
        best_epoch: epoch index used to name the output folder
            (required when save_result is True).

    Returns:
        Average per-batch loss over the test set.
    """
    import os  # local import: needed only for result saving / path handling
    global device
    model.eval()  # switch to test mode
    loss_meter = AverageMeter()
    outdir = args.checkpoint.split('/')[-1]
    for data in test_loader:
        data = data.to(device)
        with torch.no_grad():
            data_displacement, mask_pred_nosigmoid, mask_pred, bandwidth = model(data)
            y_pred = data_displacement + data.pos
            loss_total = 0.0
            for i in range(len(torch.unique(data.batch))):
                joint_gt = data.joints[data.joints_batch == i, :]
                y_pred_i = y_pred[data.batch == i, :]
                mask_pred_i = mask_pred[data.batch == i]
                loss_total += chamfer_distance_with_average(y_pred_i.unsqueeze(0), joint_gt.unsqueeze(0))
                clustered_pred = meanshift_cluster(y_pred_i, bandwidth, mask_pred_i, args)
                loss_ms = 0.0
                for j in range(args.meanshift_step):
                    loss_ms += chamfer_distance_with_average(clustered_pred[j].unsqueeze(0), joint_gt.unsqueeze(0))
                # Average the mean-shift term over its steps before weighting.
                loss_total = loss_total + args.ms_loss_weight * loss_ms / args.meanshift_step
                if save_result:
                    output_folder = 'results/{:s}/best_{:d}/'.format(outdir, best_epoch)
                    # Fix: np.save does not create intermediate directories,
                    # so ensure the output folder exists before writing.
                    os.makedirs(output_folder, exist_ok=True)
                    sample_name = data.name[i].item()
                    output_point_cloud_ply(y_pred_i, name=str(sample_name),
                                           output_folder=output_folder)
                    np.save(os.path.join(output_folder, '{:d}_attn.npy'.format(sample_name)),
                            mask_pred_i.data.to("cpu").numpy())
                    np.save(os.path.join(output_folder, '{:d}_bandwidth.npy'.format(sample_name)),
                            bandwidth.data.to("cpu").numpy())
            # Normalize by the number of samples in the batch.
            loss_total /= len(torch.unique(data.batch))
            if args.use_bce:
                mask_gt = data.mask.unsqueeze(1)
                loss_total += args.bce_loss_weight * torch.nn.functional.binary_cross_entropy_with_logits(
                    mask_pred_nosigmoid, mask_gt.float(), reduction='mean')
            loss_meter.update(loss_total.item())
    return loss_meter.avg
# Example 3
def train(train_loader, model, optimizer, args):
    """Train for one epoch and return the mean loss over all batches.

    Batch loss = mean over samples of
      chamfer(predicted joints, gt joints)
      + ms_loss_weight * (mean over mean-shift steps of chamfer(clustered, gt)),
    optionally plus bce_loss_weight * BCE on the attention-mask logits.
    """
    global device
    model.train()  # training mode (dropout / batch-norm updates)
    epoch_meter = AverageMeter()
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        displacement, mask_logits, mask_prob, bandwidth = model(batch)
        pred_joints = displacement + batch.pos
        batch_loss = 0.0
        num_samples = len(torch.unique(batch.batch))
        for sample_id in range(num_samples):
            gt_joints = batch.joints[batch.joints_batch == sample_id, :]
            sample_sel = batch.batch == sample_id
            sample_pred = pred_joints[sample_sel, :]
            sample_mask = mask_prob[sample_sel]
            batch_loss += chamfer_distance_with_average(
                sample_pred.unsqueeze(0), gt_joints.unsqueeze(0))
            shifted = meanshift_cluster(sample_pred, bandwidth, sample_mask, args)
            ms_loss = 0.0
            for step in range(args.meanshift_step):
                ms_loss += chamfer_distance_with_average(
                    shifted[step].unsqueeze(0), gt_joints.unsqueeze(0))
            # Average the mean-shift term over its steps before weighting.
            batch_loss = batch_loss + args.ms_loss_weight * ms_loss / args.meanshift_step
        batch_loss /= num_samples
        if args.use_bce:
            # BCE supervision acts on the raw (pre-sigmoid) mask predictions.
            gt_mask = batch.mask.unsqueeze(1)
            batch_loss += args.bce_loss_weight * torch.nn.functional.binary_cross_entropy_with_logits(
                mask_logits, gt_mask.float(), reduction='mean')
        batch_loss.backward()
        optimizer.step()
        epoch_meter.update(batch_loss.item())
    return epoch_meter.avg
# Example 4
def train(train_loader, model, optimizer, args, epoch):
    """Train one epoch of either the mask or the joint network.

    For args.arch == 'masknet' the loss is BCE between predicted mask
    logits and data.mask; for 'jointnet' it is the summed chamfer
    distance between regressed joints and ground truth per sample.

    Args:
        train_loader: iterable of batched graph samples.
        model: masknet (returns mask logits) or jointnet (returns displacement).
        optimizer: torch optimizer stepping the model parameters.
        args: namespace with an `arch` field ('masknet' or 'jointnet').
        epoch: current epoch index (unused here; kept for a uniform
            trainer signature).

    Returns:
        Average loss over the epoch.

    Raises:
        ValueError: if args.arch is not a recognized architecture.
    """
    global device
    model.train()  # switch to train mode
    loss_meter = AverageMeter()
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        if args.arch == 'masknet':
            mask_pred = model(data)
            mask_gt = data.mask.unsqueeze(1)
            loss = torch.nn.functional.binary_cross_entropy_with_logits(
                mask_pred, mask_gt.float(), reduction='mean')
        elif args.arch == 'jointnet':
            data_displacement = model(data)
            y_pred = data_displacement + data.pos
            loss = 0.0
            for i in range(len(torch.unique(data.batch))):
                y_gt_sample = data.y[data.batch == i, :]
                # data.y is padded per sample; keep only the real joints.
                y_gt_sample = y_gt_sample[:data.num_joint[i], :]
                y_pred_sample = y_pred[data.batch == i, :]
                loss += chamfer_distance_with_average(
                    y_pred_sample.unsqueeze(0), y_gt_sample.unsqueeze(0))
        else:
            # Fix: an unrecognized arch previously fell through and crashed
            # later with an opaque NameError on `loss.backward()`.
            raise ValueError('unknown architecture: {!r}'.format(args.arch))
        loss.backward()
        optimizer.step()
        loss_meter.update(loss.item())
    return loss_meter.avg
# Example 5
def test(test_loader, model, args, save_result=False, best_epoch=None):
    """Evaluate the mask or joint network and return the average loss.

    For args.arch == 'masknet' the loss is BCE between predicted mask
    logits and data.mask; for 'jointnet' it is the summed chamfer
    distance between regressed joints and ground truth per sample.
    Optionally writes per-sample predictions to
    results/<checkpoint part>/best_<best_epoch>/.

    Args:
        test_loader: iterable of batched graph samples.
        model: masknet (returns mask logits) or jointnet (returns displacement).
        args: namespace with `arch` ('masknet' or 'jointnet') and `checkpoint`.
        save_result: if True, dump attention masks (masknet) or point
            clouds (jointnet) for each sample.
        best_epoch: epoch index used to name the output folder
            (required when save_result is True).

    Returns:
        Average loss over the test set.

    Raises:
        ValueError: if args.arch is not a recognized architecture.
    """
    global device
    model.eval()  # switch to test mode
    loss_meter = AverageMeter()
    # NOTE(review): assumes checkpoint paths look like '<root>/<name>/...';
    # the other test() variant in this file uses split('/')[-1] instead.
    outdir = args.checkpoint.split('/')[1]
    for data in test_loader:
        data = data.to(device)
        with torch.no_grad():
            if args.arch == 'masknet':
                mask_pred = model(data)
                mask_gt = data.mask.unsqueeze(1)
                loss = torch.nn.functional.binary_cross_entropy_with_logits(
                    mask_pred, mask_gt.float(), reduction='mean')
            elif args.arch == 'jointnet':
                data_displacement = model(data)
                y_pred = data_displacement + data.pos
                loss = 0.0
                for i in range(len(torch.unique(data.batch))):
                    y_gt_sample = data.y[data.batch == i, :]
                    # data.y is padded per sample; keep only the real joints.
                    y_gt_sample = y_gt_sample[:data.num_joint[i], :]
                    y_pred_sample = y_pred[data.batch == i, :]
                    loss += chamfer_distance_with_average(
                        y_pred_sample.unsqueeze(0), y_gt_sample.unsqueeze(0))
            else:
                # Fix: an unrecognized arch previously left `loss` undefined
                # and crashed with an opaque NameError below.
                raise ValueError('unknown architecture: {!r}'.format(args.arch))
            loss_meter.update(loss.item())

            if save_result:
                output_folder = 'results/{:s}/best_{:d}/'.format(
                    outdir, best_epoch)
                if not os.path.exists(output_folder):
                    mkdir_p(output_folder)
                if args.arch == 'masknet':
                    # Convert logits to probabilities for the saved masks.
                    mask_pred = torch.sigmoid(mask_pred)
                    for i in range(len(torch.unique(data.batch))):
                        mask_pred_sample = mask_pred[data.batch == i]
                        np.save(
                            os.path.join(
                                output_folder,
                                str(data.name[i].item()) + '_attn.npy'),
                            mask_pred_sample.data.cpu().numpy())
                else:
                    for i in range(len(torch.unique(data.batch))):
                        y_pred_sample = y_pred[data.batch == i, :]
                        output_point_cloud_ply(
                            y_pred_sample,
                            name=str(data.name[i].item()),
                            output_folder=output_folder)
    return loss_meter.avg