def train(train_loader, model, optimizer, args):
    """Run one training epoch of the joint network with mean-shift supervision.

    For each batch the model returns ``(displacement, mask_logits, mask_prob,
    bandwidth)``. The displaced points are ``displacement + batch.pos``. The
    batch objective is the SUM (not the mean) of:

    * ``args.bce_loss_weight`` * BCE-with-logits(mask_logits, batch.mask),
      only when ``args.use_bce`` is set;
    * for every sample, the chamfer distance between its displaced points and
      the first ``batch.num_joint[i]`` rows of its ``batch.y``;
    * for every sample and every one of ``args.meanshift_step`` mean-shift
      outputs, ``args.ms_loss_weight`` * the chamfer distance of that output
      to the same target rows.

    Returns:
        ``AverageMeter.avg`` of the per-batch objective, where each update is
        given ``n`` = number of samples in the batch.
    """
    global device
    model.train()  # switch to train mode
    chamfer_loss = AverageMeter()

    for batch_data in train_loader:
        batch_data = batch_data.to(device)
        optimizer.zero_grad()
        displacement, mask_logits, mask_prob, bandwidth = model(batch_data)
        shifted = displacement + batch_data.pos

        objective = 0.0
        if args.use_bce:
            mask_target = batch_data.mask.unsqueeze(1)
            objective += args.bce_loss_weight * torch.nn.functional.binary_cross_entropy_with_logits(
                mask_logits, mask_target.float(), reduction='mean')

        sample_count = len(torch.unique(batch_data.batch))
        for sample_id in range(sample_count):
            in_sample = batch_data.batch == sample_id
            # Targets are padded per sample; keep only the real joints.
            target = batch_data.y[in_sample, :][:batch_data.num_joint[sample_id], :]
            predicted = shifted[in_sample, :]
            objective += chamfer_distance_with_average(
                predicted.unsqueeze(0), target.unsqueeze(0))
            clustered = meanshift_cluster(predicted, bandwidth, mask_prob[in_sample], args)
            for step in range(args.meanshift_step):
                objective += args.ms_loss_weight * chamfer_distance_with_average(
                    clustered[step].unsqueeze(0), target.unsqueeze(0))

        objective.backward()
        optimizer.step()
        chamfer_loss.update(objective.item(), n=sample_count)
    return chamfer_loss.avg
def train(train_loader, model, optimizer, args, epoch):
    """Run one training epoch for the mask network or the joint network.

    * ``args.arch == 'masknet'``: the model output is treated as mask logits
      and trained with BCE-with-logits against ``data.mask``.
    * ``args.arch == 'jointnet'``: the model output is a per-point
      displacement; ``displacement + data.pos`` is compared, per sample, with
      the first ``data.num_joint[i]`` rows of ``data.y`` using the chamfer
      distance, and the per-sample distances are summed.

    Args:
        train_loader: iterable of batches exposing ``.to(device)``.
        model: network selected by ``args.arch``.
        optimizer: optimizer stepped once per batch.
        args: namespace providing at least ``arch``.
        epoch: unused; kept for interface compatibility with callers.

    Returns:
        ``AverageMeter.avg`` of the per-batch losses.

    Raises:
        ValueError: if ``args.arch`` is neither ``'masknet'`` nor ``'jointnet'``
            (previously this surfaced as an unbound-local ``loss`` error).
    """
    global device
    if args.arch not in ('masknet', 'jointnet'):
        raise ValueError(
            "unsupported args.arch {!r} (expected 'masknet' or 'jointnet')".format(args.arch))

    model.train()  # switch to train mode
    loss_meter = AverageMeter()
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        if args.arch == 'masknet':
            mask_pred = model(data)
            mask_gt = data.mask.unsqueeze(1)
            loss = torch.nn.functional.binary_cross_entropy_with_logits(
                mask_pred, mask_gt.float(), reduction='mean')
        else:  # 'jointnet'
            data_displacement = model(data)
            y_pred = data_displacement + data.pos
            loss = 0.0
            for i in range(len(torch.unique(data.batch))):
                in_sample = data.batch == i
                # Ground-truth rows are padded per sample; keep the real joints only.
                y_gt_sample = data.y[in_sample, :][:data.num_joint[i], :]
                y_pred_sample = y_pred[in_sample, :]
                loss += chamfer_distance_with_average(
                    y_pred_sample.unsqueeze(0), y_gt_sample.unsqueeze(0))
        loss.backward()
        optimizer.step()
        loss_meter.update(loss.item())
    return loss_meter.avg
