def train(train_loader, model, optimizer, args):
    """Run one training epoch of the joint model with mean-shift supervision.

    For each batch the predicted joint positions are ``batch.pos`` plus the
    predicted per-vertex displacement. The batch loss is the plain sum of:

    * (if ``args.use_bce``) ``args.bce_loss_weight`` times the mean
      BCE-with-logits between the model's mask logits and ``batch.mask``;
    * for every sample, the chamfer distance between its displaced vertices
      and the first ``batch.num_joint[i]`` rows of its ``batch.y``;
    * for every sample and every one of the ``args.meanshift_step`` iterates
      returned by ``meanshift_cluster``, ``args.ms_loss_weight`` times the
      chamfer distance of that iterate to the same ground truth.

    Args:
        train_loader: iterable of batches supporting ``.to(device)`` and
            exposing ``pos``, ``batch``, ``y``, ``num_joint`` (and ``mask``
            when BCE is enabled).
        model: network returning a 4-tuple
            ``(displacement, mask_logits, mask, bandwidth)``.
        optimizer: optimizer zeroed and stepped once per batch.
        args: namespace read for ``use_bce``, ``bce_loss_weight``,
            ``ms_loss_weight`` and ``meanshift_step``; also forwarded to
            ``meanshift_cluster``.

    Returns:
        ``avg`` of an ``AverageMeter`` updated once per batch with the summed
        batch loss and ``n`` equal to the number of samples in that batch.
    """
    global device  # assumes a module-level ``device`` defined elsewhere
    model.train()  # switch to train mode
    meter = AverageMeter()
    for batch_data in train_loader:
        batch_data = batch_data.to(device)
        optimizer.zero_grad()
        displacement, mask_logits, mask_prob, bandwidth = model(batch_data)
        joints_pred = displacement + batch_data.pos

        # Terms are collected in exactly the order the loss is accumulated.
        terms = []
        if args.use_bce:
            target_mask = batch_data.mask.unsqueeze(1).float()
            terms.append(
                args.bce_loss_weight
                * torch.nn.functional.binary_cross_entropy_with_logits(
                    mask_logits, target_mask, reduction='mean'))

        sample_count = len(torch.unique(batch_data.batch))
        for sample_idx in range(sample_count):
            in_sample = batch_data.batch == sample_idx
            gt_joints = batch_data.y[in_sample, :][:batch_data.num_joint[sample_idx], :]
            pred_sample = joints_pred[in_sample, :]
            terms.append(chamfer_distance_with_average(
                pred_sample.unsqueeze(0), gt_joints.unsqueeze(0)))
            ms_iterates = meanshift_cluster(
                pred_sample, bandwidth, mask_prob[in_sample], args)
            terms.extend(
                args.ms_loss_weight * chamfer_distance_with_average(
                    ms_iterates[step].unsqueeze(0), gt_joints.unsqueeze(0))
                for step in range(args.meanshift_step))

        batch_loss = sum(terms, 0.0)
        batch_loss.backward()
        optimizer.step()
        meter.update(batch_loss.item(), n=sample_count)
    return meter.avg
