for epoch in range(opt.nepoch): scheduler.step() for i, data in enumerate(dataloader, 0): points, target = data points = points.transpose(2, 1) points, target = points.cuda(), target.cuda() optimizer.zero_grad() classifier = classifier.train() pred, trans, trans_feat = classifier(points) pred = pred.view(-1, num_classes) target = target.view(-1, 1)[:, 0] - 1 #print(pred.size(), target.size()) loss = F.nll_loss(pred, target) if opt.feature_transform: loss += feature_transform_regularizer(trans_feat) * 0.001 loss.backward() optimizer.step() pred_choice = pred.data.max(1)[1] correct = pred_choice.eq(target.data).cpu().sum() print('[%d: %d/%d] train loss: %f accuracy: %f' % (epoch, i, num_batch, loss.item(), correct.item() / float(opt.batchSize * 2500))) if i % 10 == 0: j, data = next(enumerate(testdataloader, 0)) points, target = data points = points.transpose(2, 1) points, target = points.cuda(), target.cuda() classifier = classifier.eval() pred, _, _ = classifier(points)
def _completion_masks(mask, device):
    """Expand a per-point mask into float ``(fill, known)`` masks of shape (B, N, 3).

    ``known`` is 1.0 where ``mask`` is True and ``fill`` is its complement.
    NOTE(review): assumes ``mask`` is a (B, N) bool tensor (it is inverted
    with ``~``) -- confirm against ShapeNetDataset.
    """
    known = mask.unsqueeze(2).repeat(1, 1, 3)
    fill = ~known
    return (fill.to(device, dtype=torch.float32),
            known.to(device, dtype=torch.float32))


def _gather_masked_points(cloud, keep, b_size):
    """Keep only the points of ``cloud`` selected by the 0/1 mask ``keep``.

    ``cloud`` and ``keep`` are (B, N, 3); the result is (B, M, 3).  Points are
    selected by a non-zero test after masking, so a selected point lying
    exactly at the origin is dropped too, and every sample must keep the same
    number of points for the final ``view`` to succeed.
    """
    masked = cloud * keep
    return masked[torch.abs(masked).sum(dim=2) != 0].view(b_size, -1, 3)


def _prepare_completion_batch(data, device):
    """Unpack a ``(points, target, mask)`` batch and move it to ``device``.

    Returns ``(points, target, fill, known)``: ``points`` is transposed with
    ``transpose(2, 1)`` for the generator, all tensors are float, and the
    masks come from ``_completion_masks``.
    """
    points, target, mask = data
    points = points.transpose(2, 1).to(device, dtype=torch.float)
    target = target.to(device, dtype=torch.float)
    fill, known = _completion_masks(mask, device)
    return points, target, fill, known


def _evaluate_completion(netG, globalD, localD, data, device, d_criterion, chamfer):
    """Score one held-out batch without gradients.

    Puts all three networks in eval mode *before* the forward passes and
    returns a dict of scalar loss tensors: ``errD_real``, ``errD_fake``,
    ``errD``, ``errG_global``, ``errG_local``, ``chamfer`` and ``errG``
    (Chamfer + both adversarial terms).
    """
    netG.eval()
    globalD.eval()
    localD.eval()
    with torch.no_grad():
        points, target, fill, known = _prepare_completion_batch(data, device)
        b_size = points.shape[0]
        real = torch.full((b_size,), 1, dtype=torch.float, device=device)
        fake = torch.full((b_size,), 0, dtype=torch.float, device=device)

        target_local = _gather_masked_points(target, fill, b_size)
        out_g_real = globalD(target.transpose(2, 1).contiguous())
        out_l_real = localD(target_local.transpose(2, 1).contiguous())

        # Known points are copied from the ground truth; only the ``fill``
        # region comes from the generator.
        pred = netG(points) * fill + target * known
        pred_local = _gather_masked_points(pred, fill, b_size)
        out_g_fake = globalD(pred.transpose(2, 1).contiguous())
        out_l_fake = localD(pred_local.transpose(2, 1).contiguous())

        losses = {
            'errD_real': d_criterion(out_g_real, real) + d_criterion(out_l_real, real),
            # Score the *predicted* clouds as fake (previously the real-batch
            # outputs were reused here by mistake).
            'errD_fake': d_criterion(out_g_fake, fake) + d_criterion(out_l_fake, fake),
            'errG_global': d_criterion(out_g_fake, real),
            'errG_local': d_criterion(out_l_fake, real),
        }
        dist1, dist2 = chamfer(pred, target)
        losses['chamfer'] = torch.mean(dist1) + torch.mean(dist2)
        losses['errD'] = losses['errD_real'] + losses['errD_fake']
        losses['errG'] = losses['chamfer'] + losses['errG_global'] + losses['errG_local']
    return losses


