Example #1
# Module-level imports (os, sys, torch, utils) and the helpers Handlers,
# io_factory, and trainval are omitted from this excerpt.
def prepare(flags):
    if len(flags.GPUS) > 0:
        torch.cuda.set_device(flags.GPUS[0])
    handlers = Handlers()

    # IO configuration
    handlers.data_io = io_factory(flags)
    handlers.data_io.initialize()
    if 'sparse' in flags.IO_TYPE:
        handlers.data_io.start_threads()
        # handlers.data_io.next()
    if 'sparse' in flags.MODEL_NAME and 'sparse' not in flags.IO_TYPE:
        sys.stderr.write('Sparse UResNet needs sparse IO.\n')
        sys.exit(1)

    # Trainer configuration
    flags.NUM_CHANNEL = handlers.data_io.num_channels()
    handlers.trainer = trainval(flags)

    # Restore weights if necessary
    handlers.iteration = 0
    loaded_iteration = 0
    if not flags.FULL:
        loaded_iteration = handlers.trainer.initialize()
        if flags.TRAIN: handlers.iteration = loaded_iteration

    # Weight save directory
    if flags.WEIGHT_PREFIX:
        save_dir = flags.WEIGHT_PREFIX[0:flags.WEIGHT_PREFIX.rfind('/')]
        if save_dir and not os.path.isdir(save_dir): os.makedirs(save_dir)

    # Log save directory
    if flags.LOG_DIR:
        if not os.path.exists(flags.LOG_DIR): os.mkdir(flags.LOG_DIR)
        logname = '%s/train_log-%07d.csv' % (flags.LOG_DIR, loaded_iteration)
        if not flags.TRAIN:
            logname = '%s/inference_log-%07d.csv' % (flags.LOG_DIR,
                                                     loaded_iteration)
        handlers.csv_logger = utils.CSVData(logname)
        if not flags.TRAIN and flags.FULL:
            handlers.metrics_logger = utils.CSVData(
                '%s/metrics_log-%07d.csv' % (flags.LOG_DIR, loaded_iteration))
            handlers.pixels_logger = utils.CSVData(
                '%s/pixels_log-%07d.csv' % (flags.LOG_DIR, loaded_iteration))
            handlers.michel_logger = utils.CSVData(
                '%s/michel_log-%07d.csv' % (flags.LOG_DIR, loaded_iteration))
            handlers.michel_logger2 = utils.CSVData(
                '%s/michel2_log-%07d.csv' % (flags.LOG_DIR, loaded_iteration))
    return handlers
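
For context, a minimal sketch of how prepare() might be driven. It assumes flags is a uresnet.flags.URESNET_FLAGS instance; any flag keys beyond those shown in these examples are illustrative, not taken from the uresnet API.

# Hypothetical driver for prepare() (assumed usage, not the repo's entry point).
import uresnet.flags

flags = uresnet.flags.URESNET_FLAGS()
flags.update({
    "full": True,                    # --full
    "io_type": "larcv_sparse",       # -io
    "model_name": "uresnet_sparse",  # -mn
    "log_dir": "log/",               # -ld
})
flags.TRAIN = False
handlers = prepare(flags)
print("starting at iteration %d" % handlers.iteration)
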
Example #2
    # __init__ of the SparseSSNetWorker class; the class statement, its base
    # class, and the module-level imports are omitted from this excerpt.
    def __init__(self,
                 broker_address,
                 plane,
                 weight_file,
                 batch_size,
                 row_tick_dim=512,
                 col_wire_dim=512,
                 device_id=None,
                 use_half=False,
                 nlayers=5,
                 nfilters=32,
                 use_compression=False,
                 **kwargs):
        """
        Constructor

        inputs
        ------
        broker_address str IP address of broker
        plane int Plane ID number. Currently [0,1,2] only
        weight_file str path to file with the network weights
        batch_size int number of images to process in one pass
        """
        if type(plane) is not int:
            raise ValueError("'plane' argument must be integer for plane ID")
        elif plane not in [0, 1, 2]:
            raise ValueError("unrecognized plane argument: should be one of [0,1,2]")

        if type(batch_size) is not int or batch_size <= 0:
            raise ValueError("'batch_size' must be a positive integer")

        self.plane = plane
        self.batch_size = batch_size
        self._still_processing_msg = False
        self._use_half = use_half
        self._use_compression = use_compression
        if self._use_half:
            # half precision is untested for the sparse ssnet; refuse to run
            raise NotImplementedError("half-precision sparse ssnet inference is not tested")
        service_name = "sparse_uresnet_plane%d" % (self.plane)
        super(SparseSSNetWorker, self).__init__(service_name, broker_address,
                                                **kwargs)

        # Get Configs going:
        # configuration from Ran:
        """
        inference --full -pl 1 -mp PATH_TO_Plane1Weights-13999.ckpt -io larcv_sparse 
              -bs 64 -nc 5 -rs 1 -ss 512 -dd 2 -uns 5 -dkeys wire,label 
              -mn uresnet_sparse -it 10 -ld log/ -if PATH_TO_INPUT_ROOT_FILE
        """
        self.config = uresnet.flags.URESNET_FLAGS()

        args = {
            "full": True,  # --full
            "plane": self.plane,  # -pl
            "model_path": weight_file,  # -mp
            "io_type": "larcv_sparse",  # -io
            "batch_size": 1,  # -bs
            "num_class": 5,  # -nc
            "uresnet_filters": nfilters,  # -uf
            "report_step": 1,  # -rs
            "spatial_size": row_tick_dim,  # -ss
            "data_dim": 2,  # -dd
            "uresnet_num_strides": nlayers,  # -uns
            "data_keys": "wire,label",  # -dkeys
            "model_name": "uresnet_sparse",  # -mn
            "iteration": 1,  # -it
            "log_dir": "log/",  # -ld
            "input_file": "none"
        }  # -if

        self.config.update(args)
        self.config.TRAIN = False
        self.config.SPATIAL_SIZE = (row_tick_dim, col_wire_dim)

        print("\n\n-- CONFIG --")
        for name in vars(self.config):
            attribute = getattr(self.config, name)
            if type(attribute) == type(self.config.parser): continue
            print("%s = %r" % (name, getattr(self.config, name)))

        # Set random seed for reproducibility
        np.random.seed(self.config.SEED)
        torch.manual_seed(self.config.SEED)

        self.trainval = trainval(self.config)
        self.trainval.initialize()
        self._log.info("Loaded sparse uresnet model. Service={}".format(service_name))
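
A hedged instantiation sketch for the constructor above: the class name SparseSSNetWorker comes from the super() call, while the broker address and weight path below are placeholders.

# Illustrative construction (placeholder values; extra **kwargs for the
# base worker class are omitted).
worker = SparseSSNetWorker(
    broker_address="tcp://localhost:6000",
    plane=1,
    weight_file="Plane1Weights-13999.ckpt",
    batch_size=1,
    row_tick_dim=512,
    col_wire_dim=512,
)
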
Example #3
# Module-level imports (numpy as np, ctypes.c_int, larcv, ROOT's std,
# uresnet.flags, trainval) are omitted from this excerpt.
def forwardpass( plane, sparse_bson_list, weights_filepath ):
    """
    Run sparse-uresnet inference on a list of BSON-serialized sparse images.

    inputs
    ------
    plane int Plane ID number
    sparse_bson_list list of pybytes objects holding BSON-encoded sparse images
    weights_filepath str path to the checkpoint file with the model weights

    output
    ------
    list of pybytes objects holding the BSON-encoded ssnet score images
    """
    # Get Configs going:
    # configuration from Ran:
    """
    inference --full -pl 1 -mp PATH_TO_Plane1Weights-13999.ckpt -io larcv_sparse 
    -bs 64 -nc 5 -rs 1 -ss 512 -dd 2 -uns 5 -dkeys wire,label 
    -mn uresnet_sparse -it 10 -ld log/ -if PATH_TO_INPUT_ROOT_FILE
    """
    config = uresnet.flags.URESNET_FLAGS()
    args = { "full":True,               # --full
             "plane":plane,        # -pl
             "model_path":weights_filepath,  # -mp
             "io_type":"larcv_sparse",  # -io
             "batch_size":1,            # -bs
             "num_class":5,             # -nc
             "uresnet_filters":16,      # -uf
             "report_step":1,           # -rs
             "spatial_size":512,        # -ss
             "data_dim":2,              # -dd
             "uresnet_num_strides": 5,  # -uns
             "data_keys":"wire,label",  # -dkeys
             "model_name":"uresnet_sparse", # -mn
             "iteration":1,            # -it
             "log_dir":"./log/",          # -ld
             "input_file":"none" }      # -if
    config.update(args)
    config.TRAIN = False
        
    print("\n\n-- CONFIG --")
    for name in vars(config):
        attribute = getattr(config,name)
        if type(attribute) == type(config.parser): continue
        print("%s = %r" % (name, getattr(config, name)))

    # Set random seed for reproducibility
    #np.random.seed(config.SEED)
    #torch.manual_seed(config.SEED)

    interface = trainval(config)
    interface.initialize()
    print("Loaded sparse pytorch_uresnet plane={}".format(plane))

    # parse the input data, loop over pybyte objects
    sparsedata_v = []
    rseid_v = []
    npts_v  = []
    ntotalpts = 0
    for bson in sparse_bson_list:
        c_run    = c_int()
        c_subrun = c_int()
        c_event  = c_int()
        c_id     = c_int()

        imgdata = larcv.json.sparseimg_from_bson_pybytes(bson,
                                                         c_run, 
                                                         c_subrun, 
                                                         c_event, 
                                                         c_id )
        npts = int(imgdata.pixellist().size()/(imgdata.nfeatures()+2))
        ntotalpts += npts
        sparsedata_v.append(imgdata)
        npts_v.append( npts )
        rseid_v.append( (c_run.value, c_subrun.value, c_event.value, c_id.value) )
        
    # make batch array
    batch_np = np.zeros( ( ntotalpts, 4 ) )
    startidx = 0
    idx      = 0
    for npts,img2d in zip( npts_v, sparsedata_v ):
        endidx   = startidx+npts
        spimg_np = larcv.as_ndarray( img2d, larcv.msg.kNORMAL )
        #print("spimg_np: {}".format(spimg_np[:,0:2]))

        # coords
        batch_np[startidx:endidx,0] = spimg_np[:,0] # tick
        batch_np[startidx:endidx,1] = spimg_np[:,1] # wire

        batch_np[startidx:endidx,2] = idx           # batch index
        batch_np[startidx:endidx,3] = spimg_np[:,2] # pixel value
        #print("batch_np: {}".format(batch_np[:,0:2]))
        startidx = endidx
        idx += 1

    # pass to network
    data_blob = { 'data': [[batch_np]] }
    results = interface.forward( data_blob )

    bson_reply = []
    startidx = 0
    for idx in range(len(results['softmax'])):
        ssnetout_np = results['softmax'][idx]
        #print("ssneout_np: {}".format(ssnetout_np.shape))
        rseid = rseid_v[idx]
        meta  = sparsedata_v[idx].meta(0)
        npts  = int( npts_v[idx] )
        endidx = startidx+npts
        #print("numpoints for img[{}]: {}".format(idx,npts))
        ssnetout_wcoords = np.zeros( (ssnetout_np.shape[0],ssnetout_np.shape[1]+2), dtype=np.float32 )

        ssnetout_wcoords[:,0] = batch_np[startidx:endidx,0] # tick
        ssnetout_wcoords[:,1] = batch_np[startidx:endidx,1] # wire

        # pixel value
        ssnetout_wcoords[:,2:2+ssnetout_np.shape[1]] = ssnetout_np[:,:]
        startidx = endidx
        #print("ssnetout_wcoords: {}".format(ssnetout_wcoords[:,0:2]))

        # build one ImageMeta per feature column of the output sparse image
        # (5 ssnet class scores), reusing the meta from the input image
        meta_v = std.vector("larcv::ImageMeta")()
        for i in range(5):
            meta_v.push_back(meta)
            
        ssnetout_spimg = larcv.sparseimg_from_ndarray( ssnetout_wcoords, meta_v, larcv.msg.kDEBUG )
        bson = larcv.json.as_bson_pybytes( ssnetout_spimg,
                                           rseid[0], rseid[1], rseid[2], rseid[3] )
        bson_reply.append(bson)

    return bson_reply
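
A hedged call sketch for forwardpass(): each input is a pybytes object holding a BSON-encoded sparse image (the same encoding larcv.json.as_bson_pybytes produces for the reply above), and one reply image comes back per input. The image list, run/subrun/event values, and weight path below are placeholders.

# Illustrative call (placeholders; sparse_images would be larcv sparse-image
# objects, and 1/0/0 are dummy run/subrun/event numbers).
bson_inputs = [ larcv.json.as_bson_pybytes(img, 1, 0, 0, imgid)
                for imgid, img in enumerate(sparse_images) ]
replies = forwardpass(1, bson_inputs, "Plane1Weights-13999.ckpt")
assert len(replies) == len(bson_inputs)  # one ssnet output image per input
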