def test(test_loader, model, args, save_result=False, best_epoch=None):
    """Evaluate the joint network with mean-shift clustering.

    Per batch the loss is: the mean over samples of
    ``chamfer(displaced points, data.joints of that sample)
    + args.ms_loss_weight * mean over mean-shift steps of chamfer(step output, joints)``,
    plus ``args.bce_loss_weight`` * BCE-with-logits on the mask when
    ``args.use_bce`` is set.

    Args:
        test_loader: iterable of batches exposing ``.to(device)``.
        model: network returning ``(displacement, mask_logits, mask_prob, bandwidth)``.
        args: namespace with ``checkpoint``, ``meanshift_step``,
            ``ms_loss_weight``, ``use_bce``, ``bce_loss_weight``.
        save_result: if True, write each sample's displaced point cloud (PLY),
            mask prediction and bandwidth (``.npy``) under
            ``results/<last checkpoint path component>/best_<best_epoch>/``.
        best_epoch: int used in the output folder name; required when
            ``save_result`` is True (``'{:d}'`` formatting rejects None).

    Returns:
        ``AverageMeter.avg`` of the per-batch losses.
    """
    import os  # local import: only needed to create the result folder

    global device
    model.eval()  # switch to test mode
    loss_meter = AverageMeter()
    outdir = args.checkpoint.split('/')[-1]

    output_folder = None
    if save_result:
        output_folder = 'results/{:s}/best_{:d}/'.format(outdir, best_epoch)
        # np.save does not create missing parent directories.
        os.makedirs(output_folder, exist_ok=True)

    for data in test_loader:
        data = data.to(device)
        # Evaluation only: keep the whole loss computation out of autograd.
        with torch.no_grad():
            data_displacement, mask_pred_nosigmoid, mask_pred, bandwidth = model(data)
            y_pred = data_displacement + data.pos
            num_samples = len(torch.unique(data.batch))
            loss_total = 0.0
            for i in range(num_samples):
                in_sample = data.batch == i
                joint_gt = data.joints[data.joints_batch == i, :]
                y_pred_i = y_pred[in_sample, :]
                mask_pred_i = mask_pred[in_sample]
                loss_total += chamfer_distance_with_average(
                    y_pred_i.unsqueeze(0), joint_gt.unsqueeze(0))
                clustered_pred = meanshift_cluster(y_pred_i, bandwidth, mask_pred_i, args)
                loss_ms = 0.0
                for j in range(args.meanshift_step):
                    loss_ms += chamfer_distance_with_average(
                        clustered_pred[j].unsqueeze(0), joint_gt.unsqueeze(0))
                loss_total = loss_total + args.ms_loss_weight * loss_ms / args.meanshift_step
                if save_result:
                    name = data.name[i].item()
                    output_point_cloud_ply(y_pred_i, name=str(name), output_folder=output_folder)
                    np.save(os.path.join(output_folder, '{:d}_attn.npy'.format(name)),
                            mask_pred_i.data.to("cpu").numpy())
                    np.save(os.path.join(output_folder, '{:d}_bandwidth.npy'.format(name)),
                            bandwidth.data.to("cpu").numpy())
            loss_total /= num_samples
            if args.use_bce:
                mask_gt = data.mask.unsqueeze(1)
                loss_total += args.bce_loss_weight * torch.nn.functional.binary_cross_entropy_with_logits(
                    mask_pred_nosigmoid, mask_gt.float(), reduction='mean')
        loss_meter.update(loss_total.item())
    return loss_meter.avg