def train(lr=0.001):
    """Train the point-cloud completion GAN for ``opt.nepoch`` epochs.

    A generator (``PointNetDenseCls``) predicts a full point cloud from the
    input points; the points where the dataset mask is True are then copied
    from the ground truth.  Two discriminators score the result: ``globalD``
    sees the whole cloud, and ``localD`` sees only the generator-filled points.
    The generator minimises the Chamfer distance to the target plus the
    adversarial loss of both discriminators.  Every 10 batches one random test
    batch is evaluated; every 100 batches losses and parameter histograms are
    written to TensorBoard.

    Args:
        lr: Learning rate for both Adam optimizers (generator and discriminators).

    Returns:
        ``(errD, errG, chamferloss)`` from the last training batch: the
        discriminator loss as a Python float, and the generator adversarial
        loss and Chamfer loss as tensors.
    """
    # A plain namespace instead of ``ArgumentParser().parse_args()``: the
    # parser declared no arguments, so any extra sys.argv entry (e.g. when
    # called from a notebook or another CLI script) aborted the run.
    opt = argparse.Namespace()
    opt.nepoch = 1
    opt.batchsize = 18
    opt.workers = 0
    opt.outf = 'completion'
    opt.dataset = '/home/cdi0/data/shape_net_core_uniform_samples_2048_split/'
    opt.feature_transform = False
    opt.model = ''
    opt.device = 'cuda:1'
    opt.lr = lr

    opt.manualSeed = random.randint(1, 10000)  # fix seed
    print("Random Seed: ", opt.manualSeed)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)

    dataset = ShapeNetDataset(dir=opt.dataset, )
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=opt.batchsize, shuffle=True,
        num_workers=int(opt.workers))
    test_dataset = ShapeNetDataset(dir=opt.dataset, train='test', )
    testdataloader = torch.utils.data.DataLoader(
        test_dataset, batch_size=opt.batchsize, shuffle=True,
        num_workers=int(opt.workers))
    print(len(dataset), len(test_dataset))

    def blue(text):
        return '\033[94m' + text + '\033[0m'

    device = opt.device
    netG = PointNetDenseCls(device=device, feature_transform=opt.feature_transform)
    localD = LocalDiscriminator(k=2, device=device)
    globalD = GlobalDiscriminator(k=2, device=device)
    if opt.model != '':
        netG.load_state_dict(torch.load(opt.model))
    netG.to(device)
    localD.to(device)
    globalD.to(device)

    # Use the requested learning rate; it was hard-coded to 0.001, so the
    # ``lr`` argument had no effect.
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(0.9, 0.999))
    optimizerD = optim.Adam(
        list(globalD.parameters()) + list(localD.parameters()),
        lr=opt.lr, betas=(0.9, 0.999))
    schedulerG = optim.lr_scheduler.StepLR(optimizerG, step_size=20, gamma=0.5)
    schedulerD = optim.lr_scheduler.StepLR(optimizerD, step_size=20, gamma=0.5)

    criterion = distChamfer
    Dcriterion = nn.BCELoss()
    real_label = 1
    fake_label = 0
    num_batch = len(dataloader)
    log_every = 100
    # TensorBoard step = log_points_per_epoch * epoch + (i // log_every);
    # derived from the loader size instead of the former hard-coded 27.
    log_points_per_epoch = (len(dataloader) + log_every - 1) // log_every
    writer = SummaryWriter()

    for epoch in range(opt.nepoch):
        for i, data in enumerate(dataloader, 0):
            points, target, fill, known = _prepare_completion_batch(data, device)
            b_size = points.shape[0]
            # Set train mode *before* any forward pass; evaluation below
            # switches the networks to eval mode.
            netG.train()
            localD.train()
            globalD.train()

            ###### train D ######
            optimizerD.zero_grad()
            # BCELoss needs float targets; an int fill value would give int64.
            label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
            target_local = _gather_masked_points(target, fill, b_size)
            errD_real = (Dcriterion(globalD(target.transpose(2, 1).contiguous()), label)
                         + Dcriterion(localD(target_local.transpose(2, 1).contiguous()), label))
            errD_real.backward()

            pred = netG(points) * fill + target * known
            pred_local = _gather_masked_points(pred, fill, b_size)
            pred_global_in = pred.transpose(2, 1).contiguous()
            pred_local_in = pred_local.transpose(2, 1).contiguous()
            label.fill_(fake_label)
            errD_fake = (Dcriterion(globalD(pred_global_in.detach()), label)
                         + Dcriterion(localD(pred_local_in.detach()), label))
            errD_fake.backward()
            errD = errD_real + errD_fake
            # Only update D while its loss is above 0.1 (presumably to keep it
            # from overpowering G).
            if errD.item() > 0.1:
                optimizerD.step()

            ###### train G ######
            optimizerG.zero_grad()
            optimizerD.zero_grad()
            label.fill_(real_label)
            errG_g = Dcriterion(globalD(pred_global_in), label)
            errG_l = Dcriterion(localD(pred_local_in), label)
            errG = errG_g + errG_l
            dist1, dist2 = criterion(pred, target)
            chamferloss = torch.mean(dist1) + torch.mean(dist2)
            loss = chamferloss + errG
            # The former ``feature_transform`` regulariser was added after
            # ``backward()`` (no effect) and referenced an undefined
            # ``trans_feat``; netG returns only the prediction, so it is dropped.
            loss.backward()
            optimizerG.step()
            print('[%d: %d/%d] D_loss: %f, G_loss: %f, Chamfer_loss: %f '
                  % (epoch, i, num_batch, errD.item(), errG.item(), chamferloss.item()))

            if i % 10 == 0:
                ev = _evaluate_completion(netG, globalD, localD,
                                          next(iter(testdataloader)), device,
                                          Dcriterion, criterion)
                # Report the *test* generator loss (the training loss was printed before).
                print('[%d: %d/%d] %s D_loss: %f, G_loss: %f '
                      % (epoch, i, num_batch, blue('test'),
                         ev['errD'].item(), ev['errG'].item()))

                if i % log_every == 0:
                    step = log_points_per_epoch * epoch + i // log_every
                    scalars = [
                        ('errD_real', errD_real),
                        ('errD_fake', errD_fake),
                        ('errD_loss', errD),
                        ('validation errD_real', ev['errD_real']),
                        ('validation errD_fake', ev['errD_fake']),
                        ('validation errD_loss', ev['errD']),
                        ('errG_global', errG_g),
                        ('errG_local', errG_l),
                        ('chamfer_loss', chamferloss),
                        ('errG_loss', loss),
                        ('validation errG_global', ev['errG_global']),
                        ('validation errG_local', ev['errG_local']),
                        ('validation chamfer_loss', ev['chamfer']),
                        ('validation errG_loss', ev['errG']),
                    ]
                    for tag, value in scalars:
                        writer.add_scalar(tag, value.item(), step)
                    for model in (globalD, localD, netG):
                        for name, param in model.named_parameters():
                            writer.add_histogram(name, param.detach().cpu().numpy(), step)

        schedulerG.step()
        schedulerD.step()

    writer.close()  # flush pending TensorBoard events and release the file
    return errD.item(), errG, chamferloss