def test(test_loader, model, args, save_result=False, best_epoch=None):
    """Evaluate the joint model with mean-shift clustering; optionally dump outputs.

    Per batch, under ``torch.no_grad()``:

    * each sample contributes its chamfer distance to ``data.joints``
      (rows selected by ``data.joints_batch``) plus ``args.ms_loss_weight``
      times the mean chamfer distance over the ``args.meanshift_step``
      mean-shift iterates returned by ``meanshift_cluster``;
    * the per-sample sum is divided by the number of samples in the batch;
    * if ``args.use_bce``, ``args.bce_loss_weight`` times the mean
      BCE-with-logits of the mask logits against ``data.mask`` is added.

    Args:
        test_loader: iterable of batches supporting ``.to(device)``.
        model: network returning a 4-tuple
            ``(displacement, mask_logits, mask, bandwidth)``.
        args: namespace read for ``checkpoint``, ``meanshift_step``,
            ``ms_loss_weight``, ``use_bce`` and ``bce_loss_weight``.
        save_result: when true, write per-sample outputs into
            ``results/<checkpoint dir name>/best_<best_epoch>/``: the displaced
            points via ``output_point_cloud_ply``, ``<name>_attn.npy`` (the
            mask output) and ``<name>_bandwidth.npy``.
        best_epoch: integer used in the output folder name; required when
            ``save_result`` is true.

    Returns:
        ``avg`` of an ``AverageMeter`` updated once per batch with that
        batch's loss.

    Raises:
        ValueError: if ``save_result`` is true and ``best_epoch`` is None.
    """
    import os

    global device  # assumes a module-level ``device`` defined elsewhere
    model.eval()  # switch to test mode
    loss_meter = AverageMeter()
    # normpath drops a trailing '/', which would otherwise make the name ''.
    outdir = os.path.basename(os.path.normpath(args.checkpoint))

    output_folder = None
    if save_result:
        if best_epoch is None:
            raise ValueError('best_epoch is required when save_result=True')
        output_folder = 'results/{:s}/best_{:d}/'.format(outdir, best_epoch)
        # np.save does not create missing parent directories.
        os.makedirs(output_folder, exist_ok=True)

    for data in test_loader:
        data = data.to(device)
        with torch.no_grad():
            data_displacement, mask_pred_nosigmoid, mask_pred, bandwidth = model(data)
            y_pred = data_displacement + data.pos
            num_samples = len(torch.unique(data.batch))
            loss_total = 0.0
            for i in range(num_samples):
                in_sample = data.batch == i
                joint_gt = data.joints[data.joints_batch == i, :]
                y_pred_i = y_pred[in_sample, :]
                mask_pred_i = mask_pred[in_sample]
                loss_total += chamfer_distance_with_average(
                    y_pred_i.unsqueeze(0), joint_gt.unsqueeze(0))
                clustered_pred = meanshift_cluster(y_pred_i, bandwidth, mask_pred_i, args)
                loss_ms = 0.0
                for j in range(args.meanshift_step):
                    loss_ms += chamfer_distance_with_average(
                        clustered_pred[j].unsqueeze(0), joint_gt.unsqueeze(0))
                loss_total = loss_total + args.ms_loss_weight * loss_ms / args.meanshift_step
                if save_result:
                    # '{:d}' below assumes data.name holds integer ids — TODO confirm
                    sample_name = data.name[i].item()
                    output_point_cloud_ply(
                        y_pred_i, name=str(sample_name), output_folder=output_folder)
                    # detach(): bandwidth may be a parameter that still requires
                    # grad even under no_grad, and .numpy() refuses those.
                    np.save(os.path.join(output_folder, '{:d}_attn.npy'.format(sample_name)),
                            mask_pred_i.detach().cpu().numpy())
                    np.save(os.path.join(output_folder, '{:d}_bandwidth.npy'.format(sample_name)),
                            bandwidth.detach().cpu().numpy())
            loss_total /= num_samples
            if args.use_bce:
                mask_gt = data.mask.unsqueeze(1)
                loss_total += args.bce_loss_weight * torch.nn.functional.binary_cross_entropy_with_logits(
                    mask_pred_nosigmoid, mask_gt.float(), reduction='mean')
            loss_meter.update(loss_total.item())
    return loss_meter.avg