def train(train_loader, model, optimizer, args):
    """Run one training epoch of the joint network with averaged mean-shift loss.

    For each sample in a batch the contribution is
    ``chamfer(displaced points, joints)
    + args.ms_loss_weight * (sum of chamfer over mean-shift steps) / args.meanshift_step``.
    The contributions are averaged over the samples in the batch. When
    ``args.use_bce`` is set, ``args.bce_loss_weight`` * BCE-with-logits on
    the mask logits is added on top.

    Returns:
        ``AverageMeter.avg`` of the per-batch losses.
    """
    global device
    model.train()  # switch to train mode
    loss_meter = AverageMeter()

    for batch_data in train_loader:
        batch_data = batch_data.to(device)
        optimizer.zero_grad()
        displacement, mask_logits, mask_prob, bandwidth = model(batch_data)
        shifted = displacement + batch_data.pos

        sample_count = len(torch.unique(batch_data.batch))
        objective = 0.0
        for sample_id in range(sample_count):
            in_sample = batch_data.batch == sample_id
            joints = batch_data.joints[batch_data.joints_batch == sample_id, :]
            predicted = shifted[in_sample, :]
            objective += chamfer_distance_with_average(predicted.unsqueeze(0), joints.unsqueeze(0))

            clustered = meanshift_cluster(predicted, bandwidth, mask_prob[in_sample], args)
            ms_sum = 0.0
            for step in range(args.meanshift_step):
                ms_sum += chamfer_distance_with_average(clustered[step].unsqueeze(0), joints.unsqueeze(0))
            objective = objective + args.ms_loss_weight * ms_sum / args.meanshift_step
        objective /= sample_count

        if args.use_bce:
            mask_target = batch_data.mask.unsqueeze(1)
            objective += args.bce_loss_weight * torch.nn.functional.binary_cross_entropy_with_logits(
                mask_logits, mask_target.float(), reduction='mean')

        objective.backward()
        optimizer.step()
        loss_meter.update(objective.item())
    return loss_meter.avg
def test(test_loader, model, args):
    """Evaluate the root-joint classifier.

    The model returns ``(pre_label, label)`` with one row per joint; rows of
    the samples in a batch are assumed to be concatenated in batch order, with
    ``data.num_joint[i]`` rows for sample ``i`` (this matches the slicing
    below). A sample counts as correct when the argmax of its predicted rows
    equals the argmax of its label rows.

    Returns:
        ``(loss_meter.avg, accuracy)``. The loss meter is updated once per
        batch with ``n`` = number of samples. Accuracy is correct samples /
        total samples, and 0.0 when the loader yields no samples (previously
        a ZeroDivisionError).
    """
    global device
    model.eval()  # switch to test mode
    loss_meter = AverageMeter()
    acc_total = 0.0
    num_samples = 0  # counted here instead of relying on AverageMeter.count
    with torch.no_grad():
        for data in test_loader:
            data = data.to(device)
            pre_label, label = model(data)
            loss = torch.nn.functional.binary_cross_entropy_with_logits(
                pre_label, label.float())
            batch_size = len(torch.unique(data.batch))
            loss_meter.update(loss.item(), n=batch_size)

            start = 0
            for i in range(batch_size):
                end = start + data.num_joint[i]
                pred_root_id = torch.argmax(pre_label[start:end]).item()
                gt_root_id = torch.argmax(label[start:end]).item()
                if pred_root_id == gt_root_id:
                    acc_total += 1.0
                start = end
            num_samples += batch_size
    accuracy = acc_total / num_samples if num_samples else 0.0
    return loss_meter.avg, accuracy
def test(test_loader, model, args, save_result=False, best_epoch=None):
    """Evaluate the mask network or the joint network, optionally saving outputs.

    * ``'masknet'``: BCE-with-logits of the model output against ``data.mask``.
      When saving, sigmoid probabilities are written per sample as
      ``<name>_attn.npy``.
    * ``'jointnet'``: the sum over samples of the chamfer distance between
      ``displacement + data.pos`` and the first ``data.num_joint[i]`` rows of
      ``data.y``. When saving, each sample's displaced point cloud is written
      with ``output_point_cloud_ply``.

    Outputs go to ``results/<second checkpoint path component>/best_<best_epoch>/``,
    which is created if missing.

    Returns:
        ``AverageMeter.avg`` of the per-batch losses.

    Raises:
        ValueError: if ``args.arch`` is neither ``'masknet'`` nor ``'jointnet'``.
    """
    global device
    if args.arch not in ('masknet', 'jointnet'):
        raise ValueError(
            "unsupported args.arch {!r} (expected 'masknet' or 'jointnet')".format(args.arch))

    model.eval()  # switch to test mode
    loss_meter = AverageMeter()
    outdir = args.checkpoint.split('/')[1]

    output_folder = None
    if save_result:
        # Built once; '{:d}' requires best_epoch to be an int here.
        output_folder = 'results/{:s}/best_{:d}/'.format(outdir, best_epoch)
        if not os.path.exists(output_folder):
            mkdir_p(output_folder)

    for data in test_loader:
        data = data.to(device)
        num_samples = len(torch.unique(data.batch))
        with torch.no_grad():
            if args.arch == 'masknet':
                mask_pred = model(data)
                mask_gt = data.mask.unsqueeze(1)
                loss = torch.nn.functional.binary_cross_entropy_with_logits(
                    mask_pred, mask_gt.float(), reduction='mean')
            else:  # 'jointnet'
                data_displacement = model(data)
                y_pred = data_displacement + data.pos
                loss = 0.0
                for i in range(num_samples):
                    in_sample = data.batch == i
                    y_gt_sample = data.y[in_sample, :][:data.num_joint[i], :]
                    y_pred_sample = y_pred[in_sample, :]
                    loss += chamfer_distance_with_average(
                        y_pred_sample.unsqueeze(0), y_gt_sample.unsqueeze(0))
        loss_meter.update(loss.item())

        if save_result:
            if args.arch == 'masknet':
                mask_prob = torch.sigmoid(mask_pred)  # model emits logits
                for i in range(num_samples):
                    mask_pred_sample = mask_prob[data.batch == i]
                    np.save(
                        os.path.join(output_folder, str(data.name[i].item()) + '_attn.npy'),
                        mask_pred_sample.data.cpu().numpy())
            else:
                for i in range(num_samples):
                    y_pred_sample = y_pred[data.batch == i, :]
                    output_point_cloud_ply(
                        y_pred_sample,
                        name=str(data.name[i].item()),
                        output_folder=output_folder)
    return loss_meter.avg
