Пример #1
0
    def __init__(self, probability_volume, params=None):
        """Initialize the sampler from a probability volume and parameters.

        Args:
            probability_volume: path of the probability volume; it is loaded
                with ``nib.load`` and normalized so that its voxels sum to 1.
            params (dict, optional): overrides for the entries of
                ``default_params``. Every key of ``default_params`` becomes an
                attribute, taken from ``params`` when present there and from
                the default otherwise. ``None`` means "no overrides".

        Raises:
            ValueError: if the probability volume does not have a positive sum.
        """
        # ``None`` instead of a shared mutable ``{}`` default argument.
        if params is None:
            params = {}
        self.log = get_logger(self.__class__.__name__)
        self.log.info('Initialize PointSampler with volume {}'.format(
            probability_volume))

        # load probability_volume
        prob_file = nib.load(probability_volume)
        # ``dataobj`` replaces nibabel's deprecated ``get_data()``. Casting to
        # float64 keeps the normalization below a true division even for
        # integer-typed volumes (which would be floor-divided under Python 2).
        self.probs = np.asanyarray(prob_file.dataobj).astype(np.float64).ravel()
        total = self.probs.sum()
        if not total > 0:  # also rejects NaN sums
            raise ValueError(
                'Probability volume {} has non-positive sum {}'.format(
                    probability_volume, total))
        self.probs = self.probs / total
        self.probs_shape = prob_file.shape
        self.probs_spacing = prob_file.affine[0, 0]
        self.probs_offset = prob_file.affine[:3, 3]
        self.log.info(
            'Loaded probability volume with spacing {} and offset {}'.format(
                self.probs_spacing, self.probs_offset))

        # init other parameters (items() works on Python 2 and 3)
        for key, default_value in default_params.items():
            setattr(self, key, params.get(key, default_value))
            self.log.info('Setting {} to {}'.format(key, getattr(self, key)))

        # load meshes; both hemispheres are built with the same options
        mesh_kwargs = dict(coord_system=self.geodesic_coord_system,
                           subdivision_level=self.subdivision_level,
                           approximate=self.approximate_gdist,
                           threaded=self.threaded,
                           num_threads=self.num_threads)
        self.mesh_left = Mesh(self.mesh_left_file, **mesh_kwargs)
        self.mesh_right = Mesh(self.mesh_right_file, **mesh_kwargs)

        # load transformed_coords_file (the 'coords' dataset and its attrs)
        self.transformed_coords = h5py.File(self.transformed_coords_file,
                                            'r')['coords']
        self.transformed_coords_sections = list(
            self.transformed_coords.attrs['sections'])
        coords_spacing = self.transformed_coords.attrs['spacing']
        if coords_spacing != self.probs_spacing:
            # %s rather than %d: spacings are floats and %d would truncate them
            self.log.error(
                'Probs spacing %s and transformed coords spacing %s are different! Potential problems!!',
                self.probs_spacing, coords_spacing)
        self.log.info('Loaded transformed coords file')

        if self.deterministic:
            # seed random generator to get same points each time
            np.random.seed(0)

        # Expose the iterator's next() as ``self.next``. The builtin ``next``
        # works for generators on both Python 2 and 3, whereas the ``.next``
        # method exists only on Python 2.
        point_pairs = self.point_pair_iterator()
        self.next = lambda: next(point_pairs)
