Example #1
# Shared imports for the three examples; `utils`, `SphericalGMMNet`, and
# `setup_logger` are project-local modules of the SphericalGMMNet repo.
import os
import datetime

import numpy as np
import torch
from torch.autograd import Variable

import utils


def eval(test_iterator, model, params, logger, rotate=True, num_epochs=5):

    logger.info("================================ Eval ================================\n")

    model.eval()  # inference mode: freeze dropout / batch-norm statistics

    s2_grids = utils.get_grids(b=params['bandwidth_0'], num_grids=params['num_grids'], base_radius=params['base_radius'])

    acc_overall = list()
    # Rebuild the test iterator so the caller's `rotate` flag is honored
    test_iterator = utils.load_data_h5(params['test_dir'], batch_size=params['batch_size'], rotate=rotate, batch=False)
    for epoch in range(num_epochs):
        acc_all = []
        with torch.no_grad():
            for _, (inputs, labels) in enumerate(test_iterator):
                
                inputs = Variable(inputs).cuda()
                B, N, D = inputs.size()

                if inputs.shape[-1] == 2:
                    zero_padding = torch.zeros((B, N, 1), dtype=inputs.dtype).cuda()
                    inputs = torch.cat((inputs, zero_padding), -1)  # [B, N, 3]

                # Data Mapping
                inputs = utils.data_mapping(inputs, base_radius=params['base_radius'])  # [B, N, 3]

                # Data Translation
                inputs = utils.data_translation(inputs, s2_grids,
                                                params)  # [B, N, 3] -> list( Tensor([B, 2b, 2b]) * num_grids )

                outputs = model(inputs)
                outputs = torch.argmax(outputs, dim=-1)
                acc_all.append(np.mean(outputs.detach().cpu().numpy() == labels.numpy()))
            epoch_acc = np.mean(acc_all)
            acc_overall.append(epoch_acc)
            logger.info('[epoch {}] Accuracy: [{}]'.format(epoch, epoch_acc))
            
    logger.info("======================================================================\n")
    return np.max(acc_overall)
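
A minimal usage sketch for eval. Every key and value in params below is an assumption inferred from the lookups inside the function above, not a value taken from the repo:

# Hypothetical setup; all values are illustrative assumptions.
params = {
    'test_dir': 'data/test.h5',
    'batch_size': 32,
    'bandwidth_0': 16,
    'num_grids': 3,
    'base_radius': 1.0,
}
logger = setup_logger("SphericalGMMNet")
model = SphericalGMMNet(params).cuda()
test_iterator = utils.load_data_h5(params['test_dir'], batch_size=params['batch_size'], rotate=True, batch=False)
best_acc = eval(test_iterator, model, params, logger, rotate=True, num_epochs=5)  # best accuracy over 5 passes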
Example #2
def test(params, date_time, num_epochs=1000):
    logger = setup_logger("SphericalGMMNet")
    logger.info("Loading Data")

    # Load Data
    logger.info("Model Setting Up")

    # Model Configuration Setup
    model = SphericalGMMNet(params).cuda()
    model = model.cuda()

    logger.info('Loading the trained models from {date_time} ...'.format(
        date_time=date_time))
    model_path = os.path.join(
        params['save_dir'],
        '{date_time}-model.ckpt'.format(date_time=date_time))
    model.load_state_dict(
        torch.load(model_path, map_location=lambda storage, loc: storage))
    model.eval()  # inference only: freeze dropout / batch-norm behavior

    # Generate the grids
    # [(radius, tensor([2b, 2b, 3])) * 3]
    s2_grids = utils.get_grids(b=params['bandwidth_0'],
                               num_grids=params['num_grids'],
                               base_radius=params['base_radius'])

    # Load Data
    logger.info("Loading Data")
    test_iterator = utils.load_data_h5(params['test_dir'],
                                       batch_size=params['batch_size'],
                                       rotate=True,
                                       batch=False)
    for epoch in range(num_epochs):
        acc_all = []
        with torch.no_grad():
            for _, (inputs, labels) in enumerate(test_iterator):
                inputs = Variable(inputs).cuda()
                B, N, D = inputs.size()

                if inputs.shape[-1] == 2:
                    zero_padding = torch.zeros((B, N, 1),
                                               dtype=inputs.dtype).cuda()
                    inputs = torch.cat((inputs, zero_padding), -1)  # [B, N, 3]

                # Data Mapping
                inputs = utils.data_mapping(
                    inputs, base_radius=params['base_radius'])  # [B, N, 3]

                # Data Translation
                inputs = utils.data_translation(
                    inputs, s2_grids, params
                )  # [B, N, 3] -> list( Tensor([B, 2b, 2b]) * num_grids )

                outputs = model(inputs)
                outputs = torch.argmax(outputs, dim=-1)
                acc_all.append(
                    np.mean(outputs.detach().cpu().numpy() == labels.numpy()))

            logger.info('[epoch {}] Accuracy: [{}]'.format(
                epoch, np.mean(acc_all)))
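
For context, a sketch of how the date_time argument lines up with the checkpoints written by train below; the timestamp and the save_dir value are assumptions for illustration:

# Hypothetical call; the timestamp stands in for one produced by
# datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') in train.
params['save_dir'] = 'checkpoints'  # assumed directory
test(params, date_time='2021-01-01_12-00-00', num_epochs=10)
# loads checkpoints/2021-01-01_12-00-00-model.ckpt and logs accuracy per epoch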
Example #3
def train(params):
    # Logger Setup and OS Configuration
    logger = setup_logger("SphericalGMMNet")
    logger.info("Loading Data")

    # Load Data
    train_iterator = utils.load_data_h5(params['train_dir'], batch_size=params['batch_size'])
    test_iterator = utils.load_data_h5(params['test_dir'], batch_size=params['batch_size'], rotate=True, batch=False)

    # Model Setup
    logger.info("Model Setting Up")
    model = SphericalGMMNet(params).cuda()

    # Model Configuration Setup
    optim = torch.optim.Adam(model.parameters(), lr=params['baselr'])
    cls_criterion = torch.nn.CrossEntropyLoss().cuda()

    # Resume If Asked
    date_time = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    if params['resume_training']:
        date_time = params['resume_training']
        model_path = os.path.join(params['save_dir'], '{date_time}-model.ckpt'.format(date_time=date_time))
        model.load_state_dict(torch.load(model_path, map_location=lambda storage, loc: storage))

    # Display Parameters
    for name, value in params.items():
        logger.info("{name} : [{value}]".format(name=name, value=value))

    # Generate the grids
    # [(radius, tensor([2b, 2b, 3])) * 3]
    s2_grids = utils.get_grids(b=params['bandwidth_0'], num_grids=params['num_grids'], base_radius=params['base_radius'])

    # TODO [Visualize Grids]
    if params['visualize']:
        utils.visualize_grids(s2_grids)
    
    # Keep track of max Accuracy during training
    acc_nr, max_acc_nr = 0, 0
    acc_r, max_acc_r = 0, 0
    
    # Iterate by Epoch
    logger.info("Start Training")
    for epoch in range(params['num_epochs']):

        # Checkpoint whenever eval accuracy improves (acc_nr / acc_r hold the previous epoch's results)
        if acc_nr > max_acc_nr:
            max_acc_nr = acc_nr
            save_path = os.path.join(params['save_dir'], '{date_time}-[NR]-[{acc}]-model.ckpt'.format(date_time=date_time, acc=acc_nr))
            torch.save(model.state_dict(), save_path)
            logger.info('Saved model checkpoints into {}...'.format(save_path))
        if acc_r > max_acc_r:
            max_acc_r = acc_r
            save_path = os.path.join(params['save_dir'], '{date_time}-[R]-[{acc}]-model.ckpt'.format(date_time=date_time, acc=acc_r))
            torch.save(model.state_dict(), save_path)
            logger.info('Saved model checkpoints into {}...'.format(save_path))

        model.train()  # re-enable training mode (eval() below switches the model to eval mode)
        running_loss = []
        for batch_idx, (inputs, labels) in enumerate(train_iterator):

            """ Variable Setup """
            inputs, labels = Variable(inputs).cuda(), Variable(labels).cuda()
            B, N, D = inputs.size()

            if inputs.shape[-1] == 2:
                zero_padding = torch.zeros((B, N, 1), dtype=inputs.dtype).cuda()
                inputs = torch.cat((inputs, zero_padding), -1)  # [B, N, 3]

            # Data Mapping
            inputs = utils.data_mapping(inputs, base_radius=params['base_radius'])  # [B, N, 3]

            
            if params['visualize']:
                
                # TODO [Visualization [Raw]]
                origins = inputs.clone()
                utils.visualize_raw(inputs, labels)
                
                # TODO [Visualization [Sphere]]
                print("---------- Static ------------")
                params['use_static_sigma'] = True
                inputs1 = utils.data_translation(inputs, s2_grids, params)  
                utils.visualize_sphere(origins, inputs1, labels, s2_grids, params, folder='sphere')
                
                print("\n---------- Covariance ------------")
                params['use_static_sigma'] = False
                params['sigma_layer_diff'] = False
                inputs2 = utils.data_translation(inputs, s2_grids, params)  
                utils.visualize_sphere(origins, inputs2, labels, s2_grids, params, folder='sphere')
                
                print("\n---------- Layer Diff ------------")
                params['use_static_sigma'] = False
                params['sigma_layer_diff'] = True
                inputs3 = utils.data_translation(inputs, s2_grids, params)  
                utils.visualize_sphere(origins, inputs3, labels, s2_grids, params, folder='other')
                # Visualization-only run: exit before any training step
                return
            else:
                # Data Translation
                inputs = utils.data_translation(inputs, s2_grids, params)  # [B, N, 3] -> list( Tensor([B, 2b, 2b]) * num_grids )

            """ Run Model """
            outputs = model(inputs)

            """ Back Propagation """
            loss = cls_criterion(outputs, labels.squeeze())
            loss.backward(retain_graph=True)
            optim.step()
            running_loss.append(loss.item())

            # Update Loss Per Batch
            logger.info("Batch: [{batch}/{total_batch}] Epoch: [{epoch}] Loss: [{loss}]".format(batch=batch_idx,
                                                                                                total_batch=len(
                                                                                                    train_iterator),
                                                                                                epoch=epoch,
                                                                                                loss=np.mean(
                                                                                                    running_loss)))

        acc_nr = eval(test_iterator, model, params, logger, rotate=False)
        logger.info(
            "**************** Epoch: [{epoch}/{total_epoch}] [NR] Accuracy: [{acc}] ****************\n".format(
                epoch=epoch, total_epoch=params['num_epochs'], acc=acc_nr))
        acc_r = eval(test_iterator, model, params, logger, rotate=True)
        logger.info(
            "**************** Epoch: [{epoch}/{total_epoch}] [R] Accuracy: [{acc}] ****************\n".format(
                epoch=epoch, total_epoch=params['num_epochs'], acc=acc_r))

    logger.info('Finished Training')
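
A hedged sketch of an entry point wiring train up. The key names match the params lookups in the examples above, but every value and default is an illustrative assumption, not a repo default:

if __name__ == '__main__':
    # All values below are illustrative assumptions.
    params = {
        'train_dir': 'data/train.h5',
        'test_dir': 'data/test.h5',
        'save_dir': 'checkpoints',
        'batch_size': 32,
        'baselr': 1e-3,               # Adam learning rate
        'num_epochs': 100,
        'bandwidth_0': 16,            # grid bandwidth b (each grid is 2b x 2b)
        'num_grids': 3,
        'base_radius': 1.0,
        'resume_training': None,      # or a '%Y-%m-%d_%H-%M-%S' timestamp to resume from
        'visualize': False,           # True runs the visualization branch and returns early
        'use_static_sigma': False,
        'sigma_layer_diff': False,
    }
    train(params)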