def test(test_loader, model, args, save_result=False):
    """Evaluate the skinning network and optionally export rigging results.

    Loss: per-(vertex, bone) ``cross_entropy_with_probs`` against the skin
    labels of the first ``args.nearest_bone`` candidate bones. The labels are
    masked by ``data.loss_mask`` and renormalised. Vertices whose normalised
    label sum is not within 1e-8 of 1 are excluded, and the result is a mean
    over the remaining entries. A batch where every entry is masked out
    yields 0 instead of NaN.

    When ``save_result`` is True, for every sample the softmaxed and masked
    prediction is scattered into a dense ``(num_vertices, num_bones)`` array
    using ``data.skin_nn``. The array is then smoothed by ``post_filter``,
    entries below half of each row's max are zeroed, and rows are
    renormalised. The bone names, the dense array (``<name>_full_pred.npy``)
    and the rig (via ``output_rigging``) are written to
    ``results/<last checkpoint path component>/``.

    Returns:
        ``AverageMeter.avg`` of the per-batch losses.
    """
    global device
    model.eval()  # switch to test mode
    loss_meter = AverageMeter()
    outdir = args.checkpoint.split('/')[-1]

    output_folder = None
    if save_result:
        output_folder = 'results/{:s}/'.format(outdir)
        if not os.path.exists(output_folder):
            mkdir_p(output_folder)

    for data in test_loader:
        data = data.to(device)
        with torch.no_grad():
            skin_pred = model(data)
            skin_gt = data.skin_label[:, 0:args.nearest_bone]
            loss_mask_batch = data.loss_mask.float()[:, 0:args.nearest_bone]
            skin_gt = skin_gt * loss_mask_batch
            skin_gt = skin_gt / (torch.sum(torch.abs(skin_gt), dim=1, keepdim=True) + 1e-8)
            # NOTE(review): this mask depends on the +1e-8 epsilon above and the
            # 1e-8 tolerance, so it is sensitive to the dtype of skin_label - confirm.
            vert_mask = (torch.abs(skin_gt.sum(dim=1) - 1.0) < 1e-8).float()
            loss = cross_entropy_with_probs(skin_pred, skin_gt, reduction='none')
            denom = (loss_mask_batch * vert_mask.unsqueeze(1)).sum()
            # Clamp only matters when nothing is selected (0/0 -> NaN otherwise).
            loss = (loss * loss_mask_batch * vert_mask.unsqueeze(1)).sum() / denom.clamp(min=1e-8)
            loss_meter.update(loss.item())

        if not save_result:
            continue
        for i in range(len(torch.unique(data.batch))):
            name = data.name[i].item()
            print('output result for model {:d}'.format(name))
            in_sample = data.batch == i
            skin_pred_i = skin_pred[in_sample]
            bone_names = get_bone_names(os.path.join(args.test_folder, "{:d}_skin.txt".format(name)))
            tpl_e = np.loadtxt(os.path.join(args.test_folder, "{:d}_tpl_e.txt".format(name))).T
            loss_mask_sample = data.loss_mask.float()[in_sample, 0:args.nearest_bone]
            skin_pred_i = torch.softmax(skin_pred_i, dim=1)
            skin_pred_i = skin_pred_i * loss_mask_sample
            skin_nn_i = data.skin_nn[in_sample, 0:args.nearest_bone]

            # Copy to host once: indexing device tensors element by element
            # forces a device->host sync for every single element.
            pred_np = skin_pred_i.detach().cpu().numpy()
            nn_np = skin_nn_i.cpu().numpy()
            skin_pred_asarray = np.zeros((len(pred_np), len(bone_names)))
            # Explicit loop keeps the original last-write-wins order for any
            # repeated bone index within a row.
            for v in range(len(pred_np)):
                for nn_id in range(len(nn_np[v, :])):
                    skin_pred_asarray[v, nn_np[v, nn_id]] = pred_np[v, nn_id]

            skin_pred_asarray = post_filter(skin_pred_asarray, tpl_e, num_ring=1)
            skin_pred_asarray[skin_pred_asarray < np.max(skin_pred_asarray, axis=1, keepdims=True) * 0.5] = 0.0
            skin_pred_asarray = skin_pred_asarray / (skin_pred_asarray.sum(axis=1, keepdims=True) + 1e-10)
            with open(os.path.join(output_folder, "{:d}_bone_names.txt".format(name)), 'w') as fout:
                for bone_name in bone_names:
                    fout.write("{:s} {:s}\n".format(bone_name[0], bone_name[1]))
            np.save(os.path.join(output_folder, "{:d}_full_pred.npy".format(name)), skin_pred_asarray)
            skel_filename = os.path.join(args.info_folder, "{:d}.txt".format(name))
            output_rigging(skel_filename, skin_pred_asarray, output_folder, name)
    return loss_meter.avg