def train(train_loader, model, optimizer, args):
    """Run one training epoch of the joint model with a per-sample-averaged loss.

    Each batch's loss is built as follows. For every sample, take the chamfer
    distance between its displaced vertices (``pos`` + predicted displacement)
    and its ground-truth ``joints`` (rows selected by ``joints_batch``). Add
    ``args.ms_loss_weight`` times the mean chamfer distance over the
    ``args.meanshift_step`` mean-shift iterates. The sum is divided by the
    number of samples. If ``args.use_bce``, ``args.bce_loss_weight`` times the
    mean BCE-with-logits of the mask logits against ``mask`` is added.

    Args:
        train_loader: iterable of batches supporting ``.to(device)``.
        model: network returning a 4-tuple
            ``(displacement, mask_logits, mask, bandwidth)``.
        optimizer: optimizer zeroed and stepped once per batch.
        args: namespace read for ``meanshift_step``, ``ms_loss_weight``,
            ``use_bce`` and ``bce_loss_weight``; also forwarded to
            ``meanshift_cluster``.

    Returns:
        ``avg`` of an ``AverageMeter`` updated once per batch with that
        batch's loss.
    """
    global device  # assumes a module-level ``device`` defined elsewhere
    model.train()  # switch to train mode
    meter = AverageMeter()
    bce_with_logits = torch.nn.functional.binary_cross_entropy_with_logits
    for batch_data in train_loader:
        batch_data = batch_data.to(device)
        optimizer.zero_grad()
        offsets, mask_logits, mask_prob, bandwidth = model(batch_data)
        predicted = offsets + batch_data.pos

        n_samples = len(torch.unique(batch_data.batch))
        total = 0.0
        for idx in range(n_samples):
            vert_sel = batch_data.batch == idx
            gt = batch_data.joints[batch_data.joints_batch == idx, :]
            pred = predicted[vert_sel, :]
            total = total + chamfer_distance_with_average(pred.unsqueeze(0), gt.unsqueeze(0))
            iterates = meanshift_cluster(pred, bandwidth, mask_prob[vert_sel], args)
            ms_sum = 0.0
            for step in range(args.meanshift_step):
                ms_sum = ms_sum + chamfer_distance_with_average(
                    iterates[step].unsqueeze(0), gt.unsqueeze(0))
            total = total + args.ms_loss_weight * ms_sum / args.meanshift_step
        total = total / n_samples

        if args.use_bce:
            target = batch_data.mask.unsqueeze(1).float()
            total = total + args.bce_loss_weight * bce_with_logits(
                mask_logits, target, reduction='mean')

        total.backward()
        optimizer.step()
        meter.update(total.item())
    return meter.avg
def train(train_loader, model, optimizer, args, epoch):
    """Run one training epoch for the network selected by ``args.arch``.

    * ``'masknet'``: ``model(data)`` yields mask logits; the loss is the mean
      BCE-with-logits against ``data.mask``.
    * ``'jointnet'``: ``model(data)`` yields per-vertex displacements; the loss
      is the sum over samples of the chamfer distance between
      ``data.pos + displacement`` and the first ``data.num_joint[i]`` rows of
      that sample's ``data.y``.

    Args:
        train_loader: iterable of batches supporting ``.to(device)``.
        model: the network being trained.
        optimizer: optimizer zeroed and stepped once per batch.
        args: namespace read for ``arch``.
        epoch: accepted but not used by this function.

    Returns:
        ``avg`` of an ``AverageMeter`` updated once per batch with that
        batch's loss.

    Raises:
        ValueError: if ``args.arch`` is neither ``'masknet'`` nor ``'jointnet'``.
            Without this check, ``loss`` would be unbound at ``loss.backward()``.
    """
    global device  # assumes a module-level ``device`` defined elsewhere
    if args.arch not in ('masknet', 'jointnet'):
        raise ValueError(
            "unsupported args.arch {!r}; expected 'masknet' or 'jointnet'".format(args.arch))
    model.train()  # switch to train mode
    loss_meter = AverageMeter()
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        if args.arch == 'masknet':
            mask_pred = model(data)
            mask_gt = data.mask.unsqueeze(1)
            loss = torch.nn.functional.binary_cross_entropy_with_logits(
                mask_pred, mask_gt.float(), reduction='mean')
        else:  # 'jointnet' (validated above)
            data_displacement = model(data)
            y_pred = data_displacement + data.pos
            loss = 0.0
            for i in range(len(torch.unique(data.batch))):
                in_sample = data.batch == i
                # ground truth is truncated to this sample's real joint count
                y_gt_sample = data.y[in_sample, :][:data.num_joint[i], :]
                y_pred_sample = y_pred[in_sample, :]
                loss += chamfer_distance_with_average(
                    y_pred_sample.unsqueeze(0), y_gt_sample.unsqueeze(0))
        loss.backward()
        optimizer.step()
        loss_meter.update(loss.item())
    return loss_meter.avg