def our_main():
    """Command-line entry point: train a PointNet classifier and report test accuracy.

    Parses CLI options, builds ShapeNet or ModelNet40 loaders, and trains
    ``PointNetCls`` with Adam and StepLR.  Every 10 training batches it
    evaluates one random test batch and prints the result.  It saves a
    checkpoint per epoch to ``--outf`` and finally prints the accuracy over
    the whole test set.

    Example:
        python train_classification.py --dataset ../dataset --nepoch=4 --dataset_type shapenet
    """
    import sys

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--batchSize', type=int, default=32, help='input batch size')
    parser.add_argument(
        '--num_points', type=int, default=2000, help='number of points sampled per shape')
    parser.add_argument(
        '--workers', type=int, help='number of data loading workers', default=4)
    parser.add_argument(
        '--nepoch', type=int, default=250, help='number of epochs to train for')
    parser.add_argument('--outf', type=str, default='cls', help='output folder')
    parser.add_argument('--model', type=str, default='', help='model path')
    parser.add_argument('--dataset', type=str, required=True, help="dataset path")
    parser.add_argument('--dataset_type', type=str, default='shapenet',
                        help="dataset type shapenet|modelnet40")
    parser.add_argument('--feature_transform', action='store_true',
                        help="use feature transform")
    # Visualising every batch was an unconditional debug call; it is now opt-in.
    parser.add_argument('--show_points', action='store_true',
                        help='display the first shape of every training batch with showpoints')
    opt = parser.parse_args()
    print(opt)

    def blue(text):
        return '\033[94m' + text + '\033[0m'

    opt.manualSeed = random.randint(1, 10000)  # fix seed
    print("Random Seed: ", opt.manualSeed)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)

    if opt.dataset_type == 'shapenet':
        dataset = ShapeNetDataset(
            root=opt.dataset, classification=True, npoints=opt.num_points)
        test_dataset = ShapeNetDataset(
            root=opt.dataset, classification=True, split='test',
            npoints=opt.num_points, data_augmentation=False)
    elif opt.dataset_type == 'modelnet40':
        dataset = ModelNetDataset(
            root=opt.dataset, npoints=opt.num_points, split='trainval')
        test_dataset = ModelNetDataset(
            root=opt.dataset, split='test', npoints=opt.num_points,
            data_augmentation=False)
    else:
        sys.exit('wrong dataset type')

    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=opt.batchSize, shuffle=True,
        num_workers=int(opt.workers))
    testdataloader = torch.utils.data.DataLoader(
        test_dataset, batch_size=opt.batchSize, shuffle=True,
        num_workers=int(opt.workers))
    print(len(dataset), len(test_dataset))
    num_classes = len(dataset.classes)
    print('classes', num_classes)

    os.makedirs(opt.outf, exist_ok=True)

    if opt.show_points:
        from utils.show3d_balls import showpoints

    classifier = PointNetCls(k=num_classes, feature_transform=opt.feature_transform)
    if opt.model != '':
        classifier.load_state_dict(torch.load(opt.model))
    optimizer = optim.Adam(classifier.parameters(), lr=0.001, betas=(0.9, 0.999))
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
    classifier.cuda()
    num_batch = len(dataloader)

    def batch_accuracy(pred, target):
        # Divide by the real batch size: the last batch may be smaller than --batchSize.
        correct = pred.detach().max(1)[1].eq(target).sum().item()
        return correct / float(target.size(0))

    for epoch in range(opt.nepoch):
        for i, data in enumerate(dataloader, 0):
            points, target = data
            target = target[:, 0]
            if opt.show_points:
                showpoints(points[0].numpy())
            points = points.transpose(2, 1)
            points, target = points.cuda(), target.cuda()
            optimizer.zero_grad()
            classifier.train()
            pred, _, trans_feat = classifier(points)
            loss = F.nll_loss(pred, target)
            if opt.feature_transform:
                loss += feature_transform_regularizer(trans_feat) * 0.001
            loss.backward()
            optimizer.step()
            print('[%d: %d/%d] train loss: %f accuracy: %f'
                  % (epoch, i, num_batch, loss.item(), batch_accuracy(pred, target)))

            if i % 10 == 0:
                points, target = next(iter(testdataloader))
                target = target[:, 0]
                points = points.transpose(2, 1)
                points, target = points.cuda(), target.cuda()
                classifier.eval()
                with torch.no_grad():
                    pred, _, _ = classifier(points)
                    test_loss = F.nll_loss(pred, target)
                print('[%d: %d/%d] %s loss: %f accuracy: %f'
                      % (epoch, i, num_batch, blue('test'), test_loss.item(),
                         batch_accuracy(pred, target)))

        # Since PyTorch 1.1 the scheduler must step *after* the optimizer;
        # stepping at the start of the epoch shifted the LR decay by one epoch.
        scheduler.step()
        torch.save(classifier.state_dict(), '%s/cls_model_%d.pth' % (opt.outf, epoch))

    total_correct = 0
    total_testset = 0
    classifier.eval()
    with torch.no_grad():
        for points, target in tqdm(testdataloader):
            target = target[:, 0]
            points = points.transpose(2, 1)
            points, target = points.cuda(), target.cuda()
            pred, _, _ = classifier(points)
            total_correct += pred.max(1)[1].eq(target).sum().item()
            total_testset += points.size(0)

    print("final accuracy {}".format(total_correct / float(total_testset)))