def train(train_loader, model, optimizer, args):
    """Run one training epoch of the root-joint classifier with top-k hard mining.

    The loss is ``mean(BCE) + mean(top-k largest BCE terms)``, where
    k = ``int(args.topk * len(pre_label))``. k is clamped to at least 1, so
    the hard-example mean is never the NaN mean of an empty tensor, and to at
    most the number of loss terms, so ``torch.topk`` never raises.

    Returns:
        ``AverageMeter.avg`` of the per-batch losses.
    """
    global device
    model.train()  # switch to train mode
    loss_meter = AverageMeter()
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        pre_label, label = model(data)
        # .float() matches the evaluation loop; BCE-with-logits needs a float target.
        loss_1 = torch.nn.functional.binary_cross_entropy_with_logits(
            pre_label, label.float(), reduction='none')
        flat_loss = loss_1.view(-1)
        k = min(max(int(args.topk * len(pre_label)), 1), flat_loss.numel())
        topk_val, _ = torch.topk(flat_loss, k=k, dim=0, sorted=False)
        loss2 = topk_val.mean()
        loss = loss_1.mean() + loss2
        loss.backward()
        optimizer.step()
        loss_meter.update(loss.item())
    return loss_meter.avg
def train(train_loader, model, optimizer, args):
    """Run one training epoch of the skinning network.

    The targets are the skin labels of the first ``args.nearest_bone``
    candidate bones, masked by ``data.loss_mask`` and renormalised per vertex.
    The loss is ``cross_entropy_with_probs`` averaged over the (vertex, bone)
    entries kept by both the loss mask and the per-vertex ``vert_mask``.

    If a batch keeps no entries at all, the denominator is clamped, so the
    loss is 0 rather than a NaN that would corrupt the weights on backward.

    Returns:
        ``AverageMeter.avg`` of the per-batch losses.
    """
    global device
    model.train()  # switch to train mode
    loss_meter = AverageMeter()
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        skin_pred = model(data)
        skin_gt = data.skin_label[:, 0:args.nearest_bone]
        loss_mask_batch = data.loss_mask.float()[:, 0:args.nearest_bone]
        skin_gt = skin_gt * loss_mask_batch
        skin_gt = skin_gt / (torch.sum(torch.abs(skin_gt), dim=1, keepdim=True) + 1e-8)
        # mask out vertices whose skinning is missing from the picked K bones.
        # NOTE(review): relies on the +1e-8 epsilon above and the 1e-8 tolerance;
        # the effect depends on the dtype of skin_label - confirm.
        vert_mask = (torch.abs(skin_gt.sum(dim=1) - 1.0) < 1e-8).float()
        loss = cross_entropy_with_probs(skin_pred, skin_gt, reduction='none')
        denom = (loss_mask_batch * vert_mask.unsqueeze(1)).sum()
        # clamp leaves any non-degenerate denominator unchanged; it only avoids 0/0.
        loss = (loss * loss_mask_batch * vert_mask.unsqueeze(1)).sum() / denom.clamp(min=1e-8)
        loss.backward()
        optimizer.step()
        loss_meter.update(loss.item())
    return loss_meter.avg