Пример #2
0
 def __init__(self,
              mesh_file,
              coord_system=None,
              approximate=False,
              subdivision_level=0,
              threaded=True,
              num_threads=1,
              inflated_file=None):
     """Load a surface mesh and set up geodesic-distance computation.

     Args:
         mesh_file: path passed to ``mesh_io.load_mesh_geometry``; its
             'coords' are stored as float64 and its 'faces' as int32.
         coord_system: optional input to ``load_geodesic_coord_system``.
             When ``None``, ``coords_2d`` stays ``None``.
         approximate: forwarded to ``gdist.GeodesicAlgorithm``.
         subdivision_level: forwarded to ``gdist.GeodesicAlgorithm``.
         threaded: stored on the instance (not used in this method).
         num_threads: stored on the instance (not used in this method).
         inflated_file: optional path of an inflated version of the mesh.
             Its vertex coordinates are stored in ``inflated_coords``, which
             is ``None`` when no file is given.
     """
     self.log = get_logger(self.__class__.__name__)
     self.log.info('Initializing mesh for {}'.format(mesh_file))
     # load data using Konrads mesh_io
     data = mesh_io.load_mesh_geometry(mesh_file)
     self.coords = data['coords'].astype(np.float64)
     self.triangs = data['faces'].astype(np.int32)
     self.mesh = trimesh.Trimesh(self.coords, self.triangs)
     self.gdist_algo = gdist.GeodesicAlgorithm(
         self.coords,
         self.triangs,
         approximate=approximate,
         subdivision_level=subdivision_level)
     self.threaded = threaded
     self.num_threads = num_threads
     self.coords_2d = None
     if coord_system is not None:
         self.log.info('Loading geodesic coordinate system from {}'.format(
             coord_system))
         self.coords_2d = self.load_geodesic_coord_system(coord_system)
     # Always define the attribute (like coords_2d) so that later access
     # does not raise AttributeError when no inflated mesh was given.
     self.inflated_coords = None
     if inflated_file is not None:
         self.log.info(
             'Loading inflated mesh from {}'.format(inflated_file))
         inflated_data = mesh_io.load_mesh_geometry(inflated_file)
         self.inflated_coords = inflated_data['coords']
     # Map trimesh vertex indices to indices into self.coords. Presumably
     # trimesh may reorder vertices on construction, so both vertex arrays
     # are sorted lexicographically by (x, y, z) and matched by rank.
     # np.lexsort uses its *last* key as the primary key, hence the reversed
     # transpose. Unlike a structured view, it needs no contiguous layout.
     coords_ind = np.lexsort(self.coords.T[::-1])
     trimesh_ind = np.lexsort(np.asarray(self.mesh.vertices).T[::-1])
     # Inverse permutation: the sorted rank of every trimesh vertex,
     # i.e. trimesh_ind[trimesh_ind_argsort] == [0, 1, 2, ...].
     trimesh_ind_argsort = np.argsort(trimesh_ind, axis=0)
     # self.coords[self.trimesh_to_coord[i]] == self.mesh.vertices[i]
     self.trimesh_to_coord = coords_ind[trimesh_ind_argsort]
Пример #3
0
    "load_last": False,
    "train_dir": "train",
    "eval_dir": "eval",
    "data_dir": "data",
    "log_dir": "log",
    "best_dir": "best",
    "train_data": "LCCC-train-small.txt",
    "valid_data": "LCCC-valid-small.txt",
    "test_data": "LCCC-test.txt",
    "cgpt_parameters_dir": "chinese_gpt_original/Cgpt_model.bin"
}

# Relative working directories (train_dir is joined with the latest
# checkpoint name further below).
train_dir, data_dir, log_dir = 'train', 'data', 'log'
logger = get_logger('main.log')
# Make only CUDA device 0 visible to this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

try:

    model_path = os.path.join(train_dir, get_latest_ckpt(train_dir))

    if not os.path.isfile(model_path):
        print('cannot find {}'.format(model_path))
        exit(0)

    device = torch.device("cuda")

    vocab = Vocab(config['vocab_path'])

    print('Building models')
Пример #4
0
from __future__ import print_function, division, absolute_import
import numpy as np
import math
from my_utils import get_logger

# Public API of this module: only the BoundingBox class is exported.
__all__ = ["BoundingBox"]

# Module-level logger named after this module.
logger = get_logger(__name__)


class BoundingBox:
    """Rectangle described by its center, width, height and a spacing."""

    def __init__(self, center, w, h, **kwargs):
        """A BoundingBox is a representation for a rectangle.

        It can be scaled, painted to an array and used to crop from an array.
        Internally, it knows a spacing, so if it is resized, the spacing is updated accordingly.

        Args:
            center (tuple): (x, y) coordinates of the center.
            w (number): width of box
            h (number): height of box
            kwargs:
                spacing (float): spacing of the bounding box coordinates
                    (defaults to 1.0).

        Raises:
            ValueError: if ``center`` does not have exactly two elements.
        """
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would silently accept malformed centers.
        if len(center) != 2:
            raise ValueError(
                'center must have exactly 2 elements, got {}'.format(
                    len(center)))
        # All geometry is stored as floats regardless of the input type.
        self.center = (float(center[0]), float(center[1]))
        self.w = float(w)
        self.h = float(h)
        self.spacing = float(kwargs.get('spacing', 1.))
