Example #1
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random
import numpy as np

from models.autoencoder import Model
from data.modelnet_shrec_loader import ModelNet_Shrec_Loader
from data.shapenet_loader import ShapeNetLoader
from util.visualizer import Visualizer

if __name__ == '__main__':
    if opt.dataset == 'modelnet' or opt.dataset == 'shrec':
        trainset = ModelNet_Shrec_Loader(opt.dataroot, 'train', opt)
        dataset_size = len(trainset)
        trainloader = torch.utils.data.DataLoader(trainset,
                                                  batch_size=opt.batch_size,
                                                  shuffle=True,
                                                  num_workers=opt.nThreads)
        print('#training point clouds = %d' % len(trainset))

        testset = ModelNet_Shrec_Loader(opt.dataroot, 'test', opt)
        testloader = torch.utils.data.DataLoader(testset,
                                                 batch_size=opt.batch_size,
                                                 shuffle=False,
                                                 num_workers=opt.nThreads)
    elif opt.dataset == 'shapenet':
        trainset = ShapeNetLoader(opt.dataroot, 'train', opt)
        dataset_size = len(trainset)
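
The snippet above reads several fields from an 'opt' object that is parsed elsewhere in the repository. A minimal, hypothetical stand-in covering only the fields used here (field names taken from the code above, values are placeholders):

from types import SimpleNamespace

# Hypothetical stand-in for the parsed options object used above.
opt = SimpleNamespace(
    dataset='modelnet',        # one of 'modelnet', 'shrec', 'shapenet'
    dataroot='/path/to/data',  # placeholder dataset root
    batch_size=8,
    nThreads=4,
)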
Example #2
def train(model, config):

    trainset = ModelNet_Shrec_Loader(
        os.path.join(config.data, 'train_files.txt'), 'train', config.data,
        config)
    dataset_size = len(trainset)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=config.batch_size,
                                              shuffle=True,
                                              num_workers=config.num_threads)
    print('#training point clouds = %d' % len(trainset))

    start_epoch = 0
    WEIGHTS = config.weights
    if WEIGHTS != -1:
        ld = config.log_dir
        start_epoch = WEIGHTS + 1
        ACC_LOGGER.load(
            (os.path.join(ld, "{}_acc_train_accuracy.csv".format(config.name)),
             os.path.join(ld, "{}_acc_eval_accuracy.csv".format(config.name))),
            epoch=WEIGHTS)
        LOSS_LOGGER.load(
            (os.path.join(ld, "{}_loss_train_loss.csv".format(config.name)),
             os.path.join(ld, '{}_loss_eval_loss.csv'.format(config.name))),
            epoch=WEIGHTS)

    print("Starting training")
    best_accuracy = 0
    losses = []
    accs = []
    if config.num_classes == 10:
        config.dropout = config.dropout + 0.1

    begin = start_epoch
    end = config.max_epoch + start_epoch
    for epoch in range(begin, end + 1):
        epoch_iter = 0
        for i, data in enumerate(trainloader):
            epoch_iter += config.batch_size

            input_pc, input_sn, input_label, input_node, input_node_knn_I = data
            model.set_input(input_pc, input_sn, input_label, input_node,
                            input_node_knn_I)

            model.optimize(epoch=epoch)
            errors = model.get_current_errors()
            losses.append(errors['train_loss'])
            accs.append(errors['train_accuracy'])

            if i % max(config.train_log_frq // config.batch_size, 1) == 0:
                acc = np.mean(accs)
                loss = np.mean(losses)
                LOSS_LOGGER.log(loss, epoch, "train_loss")
                ACC_LOGGER.log(acc, epoch, "train_accuracy")
                print("EPOCH {} acc: {} loss: {}".format(epoch, acc, loss))
                ACC_LOGGER.save(config.log_dir)
                LOSS_LOGGER.save(config.log_dir)
                ACC_LOGGER.plot(dest=config.log_dir)
                LOSS_LOGGER.plot(dest=config.log_dir)
                losses = []
                accs = []

        best_accuracy = test(model,
                             config,
                             best_accuracy=best_accuracy,
                             epoch=epoch)

        if epoch % config.save_each == 0 or epoch == end:
            print("Saving network...")
            save_path = os.path.join(
                config.log_dir,
                config.snapshot_prefix + '_encoder_' + str(epoch))
            model.save_network(model.encoder, save_path, 0)
            save_path = os.path.join(
                config.log_dir,
                config.snapshot_prefix + '_classifier_' + str(epoch))
            model.save_network(model.classifier, save_path, 0)

        if epoch % config.lr_decay_step == 0 and epoch > 0:
            model.update_learning_rate(0.5)
        # batch normalization momentum decay:
        next_epoch = epoch + 1
        if (config.bn_momentum_decay_step
                is not None) and (next_epoch >= 1) and (
                    next_epoch % config.bn_momentum_decay_step == 0):
            current_bn_momentum = config.bn_momentum * (
                config.bn_momentum_decay
                **(next_epoch // config.bn_momentum_decay_step))
            print('BN momentum updated to: %f' % current_bn_momentum)
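
The batch-normalization momentum schedule at the end of the loop decays exponentially every bn_momentum_decay_step epochs. A standalone sketch of that formula with hypothetical values (the snippet itself only prints the computed momentum):

# Exponential BN momentum schedule, as computed above.
# The starting momentum, decay factor and step are hypothetical example values.
def bn_momentum_at(epoch, bn_momentum=0.1, decay=0.6, decay_step=20):
    return bn_momentum * decay ** (epoch // decay_step)

for e in (0, 20, 40, 60):
    print('epoch %d -> BN momentum %f' % (e, bn_momentum_at(e)))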
Example #3
def main():
    while opt.batch_size * opt.rot_equivariant_no * opt.input_pc_num > 8*12*1024:
        opt.batch_size = round(opt.batch_size / 2)
    print('batch_size %d ' % opt.batch_size)

    trainset = ModelNet_Shrec_Loader(opt.dataroot, 'train', opt)
    dataset_size = len(trainset)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=opt.batch_size, shuffle=True,
                                              num_workers=opt.nThreads)

    testset = ModelNet_Shrec_Loader(opt.dataroot, 'test', opt)
    testloader = torch.utils.data.DataLoader(testset, batch_size=opt.batch_size, shuffle=False,
                                             num_workers=opt.nThreads)
    # create model, optionally load pre-trained model
    model_path = 'net_gpu_2'
    print(model_path)
    model = Model(opt)
    model.encoder.load_state_dict(
        model_state_dict_parallel_convert(torch.load(
            model_path+'_encoder.pth', map_location='cpu'),
            mode='same'))
    model.classifier.load_state_dict(
        model_state_dict_parallel_convert(torch.load(
            model_path+'_classifier.pth', map_location='cpu'),
            mode='same'))

    visualizer = Visualizer(opt)

    # test network
    batch_amount = 0
    model.test_loss.data.zero_()
    model.test_accuracy.data.zero_()

    per_class_correct = np.zeros(opt.classes)
    per_class_amount = np.zeros(opt.classes)
    per_class_acc = np.zeros(opt.classes)

    softmax = torch.nn.Softmax(dim=1).to(opt.device)

    voting_num = 12
    for i, data in enumerate(testloader):
        B = data[0].size()[0]
        C = opt.classes

        input_pc, input_sn, input_label, input_node, input_node_knn_I = data

        # perform voting
        score_sum = torch.zeros((B, C), dtype=torch.float32, device=opt.device, requires_grad=False)  # BxC
        loss_sum = torch.tensor([0], dtype=torch.float32, device=opt.device, requires_grad=False)
        for v in range(voting_num):
            if opt.rot_equivariant_mode == '2d':
                angle = (2 * math.pi / voting_num) * v
                rot_input_pc, rot_input_sn, rot_input_node = augmentation.rotate_point_cloud_with_normal_som_pytorch_batch(
                    input_pc,
                    input_sn,
                    input_node,
                    angle)
            elif opt.rot_equivariant_mode == '3d':
                rot_input_pc, rot_input_sn, rot_input_node = augmentation.rotate_point_cloud_with_normal_som_pytorch_batch_3d(
                    input_pc,
                    input_sn,
                    input_node)
            else:
                raise ValueError('unknown rot_equivariant_mode: %s' % opt.rot_equivariant_mode)

            model.set_input(rot_input_pc, rot_input_sn, input_label, rot_input_node, input_node_knn_I)
            model.test_model()

            # accumulate score
            score_sum += softmax(model.score.detach())
            # score_sum += model.score.detach()
            loss_sum += model.loss.detach()

        # calculate voted score/prediction
        batch_amount += B

        # accumulate loss
        model.test_loss += (loss_sum / voting_num) * B

        # accumulate accuracy
        _, predicted_idx = torch.max(score_sum, dim=1, keepdim=False)
        correct_mask = torch.eq(predicted_idx, model.label).float()
        test_accuracy = torch.mean(correct_mask).cpu()
        model.test_accuracy += test_accuracy * B

        # per class accuracy
        for b in range(model.label.size()[0]):  # tensor
            per_class_amount[model.label[b]] += 1
            if correct_mask[b] >= 0.9:
                per_class_correct[model.label[b]] += 1

    model.test_loss /= batch_amount
    model.test_accuracy /= batch_amount

    print('test sample number %d' % batch_amount)
    print('Loss %f, accuracy %f' % (model.test_loss.item(), model.test_accuracy.item()))

    # per class accuracy
    per_class_acc = per_class_correct / per_class_amount
    print('Per class accuracy: %f' % np.mean(per_class_acc))

    return model.test_accuracy.item(), np.mean(per_class_acc)
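
The test loop above averages softmax scores over voting_num rotated copies of each shape before taking the argmax. A self-contained sketch of just that aggregation step, with random logits standing in for model.score (B samples, C classes, shapes as in the loop above):

import torch

B, C, voting_num = 4, 10, 12
softmax = torch.nn.Softmax(dim=1)

# Sum the softmaxed scores of every rotated copy, then predict by argmax.
score_sum = torch.zeros(B, C)
for v in range(voting_num):
    logits = torch.randn(B, C)   # stand-in for model.score on the v-th rotation
    score_sum += softmax(logits)
_, predicted_idx = torch.max(score_sum, dim=1)
print(predicted_idx)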
Example #4
        loss = model.test_loss.item()
        acc = model.test_accuracy.item()
        print('Tested network. So far best: %f' % best_accuracy)
        print("TESTING EPOCH {} acc: {} loss: {}".format(epoch, acc, loss))
        LOSS_LOGGER.log(loss, epoch, "eval_loss")
        ACC_LOGGER.log(acc, epoch, "eval_accuracy")
        return best_accuracy


if __name__ == '__main__':

    config = get_config()
    LOSS_LOGGER = Logger("{}_loss".format(config.name))
    ACC_LOGGER = Logger("{}_acc".format(config.name))
    testset = ModelNet_Shrec_Loader(
        os.path.join(config.data, 'test_files.txt'), 'test', config.data,
        config)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=config.batch_size,
                                             shuffle=False,
                                             num_workers=config.num_threads)

    if not config.test:
        model = Model(config)
        if config.weights != -1:
            weights = os.path.join(
                config.log_dir,
                config.snapshot_prefix + '_encoder_' + str(config.weights))
            model.encoder.load_state_dict(torch.load(weights))
            weights = os.path.join(
                config.log_dir,