comp_pre_label_dist[i, 0] = pred_labels_dist_cls[ i, majority_pre_id] comp_pre_label_dist[i, 1] = sum_[ i] - pred_labels_dist_cls[i, majority_pre_id] elif class_id < pred_labels_dist_cls.shape[1]: comp_pre_label_dist[i, 1] = pred_labels_dist_cls[ i, majority_pre_id] comp_pre_label_dist[i, 0] = sum_[ i] - pred_labels_dist_cls[i, majority_pre_id] else: comp_pre_label_dist[i, 0] = 0 comp_pre_label_dist[i, 1] = sum_[i] seg_loss += loss_r(comp_pre_label_dist, comp_gt_label) loss = feature_transform_regularizer( trans) * 0.001 + seg_loss / 2500 + num_obj_loss loss.backward() optimizer.step() # correct = pred_labels.eq(gt_labels.data).cpu().sum() correct = 0 # total = 0 pred_labels_cpu = pred_labels.cpu() gt_labels_cpu = gt_labels.cpu() gt_numobj_cpu, _ = torch.max(gt_labels_cpu, dim=1) pred_numobj_cpu = pred_numobj.cpu() for b in range(batch_size): pred_numobj_ = pred_numobj_cpu[b] pred_labels_ = pred_labels_cpu[b] gt_labels_ = gt_labels_cpu[b].T gt_numobj_ = gt_numobj_cpu[b] # print("1:", pred_labels_.min(), pred_labels_.max())
def PointNetSeg():
    """Train the PointNet part-segmentation model and report the test mIoU.

    Relies on names defined elsewhere in the script: ``opt``, ``dataloader``,
    ``test_dataloader``, ``num_classes``, ``blue``, ``point_net_seg`` and
    ``feature_transform_regularizer``.  Labels from the loader are shifted by
    ``- 1`` (the original comment notes ShapeNet part labels run from 1 to k).

    Returns:
        ``(train_loss, test_loss, train_acc, test_acc)``: dicts keyed by epoch
        index.  Entry 0 holds the first batch; entry ``e + 1`` holds the last
        batch of epoch ``e`` whose index is a multiple of 10.
    """
    classifier = point_net_seg(num_classes, feature_transform=opt.feature_transform)
    if opt.model != '':
        classifier.load_state_dict(torch.load(opt.model))  # resume from trained params
    optimizer = optim.Adam(classifier.parameters(), lr=0.001, betas=(0.9, 0.999))
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
    classifier.cuda()  # move the model to the GPU
    num_batch = len(dataloader)
    # Last batch index that is a multiple of 10.  The former ``i + 10 > num_batch``
    # test never fired when the batch count was an exact multiple of 10.
    last_logged_batch = ((num_batch - 1) // 10) * 10

    train_loss = {}
    test_loss = {}
    train_acc = {}
    test_acc = {}

    for epoch in range(opt.nepoch):
        for i, data in enumerate(dataloader, 0):
            points, target = data
            points = points.transpose(2, 1)
            points, target = points.cuda(), target.cuda()
            optimizer.zero_grad()
            classifier.train()
            try:
                pred, trans, trans_feat = classifier(points)
            except RuntimeError as exception:
                if "out of memory" not in str(exception):
                    raise
                print("WARNING: out of memory")
                if hasattr(torch.cuda, 'empty_cache'):
                    torch.cuda.empty_cache()
                # There is no prediction for this batch: skip it rather than
                # fall through with an undefined (or stale) ``pred``.
                continue
            pred = pred.view(-1, num_classes)  # [B*N, k]
            target = target.view(-1, 1)[:, 0] - 1  # labels 1..k -> 0..k-1
            loss = F.nll_loss(pred, target)
            if opt.feature_transform:
                loss += feature_transform_regularizer(trans_feat) * 0.001
            loss.backward()
            optimizer.step()
            pred_seg = pred.detach().max(1)[1]  # [B*N] predicted part index
            # Normalise by the actual number of labelled points instead of the
            # hard-coded ``batch_size * 2500`` (wrong for a short last batch).
            train_accuracy = pred_seg.eq(target).sum().item() / float(target.numel())
            print('[%d: %d/%d] train loss: %f accuracy: %f'
                  % (epoch, i, num_batch, loss.item(), train_accuracy))

            if epoch == 0 and i == 0:
                train_loss[epoch] = loss.item()
                train_acc[epoch] = train_accuracy

            if i % 10 == 0:
                if i == last_logged_batch:
                    train_loss[epoch + 1] = loss.item()
                    train_acc[epoch + 1] = train_accuracy

                points, target = next(iter(test_dataloader))
                points = points.transpose(2, 1)
                points, target = points.cuda(), target.cuda()
                classifier.eval()
                with torch.no_grad():
                    pred, _, _ = classifier(points)
                    pred = pred.view(-1, num_classes)
                    target = target.view(-1, 1)[:, 0] - 1
                    batch_test_loss = F.nll_loss(pred, target).item()
                    pred_seg = pred.max(1)[1]
                    test_accuracy = pred_seg.eq(target).sum().item() / float(target.numel())
                print('[%d: %d/%d] %s loss: %f accuracy: %f'
                      % (epoch, i, num_batch, blue('test'), batch_test_loss, test_accuracy))

                if epoch == 0 and i == 0:
                    test_loss[epoch] = batch_test_loss
                    test_acc[epoch] = test_accuracy
                if i == last_logged_batch:
                    test_loss[epoch + 1] = batch_test_loss
                    test_acc[epoch + 1] = test_accuracy

        scheduler.step()
        torch.save(classifier.state_dict(),
                   '%s/seg_model_%s_%d.pth' % (opt.outf, opt.class_choice, epoch))

    # Benchmark mIoU over the whole test set.
    shape_ious = []
    classifier.eval()
    with torch.no_grad():
        for points, target in tqdm(test_dataloader):
            points = points.transpose(2, 1).cuda()
            pred, _, _ = classifier(points)
            # max over dim 2 -> [B, N] part indices (assumes pred is [B, N, k]).
            pred_np = pred.max(2)[1].cpu().numpy()
            target_np = target.cpu().numpy() - 1
            for shape_pred, shape_gt in zip(pred_np, target_np):
                part_ious = []
                for part in range(num_classes):
                    pred_is_part = shape_pred == part
                    gt_is_part = shape_gt == part
                    union = np.sum(np.logical_or(pred_is_part, gt_is_part))
                    if union == 0:
                        # Part absent from both prediction and ground truth: count IoU as 1.
                        part_ious.append(1)
                    else:
                        part_ious.append(np.sum(np.logical_and(pred_is_part, gt_is_part)) / union)
                shape_ious.append(np.mean(part_ious))

    print('mIOU for class {}: {}'.format(opt.class_choice, np.mean(shape_ious)))
    return train_loss, test_loss, train_acc, test_acc
def PointNetCls():
    """Train the PointNet shape classifier and print the final test accuracy.

    Relies on names defined elsewhere in the script: ``opt``, ``dataloader``,
    ``test_dataloader``, ``num_classes``, ``blue``, ``point_net_cls`` and
    ``feature_transform_regularizer``.

    Returns:
        ``(train_loss, test_loss, train_acc, test_acc)``: dicts keyed by epoch
        index.  Entry 0 holds the first batch; entry ``e + 1`` holds the last
        batch of epoch ``e`` whose index is a multiple of 10.
    """
    # feature_transform defaults to False.
    classifier = point_net_cls(k=num_classes, feature_transform=opt.feature_transform)
    if opt.model != '':
        classifier.load_state_dict(torch.load(opt.model))  # resume from trained params
    optimizer = optim.Adam(classifier.parameters(), lr=0.001, betas=(0.9, 0.999))
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
    classifier.cuda()  # move the model to the GPU
    num_batch = len(dataloader)
    # Last batch index that is a multiple of 10.  The former ``i + 10 > num_batch``
    # test never fired when the batch count was an exact multiple of 10.
    last_logged_batch = ((num_batch - 1) // 10) * 10

    def batch_accuracy(pred, target):
        # Divide by the real batch size: the last batch may be smaller than opt.batch_size.
        correct = pred.detach().max(1)[1].eq(target).sum().item()
        return correct / float(target.size(0))

    train_loss = {}
    test_loss = {}
    train_acc = {}
    test_acc = {}

    for epoch in range(opt.nepoch):
        for i, data in enumerate(dataloader, 0):
            points, target = data
            target = target[:, 0]  # [B, 1] -> [B]
            points = points.transpose(2, 1)  # [B, N, 3] -> [B, 3, N]
            points, target = points.cuda(), target.cuda()
            optimizer.zero_grad()
            classifier.train()
            try:
                pred, trans, trans_feat = classifier(points)
            except RuntimeError as exception:
                if "out of memory" not in str(exception):
                    raise
                print("WARNING: out of memory")
                if hasattr(torch.cuda, 'empty_cache'):
                    torch.cuda.empty_cache()
                # There is no prediction for this batch: skip it rather than
                # fall through with an undefined (or stale) ``pred``.
                continue
            loss = F.nll_loss(pred, target)
            if opt.feature_transform:
                # NOTE(review): weight 0.01 differs from the 0.001 used by the
                # other training loops in this file -- confirm it is intended.
                loss += feature_transform_regularizer(trans_feat) * 0.01
            loss.backward()
            optimizer.step()
            train_accuracy = batch_accuracy(pred, target)
            print('[%d: %d/%d] train loss: %f accuracy: %f'
                  % (epoch, i, num_batch, loss.item(), train_accuracy))

            if epoch == 0 and i == 0:
                train_loss[epoch] = loss.item()
                train_acc[epoch] = train_accuracy

            # Every 10 batches, score one random test batch.
            if i % 10 == 0:
                if i == last_logged_batch:
                    train_loss[epoch + 1] = loss.item()
                    train_acc[epoch + 1] = train_accuracy

                points, target = next(iter(test_dataloader))
                target = target[:, 0]
                points = points.transpose(2, 1)
                points, target = points.cuda(), target.cuda()
                classifier.eval()
                with torch.no_grad():
                    pred, _, _ = classifier(points)
                    batch_test_loss = F.nll_loss(pred, target).item()
                test_accuracy = batch_accuracy(pred, target)
                print('[%d: %d/%d] %s loss: %f accuracy: %f'
                      % (epoch, i, num_batch, blue('test'), batch_test_loss, test_accuracy))

                if epoch == 0 and i == 0:
                    test_loss[epoch] = batch_test_loss
                    test_acc[epoch] = test_accuracy
                if i == last_logged_batch:
                    test_loss[epoch + 1] = batch_test_loss
                    test_acc[epoch + 1] = test_accuracy

        scheduler.step()
        # Checkpoint every epoch (default path: 'cls/cls_model_0.pth').
        torch.save(classifier.state_dict(), '%s/cls_model_%d.pth' % (opt.outf, epoch))

    # Accuracy on the whole test set.
    total_correct = 0
    total_testset = 0
    classifier.eval()
    with torch.no_grad():
        for points, target in tqdm(test_dataloader):
            target = target[:, 0]
            points = points.transpose(2, 1)
            points, target = points.cuda(), target.cuda()
            pred, _, _ = classifier(points)
            total_correct += pred.max(1)[1].eq(target).sum().item()
            total_testset += points.size(0)

    print('final accuracy {}'.format(total_correct / float(total_testset)))
    return train_loss, test_loss, train_acc, test_acc