import torchvision from torch.autograd import Variable import torch.nn as nn import torch.nn.functional as F import torch.optim as optim import random import numpy as np from models.autoencoder import Model from data.modelnet_shrec_loader import ModelNet_Shrec_Loader from data.shapenet_loader import ShapeNetLoader from util.visualizer import Visualizer if __name__ == '__main__': if opt.dataset == 'modelnet' or opt.dataset == 'shrec': trainset = ModelNet_Shrec_Loader(opt.dataroot, 'train', opt) dataset_size = len(trainset) trainloader = torch.utils.data.DataLoader(trainset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.nThreads) print('#training point clouds = %d' % len(trainset)) testset = ModelNet_Shrec_Loader(opt.dataroot, 'test', opt) testloader = torch.utils.data.DataLoader(testset, batch_size=opt.batch_size, shuffle=False, num_workers=opt.nThreads) elif opt.dataset == 'shapenet': trainset = ShapeNetLoader(opt.dataroot, 'train', opt) dataset_size = len(trainset)
def train(model, config):
    """Train ``model`` on the ModelNet/SHREC training split described by ``config``.

    Builds a shuffled training DataLoader from ``<config.data>/train_files.txt``.
    When ``config.weights`` is not -1, training resumes at epoch
    ``config.weights + 1`` and the accuracy/loss logs are reloaded from
    ``config.log_dir``. Epochs ``start_epoch`` .. ``start_epoch + config.max_epoch``
    (inclusive) are then run. For each epoch:

    * every batch is fed to ``model.optimize``; running train accuracy/loss are
      logged, saved and plotted about every ``config.train_log_frq`` samples;
    * the model is evaluated with ``test`` (the best accuracy is tracked);
    * encoder/classifier snapshots are saved every ``config.save_each`` epochs
      and on the last epoch;
    * the learning rate is halved every ``config.lr_decay_step`` epochs;
    * the scheduled batch-norm momentum for the next epoch is printed.

    Relies on the module-level ``ACC_LOGGER``, ``LOSS_LOGGER`` and ``test``
    defined elsewhere in this file.

    Args:
        model: classification model exposing ``set_input``, ``optimize``,
            ``get_current_errors``, ``save_network``, ``update_learning_rate``,
            ``encoder`` and ``classifier``.
        config: run configuration object (attribute access).
    """
    trainset = ModelNet_Shrec_Loader(
        os.path.join(config.data, 'train_files.txt'), 'train', config.data, config)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=config.batch_size,
                                              shuffle=True, num_workers=config.num_threads)
    print('#training point clouds = %d' % len(trainset))

    start_epoch = 0
    WEIGHTS = config.weights
    if WEIGHTS != -1:
        # Resume: continue after the checkpointed epoch and reload its log history.
        ld = config.log_dir
        start_epoch = WEIGHTS + 1
        ACC_LOGGER.load(
            (os.path.join(ld, "{}_acc_train_accuracy.csv".format(config.name)),
             os.path.join(ld, "{}_acc_eval_accuracy.csv".format(config.name))),
            epoch=WEIGHTS)
        LOSS_LOGGER.load(
            (os.path.join(ld, "{}_loss_train_loss.csv".format(config.name)),
             os.path.join(ld, '{}_loss_eval_loss.csv'.format(config.name))),
            epoch=WEIGHTS)

    print("Starting training")
    best_accuracy = 0
    losses = []
    accs = []
    # NOTE(review): config is mutated after the model was constructed; confirm
    # the model actually reads config.dropout later for this to take effect.
    if config.num_classes == 10:
        config.dropout = config.dropout + 0.1

    # Log roughly every ``train_log_frq`` samples, i.e. every N batches.
    # Integer division keeps N integral: with true division (e.g. 100 / 8 == 12.5)
    # ``i % 12.5 == 0`` only holds on every 25th batch, halving the logging rate.
    log_every = max(int(config.train_log_frq // config.batch_size), 1)

    begin = start_epoch
    end = config.max_epoch + start_epoch
    for epoch in range(begin, end + 1):
        for i, data in enumerate(trainloader):
            input_pc, input_sn, input_label, input_node, input_node_knn_I = data
            model.set_input(input_pc, input_sn, input_label, input_node, input_node_knn_I)
            model.optimize(epoch=epoch)
            errors = model.get_current_errors()
            losses.append(errors['train_loss'])
            accs.append(errors['train_accuracy'])

            if i % log_every == 0:
                # Average over the batches seen since the previous log point.
                acc = np.mean(accs)
                loss = np.mean(losses)
                LOSS_LOGGER.log(loss, epoch, "train_loss")
                ACC_LOGGER.log(acc, epoch, "train_accuracy")
                print("EPOCH {} acc: {} loss: {}".format(epoch, acc, loss))
                ACC_LOGGER.save(config.log_dir)
                LOSS_LOGGER.save(config.log_dir)
                ACC_LOGGER.plot(dest=config.log_dir)
                LOSS_LOGGER.plot(dest=config.log_dir)
                losses = []
                accs = []

        best_accuracy = test(model, config, best_accuracy=best_accuracy, epoch=epoch)

        if epoch % config.save_each == 0 or epoch == end:
            print("Saving network...")
            save_path = os.path.join(
                config.log_dir, config.snapshot_prefix + '_encoder_' + str(epoch))
            model.save_network(model.encoder, save_path, 0)
            save_path = os.path.join(
                config.log_dir, config.snapshot_prefix + '_classifier_' + str(epoch))
            model.save_network(model.classifier, save_path, 0)

        if epoch % config.lr_decay_step == 0 and epoch > 0:
            model.update_learning_rate(0.5)

        # batch normalization momentum decay (reported for the upcoming epoch):
        next_epoch = epoch + 1
        if (config.bn_momentum_decay_step is not None) and (next_epoch >= 1) and (
                next_epoch % config.bn_momentum_decay_step == 0):
            current_bn_momentum = config.bn_momentum * (
                config.bn_momentum_decay ** (next_epoch // config.bn_momentum_decay_step))
            print('BN momentum updated to: %f' % current_bn_momentum)
def main(model_path='net_gpu_2', voting_num=12):
    """Evaluate a pre-trained encoder/classifier on the test split with rotation voting.

    The batch size in the module-level ``opt`` is shrunk (in place) until
    ``batch_size * rot_equivariant_no * input_pc_num`` fits a budget of
    8 * 12 * 1024. Weights are loaded from ``<model_path>_encoder.pth`` and
    ``<model_path>_classifier.pth``. Each test batch is then scored
    ``voting_num`` times under different rotations:

    * ``'2d'`` mode uses evenly spaced angles;
    * ``'3d'`` mode uses the batch-3d rotation helper.

    Softmax scores are summed and the loss is averaged over votes. Overall and
    per-class accuracy of the voted prediction are printed.

    Args:
        model_path: checkpoint path prefix (default keeps the original value).
        voting_num: number of rotated copies voted per test sample.

    Returns:
        Tuple of (overall test accuracy as float, mean per-class accuracy).
        The mean is taken over classes that occur in the test set.

    Raises:
        Exception: if ``opt.rot_equivariant_mode`` is neither '2d' nor '3d'.
    """
    # Stop at 1: round(1 / 2) == 0 would otherwise produce an invalid zero
    # batch size for the DataLoader.
    while (opt.batch_size > 1 and
           opt.batch_size * opt.rot_equivariant_no * opt.input_pc_num > 8 * 12 * 1024):
        opt.batch_size = round(opt.batch_size / 2)
    print('batch_size %d ' % opt.batch_size)

    # Only the test split is needed for evaluation.
    testset = ModelNet_Shrec_Loader(opt.dataroot, 'test', opt)
    testloader = torch.utils.data.DataLoader(testset, batch_size=opt.batch_size,
                                             shuffle=False, num_workers=opt.nThreads)

    # create model and load the pre-trained weights
    # NOTE(review): torch.load unpickles the file; only point model_path at trusted checkpoints.
    print(model_path)
    model = Model(opt)
    model.encoder.load_state_dict(
        model_state_dict_parallel_convert(torch.load(
            model_path + '_encoder.pth', map_location='cpu'), mode='same'))
    model.classifier.load_state_dict(
        model_state_dict_parallel_convert(torch.load(
            model_path + '_classifier.pth', map_location='cpu'), mode='same'))

    # test network
    batch_amount = 0
    model.test_loss.data.zero_()
    model.test_accuracy.data.zero_()

    per_class_correct = np.zeros(opt.classes)
    per_class_amount = np.zeros(opt.classes)

    softmax = torch.nn.Softmax(dim=1).to(opt.device)
    for data in testloader:
        B = data[0].size()[0]
        C = opt.classes
        input_pc, input_sn, input_label, input_node, input_node_knn_I = data

        # perform voting
        score_sum = torch.zeros((B, C), dtype=torch.float32, device=opt.device,
                                requires_grad=False)  # BxC
        loss_sum = torch.tensor([0], dtype=torch.float32, device=opt.device,
                                requires_grad=False)
        for v in range(voting_num):
            if opt.rot_equivariant_mode == '2d':
                angle = (2 * math.pi / voting_num) * v
                rot_input_pc, rot_input_sn, rot_input_node = \
                    augmentation.rotate_point_cloud_with_normal_som_pytorch_batch(
                        input_pc, input_sn, input_node, angle)
            elif opt.rot_equivariant_mode == '3d':
                rot_input_pc, rot_input_sn, rot_input_node = \
                    augmentation.rotate_point_cloud_with_normal_som_pytorch_batch_3d(
                        input_pc, input_sn, input_node)
            else:
                raise Exception('wrong mode.')

            model.set_input(rot_input_pc, rot_input_sn, input_label, rot_input_node,
                            input_node_knn_I)
            model.test_model()

            # accumulate score
            score_sum += softmax(model.score.detach())
            loss_sum += model.loss.detach()

        # calculate voted score/prediction
        batch_amount += B
        # accumulate loss (mean over votes, weighted by batch size)
        model.test_loss += (loss_sum / voting_num) * B
        # accumulate accuracy
        _, predicted_idx = torch.max(score_sum, dim=1, keepdim=False)
        correct_mask = torch.eq(predicted_idx, model.label).float()
        test_accuracy = torch.mean(correct_mask).cpu()
        model.test_accuracy += test_accuracy * B

        # per class accuracy: move labels/hits to the CPU once per batch instead
        # of indexing device tensors element by element.
        labels = model.label.cpu().numpy()
        hits = correct_mask.cpu().numpy() >= 0.9
        for label, hit in zip(labels, hits):
            per_class_amount[label] += 1
            if hit:
                per_class_correct[label] += 1

    model.test_loss /= batch_amount
    model.test_accuracy /= batch_amount
    print('test sample number %d' % batch_amount)
    print('Loss %f, accuracy %f' % (model.test_loss.item(), model.test_accuracy.item()))

    # per class accuracy: classes absent from the test set would contribute
    # 0 / 0 = nan and poison the mean, so average only over present classes.
    present = per_class_amount > 0
    per_class_acc = np.zeros(opt.classes)
    per_class_acc[present] = per_class_correct[present] / per_class_amount[present]
    mean_per_class_acc = np.mean(per_class_acc[present]) if present.any() else float('nan')
    print('Per class accuracy: %f' % mean_per_class_acc)
    return model.test_accuracy.item(), mean_per_class_acc
loss = model.test_loss.item() acc = model.test_accuracy.item() print('Tested network. So far best: %f' % best_accuracy) print("TESTING EPOCH {} acc: {} loss: {}".format(epoch, acc, loss)) LOSS_LOGGER.log(loss, epoch, "eval_loss") ACC_LOGGER.log(acc, epoch, "eval_accuracy") return best_accuracy if __name__ == '__main__': config = get_config() LOSS_LOGGER = Logger("{}_loss".format(config.name)) ACC_LOGGER = Logger("{}_acc".format(config.name)) testset = ModelNet_Shrec_Loader( os.path.join(config.data, 'test_files.txt'), 'test', config.data, config) testloader = torch.utils.data.DataLoader(testset, batch_size=config.batch_size, shuffle=False, num_workers=config.num_threads) if not config.test: model = Model(config) if config.weights != -1: weights = os.path.join( config.log_dir, config.snapshot_prefix + '_encoder_' + str(config.weights)) model.encoder.load_state_dict(torch.load(weights)) weights = os.path.join( config.log_dir,