Пример #5
0
    def __init__(self,
                 net,
                 params,
                 train_batch_iter,
                 test_batch_iter,
                 timestamp=''):
        """ Create Tensorflow Neural Network Trainer.

        Args:
            net: Keras model (created by build_net in net_definition.py)
            params (dict): training parameters (from config.py). Every key of
                ``default_train_params`` becomes an attribute, taken from
                ``params`` when present there and from the default otherwise.
                Neither ``params`` nor the lists it holds are modified.
            train_batch_iter (BatchIterator): iterator over training data
            test_batch_iter (BatchIterator): iterator over test data
            timestamp (str): subdirectory name appended to ``log_dir``.

        Raises:
            NotImplementedError: if ``mode`` is neither 'regression' nor
                'segmentation'.
        """
        self.log = get_logger(self.__class__.__name__,
                              rank=MPI.COMM_WORLD.Get_rank())
        self.log.info('Initializing NNTrainer')
        self.net = net
        self.train_batch_iter = train_batch_iter
        self.test_batch_iter = test_batch_iter
        # set params (log values truncated to 200 characters)
        for name, default_value in default_train_params.items():
            setattr(self, name, params.get(name, default_value))
            self.log.info('Setting {} to {:.200}'.format(
                name, '{}'.format(getattr(self, name))))
        # Normalize scalar settings into lists.
        if not isinstance(self.loss_name, list):
            self.loss_name = [self.loss_name]
        # Build a new list rather than assigning into self.metrics_names:
        # that list may be the very object held by ``params`` or by the
        # module-level ``default_train_params``, and must not be mutated.
        self.metrics_names = [
            metric_name if isinstance(metric_name, list) else
            [metric_name] * len(self.net.outputs)
            for metric_name in self.metrics_names
        ]
        # one value per loss for each of these settings
        for param in ('huber_delta', 'loss_weight', 'switch_c1', 'switch_c2',
                      'switch_c3', 'switch_value'):
            if not isinstance(getattr(self, param), list):
                setattr(self, param,
                        [getattr(self, param)] * len(self.loss_name))
        if not os.path.exists(self.snapshot_dir):
            os.makedirs(self.snapshot_dir)
        self.log_dir = os.path.join(self.log_dir, timestamp)

        self.train_summary = None
        self.test_summary = None
        self.global_step_var = tf.Variable(0,
                                           name='global_step',
                                           trainable=False)
        # variables and operations for training
        if self.mode == 'regression':
            # one float target placeholder per network output
            self.target_var = []
            for i, output in enumerate(self.net.outputs):
                self.target_var.append(
                    tf.placeholder(tf.float32,
                                   shape=output.shape.as_list(),
                                   name='target%d' % i))
        elif self.mode == 'segmentation':
            # single int target: shape of the first output without its
            # last dimension
            self.target_var = [
                tf.placeholder(tf.int32,
                               shape=self.net.outputs[0].shape.as_list()[:-1],
                               name='target')
            ]
        else:
            raise NotImplementedError
        self.metrics_ops = self.get_metrics()
        self.loss_op = self.get_loss()
        self.train_op, self.lr_var = self.get_train_step()
        self.train_update_op = tf.group(*self.net.get_train_updates())
        self.prediction_op = self.get_prediction()
        self.init_op = None
        self.checkpoint_saver = None
        self.sess = None

        # summary operations
        self.train_summary_op = None
        self.test_summary_op = None
        self.eval_image_var = None
        self.test_metrics_vars = None
        self.test_loss_var = None
        self.add_summary()

        # initialize variables for finding best model (stop_criterium: max_global)
        # NOTE(review): -1000 is a sentinel; if "best" is chosen with a `>`
        # comparison, metrics <= -1000 can never become best - confirm.
        self.best_metric = -1000
        self.best_iteration = 0

        # initialize variables and saver
        self.init_op = tf.group(tf.global_variables_initializer(),
                                tf.local_variables_initializer())

        self.checkpoint_saver = tf.train.Saver(max_to_keep=5)
Пример #6
0
def main():
    """Log the start and end markers of a run to the ``DIR`` logger."""
    log = get_logger(DIR)
    for marker in ('--- start ---', '--- end ---'):
        log.info(marker)