def test(test_loader, model, args, save_result=False, best_epoch=None):
    """Evaluate the network selected by ``args.arch``; optionally dump outputs.

    * ``'masknet'``: the loss is the mean BCE-with-logits of ``model(data)``
      against ``data.mask``. With ``save_result``, ``sigmoid`` of the logits
      for each sample is saved as ``<name>_attn.npy``.
    * ``'jointnet'``: the loss is the sum over samples of the chamfer distance
      between ``data.pos + model(data)`` and the first ``data.num_joint[i]``
      rows of that sample's ``data.y``. With ``save_result``, each sample's
      predicted points are written via ``output_point_cloud_ply``.

    Outputs go to ``results/<checkpoint dir name>/best_<best_epoch>/``.

    Args:
        test_loader: iterable of batches supporting ``.to(device)``.
        model: the network being evaluated.
        args: namespace read for ``arch`` and ``checkpoint``.
        save_result: whether to write per-sample outputs.
        best_epoch: integer used in the output folder name; required when
            ``save_result`` is true.

    Returns:
        ``avg`` of an ``AverageMeter`` updated once per batch with that
        batch's loss.

    Raises:
        ValueError: if ``args.arch`` is unsupported, or if ``save_result`` is
            true and ``best_epoch`` is None.
    """
    global device  # assumes a module-level ``device`` defined elsewhere
    if args.arch not in ('masknet', 'jointnet'):
        raise ValueError(
            "unsupported args.arch {!r}; expected 'masknet' or 'jointnet'".format(args.arch))
    model.eval()  # switch to test mode
    loss_meter = AverageMeter()
    # Last path component. split('/')[1] raised IndexError for a bare name and
    # picked the wrong part for './checkpoints/x'.
    outdir = os.path.basename(os.path.normpath(args.checkpoint))

    output_folder = None
    if save_result:
        if best_epoch is None:
            raise ValueError('best_epoch is required when save_result=True')
        output_folder = 'results/{:s}/best_{:d}/'.format(outdir, best_epoch)
        os.makedirs(output_folder, exist_ok=True)

    for data in test_loader:
        data = data.to(device)
        with torch.no_grad():
            num_samples = len(torch.unique(data.batch))
            if args.arch == 'masknet':
                mask_pred = model(data)
                mask_gt = data.mask.unsqueeze(1)
                loss = torch.nn.functional.binary_cross_entropy_with_logits(
                    mask_pred, mask_gt.float(), reduction='mean')
            else:  # 'jointnet' (validated above)
                data_displacement = model(data)
                y_pred = data_displacement + data.pos
                loss = 0.0
                for i in range(num_samples):
                    in_sample = data.batch == i
                    y_gt_sample = data.y[in_sample, :][:data.num_joint[i], :]
                    y_pred_sample = y_pred[in_sample, :]
                    loss += chamfer_distance_with_average(
                        y_pred_sample.unsqueeze(0), y_gt_sample.unsqueeze(0))
            loss_meter.update(loss.item())

            if save_result:
                if args.arch == 'masknet':
                    mask_prob = torch.sigmoid(mask_pred)
                    for i in range(num_samples):
                        np.save(
                            os.path.join(output_folder, str(data.name[i].item()) + '_attn.npy'),
                            mask_prob[data.batch == i].detach().cpu().numpy())
                else:
                    for i in range(num_samples):
                        output_point_cloud_ply(
                            y_pred[data.batch == i, :],
                            name=str(data.name[i].item()),
                            output_folder=output_folder)
    return loss_meter.avg