def train(train_loader, model, criterion, optimizer, epoch):
    """Train for one epoch with masked multi-output supervision.

    Each output of ``model(inputs)`` is restricted to the positions where
    ``mask`` is set and compared with the equally masked target. The
    criterion values are each divided by the batch size and summed into one
    loss. A progress bar reports the running average loss. After the epoch, a
    preview of the last batch (last model output) is written to
    ``images/<epoch:06d>.jpg``. The folder is created if missing, because
    ``cv2.imwrite`` does not create directories. The preview is skipped when
    the loader is empty (previously a NameError).

    Args:
        train_loader: yields ``(inputs, target, mask)`` tuples.
        model: network returning an iterable of outputs.
        criterion: loss applied to each masked output.
        optimizer: stepped once per batch.
        epoch: used only to name the preview image.
    """
    import os  # local import: only needed to create the preview folder

    # log
    loss_log = AverageMeter()
    bar = Bar('Training', max=len(train_loader))

    model.train()
    inputs = outputs = mask = None
    for i, (inputs, target, mask) in enumerate(train_loader):
        # cuda
        inputs = inputs.cuda()
        target = target.cuda()
        mask = mask.cuda()

        # inference
        outputs = model(inputs)

        # calculate loss over masked-in positions only
        target = torch.masked_select(target, mask)
        loss = 0
        for output in outputs:
            output = torch.masked_select(output, mask)
            loss += criterion(output, target) / inputs.shape[0]
        loss_log.update(loss.item(), inputs.size(0))

        # update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # show progress
        bar.suffix = '({batch}/{size}) | Total: {total:} | ETA: {eta:} | Loss: {loss:.6f}'.format(
            batch=i + 1,
            size=len(train_loader),
            total=bar.elapsed_td,
            eta=bar.eta_td,
            loss=loss_log.avg)
        bar.next()
    bar.finish()

    if outputs is None:
        return  # empty loader: nothing to visualise

    # save inference image
    os.makedirs('images', exist_ok=True)
    cv2.imwrite('images/{0:06d}.jpg'.format(epoch), make_inference_image(inputs, outputs[-1], mask))
def test(test_loader, model, args, save_result=False, best_epoch=None):
    """Evaluate the joint-connectivity classifier; optionally dump cost matrices.

    The model returns ``(pre_label, label)`` with one row per candidate joint
    pair. The loss is BCE-with-logits against ``label.float()``. When
    ``save_result`` is set, the pair rows of sample ``i`` are assumed to be
    the next ``data.num_pair[i]`` rows in batch order (matching the slicing
    below). Each sample's pair probabilities are scattered into a
    ``num_joint x num_joint`` matrix, which is converted to a cost
    ``1 - matrix`` and saved as ``<name>_cost.npy`` under
    ``results/<second checkpoint path component>/best_<best_epoch>/``.

    Returns:
        ``AverageMeter.avg`` of per-batch losses, each weighted by the number
        of samples in the batch.
    """
    global device
    model.eval()  # switch to test mode

    output_folder = None
    if save_result:
        run_name = args.checkpoint.split('/')[1]
        output_folder = 'results/{:s}/best_{:d}/'.format(run_name, best_epoch)
        if not os.path.exists(output_folder):
            mkdir_p(output_folder)

    loss_meter = AverageMeter()
    for batch_data in test_loader:
        batch_data = batch_data.to(device)
        with torch.no_grad():
            logits, target = model(batch_data)
            batch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
                logits, target.float())

        num_graphs = len(torch.unique(batch_data.batch))
        if save_result:
            probs = torch.sigmoid(logits)
            offset = 0
            for g in range(num_graphs):
                stop = offset + batch_data.num_pair[g]
                pairs_np = batch_data.pairs[offset:stop].long().data.cpu().numpy()
                probs_np = probs[offset:stop].data.cpu().numpy().squeeze()
                offset = stop

                n_joints = batch_data.num_joint[g]
                cost = np.zeros((n_joints, n_joints))
                cost[pairs_np[:, 0], pairs_np[:, 1]] = probs_np
                cost = 1 - cost

                file_name = str(batch_data.name[g].item()) + '_cost.npy'
                print('saving: {:s}'.format(file_name))
                np.save(os.path.join(output_folder, file_name), cost)

        loss_meter.update(batch_loss.item(), n=num_graphs)
    return loss_meter.avg