Example #1
def main(log_dir, model_path, decay, data_dir, dataset, partition, batch_size,
         pretrain, learning_rate, num_workers, epochs, feat, rand_rot,
         image_shape, base_order, sample_order):
    arguments = copy.deepcopy(locals())

    # Create logging directory
    os.makedirs(log_dir, exist_ok=True)
    shutil.copy2(__file__, os.path.join(log_dir, 'script.py'))
    shutil.copy2(model_path, os.path.join(log_dir, 'model.py'))

    # Set up logger
    logger = logging.getLogger('train')
    logger.setLevel(logging.DEBUG)
    logger.handlers = []
    ch = logging.StreamHandler()
    logger.addHandler(ch)
    fh = logging.FileHandler(os.path.join(log_dir, 'log.txt'))
    logger.addHandler(fh)
    logger.info('%s', repr(arguments))

    # Speed up convolutions using cuDNN
    torch.backends.cudnn.benchmark = True

    # Load the model
    loader = importlib.machinery.SourceFileLoader(
        'model', os.path.join(log_dir, 'model.py'))
    mod = types.ModuleType(loader.name)
    loader.exec_module(mod)
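    # Note: the model is imported from the copy of model.py saved to log_dir
    # above, so each run keeps a snapshot of the exact architecture it used.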
    num_classes = int(dataset[-2:])
    model = mod.Model(num_classes, feat=feat)
    model = nn.DataParallel(model)
    model = model.cuda()

    if pretrain:
        pretrained_dict = torch.load(pretrain)
        load_partial_model(model, pretrained_dict)

    logger.info('{} parameters in total'.format(
        sum(x.numel() for x in model.parameters())))
    logger.info('{} parameters in the last layer'.format(
        sum(x.numel() for x in model.module.out_layer.parameters())))

    # Load the dataset
    # (CacheNPY caches each projected mesh on disk as a .npy file)
    transform = CacheNPY(prefix='sp{}_'.format(sample_order),
                         transform=torchvision.transforms.Compose([
                             ToMesh(random_rotations=rand_rot,
                                    random_translation=0),
                             ProjectOnSphere(dataset=dataset,
                                             image_shape=image_shape,
                                             normalize=True)
                         ]))

    transform_test = CacheNPY(prefix='sp{}_'.format(sample_order),
                              transform=torchvision.transforms.Compose([
                                  ToMesh(random_rotations=False,
                                         random_translation=0),
                                  ProjectOnSphere(dataset=dataset,
                                                  image_shape=image_shape,
                                                  normalize=True)
                              ]))

    if dataset == 'modelnet10':

        def target_transform(x):
            classes = [
                'bathtub', 'bed', 'chair', 'desk', 'dresser', 'monitor',
                'night_stand', 'sofa', 'table', 'toilet'
            ]
            return classes.index(x)
    elif dataset == 'modelnet40':

        def target_transform(x):
            classes = [
                'airplane', 'bowl', 'desk', 'keyboard', 'person', 'sofa',
                'tv_stand', 'bathtub', 'car', 'door', 'lamp', 'piano',
                'stairs', 'vase', 'bed', 'chair', 'dresser', 'laptop', 'plant',
                'stool', 'wardrobe', 'bench', 'cone', 'flower_pot', 'mantel',
                'radio', 'table', 'xbox', 'bookshelf', 'cup', 'glass_box',
                'monitor', 'range_hood', 'tent', 'bottle', 'curtain', 'guitar',
                'night_stand', 'sink', 'toilet'
            ]
            return classes.index(x)
    else:
        raise ValueError('invalid dataset: must be modelnet10 or modelnet40')

    train_set = ModelNet(data_dir,
                         image_shape=image_shape,
                         base_order=base_order,
                         sample_order=sample_order,
                         dataset=dataset,
                         partition='train',
                         transform=transform,
                         target_transform=target_transform)
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=num_workers,
                                               pin_memory=True,
                                               drop_last=True)
    test_set = ModelNet(data_dir,
                        image_shape=image_shape,
                        base_order=base_order,
                        sample_order=sample_order,
                        dataset=dataset,
                        partition='test',
                        transform=transform_test,
                        target_transform=target_transform)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=num_workers,
                                              pin_memory=True,
                                              drop_last=False)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    if decay:
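        # multiply the learning rate by gamma=0.7 every 25 epochs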
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size=25,
                                                    gamma=0.7)

    def train_step(data, target):
        model.train()
        data, target = data.cuda(), target.cuda()

        prediction = model(data)
        loss = F.nll_loss(prediction, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        correct = prediction.data.max(1)[1].eq(target.data).long().cpu().sum()

        return loss.item(), correct.item()

    def test_step(data, target):
        model.eval()
        data, target = data.cuda(), target.cuda()

        # no autograd graph is needed during evaluation
        with torch.no_grad():
            prediction = model(data)
            loss = F.nll_loss(prediction, target)

        correct = prediction.data.max(1)[1].eq(target.data).long().cpu().sum()

        return loss.item(), correct.item()

    def get_learning_rate(epoch):
        limits = [100, 200]
        lrs = [1, 0.1, 0.01]
        assert len(lrs) == len(limits) + 1
        for lim, lr in zip(limits, lrs):
            if epoch < lim:
                return lr * learning_rate
        return lrs[-1] * learning_rate
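    # get_learning_rate is unused in this example (StepLR handles decay when
    # `decay` is set); e.g. with learning_rate=1e-3 it would return 1e-3 for
    # epochs 0-99, 1e-4 for epochs 100-199, and 1e-5 from epoch 200 on.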

    best_acc = 0.0
    for epoch in range(epochs):
        # training
        total_loss = 0
        total_correct = 0
        time_before_load = time.perf_counter()
        for batch_idx, (data, target) in enumerate(train_loader):
            time_after_load = time.perf_counter()
            time_before_step = time.perf_counter()
            loss, correct = train_step(data, target)

            total_loss += loss
            total_correct += correct

            logger.info(
                '[{}:{}/{}] LOSS={:.2} <LOSS>={:.2} ACC={:.2} <ACC>={:.2} time={:.2}+{:.2}'
                .format(epoch, batch_idx, len(train_loader), loss,
                        total_loss / (batch_idx + 1), correct / len(data),
                        total_correct / len(data) / (batch_idx + 1),
                        time_after_load - time_before_load,
                        time.perf_counter() - time_before_step))
            time_before_load = time.perf_counter()

        if decay:
            # step the LR scheduler after the epoch's optimizer updates
            # (the required ordering in PyTorch >= 1.1)
            scheduler.step()

        # test
        total_loss = 0
        total_correct = 0
        count = 0
        for batch_idx, (data, target) in enumerate(test_loader):
            loss, correct = test_step(data, target)
            total_loss += loss
            total_correct += correct
            count += 1
        acc = total_correct / len(test_set)
        logger.info('[Epoch {} Test] <LOSS>={:.2} <ACC>={:.2}'.format(
            epoch, total_loss / count, acc))

        # save the state (sparse tensors are dropped since they cannot be serialized)
        state_dict_no_sparse = [
            it for it in model.state_dict().items()
            if it[1].type() != "torch.cuda.sparse.FloatTensor"
        ]
        state_dict_no_sparse = OrderedDict(state_dict_no_sparse)
        torch.save(state_dict_no_sparse, os.path.join(log_dir, "state.pkl"))

        # save the best model
        if acc > best_acc:
            shutil.copy2(os.path.join(log_dir, "state.pkl"),
                         os.path.join(log_dir, "best.pkl"))
            best_acc = acc
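The helper load_partial_model used above is not shown on this page. A minimal sketch, assuming it behaves like the load_my_state_dict helper in Example #3 below (copy matching entries from the checkpoint, skipping the output layer):

def load_partial_model(model, state_dict, exclude='out_layer'):
    # Copy parameters whose names match, skipping the excluded layer so a
    # checkpoint trained with a different number of classes can be reused.
    own_state = model.state_dict()
    for name, param in state_dict.items():
        if name not in own_state or exclude in name:
            continue
        if hasattr(param, 'data'):
            # backwards compatibility for serialized Parameter objects
            param = param.data
        own_state[name].copy_(param)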
Example #2
def main(checkpoint_path, data_dir, dataset, partition, batch_size, feat,
         num_workers, image_shape, base_order, sample_order):

    torch.backends.cudnn.benchmark = True

    # Load the model
    loader = importlib.machinery.SourceFileLoader('model', "model.py")
    mod = types.ModuleType(loader.name)
    loader.exec_module(mod)
    num_classes = int(dataset[-2:])
    model = mod.Model(num_classes, feat=feat)
    model = nn.DataParallel(model)
    model = model.cuda()

    # load checkpoint
    ckpt = checkpoint_path
    pretrained_dict = torch.load(ckpt)
    load_partial_model(model, pretrained_dict)

    print("{} parameters in total".format(
        sum(x.numel() for x in model.parameters())))
    print("{} parameters in the last layer".format(
        sum(x.numel() for x in model.module.out_layer.parameters())))

    # Load the dataset
    # (CacheNPY caches each projected mesh on disk as a .npy file; only
    # transform_test is actually used below)
    transform = CacheNPY(prefix='sp{}_'.format(sample_order),
                         transform=torchvision.transforms.Compose([
                             ToMesh(random_rotations=False,
                                    random_translation=0),
                             ProjectOnSphere(dataset=dataset,
                                             image_shape=image_shape,
                                             normalize=True)
                         ]))

    transform_test = CacheNPY(prefix='sp{}_'.format(sample_order),
                              transform=torchvision.transforms.Compose([
                                  ToMesh(random_rotations=False,
                                         random_translation=0),
                                  ProjectOnSphere(dataset=dataset,
                                                  image_shape=image_shape,
                                                  normalize=True)
                              ]))

    if dataset == 'modelnet10':

        def target_transform(x):
            classes = [
                'bathtub', 'bed', 'chair', 'desk', 'dresser', 'monitor',
                'night_stand', 'sofa', 'table', 'toilet'
            ]
            return classes.index(x)
    elif dataset == 'modelnet40':

        def target_transform(x):
            classes = [
                'airplane', 'bowl', 'desk', 'keyboard', 'person', 'sofa',
                'tv_stand', 'bathtub', 'car', 'door', 'lamp', 'piano',
                'stairs', 'vase', 'bed', 'chair', 'dresser', 'laptop', 'plant',
                'stool', 'wardrobe', 'bench', 'cone', 'flower_pot', 'mantel',
                'radio', 'table', 'xbox', 'bookshelf', 'cup', 'glass_box',
                'monitor', 'range_hood', 'tent', 'bottle', 'curtain', 'guitar',
                'night_stand', 'sink', 'toilet'
            ]
            return classes.index(x)
    else:
        raise ValueError('invalid dataset: must be modelnet10 or modelnet40')

    test_set = ModelNet(data_dir,
                        image_shape=image_shape,
                        base_order=base_order,
                        sample_order=sample_order,
                        dataset=dataset,
                        partition='test',
                        transform=transform_test,
                        target_transform=target_transform)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=num_workers,
                                              pin_memory=True,
                                              drop_last=False)

    def test_step(data, target):
        model.eval()
        data, target = data.cuda(), target.cuda()

        # no autograd graph is needed during evaluation
        with torch.no_grad():
            prediction = model(data)
            loss = F.nll_loss(prediction, target)

        correct = prediction.data.max(1)[1].eq(target.data).long().cpu().sum()

        return loss.item(), correct.item()

    # test
    total_loss = 0
    total_correct = 0
    count = 0
    for batch_idx, (data, target) in enumerate(test_loader):
        loss, correct = test_step(data, target)
        total_loss += loss
        total_correct += correct
        count += 1
        print("[Test] <LOSS>={:.2} <ACC>={:2}".format(
            total_loss / (count + 1), total_correct / len(test_set)))
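A main like this is usually driven from the command line; a hypothetical argparse wrapper (every flag name and default here is illustrative, not from the original script) could look like:

import argparse

if __name__ == '__main__':
    p = argparse.ArgumentParser()
    p.add_argument('--checkpoint_path', required=True)
    p.add_argument('--data_dir', required=True)
    p.add_argument('--dataset', choices=['modelnet10', 'modelnet40'],
                   default='modelnet40')
    p.add_argument('--partition', default='test')
    p.add_argument('--batch_size', type=int, default=16)
    p.add_argument('--feat', type=int, default=32)
    p.add_argument('--num_workers', type=int, default=4)
    p.add_argument('--image_shape', type=int, nargs=2, default=[64, 64])
    p.add_argument('--base_order', type=int, default=0)
    p.add_argument('--sample_order', type=int, default=5)
    # forward the parsed flags to main's keyword arguments
    main(**vars(p.parse_args()))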
Example #3
def main(sp_mesh_dir, sp_mesh_level, log_dir, model_path, augmentation, decay,
         data_dir, tiny, dataset, partition, batch_size, learning_rate,
         num_workers, epochs, pretrain, feat, rand_rot):
    arguments = copy.deepcopy(locals())

    sp_mesh_file = os.path.join(sp_mesh_dir,
                                "icosphere_{}.pkl".format(sp_mesh_level))

    os.makedirs(log_dir, exist_ok=True)
    shutil.copy2(__file__, os.path.join(log_dir, "script.py"))
    shutil.copy2(model_path, os.path.join(log_dir, "model.py"))

    logger = logging.getLogger("train")
    logger.setLevel(logging.DEBUG)
    logger.handlers = []
    ch = logging.StreamHandler()
    logger.addHandler(ch)
    fh = logging.FileHandler(os.path.join(log_dir, "log.txt"))
    logger.addHandler(fh)

    logger.info("%s", repr(arguments))

    torch.backends.cudnn.benchmark = True

    # Load the model
    loader = importlib.machinery.SourceFileLoader(
        'model', os.path.join(log_dir, "model.py"))
    mod = types.ModuleType(loader.name)
    loader.exec_module(mod)

    num_classes = int(dataset[-2:])
    if tiny:
        model = mod.Model_tiny(num_classes, mesh_folder=sp_mesh_dir, feat=feat)
    else:
        model = mod.Model(num_classes, mesh_folder=sp_mesh_dir, feat=feat)
    model = nn.DataParallel(model)
    model.cuda()

    if pretrain:
        pretrained_dict = torch.load(pretrain)

        def load_my_state_dict(self, state_dict, exclude='out_layer'):
            from torch.nn.parameter import Parameter

            own_state = self.state_dict()
            for name, param in state_dict.items():
                if name not in own_state:
                    continue
                if exclude in name:
                    continue
                if isinstance(param, Parameter):
                    # backwards compatibility for serialized parameters
                    param = param.data
                own_state[name].copy_(param)

        load_my_state_dict(model, pretrained_dict)

    logger.info("{} paramerters in total".format(
        sum(x.numel() for x in model.parameters())))
    logger.info("{} paramerters in the last layer".format(
        sum(x.numel() for x in model.module.out_layer.parameters())))

    # Load the dataset
    # (CacheNPY caches each projected mesh on disk as a .npy file)
    transform = CacheNPY(prefix="sp{}_".format(sp_mesh_level),
                         transform=torchvision.transforms.Compose([
                             ToMesh(random_rotations=False,
                                    random_translation=0),
                             ProjectOnSphere(meshfile=sp_mesh_file,
                                             dataset=dataset,
                                             normalize=True)
                         ]),
                         sp_mesh_dir=sp_mesh_dir,
                         sp_mesh_level=sp_mesh_level)

    transform_test = CacheNPY(prefix="sp{}_".format(sp_mesh_level),
                              transform=torchvision.transforms.Compose([
                                  ToMesh(random_rotations=False,
                                         random_translation=0),
                                  ProjectOnSphere(meshfile=sp_mesh_file,
                                                  dataset=dataset,
                                                  normalize=True)
                              ]),
                              sp_mesh_dir=sp_mesh_dir,
                              sp_mesh_level=sp_mesh_level)

    if dataset == 'modelnet10':

        def target_transform(x):
            classes = [
                'bathtub', 'bed', 'chair', 'desk', 'dresser', 'monitor',
                'night_stand', 'sofa', 'table', 'toilet'
            ]
            return classes.index(x)
    elif dataset == 'modelnet40':

        def target_transform(x):
            classes = [
                'airplane', 'bowl', 'desk', 'keyboard', 'person', 'sofa',
                'tv_stand', 'bathtub', 'car', 'door', 'lamp', 'piano',
                'stairs', 'vase', 'bed', 'chair', 'dresser', 'laptop', 'plant',
                'stool', 'wardrobe', 'bench', 'cone', 'flower_pot', 'mantel',
                'radio', 'table', 'xbox', 'bookshelf', 'cup', 'glass_box',
                'monitor', 'range_hood', 'tent', 'bottle', 'curtain', 'guitar',
                'night_stand', 'sink', 'toilet'
            ]
            return classes.index(x)
    else:
        raise ValueError('invalid dataset: must be modelnet10 or modelnet40')

    train_set = ModelNet(data_dir,
                         dataset=dataset,
                         partition='train',
                         transform=transform,
                         target_transform=target_transform)
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=num_workers,
                                               pin_memory=True,
                                               drop_last=True)
    test_set = ModelNet(data_dir,
                        dataset=dataset,
                        partition='test',
                        transform=transform_test,
                        target_transform=target_transform)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=num_workers,
                                              pin_memory=True,
                                              drop_last=False)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    if decay:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size=25,
                                                    gamma=0.7)

    def train_step(data, target):
        model.train()
        data, target = data.cuda(), target.cuda()

        prediction = model(data)
        loss = F.nll_loss(prediction, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        correct = prediction.data.max(1)[1].eq(target.data).long().cpu().sum()

        return loss.item(), correct.item()

    def test_step(data, target):
        model.eval()
        data, target = data.cuda(), target.cuda()

        # no autograd graph is needed during evaluation
        with torch.no_grad():
            prediction = model(data)
            loss = F.nll_loss(prediction, target)

        correct = prediction.data.max(1)[1].eq(target.data).long().cpu().sum()

        return loss.item(), correct.item()

    def get_learning_rate(epoch):
        limits = [100, 200]
        lrs = [1, 0.1, 0.01]
        assert len(lrs) == len(limits) + 1
        for lim, lr in zip(limits, lrs):
            if epoch < lim:
                return lr * learning_rate
        return lrs[-1] * learning_rate

    for epoch in range(epochs):
        # training
        total_loss = 0
        total_correct = 0
        time_before_load = time.perf_counter()
        for batch_idx, (data, target) in enumerate(train_loader):
            time_after_load = time.perf_counter()
            time_before_step = time.perf_counter()
            loss, correct = train_step(data, target)

            total_loss += loss
            total_correct += correct

            logger.info(
                "[{}:{}/{}] LOSS={:.2} <LOSS>={:.2} ACC={:.2} <ACC>={:.2} time={:.2}+{:.2}"
                .format(epoch, batch_idx, len(train_loader), loss,
                        total_loss / (batch_idx + 1), correct / len(data),
                        total_correct / len(data) / (batch_idx + 1),
                        time_after_load - time_before_load,
                        time.perf_counter() - time_before_step))
            time_before_load = time.perf_counter()

        if decay:
            # step the LR scheduler after the epoch's optimizer updates
            # (the required ordering in PyTorch >= 1.1)
            scheduler.step()

        # test
        total_loss = 0
        total_correct = 0
        count = 0
        for batch_idx, (data, target) in enumerate(test_loader):
            loss, correct = test_step(data, target)
            total_loss += loss
            total_correct += correct
            count += 1
        logger.info("[Epoch {} Test] <LOSS>={:.2} <ACC>={:2}".format(
            epoch, total_loss / (count + 1), total_correct / len(test_set)))

        # remove sparse matrices since they cannot be stored
        state_dict_no_sparse = [
            it for it in model.state_dict().items()
            if it[1].type() != "torch.cuda.sparse.FloatTensor"
        ]
        state_dict_no_sparse = OrderedDict(state_dict_no_sparse)
        torch.save(state_dict_no_sparse, os.path.join(log_dir, "state.pkl"))
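Because the sparse entries are filtered out before saving, restoring this checkpoint has to tolerate missing keys; PyTorch's strict=False does exactly that. A sketch, assuming the model is rebuilt as above:

state = torch.load(os.path.join(log_dir, "state.pkl"))
# strict=False ignores the sparse entries that were removed before saving
model.load_state_dict(state, strict=False)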
Example #4
def main(log_dir, model_path, augmentation, dataset, num_cls, few, batch_size,
         num_workers, learning_rate):
    arguments = copy.deepcopy(locals())

    os.mkdir(log_dir)
    shutil.copy2(__file__, os.path.join(log_dir, "script.py"))
    shutil.copy2(model_path, os.path.join(log_dir, "model.py"))
    shutil.copy2(os.path.join(ROOT, "dataset.py"),
                 os.path.join(log_dir, "dataset.py"))

    logger = logging.getLogger("train")
    logger.setLevel(logging.DEBUG)
    logger.handlers = []
    ch = logging.StreamHandler()
    logger.addHandler(ch)
    fh = logging.FileHandler(os.path.join(log_dir, "log.txt"))
    logger.addHandler(fh)

    logger.info("%s", repr(arguments))

    torch.backends.cudnn.benchmark = True

    # Load the model
    loader = importlib.machinery.SourceFileLoader(
        'model', os.path.join(log_dir, "model.py"))
    mod = types.ModuleType(loader.name)
    loader.exec_module(mod)

    model = mod.Model(num_cls)
    model.cuda()

    logger.info("{} paramerters in total".format(
        sum(x.numel() for x in model.parameters())))
    logger.info("{} paramerters in the last layer".format(
        sum(x.numel() for x in model.out_layer.parameters())))

    bw = model.bandwidths[0]

    # Load the dataset
    # Increasing `repeat` will generate more cached files
    train_transform = CacheNPY(prefix="b{}_".format(bw),
                               repeat=augmentation,
                               pick_randomly=True,
                               transform=torchvision.transforms.Compose([
                                   ToMesh(random_rotations=True,
                                          random_translation=0.1),
                                   ProjectOnSphere(bandwidth=bw)
                               ]))

    # An alternative test transform (pick_randomly=False, stacking several
    # augmented views per mesh into one tensor) was disabled; reuse the
    # train transform instead.
    test_transform = train_transform

    if "10" in dataset:
        train_data_type = "test"
        test_data_type = "train"
    else:
        train_data_type = "train"
        test_data_type = "test"

    train_set = ModelNet("/home/lixin/Documents/s2cnn/ModelNet",
                         dataset,
                         train_data_type,
                         few=few,
                         transform=train_transform)
    # few-shot runs keep every sample; only drop the last partial batch
    # when training on the full set
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=num_workers,
                                               pin_memory=True,
                                               drop_last=not few)

    test_set = ModelNet("/home/lixin/Documents/s2cnn/ModelNet",
                        dataset,
                        test_data_type,
                        transform=test_transform)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              num_workers=num_workers,
                                              pin_memory=True,
                                              drop_last=False)
    # lr starts at 0: it is set each epoch below via get_learning_rate
    optimizer = torch.optim.SGD(model.parameters(), lr=0, momentum=0.9)

    def train_step(data, target):
        model.train()
        data, target = data.cuda(), target.cuda()

        prediction = model(data)
        loss = F.nll_loss(prediction, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        correct = prediction.data.max(1)[1].eq(target.data).long().cpu().sum()

        return loss.item(), correct.item()

    def test(epoch):
        predictions = []
        gt = []

        for batch_idx, (data, target) in enumerate(test_loader):
            model.eval()
            data, target = data.cuda(), target.cuda()
            with torch.no_grad():
                pred = model(data).data

            predictions.append(pred.cpu().numpy())
            gt.append(target.cpu().numpy())

        predictions = np.concatenate(predictions)
        gt = np.concatenate(gt)

        predictions_class = np.argmax(predictions, axis=1)
        acc = np.sum(predictions_class == gt) / len(test_set)
        logger.info("Test Acc: {}".format(acc))
        return acc

    def get_learning_rate(epoch):
        limits = [100, 200]
        lrs = [1, 0.1, 0.01]
        assert len(lrs) == len(limits) + 1
        for lim, lr in zip(limits, lrs):
            if epoch < lim:
                return lr * learning_rate
        return lrs[-1] * learning_rate

    best_acc = 0.
    for epoch in range(300):

        lr = get_learning_rate(epoch)
        logger.info("learning rate = {} and batch size = {}".format(
            lr, train_loader.batch_size))
        for p in optimizer.param_groups:
            p['lr'] = lr

        total_loss = 0
        total_correct = 0
        time_before_load = time.perf_counter()
        for batch_idx, (data, target) in enumerate(train_loader):
            time_after_load = time.perf_counter()
            time_before_step = time.perf_counter()
            loss, correct = train_step(data, target)

            total_loss += loss
            total_correct += correct

            logger.info(
                "[{}:{}/{}] LOSS={:.3} <LOSS>={:.3} ACC={:.3} <ACC>={:.3} time={:.2}+{:.2}"
                .format(epoch, batch_idx, len(train_loader), loss,
                        total_loss / (batch_idx + 1), correct / len(data),
                        total_correct / len(data) / (batch_idx + 1),
                        time_after_load - time_before_load,
                        time.perf_counter() - time_before_step))
            time_before_load = time.perf_counter()

        test_acc = test(epoch)
        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(model.state_dict(),
                       os.path.join(log_dir, "best_state.pkl"))

        torch.save(model.state_dict(), os.path.join(log_dir, "state.pkl"))
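The manual per-epoch LR assignment above can also be written with PyTorch's built-in LambdaLR scheduler; a sketch, assuming the optimizer starts at lr=learning_rate rather than 0:

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    # same piecewise schedule as get_learning_rate: 1x, 0.1x, then 0.01x
    lr_lambda=lambda epoch: 1 if epoch < 100 else 0.1 if epoch < 200 else 0.01)

for epoch in range(300):
    # ... one epoch of training ...
    scheduler.step()  # advance to the next epoch's multiplier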
Example #5
def main(sp_mesh_dir, sp_mesh_level, log_dir, data_dir, eval_time,
         dataset, partition, batch_size, jobs, tiny, feat, no_cuda, neval):
    torch.set_num_threads(jobs)
    print("Running on {} CPU(s)".format(torch.get_num_threads()))
    if no_cuda:
        device = torch.device("cpu")
    else:
        device = torch.device("cuda")
        torch.backends.cudnn.benchmark = True

    sp_mesh_file = os.path.join(sp_mesh_dir, "icosphere_{}.pkl".format(sp_mesh_level))


    # Load the model
    loader = importlib.machinery.SourceFileLoader('model', "model.py")
    mod = types.ModuleType(loader.name)
    loader.exec_module(mod)

    num_classes = int(dataset[-2:])
    if not tiny:
        model = mod.Model(num_classes, mesh_folder=sp_mesh_dir, feat=feat)
    else:
        model = mod.Model_tiny(num_classes, mesh_folder=sp_mesh_dir, feat=feat)

    # load checkpoint
    ckpt = os.path.join(log_dir, "state.pkl")
    if no_cuda:
        pretrained_dict = torch.load(ckpt, map_location=lambda storage, loc:storage)
    else:
        pretrained_dict = torch.load(ckpt)

    def load_my_state_dict(self, state_dict, exclude='out_layer'):
        from torch.nn.parameter import Parameter
 
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            if exclude in name:
                continue
            if isinstance(param, Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            own_state[name].copy_(param)

    load_my_state_dict(model, pretrained_dict)  
    model.to(device)
  

    print("{} paramerters in total".format(sum(x.numel() for x in model.parameters())))
    print("{} paramerters in the last layer".format(sum(x.numel() for x in model.out_layer.parameters())))

    # Load the dataset
    # (CacheNPY caches each projected mesh on disk as a .npy file)
    transform = CacheNPY(prefix="sp{}_".format(sp_mesh_level), transform=torchvision.transforms.Compose(
        [
            ToMesh(random_rotations=False, random_translation=0),
            ProjectOnSphere(meshfile=sp_mesh_file, dataset=dataset, normalize=True)
        ]
    ), sp_mesh_dir=sp_mesh_dir, sp_mesh_level=sp_mesh_level)

    transform_test = CacheNPY(prefix="sp{}_".format(sp_mesh_level), transform=torchvision.transforms.Compose(
        [
            ToMesh(random_rotations=False, random_translation=0),
            ProjectOnSphere(meshfile=sp_mesh_file, dataset=dataset, normalize=True)
        ]
    ), sp_mesh_dir=sp_mesh_dir, sp_mesh_level=sp_mesh_level)

    if dataset == 'modelnet10':
        def target_transform(x):
            classes = ['bathtub', 'bed', 'chair', 'desk', 'dresser', 'monitor', 'night_stand', 'sofa', 'table', 'toilet']
            return classes.index(x)
    elif dataset == 'modelnet40':
        def target_transform(x):
            classes = ['airplane', 'bowl', 'desk', 'keyboard', 'person', 'sofa', 'tv_stand', 'bathtub', 'car', 'door',
                       'lamp', 'piano', 'stairs', 'vase', 'bed', 'chair', 'dresser', 'laptop', 'plant', 'stool',
                       'wardrobe', 'bench', 'cone', 'flower_pot', 'mantel', 'radio', 'table', 'xbox', 'bookshelf', 'cup',
                       'glass_box', 'monitor', 'range_hood', 'tent', 'bottle', 'curtain', 'guitar', 'night_stand', 'sink', 'toilet']
            return classes.index(x)
    else:
        raise ValueError('invalid dataset: must be modelnet10 or modelnet40')

    test_set = ModelNet(data_dir, dataset=dataset, partition='test', transform=transform_test, target_transform=target_transform)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=True, num_workers=jobs, pin_memory=True, drop_last=False)

    def test_step(data, target):
        model.eval()
        data, target = data.to(device), target.to(device)

        t = time()
        # no autograd graph is needed during evaluation
        with torch.no_grad():
            prediction = model(data)
        dt = time() - t
        loss = F.nll_loss(prediction, target)

        correct = prediction.data.max(1)[1].eq(target.data).long().cpu().sum()

        return loss.item(), correct.item(), dt

    # test
    total_loss = 0
    total_correct = 0
    count = 0
    total_time = []
    for batch_idx, (data, target) in enumerate(test_loader):
        loss, correct, dt = test_step(data, target)
        total_time.append(dt)
        total_loss += loss
        total_correct += correct
        count += 1
        if eval_time and count >= neval:
            # average over batches after the first 10, which serve as warm-up
            print("Time per batch: {} secs".format(np.mean(total_time[10:])))
            break
    if not eval_time:
        print("[Test] <LOSS>={:.2} <ACC>={:.2}".format(total_loss / count, total_correct / len(test_set)))