def main(log_dir, model_path, decay, data_dir, dataset, partition, batch_size,
         pretrain, learning_rate, num_workers, epochs, feat, rand_rot,
         image_shape, base_order, sample_order):
    """Train a ModelNet classifier on spherical projections of the meshes.

    The model class is loaded from ``model_path``; a copy of it and of this
    script is stored in ``log_dir`` so the run can be reproduced.  After every
    epoch the whole test split is evaluated, the current weights are written
    to ``log_dir/state.pkl`` and, when the test accuracy improves, copied to
    ``log_dir/best.pkl``.

    Args:
        log_dir: Directory for ``log.txt``, code snapshots and checkpoints.
        model_path: Python file defining ``Model(num_classes, feat=...)``.
        decay: If truthy, multiply the learning rate by 0.7 every 25 epochs.
        data_dir: Root directory handed to ``ModelNet``.
        dataset: ``'modelnet10'`` or ``'modelnet40'``.
        partition: Not used by this function.
        batch_size: Mini-batch size for the train and test loaders.
        pretrain: Optional checkpoint path loaded with ``load_partial_model``.
        learning_rate: Initial Adam learning rate.
        num_workers: Number of DataLoader worker processes.
        epochs: Number of training epochs.
        feat: Forwarded to ``Model`` as ``feat``.
        rand_rot: ``random_rotations`` flag of the training ``ToMesh`` transform.
        image_shape: Forwarded to ``ProjectOnSphere`` and ``ModelNet``.
        base_order: Forwarded to ``ModelNet``.
        sample_order: Forwarded to ``ModelNet``; also part of the cache prefix.

    Raises:
        ValueError: If ``dataset`` is not ``'modelnet10'`` or ``'modelnet40'``.
    """
    # Must run first so that only the call arguments are captured.
    arguments = copy.deepcopy(locals())

    # Validate the dataset name before anything is written to disk.
    if dataset == 'modelnet10':
        classes = [
            'bathtub', 'bed', 'chair', 'desk', 'dresser', 'monitor',
            'night_stand', 'sofa', 'table', 'toilet'
        ]
    elif dataset == 'modelnet40':
        classes = [
            'airplane', 'bowl', 'desk', 'keyboard', 'person', 'sofa',
            'tv_stand', 'bathtub', 'car', 'door', 'lamp', 'piano', 'stairs',
            'vase', 'bed', 'chair', 'dresser', 'laptop', 'plant', 'stool',
            'wardrobe', 'bench', 'cone', 'flower_pot', 'mantel', 'radio',
            'table', 'xbox', 'bookshelf', 'cup', 'glass_box', 'monitor',
            'range_hood', 'tent', 'bottle', 'curtain', 'guitar', 'night_stand',
            'sink', 'toilet'
        ]
    else:
        raise ValueError('invalid dataset. must be modelnet10 or modelnet40')
    class_index = {name: idx for idx, name in enumerate(classes)}

    def target_transform(x):
        # Map a ModelNet class name to its integer label.
        return class_index[x]

    # Create the logging directory and snapshot the code used for this run.
    os.makedirs(log_dir, exist_ok=True)
    shutil.copy2(__file__, os.path.join(log_dir, 'script.py'))
    shutil.copy2(model_path, os.path.join(log_dir, 'model.py'))

    # Log both to the console and to log_dir/log.txt.
    logger = logging.getLogger('train')
    logger.setLevel(logging.DEBUG)
    logger.handlers = []
    logger.addHandler(logging.StreamHandler())
    logger.addHandler(logging.FileHandler(os.path.join(log_dir, 'log.txt')))
    logger.info('%s', repr(arguments))

    # Speed up convolutions using cuDNN
    torch.backends.cudnn.benchmark = True

    # Load the model from the snapshot copied into log_dir.
    loader = importlib.machinery.SourceFileLoader(
        'model', os.path.join(log_dir, 'model.py'))
    mod = types.ModuleType(loader.name)
    loader.exec_module(mod)

    num_classes = int(dataset[-2:])
    model = nn.DataParallel(mod.Model(num_classes, feat=feat)).cuda()
    if pretrain:
        pretrained_dict = torch.load(pretrain)
        load_partial_model(model, pretrained_dict)

    logger.info('{} parameters in total'.format(
        sum(x.numel() for x in model.parameters())))
    logger.info('{} parameters in the last layer'.format(
        sum(x.numel() for x in model.module.out_layer.parameters())))

    def make_transform(random_rotations):
        # Mesh -> spherical projection, cached under a sample_order-specific prefix.
        return CacheNPY(
            prefix='sp{}_'.format(sample_order),
            transform=torchvision.transforms.Compose([
                ToMesh(random_rotations=random_rotations, random_translation=0),
                ProjectOnSphere(dataset=dataset, image_shape=image_shape,
                                normalize=True)
            ]))

    transform = make_transform(rand_rot)
    transform_test = make_transform(False)

    train_set = ModelNet(data_dir, image_shape=image_shape,
                         base_order=base_order, sample_order=sample_order,
                         dataset=dataset, partition='train',
                         transform=transform,
                         target_transform=target_transform)
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=num_workers,
                                               pin_memory=True,
                                               drop_last=True)
    test_set = ModelNet(data_dir, image_shape=image_shape,
                        base_order=base_order, sample_order=sample_order,
                        dataset=dataset, partition='test',
                        transform=transform_test,
                        target_transform=target_transform)
    # Accuracy does not depend on order; a fixed order keeps evaluation deterministic.
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              num_workers=num_workers,
                                              pin_memory=True,
                                              drop_last=False)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = None
    if decay:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=25,
                                                    gamma=0.7)

    def train_step(data, target):
        # One optimisation step; returns (batch loss, number of correct predictions).
        model.train()
        data, target = data.cuda(), target.cuda()
        prediction = model(data)
        loss = F.nll_loss(prediction, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        correct = prediction.data.max(1)[1].eq(target.data).long().cpu().sum()
        return loss.item(), correct.item()

    def test_step(data, target):
        # Forward pass only, without building an autograd graph.
        data, target = data.cuda(), target.cuda()
        with torch.no_grad():
            prediction = model(data)
            loss = F.nll_loss(prediction, target)
        correct = prediction.max(1)[1].eq(target).long().cpu().sum()
        return loss.item(), correct.item()

    best_acc = 0.0
    for epoch in range(epochs):
        # training
        total_loss = 0
        total_correct = 0
        total_seen = 0
        time_before_load = time.perf_counter()
        for batch_idx, (data, target) in enumerate(train_loader):
            time_after_load = time.perf_counter()
            time_before_step = time.perf_counter()
            loss, correct = train_step(data, target)
            total_loss += loss
            total_correct += correct
            total_seen += len(data)
            logger.info(
                '[{}:{}/{}] LOSS={:.2} <LOSS>={:.2} ACC={:.2} <ACC>={:.2} time={:.2}+{:.2}'
                .format(epoch, batch_idx, len(train_loader), loss,
                        total_loss / (batch_idx + 1), correct / len(data),
                        total_correct / total_seen,
                        time_after_load - time_before_load,
                        time.perf_counter() - time_before_step))
            time_before_load = time.perf_counter()

        # Step the schedule after this epoch's optimizer.step() calls (the order
        # PyTorch >= 1.1 expects); the first decay then applies from epoch 25.
        if scheduler is not None:
            scheduler.step()

        # test
        model.eval()
        total_loss = 0
        total_correct = 0
        num_batches = 0
        for data, target in test_loader:
            loss, correct = test_step(data, target)
            total_loss += loss
            total_correct += correct
            num_batches += 1
        acc = total_correct / len(test_set)
        # Mean of the per-batch losses (previously divided by num_batches + 1).
        logger.info('[Epoch {} Test] <LOSS>={:.2} <ACC>={:2}'.format(
            epoch, total_loss / max(num_batches, 1), acc))

        # Save the state without sparse tensors, which are not stored.
        state_dict_no_sparse = OrderedDict(
            (name, tensor) for name, tensor in model.state_dict().items()
            if tensor.type() != "torch.cuda.sparse.FloatTensor")
        torch.save(state_dict_no_sparse, os.path.join(log_dir, "state.pkl"))

        # Keep a copy of the best checkpoint so far.
        if acc > best_acc:
            shutil.copy2(os.path.join(log_dir, "state.pkl"),
                         os.path.join(log_dir, "best.pkl"))
            best_acc = acc
def main(checkpoint_path, data_dir, dataset, partition, batch_size, feat,
         num_workers, image_shape, base_order, sample_order):
    """Evaluate a trained checkpoint on the ModelNet test split.

    The model class is loaded from ``model.py`` (resolved relative to the
    current working directory), wrapped in ``nn.DataParallel`` and moved to
    the GPU.  The checkpoint is applied with ``load_partial_model``.  Prints
    the mean per-batch NLL loss and the accuracy over the whole test split.

    Args:
        checkpoint_path: File passed to ``torch.load``.
        data_dir: Root directory handed to ``ModelNet``.
        dataset: ``'modelnet10'`` or ``'modelnet40'``.
        partition: Not used by this function.
        batch_size: Test mini-batch size.
        feat: Forwarded to ``Model`` as ``feat``.
        num_workers: Number of DataLoader worker processes.
        image_shape: Forwarded to ``ProjectOnSphere`` and ``ModelNet``.
        base_order: Forwarded to ``ModelNet``.
        sample_order: Forwarded to ``ModelNet``; also part of the cache prefix.

    Raises:
        ValueError: If ``dataset`` is not ``'modelnet10'`` or ``'modelnet40'``.
    """
    # Validate the dataset name before doing any expensive work.
    if dataset == 'modelnet10':
        classes = [
            'bathtub', 'bed', 'chair', 'desk', 'dresser', 'monitor',
            'night_stand', 'sofa', 'table', 'toilet'
        ]
    elif dataset == 'modelnet40':
        classes = [
            'airplane', 'bowl', 'desk', 'keyboard', 'person', 'sofa',
            'tv_stand', 'bathtub', 'car', 'door', 'lamp', 'piano', 'stairs',
            'vase', 'bed', 'chair', 'dresser', 'laptop', 'plant', 'stool',
            'wardrobe', 'bench', 'cone', 'flower_pot', 'mantel', 'radio',
            'table', 'xbox', 'bookshelf', 'cup', 'glass_box', 'monitor',
            'range_hood', 'tent', 'bottle', 'curtain', 'guitar', 'night_stand',
            'sink', 'toilet'
        ]
    else:
        raise ValueError('invalid dataset. must be modelnet10 or modelnet40')
    class_index = {name: idx for idx, name in enumerate(classes)}

    def target_transform(x):
        # Map a ModelNet class name to its integer label.
        return class_index[x]

    torch.backends.cudnn.benchmark = True

    # Load the model
    loader = importlib.machinery.SourceFileLoader('model', "model.py")
    mod = types.ModuleType(loader.name)
    loader.exec_module(mod)

    num_classes = int(dataset[-2:])
    model = nn.DataParallel(mod.Model(num_classes, feat=feat)).cuda()

    # load checkpoint
    pretrained_dict = torch.load(checkpoint_path)
    load_partial_model(model, pretrained_dict)

    print("{} parameters in total".format(
        sum(x.numel() for x in model.parameters())))
    print("{} parameters in the last layer".format(
        sum(x.numel() for x in model.module.out_layer.parameters())))

    # Deterministic mesh -> sphere projection, cached per sample_order.
    transform_test = CacheNPY(
        prefix='sp{}_'.format(sample_order),
        transform=torchvision.transforms.Compose([
            ToMesh(random_rotations=False, random_translation=0),
            ProjectOnSphere(dataset=dataset, image_shape=image_shape,
                            normalize=True)
        ]))

    test_set = ModelNet(data_dir, image_shape=image_shape,
                        base_order=base_order, sample_order=sample_order,
                        dataset=dataset, partition='test',
                        transform=transform_test,
                        target_transform=target_transform)
    # Accuracy does not depend on order; a fixed order keeps evaluation deterministic.
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              num_workers=num_workers,
                                              pin_memory=True,
                                              drop_last=False)

    # test
    model.eval()
    total_loss = 0
    total_correct = 0
    num_batches = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.cuda(), target.cuda()
            prediction = model(data)
            total_loss += F.nll_loss(prediction, target).item()
            total_correct += prediction.max(1)[1].eq(target).long().cpu().sum().item()
            num_batches += 1
    # Mean of the per-batch losses (previously divided by num_batches + 1).
    print("[Test] <LOSS>={:.2} <ACC>={:2}".format(
        total_loss / max(num_batches, 1), total_correct / len(test_set)))
def main(sp_mesh_dir, sp_mesh_level, log_dir, model_path, augmentation, decay,
         data_dir, tiny, dataset, partition, batch_size, learning_rate,
         num_workers, epochs, pretrain, feat, rand_rot):
    """Train a ModelNet classifier on icosphere-mesh projections.

    The model class (``Model`` or, with ``tiny``, ``Model_tiny``) is loaded
    from ``model_path``; a copy of it and of this script is stored in
    ``log_dir``.  After every epoch the whole test split is evaluated and the
    current weights (minus sparse tensors) are written to ``log_dir/state.pkl``.

    Args:
        sp_mesh_dir: Folder with ``icosphere_<level>.pkl`` mesh files.
        sp_mesh_level: Icosphere level; selects the mesh file and cache prefix.
        log_dir: Directory for ``log.txt``, code snapshots and checkpoints.
        model_path: Python file defining ``Model`` / ``Model_tiny``.
        augmentation: Not used by this function.
        decay: If truthy, multiply the learning rate by 0.7 every 25 epochs.
        data_dir: Root directory handed to ``ModelNet``.
        tiny: Use ``Model_tiny`` instead of ``Model``.
        dataset: ``'modelnet10'`` or ``'modelnet40'``.
        partition: Not used by this function.
        batch_size: Mini-batch size for the train and test loaders.
        learning_rate: Initial Adam learning rate.
        num_workers: Number of DataLoader worker processes.
        epochs: Number of training epochs.
        pretrain: Optional checkpoint path; all tensors except ``out_layer``
            ones are copied into the model.
        feat: Forwarded to the model as ``feat``.
        rand_rot: Logged but not used (see the NOTE below).

    Raises:
        ValueError: If ``dataset`` is not ``'modelnet10'`` or ``'modelnet40'``.
    """
    # Must run first so that only the call arguments are captured.
    arguments = copy.deepcopy(locals())

    # Validate the dataset name before anything is written to disk.
    if dataset == 'modelnet10':
        classes = [
            'bathtub', 'bed', 'chair', 'desk', 'dresser', 'monitor',
            'night_stand', 'sofa', 'table', 'toilet'
        ]
    elif dataset == 'modelnet40':
        classes = [
            'airplane', 'bowl', 'desk', 'keyboard', 'person', 'sofa',
            'tv_stand', 'bathtub', 'car', 'door', 'lamp', 'piano', 'stairs',
            'vase', 'bed', 'chair', 'dresser', 'laptop', 'plant', 'stool',
            'wardrobe', 'bench', 'cone', 'flower_pot', 'mantel', 'radio',
            'table', 'xbox', 'bookshelf', 'cup', 'glass_box', 'monitor',
            'range_hood', 'tent', 'bottle', 'curtain', 'guitar', 'night_stand',
            'sink', 'toilet'
        ]
    else:
        raise ValueError('invalid dataset. must be modelnet10 or modelnet40')
    class_index = {name: idx for idx, name in enumerate(classes)}

    def target_transform(x):
        # Map a ModelNet class name to its integer label.
        return class_index[x]

    sp_mesh_file = os.path.join(sp_mesh_dir,
                                "icosphere_{}.pkl".format(sp_mesh_level))

    # Create the logging directory and snapshot the code used for this run.
    os.makedirs(log_dir, exist_ok=True)
    shutil.copy2(__file__, os.path.join(log_dir, "script.py"))
    shutil.copy2(model_path, os.path.join(log_dir, "model.py"))

    # Log both to the console and to log_dir/log.txt.
    logger = logging.getLogger("train")
    logger.setLevel(logging.DEBUG)
    logger.handlers = []
    logger.addHandler(logging.StreamHandler())
    logger.addHandler(logging.FileHandler(os.path.join(log_dir, "log.txt")))
    logger.info("%s", repr(arguments))

    torch.backends.cudnn.benchmark = True

    # Load the model from the snapshot copied into log_dir.
    loader = importlib.machinery.SourceFileLoader(
        'model', os.path.join(log_dir, "model.py"))
    mod = types.ModuleType(loader.name)
    loader.exec_module(mod)

    num_classes = int(dataset[-2:])
    model_cls = mod.Model_tiny if tiny else mod.Model
    model = nn.DataParallel(model_cls(num_classes, mesh_folder=sp_mesh_dir,
                                      feat=feat))
    model.cuda()

    if pretrain:
        pretrained_dict = torch.load(pretrain)

        def load_my_state_dict(self, state_dict, exclude='out_layer'):
            # Copy tensors whose names exist in the model, skipping any name
            # containing `exclude` (the output layer by default) -- presumably
            # so a checkpoint trained with another class count can be reused.
            from torch.nn.parameter import Parameter
            own_state = self.state_dict()
            for name, param in state_dict.items():
                if name not in own_state or exclude in name:
                    continue
                if isinstance(param, Parameter):
                    # backwards compatibility for serialized parameters
                    param = param.data
                own_state[name].copy_(param)

        load_my_state_dict(model, pretrained_dict)

    logger.info("{} paramerters in total".format(
        sum(x.numel() for x in model.parameters())))
    logger.info("{} paramerters in the last layer".format(
        sum(x.numel() for x in model.module.out_layer.parameters())))

    def make_transform():
        # Mesh -> icosphere projection, cached per mesh level.
        # NOTE(review): `rand_rot` is accepted and logged but training also
        # uses random_rotations=False -- confirm this is intended.
        return CacheNPY(
            prefix="sp{}_".format(sp_mesh_level),
            transform=torchvision.transforms.Compose([
                ToMesh(random_rotations=False, random_translation=0),
                ProjectOnSphere(meshfile=sp_mesh_file, dataset=dataset,
                                normalize=True)
            ]),
            sp_mesh_dir=sp_mesh_dir,
            sp_mesh_level=sp_mesh_level)

    transform = make_transform()
    transform_test = make_transform()

    train_set = ModelNet(data_dir, dataset=dataset, partition='train',
                         transform=transform,
                         target_transform=target_transform)
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=num_workers,
                                               pin_memory=True,
                                               drop_last=True)
    test_set = ModelNet(data_dir, dataset=dataset, partition='test',
                        transform=transform_test,
                        target_transform=target_transform)
    # Accuracy does not depend on order; a fixed order keeps evaluation deterministic.
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              num_workers=num_workers,
                                              pin_memory=True,
                                              drop_last=False)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = None
    if decay:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=25,
                                                    gamma=0.7)

    def train_step(data, target):
        # One optimisation step; returns (batch loss, number of correct predictions).
        model.train()
        data, target = data.cuda(), target.cuda()
        prediction = model(data)
        loss = F.nll_loss(prediction, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        correct = prediction.data.max(1)[1].eq(target.data).long().cpu().sum()
        return loss.item(), correct.item()

    def test_step(data, target):
        # Forward pass only, without building an autograd graph.
        data, target = data.cuda(), target.cuda()
        with torch.no_grad():
            prediction = model(data)
            loss = F.nll_loss(prediction, target)
        correct = prediction.max(1)[1].eq(target).long().cpu().sum()
        return loss.item(), correct.item()

    for epoch in range(epochs):
        # training
        total_loss = 0
        total_correct = 0
        total_seen = 0
        time_before_load = time.perf_counter()
        for batch_idx, (data, target) in enumerate(train_loader):
            time_after_load = time.perf_counter()
            time_before_step = time.perf_counter()
            loss, correct = train_step(data, target)
            total_loss += loss
            total_correct += correct
            total_seen += len(data)
            logger.info(
                "[{}:{}/{}] LOSS={:.2} <LOSS>={:.2} ACC={:.2} <ACC>={:.2} time={:.2}+{:.2}"
                .format(epoch, batch_idx, len(train_loader), loss,
                        total_loss / (batch_idx + 1), correct / len(data),
                        total_correct / total_seen,
                        time_after_load - time_before_load,
                        time.perf_counter() - time_before_step))
            time_before_load = time.perf_counter()

        # Step the schedule after this epoch's optimizer.step() calls (the order
        # PyTorch >= 1.1 expects); the first decay then applies from epoch 25.
        if scheduler is not None:
            scheduler.step()

        # test
        model.eval()
        total_loss = 0
        total_correct = 0
        num_batches = 0
        for data, target in test_loader:
            loss, correct = test_step(data, target)
            total_loss += loss
            total_correct += correct
            num_batches += 1
        # Mean of the per-batch losses (previously divided by num_batches + 1).
        logger.info("[Epoch {} Test] <LOSS>={:.2} <ACC>={:2}".format(
            epoch, total_loss / max(num_batches, 1),
            total_correct / len(test_set)))

        # remove sparse matrices since they cannot be stored
        state_dict_no_sparse = OrderedDict(
            (name, tensor) for name, tensor in model.state_dict().items()
            if tensor.type() != "torch.cuda.sparse.FloatTensor")
        torch.save(state_dict_no_sparse, os.path.join(log_dir, "state.pkl"))
def main(log_dir, model_path, augmentation, dataset, num_cls, few, batch_size,
         num_workers, learning_rate,
         data_dir="/home/lixin/Documents/s2cnn/ModelNet", epochs=300):
    """Train a ModelNet classifier with SGD and a step learning-rate schedule.

    ``log_dir`` must not exist yet (``os.mkdir`` raises otherwise).  This
    script, the model file and ``ROOT/dataset.py`` are copied into it.  Each
    epoch is followed by a test pass; the latest weights are saved to
    ``log_dir/state.pkl`` and the best-so-far ones to ``log_dir/best_state.pkl``.

    Args:
        log_dir: New directory for logs, code snapshots and checkpoints.
        model_path: Python file defining ``Model(num_cls)`` with attributes
            ``bandwidths`` and ``out_layer``.
        augmentation: ``repeat`` count handed to ``CacheNPY``.
        dataset: Dataset name handed to ``ModelNet``; names containing
            ``"10"`` swap the train/test splits (see NOTE below).
        num_cls: Number of output classes.
        few: Forwarded to the training ``ModelNet``; when truthy the last
            partial training batch is kept.
        batch_size: Mini-batch size for both loaders.
        num_workers: Number of DataLoader worker processes.
        learning_rate: Base learning rate (x1 before epoch 100, x0.1 before
            epoch 200, x0.01 afterwards).
        data_dir: ModelNet root directory (defaults to the previously
            hard-coded path).
        epochs: Number of training epochs (defaults to the previously
            hard-coded 300).
    """
    # Must run first so that only the call arguments are captured.
    arguments = copy.deepcopy(locals())

    # os.mkdir raises if log_dir already exists, so earlier runs are not overwritten.
    os.mkdir(log_dir)
    shutil.copy2(__file__, os.path.join(log_dir, "script.py"))
    shutil.copy2(model_path, os.path.join(log_dir, "model.py"))
    shutil.copy2(os.path.join(ROOT, "dataset.py"),
                 os.path.join(log_dir, "dataset.py"))

    # Log both to the console and to log_dir/log.txt.
    logger = logging.getLogger("train")
    logger.setLevel(logging.DEBUG)
    logger.handlers = []
    logger.addHandler(logging.StreamHandler())
    logger.addHandler(logging.FileHandler(os.path.join(log_dir, "log.txt")))
    logger.info("%s", repr(arguments))

    torch.backends.cudnn.benchmark = True

    # Load the model from the snapshot copied into log_dir.
    loader = importlib.machinery.SourceFileLoader(
        'model', os.path.join(log_dir, "model.py"))
    mod = types.ModuleType(loader.name)
    loader.exec_module(mod)

    model = mod.Model(num_cls)
    model.cuda()

    logger.info("{} paramerters in total".format(
        sum(x.numel() for x in model.parameters())))
    logger.info("{} paramerters in the last layer".format(
        sum(x.numel() for x in model.out_layer.parameters())))

    bw = model.bandwidths[0]

    # Increasing `repeat` will generate more cached files.
    train_transform = CacheNPY(prefix="b{}_".format(bw),
                               repeat=augmentation,
                               pick_randomly=True,
                               transform=torchvision.transforms.Compose([
                                   ToMesh(random_rotations=True,
                                          random_translation=0.1),
                                   ProjectOnSphere(bandwidth=bw)
                               ]))
    # NOTE(review): evaluation reuses the randomized training transform.
    test_transform = train_transform

    # NOTE(review): for datasets whose name contains "10" the splits are
    # swapped (train on "test", evaluate on "train") -- confirm intended.
    if "10" in dataset:
        train_data_type, test_data_type = "test", "train"
    else:
        train_data_type, test_data_type = "train", "test"

    train_set = ModelNet(data_dir, dataset, train_data_type, few=few,
                         transform=train_transform)
    # Keep the last partial batch only in few-shot mode.
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=num_workers,
                                               pin_memory=True,
                                               drop_last=not few)
    test_set = ModelNet(data_dir, dataset, test_data_type,
                        transform=test_transform)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              num_workers=num_workers,
                                              pin_memory=True,
                                              drop_last=False)

    # The learning rate is assigned explicitly at the start of every epoch.
    optimizer = torch.optim.SGD(model.parameters(), lr=0, momentum=0.9)

    def train_step(data, target):
        # One optimisation step; returns (batch loss, number of correct predictions).
        model.train()
        data, target = data.cuda(), target.cuda()
        prediction = model(data)
        loss = F.nll_loss(prediction, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        correct = prediction.data.max(1)[1].eq(target.data).long().cpu().sum()
        return loss.item(), correct.item()

    def evaluate():
        # Argmax accuracy over the whole test split; logged and returned.
        model.eval()
        predictions = []
        gt = []
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.cuda(), target.cuda()
                predictions.append(model(data).cpu().numpy())
                gt.append(target.cpu().numpy())
        predictions_class = np.argmax(np.concatenate(predictions), axis=1)
        acc = np.sum(predictions_class == np.concatenate(gt)) / len(test_set)
        logger.info("Test Acc: {}".format(acc))
        return acc

    def get_learning_rate(epoch):
        # Piecewise-constant schedule: x1, x0.1, x0.01 at epochs 100 and 200.
        limits = [100, 200]
        lrs = [1, 0.1, 0.01]
        assert len(lrs) == len(limits) + 1
        for lim, lr in zip(limits, lrs):
            if epoch < lim:
                return lr * learning_rate
        return lrs[-1] * learning_rate

    best_acc = 0.
    for epoch in range(epochs):
        lr = get_learning_rate(epoch)
        logger.info("learning rate = {} and batch size = {}".format(
            lr, train_loader.batch_size))
        for group in optimizer.param_groups:
            group['lr'] = lr

        total_loss = 0
        total_correct = 0
        total_seen = 0
        time_before_load = time.perf_counter()
        for batch_idx, (data, target) in enumerate(train_loader):
            time_after_load = time.perf_counter()
            time_before_step = time.perf_counter()
            loss, correct = train_step(data, target)
            total_loss += loss
            total_correct += correct
            # Count real samples: the last batch can be short when drop_last is off.
            total_seen += len(data)
            logger.info(
                "[{}:{}/{}] LOSS={:.3} <LOSS>={:.3} ACC={:.3} <ACC>={:.3} time={:.2}+{:.2}"
                .format(epoch, batch_idx, len(train_loader), loss,
                        total_loss / (batch_idx + 1), correct / len(data),
                        total_correct / total_seen,
                        time_after_load - time_before_load,
                        time.perf_counter() - time_before_step))
            time_before_load = time.perf_counter()

        test_acc = evaluate()
        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(model.state_dict(),
                       os.path.join(log_dir, "best_state.pkl"))

        torch.save(model.state_dict(), os.path.join(log_dir, "state.pkl"))
def main(sp_mesh_dir, sp_mesh_level, log_dir, data_dir, eval_time, dataset,
         partition, batch_size, jobs, tiny, feat, no_cuda, neval):
    """Evaluate ``log_dir/state.pkl`` on the ModelNet test split, or time it.

    With ``eval_time`` false, prints the mean per-batch NLL loss and the
    accuracy over the whole test split.  With ``eval_time`` true, runs at most
    ``neval`` batches and prints the mean forward time per batch, skipping
    the first 10 batches as warm-up when more than 10 were timed.

    Args:
        sp_mesh_dir: Folder with ``icosphere_<level>.pkl`` mesh files.
        sp_mesh_level: Icosphere level; selects the mesh file and cache prefix.
        log_dir: Directory containing the ``state.pkl`` checkpoint.
        data_dir: Root directory handed to ``ModelNet``.
        eval_time: Benchmark forward time instead of reporting accuracy.
        dataset: ``'modelnet10'`` or ``'modelnet40'``.
        partition: Not used by this function.
        batch_size: Test mini-batch size.
        jobs: Torch CPU threads and DataLoader worker count.
        tiny: Use ``Model_tiny`` instead of ``Model``.
        feat: Forwarded to the model as ``feat``.
        no_cuda: Run on CPU instead of CUDA.
        neval: Number of batches to time when ``eval_time`` is set.

    Raises:
        ValueError: If ``dataset`` is not ``'modelnet10'`` or ``'modelnet40'``.
    """
    # Validate the dataset name before doing any expensive work.
    if dataset == 'modelnet10':
        classes = ['bathtub', 'bed', 'chair', 'desk', 'dresser', 'monitor',
                   'night_stand', 'sofa', 'table', 'toilet']
    elif dataset == 'modelnet40':
        classes = ['airplane', 'bowl', 'desk', 'keyboard', 'person', 'sofa',
                   'tv_stand', 'bathtub', 'car', 'door', 'lamp', 'piano',
                   'stairs', 'vase', 'bed', 'chair', 'dresser', 'laptop',
                   'plant', 'stool', 'wardrobe', 'bench', 'cone', 'flower_pot',
                   'mantel', 'radio', 'table', 'xbox', 'bookshelf', 'cup',
                   'glass_box', 'monitor', 'range_hood', 'tent', 'bottle',
                   'curtain', 'guitar', 'night_stand', 'sink', 'toilet']
    else:
        raise ValueError('invalid dataset. must be modelnet10 or modelnet40')
    class_index = {name: idx for idx, name in enumerate(classes)}

    def target_transform(x):
        # Map a ModelNet class name to its integer label.
        return class_index[x]

    torch.set_num_threads(jobs)
    print("Running on {} CPU(s)".format(torch.get_num_threads()))
    if no_cuda:
        device = torch.device("cpu")
    else:
        device = torch.device("cuda")
        torch.backends.cudnn.benchmark = True
    use_cuda = device.type == "cuda"

    sp_mesh_file = os.path.join(sp_mesh_dir,
                                "icosphere_{}.pkl".format(sp_mesh_level))

    # Load the model
    loader = importlib.machinery.SourceFileLoader('model', "model.py")
    mod = types.ModuleType(loader.name)
    loader.exec_module(mod)

    num_classes = int(dataset[-2:])
    model_cls = mod.Model_tiny if tiny else mod.Model
    model = model_cls(num_classes, mesh_folder=sp_mesh_dir, feat=feat)

    # load checkpoint
    ckpt = os.path.join(log_dir, "state.pkl")
    if no_cuda:
        pretrained_dict = torch.load(
            ckpt, map_location=lambda storage, loc: storage)
    else:
        pretrained_dict = torch.load(ckpt)

    def load_my_state_dict(self, state_dict, exclude=None):
        # Copy matching tensors into `self`; returns how many were copied.
        # The output layer is no longer excluded by default: evaluation needs
        # the trained classifier, not its random initialisation.
        from torch.nn.parameter import Parameter
        own_state = self.state_dict()
        loaded = 0
        for name, param in state_dict.items():
            # nn.DataParallel.state_dict() prefixes every key with "module.".
            if name not in own_state and name.startswith("module."):
                name = name[len("module."):]
            if name not in own_state:
                continue
            if exclude is not None and exclude in name:
                continue
            if isinstance(param, Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            own_state[name].copy_(param)
            loaded += 1
        return loaded

    num_loaded = load_my_state_dict(model, pretrained_dict)
    # Surface silent mismatches between checkpoint and model keys.
    print("Loaded {}/{} tensors from {}".format(
        num_loaded, len(pretrained_dict), ckpt))
    model.to(device)

    print("{} paramerters in total".format(
        sum(x.numel() for x in model.parameters())))
    print("{} paramerters in the last layer".format(
        sum(x.numel() for x in model.out_layer.parameters())))

    # Deterministic mesh -> icosphere projection, cached per mesh level.
    transform_test = CacheNPY(
        prefix="sp{}_".format(sp_mesh_level),
        transform=torchvision.transforms.Compose([
            ToMesh(random_rotations=False, random_translation=0),
            ProjectOnSphere(meshfile=sp_mesh_file, dataset=dataset,
                            normalize=True)
        ]),
        sp_mesh_dir=sp_mesh_dir,
        sp_mesh_level=sp_mesh_level)

    test_set = ModelNet(data_dir, dataset=dataset, partition='test',
                        transform=transform_test,
                        target_transform=target_transform)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              num_workers=jobs,
                                              pin_memory=use_cuda,
                                              drop_last=False)

    def test_step(data, target):
        # Forward pass without autograd; returns (loss, n_correct, seconds).
        data, target = data.to(device), target.to(device)
        with torch.no_grad():
            if use_cuda:
                torch.cuda.synchronize()
            t = time()
            prediction = model(data)
            if use_cuda:
                # CUDA kernels run asynchronously; wait before reading the clock.
                torch.cuda.synchronize()
            dt = time() - t
            loss = F.nll_loss(prediction, target)
        correct = prediction.max(1)[1].eq(target).long().cpu().sum()
        return loss.item(), correct.item(), dt

    # test
    model.eval()
    total_loss = 0
    total_correct = 0
    count = 0
    total_time = []
    for data, target in test_loader:
        loss, correct, dt = test_step(data, target)
        total_time.append(dt)
        total_loss += loss
        total_correct += correct
        count += 1
        if eval_time and count >= neval:
            break

    if eval_time:
        # Skip 10 warm-up batches, unless that would leave nothing to average.
        timed = total_time[10:] or total_time
        print("Time per batch: {} secs".format(np.mean(timed)))
    else:
        # Mean of the per-batch losses (previously divided by count + 1).
        print("[Test] <LOSS>={:.2} <ACC>={:2}".format(
            total_loss / max(count, 1), total_correct / len(test_set)))