Beispiel #1
0
def forwardpassy(data_t, CHECKPOINT_FILE=None):
    """Load the sparse infill network and run a CPU forward pass on one image.

    Parameters
    ----------
    data_t : numpy.ndarray
        2-D array with one row per pixel. Columns 0 and 1 are used as the
        pixel coordinates and column 2 as the pixel value.
    CHECKPOINT_FILE : str, optional
        Path of the checkpoint to load. When ``None`` a hard-coded default
        path is used.

    Returns
    -------
    numpy.ndarray
        ``data_t`` itself (unchanged) when no pixel value is zero, i.e. there
        is nothing to fill in. Otherwise a new ``float32`` array of shape
        ``(ncoords, 3)`` whose first two columns are the input coordinates
        and whose third column is the network output.
    """
    ncoords = np.size(data_t, 0)

    # A dead pixel is one whose *value* (column 2) is zero. Only that column
    # is inspected: a coordinate of 0 (first row / first column of the image)
    # must not be counted as a dead channel.
    if np.count_nonzero(data_t[:, 2] == 0) == 0:
        print("SKIPPED DUE TO NO DEAD CHANNELS")
        return data_t

    # tensor for coords; the third column is left at its initial value of 1
    coord_t = torch.ones((ncoords, 3), dtype=torch.int)
    # builtin int instead of np.int: the alias was removed in NumPy 1.24
    coord_t[:, 0:2] = torch.from_numpy(data_t[:, 0:2].astype(int))

    # tensor for input pixels
    input_t = torch.zeros((ncoords, 1), dtype=torch.float)
    input_t[:, 0] = torch.from_numpy(data_t[:, 2])

    # loading model with hard coded parameters used in training
    # ( (height,width),reps,ninput_features, noutput_features,nplanes, show_sizes=False)
    if CHECKPOINT_FILE is None:
        CHECKPOINT_FILE = "/mnt/disk1/nutufts/kmason/sparsenet/ubdl/sparse_infill/sparse_infill/training/sparseinfill_yplane_test.tar"
    model = SparseInfill((512, 496), 1, 16, 16, 5, show_sizes=False)

    # Map every stored tensor to the CPU, whichever device it was saved from.
    checkpoint = torch.load(CHECKPOINT_FILE, map_location="cpu")

    # Some stored weights are rank 3; insert a singleton second axis so they
    # become rank 4 before being loaded into the model.
    state_dict = checkpoint["state_dict"]
    reshaped_layers = ("unet", "conv2", "conv1", "sparseModel")
    for name, arr in state_dict.items():
        if ("weight" in name and len(arr.shape) == 3
                and any(tag in name for tag in reshaped_layers)):
            # value replacement only; the key set is not modified while iterating
            state_dict[name] = arr.reshape(
                (arr.shape[0], 1, arr.shape[1], arr.shape[2]))

    model.load_state_dict(state_dict)
    model.eval()
    loadedmodeltime = time.time()

    # run the forward pass
    with torch.set_grad_enabled(False):
        out_t = model(coord_t, input_t, 1)
    forwardpasstime = time.time()

    out_np = out_t.detach().cpu().numpy()
    predict_t = np.zeros((ncoords, 3), dtype=np.float32)
    predict_t[:, 0:2] = data_t[:, 0:2]
    predict_t[:, 2] = out_np[:, 0]
    print("forwardpass: ", forwardpasstime - loadedmodeltime)

    return predict_t
Beispiel #2
0
def load_model(CHECKPOINT_FILE, DEVICE):
    """Build a SparseInfill network, load checkpoint weights, return it in eval mode.

    Parameters
    ----------
    CHECKPOINT_FILE : str
        Path of the checkpoint; it must contain a ``"state_dict"`` entry.
    DEVICE : str or torch.device
        Device the model is placed on and the checkpoint tensors are mapped to.

    Returns
    -------
    SparseInfill
        The model on ``DEVICE`` with the checkpoint weights loaded and
        ``eval()`` already called.
    """
    device = torch.device(DEVICE)
    model = SparseInfill((512, 496), 1, 16, 16, 5,
                         show_sizes=False).to(device)

    # Pass the target device directly: every stored tensor is remapped,
    # whichever device it was saved from, and DEVICE may be either a string
    # or a torch.device.
    checkpoint = torch.load(CHECKPOINT_FILE, map_location=device)

    # Some stored weights are rank 3; insert a singleton second axis so they
    # become rank 4 before being loaded into the model.
    state_dict = checkpoint["state_dict"]
    reshaped_layers = ("unet", "conv2", "conv1", "sparseModel")
    for name, arr in state_dict.items():
        if ("weight" in name and len(arr.shape) == 3
                and any(tag in name for tag in reshaped_layers)):
            # value replacement only; the key set is not modified while iterating
            state_dict[name] = arr.reshape(
                (arr.shape[0], 1, arr.shape[1], arr.shape[2]))

    model.load_state_dict(state_dict)
    model.eval()
    return model
def main():
    """Train the SparseInfill network with periodic validation and checkpoints.

    All configuration is read from names defined outside this function
    (GPUMODE, DEVICE_IDS, IMAGE_HEIGHT, IMAGE_WIDTH, RESUME_FROM_CHECKPOINT,
    CHECKPOINT_FILE, CHECKPOINT_MAP_LOCATIONS, CHECKPOINT_FROM_DATA_PARALLEL,
    BATCHSIZE_TRAIN, BATCHSIZE_VALID, NWORKERS_TRAIN, INPUTFILE_TRAIN,
    TICKBACKWARD, RUNPROFILER, ITER_PER_CHECKPOINT, start_iter) and from the
    globals ``best_prec1`` and ``num_iters``, both of which this function may
    rebind. The loop stops early if the training or validation routine raises.
    Returns ``None``.
    """
    global best_prec1
    global writer
    global num_iters

    if GPUMODE:
        DEVICE = torch.device("cuda:%d" % (DEVICE_IDS[0]))
    else:
        DEVICE = torch.device("cpu")

    # create model, mark it to run on the chosen device
    ninput_features = 16
    noutput_features = 16
    nplanes = 5
    reps = 1
    # SparseInfill(inputshape, reps, nin_features, nout_features, nplanes, show_sizes)
    model = SparseInfill((IMAGE_HEIGHT, IMAGE_WIDTH), reps,
                         ninput_features, noutput_features,
                         nplanes, show_sizes=False).to(DEVICE)

    # Resume training option. The checkpoint is read from disk once and kept
    # so the optimizer state can be restored after the optimizer exists.
    checkpoint = None
    if RESUME_FROM_CHECKPOINT:
        print("RESUMING FROM CHECKPOINT FILE  %s" % (CHECKPOINT_FILE,))
        checkpoint = torch.load(CHECKPOINT_FILE,
                                map_location=CHECKPOINT_MAP_LOCATIONS)
        best_prec1 = checkpoint["best_prec1"]
        if CHECKPOINT_FROM_DATA_PARALLEL:
            # wrap first so the state-dict keys match the wrapped model
            model = nn.DataParallel(model, device_ids=DEVICE_IDS)
        model.load_state_dict(checkpoint["state_dict"])

    if not CHECKPOINT_FROM_DATA_PARALLEL and len(DEVICE_IDS) > 1:
        # distribute across device_ids (after the plain weights were loaded)
        model = nn.DataParallel(model, device_ids=DEVICE_IDS).to(device=DEVICE)

    # flip to True to dump the model and exit
    if False:
        print("Loaded model:  %s" % (model,))
        return

    # define loss function (criterion) and optimizer
    criterion = SparseInfillLoss().to(device=DEVICE)

    # training parameters
    lr = 1.0e-4
    momentum = 0.9  # only used by the commented-out SGD alternative below
    weight_decay = 1.0e-4

    # training length
    start_epoch = 0
    epochs = 10
    iter_per_epoch = None  # determined once the dataset size is known
    iter_per_valid = 10

    nbatches_per_itertrain = 5
    itersize_train = BATCHSIZE_TRAIN * nbatches_per_itertrain
    trainbatches_per_print = -1

    nbatches_per_itervalid = 5
    validbatches_per_print = -1

    # SETUP OPTIMIZER

    # SGD w/ momentum
    #optimizer = torch.optim.SGD(model.parameters(), lr,
    #                            momentum=momentum,
    #                            weight_decay=weight_decay)

    # ADAM
    # optimizer = torch.optim.Adam(model.parameters(),
    #                             lr=lr,
    #                             weight_decay=weight_decay)

    # RMSPROP
    optimizer = torch.optim.RMSprop(model.parameters(),
                                    lr=lr,
                                    weight_decay=weight_decay)

    # optimize algorithms based on input size (good if input size is constant)
    cudnn.benchmark = True

    # LOAD THE DATASET
    iotrain = load_infill_larcvdata("train_adc", INPUTFILE_TRAIN,
                                    BATCHSIZE_TRAIN, NWORKERS_TRAIN,
                                    input_producer_name="ADCMasked",
                                    true_producer_name="ADC",
                                    plane=0,
                                    tickbackward=TICKBACKWARD,
                                    readonly_products=None)
    # NOTE(review): the validation loader is built from the *training* file,
    # batch size and worker count -- confirm whether validation-specific
    # settings were intended here.
    iovalid = load_infill_larcvdata("valid_adc", INPUTFILE_TRAIN,
                                    BATCHSIZE_TRAIN, NWORKERS_TRAIN,
                                    input_producer_name="ADCMasked",
                                    true_producer_name="ADC",
                                    plane=0,
                                    tickbackward=TICKBACKWARD,
                                    readonly_products=None)

    print("pause to give time to feeders")

    NENTRIES = len(iotrain)
    print("Number of entries in training set:  %s" % (NENTRIES,))

    if NENTRIES > 0:
        # Floor division, clamped to 1: a dataset smaller than one training
        # iteration would otherwise give 0 and make the divisions and modulos
        # by iter_per_epoch below raise ZeroDivisionError.
        iter_per_epoch = max(1, NENTRIES // itersize_train)
        if num_iters is None:
            # we set it by the number of requested epochs
            num_iters = (epochs - start_epoch) * NENTRIES
        else:
            epochs = num_iters // NENTRIES
    else:
        iter_per_epoch = 1

    print("Number of epochs:  %s" % (epochs,))
    print("Iter per epoch:  %s" % (iter_per_epoch,))

    with torch.autograd.profiler.profile(enabled=RUNPROFILER) as prof:

        # Restore only the optimizer state here. The model weights were
        # already loaded above; loading them again at this point would fail
        # whenever the model was wrapped in DataParallel *after* that first
        # load, because the key names no longer match.
        if checkpoint is not None:
            optimizer.load_state_dict(checkpoint['optimizer'])

        for ii in range(start_iter, num_iters):

            adjust_learning_rate(optimizer, ii, lr)
            lr_report = " ".join("lr=%.3e" % (param_group['lr'])
                                 for param_group in optimizer.param_groups)
            print("MainLoop Iter:%d Epoch:%d.%d  %s"
                  % (ii, ii // iter_per_epoch, ii % iter_per_epoch, lr_report))

            # train for one iteration
            try:
                train(iotrain, DEVICE, BATCHSIZE_TRAIN, model,
                      criterion, optimizer,
                      nbatches_per_itertrain, ii, trainbatches_per_print)
            except Exception as e:
                print("Error in training routine!")
                print(str(e))
                print(e.__class__.__name__)
                # print_exc() takes (limit, file, ...), not the exception
                traceback.print_exc()
                break

            # evaluate on validation set
            if ii % iter_per_valid == 0 and ii > 0:
                try:
                    totloss, acc5 = validate(iovalid, DEVICE, BATCHSIZE_VALID, model,
                                             criterion, optimizer,
                                             nbatches_per_itervalid, ii,
                                             validbatches_per_print)
                except Exception as e:
                    print("Error in validation routine!")
                    print(str(e))
                    print(e.__class__.__name__)
                    traceback.print_exc()
                    break

                # remember best prec@1 and save checkpoint
                prec1 = acc5
                is_best = prec1 > best_prec1
                best_prec1 = max(prec1, best_prec1)

                # check point for best model
                if is_best:
                    print("Saving best model")
                    save_checkpoint({
                        'iter': ii,
                        'epoch': ii // iter_per_epoch,
                        'state_dict': model.state_dict(),
                        'best_prec1': best_prec1,
                        'optimizer': optimizer.state_dict(),
                    }, is_best, -1)

            # periodic checkpoint
            if ii > 0 and ii % ITER_PER_CHECKPOINT == 0:
                print("saving periodic checkpoint")
                save_checkpoint({
                    'iter': ii,
                    'epoch': ii // iter_per_epoch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict(),
                }, False, ii)
            # flush the print buffer after iteration
            sys.stdout.flush()
def forwardpass(sparseimg_bson_list, checkpoint_file):
    """Load the sparse infill network and run a CPU forward pass on each image.

    Parameters
    ----------
    sparseimg_bson_list : iterable
        Pybyte objects, each holding a bson-serialized SparseImage.
    checkpoint_file : str
        Path of the checkpoint; it must contain a ``"state_dict"`` entry.

    Returns
    -------
    list
        One bson pybytes object per input image, in input order. Each holds a
        SparseImage with the input coordinates and the network output as its
        single feature, tagged with the input's (run, subrun, event, id).
    """
    print("[INFILL] forward pass")

    print("[INFILL] Load modules: ROOT")
    from ROOT import std
    print("[INFILL] Load modules: larcv")
    from larcv import larcv
    print("[INFILL] Load modules: ublarcapp")
    from ublarcvapp import ublarcvapp
    print("[INFILL] Load modules: load_jsonutils")
    # NOTE(review): the attribute is only referenced, not called. Left as-is
    # in case the attribute lookup itself is what triggers the library load --
    # confirm whether `load_jsonutils()` was intended.
    larcv.json.load_jsonutils

    batchsize = 1
    starttime = time.time()

    # load the model
    print("[INFILL] define sparse model")
    model = SparseInfill((512, 496), 1, 16, 16, 5, show_sizes=False)

    # load checkpoint data, mapping every stored tensor to the CPU
    print("[INFILL] load checkpoint file {}".format(checkpoint_file))
    checkpoint = torch.load(checkpoint_file, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])
    # put the network in inference mode, as the other inference paths do
    model.eval()
    loadedmodeltime = time.time() - starttime
    print("[INFILL] model loading time: %.2f secs" % (loadedmodeltime))

    # parse the input data, loop over pybyte objects
    sparsedata_v = []
    rse_v = []
    for bson in sparseimg_bson_list:
        c_run = c_int()
        c_subrun = c_int()
        c_event = c_int()
        c_id = c_int()

        imgdata = larcv.json.sparseimg_from_bson_pybytes(
            bson, c_run, c_subrun, c_event, c_id)

        sparsedata_v.append(imgdata)
        rse_v.append((c_run.value, c_subrun.value, c_event.value, c_id.value))

    # ceiling division; `//` keeps the count an int so range() accepts it
    nbatches = (len(sparsedata_v) + batchsize - 1) // batchsize

    # accumulated wall-clock time per stage
    tloadbatchdata = 0.
    trunmodel = 0.
    tpackoutput = 0.

    bson_results_v = []

    # loop over batches
    for ibatch in range(nbatches):

        starttime = time.time()
        first = ibatch * batchsize
        # the slice is shorter than batchsize for a trailing partial batch
        batch_imgs = sparsedata_v[first:first + batchsize]

        # count the number of pts in each image of the batch; the pixel list
        # stores (nfeatures + 2) values per point
        npts_v = []
        for img in batch_imgs:
            imgpts = int(img.pixellist().size() // (img.nfeatures() + 2))
            print("pixlist={} nfeatures={} npts={}".format(
                img.pixellist().size(), img.nfeatures(), imgpts))
            npts_v.append(imgpts)
        totalpts = sum(npts_v)
        print("Pt totals: ", npts_v)

        # prepare numpy arrays for batch
        # (builtin int instead of np.int: the alias was removed in NumPy 1.24)
        coord_np = np.zeros((totalpts, 3), dtype=int)
        input_np = np.zeros((totalpts, 1), dtype=np.float32)

        # fill data into batch arrays
        nfilled = 0
        for ib, img in enumerate(batch_imgs):
            img_np = larcv.as_ndarray(img, larcv.msg.kNORMAL)
            start = nfilled
            end = nfilled + npts_v[ib]

            coord_np[start:end, 0:2] = img_np[:, 0:2].astype(int)
            coord_np[start:end, 2] = ib
            input_np[start:end, 0] = img_np[:, 2]
            nfilled = end

        # convert to torch
        coord_t = torch.from_numpy(coord_np).to(torch.device("cpu"))
        input_t = torch.from_numpy(input_np).to(torch.device("cpu"))

        tloadbatchdata += time.time() - starttime

        starttime = time.time()
        # run through the model
        with torch.set_grad_enabled(False):
            out_t = model(coord_t, input_t, batchsize)

        out_np = out_t.detach().cpu().numpy()
        trunmodel += time.time() - starttime

        # pack output back into bson
        starttime = time.time()

        nfilled = 0
        for ib, img in enumerate(batch_imgs):
            img_idx = first + ib
            rsei = rse_v[img_idx]
            meta = img.meta_v().front()

            outmeta_v = std.vector("larcv::ImageMeta")()
            outmeta_v.push_back(meta)

            start = nfilled
            end = nfilled + npts_v[ib]

            sparse_np = np.zeros((npts_v[ib], 3), dtype=np.float32)
            sparse_np[:, 0:2] = coord_np[start:end, 0:2]
            sparse_np[:, 2] = out_np[start:end, 0]
            nfilled = end

            # make the sparseimage object
            sparseimg = larcv.sparseimg_from_ndarray(sparse_np, outmeta_v,
                                                     larcv.msg.kNORMAL)

            # convert to bson string
            print("[{}] {}".format(img_idx, rsei))
            bson = larcv.json.as_bson_pybytes(sparseimg, rsei[0], rsei[1],
                                              rsei[2], rsei[3])

            # store in output list
            bson_results_v.append(bson)

        # accumulate like the other two timers
        tpackoutput += time.time() - starttime

    return bson_results_v
Beispiel #5
0
def main():
    """Evaluate a SparseInfill checkpoint and write images, metrics and plots.

    For each of ``nentries`` entries the function runs the network on a batch
    from the test loader, accumulates accuracy figures for dead pixels (input
    value equal to 0), stores four dense images (ADC, Input, Output, Overlay)
    in the output ROOT file and fills the histograms through ``pixelloop``.
    Afterwards it prints the averaged figures and saves three PNG plots.

    Configuration comes from names defined outside this function (GPUMODE,
    DEVICE_IDS, CHECKPOINT_MAP_LOCATIONS, batchsize, nworkers, plane,
    tickbackward, readonly_products, image_height, image_width, reps,
    ninput_features, noutput_features, nplanes, nentries). Returns ``None``.
    """
    inputfiles = ["/mnt/disk1/nutufts/kmason/data/sparseinfill_data_test.root"]
    outputfile = ["sparseoutput.root"]
    CHECKPOINT_FILE = "../training/vplane_24000.tar"

    # Every histogram gets its own object name: ROOT registers histograms by
    # name, and reusing one replaces the earlier registration.
    trueadc_h = TH1F('trueadc_h', 'adc value', 100, 0., 100.)
    predictadc_h = TH1F('predictadc_h', 'adc value', 100, 0., 100.)
    diffs2d_thresh_h = TH2F('diffs2d_thresh_h', 'diff2d',
                            90, 10., 100., 90, 10., 100.)
    diff2d_h = TH2F('diff2d_h', 'diff2d', 100, 0., 100., 100, 0., 100.)

    if GPUMODE:
        DEVICE = torch.device("cuda:%d" % (DEVICE_IDS[0]))
    else:
        DEVICE = torch.device("cpu")

    iotest = load_infill_larcvdata("infillsparsetest",
                                   inputfiles,
                                   batchsize,
                                   nworkers,
                                   "ADCMasked",
                                   "ADC",
                                   plane,
                                   tickbackward=tickbackward,
                                   readonly_products=readonly_products)

    # second reader on the same file, used for the run/subrun/event ids
    inputmeta = larcv.IOManager(larcv.IOManager.kREAD)
    inputmeta.add_in_file(inputfiles[0])
    inputmeta.initialize()

    # setup model
    model = SparseInfill((image_height, image_width),
                         reps,
                         ninput_features,
                         noutput_features,
                         nplanes,
                         show_sizes=False).to(DEVICE)

    # load checkpoint data
    checkpoint = torch.load(
        CHECKPOINT_FILE,
        map_location=CHECKPOINT_MAP_LOCATIONS)  # load weights to gpuid
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()

    tstart = time.time()
    # running sums for the average accuracies
    totacc2 = 0
    totacc5 = 0
    totacc10 = 0
    totacc20 = 0
    totalbinacc = 0

    # output IOManager
    outputdata = larcv.IOManager(larcv.IOManager.kWRITE, "IOManager",
                                 larcv.IOManager.kTickForward)
    outputdata.set_out_file(outputfile[0])
    outputdata.initialize()

    # containers written to the output file
    ev_out_ADC = outputdata.get_data(larcv.kProductImage2D, "ADC")
    ev_out_input = outputdata.get_data(larcv.kProductImage2D, "Input")
    ev_out_output = outputdata.get_data(larcv.kProductImage2D, "Output")
    ev_out_overlay = outputdata.get_data(larcv.kProductImage2D, "Overlay")

    totaltime = 0
    for n in range(nentries):
        starttime = time.time()
        print("On entry:  %s" % (n,))
        inputmeta.read_entry(n)
        ev_meta = inputmeta.get_data(larcv.kProductSparseImage, "ADC")

        infilldict = iotest.get_tensor_batch(DEVICE)
        coord_t = infilldict["coord"]
        input_t = infilldict["ADCMasked"]
        true_t = infilldict["ADC"]

        # run through model; no autograd graph is needed for evaluation
        with torch.set_grad_enabled(False):
            predict_t = model(coord_t, input_t, 1)
        forwardpasstime = time.time()

        # calculate accuracies
        labels = input_t.eq(0).float()               # 1 where the input is dead
        chargelabel = labels * (true_t > 0).float()  # dead and truth has charge
        # NOTE(review): if an entry has no dead pixels (or none with charge)
        # these denominators are zero and the running sums become NaN/inf.
        totaldeadcharge = chargelabel.sum().float()
        totaldead = labels.sum().float()
        predictdead = labels * predict_t
        truedead = true_t * labels
        predictcharge = chargelabel * predict_t
        truecharge = chargelabel * true_t
        err = (predictcharge - truecharge).abs()

        totacc2 += (err.lt(2).float() *
                    chargelabel.float()).sum().item() / totaldeadcharge
        totacc5 += (err.lt(5).float() *
                    chargelabel.float()).sum().item() / totaldeadcharge
        totacc10 += (err.lt(10).float() *
                     chargelabel.float()).sum().item() / totaldeadcharge
        totacc20 += (err.lt(20).float() *
                     chargelabel.float()).sum().item() / totaldeadcharge

        bineq0 = (truedead.eq(0).float() * predictdead.eq(0).float() *
                  labels).sum().item()
        bingt0 = (truedead.gt(0).float() *
                  predictdead.gt(0).float()).sum().item()
        totalbinacc += (bineq0 + bingt0) / totaldead

        # construct dense images
        ADC_img = larcv.Image2D(image_width, image_height)
        Input_img = larcv.Image2D(image_width, image_height)
        Output_img = larcv.Image2D(image_width, image_height)
        Overlay_img = larcv.Image2D(image_width, image_height)

        (ADC_img, Input_img, Output_img, Overlay_img,
         trueadc_h, predictadc_h, diff2d_h, diffs2d_thresh_h) = pixelloop(
            true_t, coord_t, predict_t, input_t, ADC_img, Input_img,
            Output_img, Overlay_img, trueadc_h, predictadc_h, diff2d_h,
            diffs2d_thresh_h)

        ev_out_ADC.Append(ADC_img)
        ev_out_input.Append(Input_img)
        ev_out_output.Append(Output_img)
        ev_out_overlay.Append(Overlay_img)

        outputdata.set_id(ev_meta.run(), ev_meta.subrun(), ev_meta.event())
        outputdata.save_entry()
        endentrytime = time.time()
        print("total entry time:  %s" % (endentrytime - starttime,))
        print("forward pass time:  %s" % (forwardpasstime - starttime,))
        totaltime += forwardpasstime - starttime

    avgacc2 = (totacc2 / nentries) * 100
    avgacc5 = (totacc5 / nentries) * 100
    avgacc10 = (totacc10 / nentries) * 100
    avgacc20 = (totacc20 / nentries) * 100
    avgbin = (totalbinacc / nentries) * 100

    tend = time.time() - tstart
    print("elapsed time,  %s secs  %s  sec/batch"
          % (tend, tend / float(nentries)))
    print("average forward pass time:  %s" % (totaltime / nentries,))
    print("--------------------------------------------------------------------")
    print(" For dead pixels that should have charge...")
    print("<2 ADC of true:  %s %%" % (avgacc2.item(),))
    print("<5 ADC of true:  %s %%" % (avgacc5.item(),))
    print("<10 ADC of true:  %s %%" % (avgacc10.item(),))
    print("<20 ADC of true:  %s %%" % (avgacc20.item(),))
    print("binary acc in dead:  %s %%" % (avgbin.item(),))
    print("--------------------------------------------------------------------")

    # create canvases to save as pngs

    # ADC values
    rt.gStyle.SetOptStat(0)
    c1 = TCanvas("ADC Values", "ADC Values", 600, 600)
    trueadc_h.GetXaxis().SetTitle("ADC Value")
    trueadc_h.GetYaxis().SetTitle("Number of pixels")
    c1.UseCurrentStyle()
    trueadc_h.SetLineColor(632)
    predictadc_h.SetLineColor(600)
    c1.SetLogy()
    trueadc_h.Draw()
    predictadc_h.Draw("SAME")
    legend = TLegend(0.1, 0.7, 0.48, 0.9)
    legend.AddEntry(trueadc_h, "True Image", "l")
    legend.AddEntry(predictadc_h, "Output Image", "l")
    legend.Draw()
    c1.SaveAs(("ADCValues.png"))

    # 2d ADC difference histogram
    c2 = TCanvas("diffs2D", "diffs2D", 600, 600)
    c2.UseCurrentStyle()
    line = TLine(0, 0, 80, 80)
    line.SetLineColor(632)
    diff2d_h.SetOption("COLZ")
    c2.SetLogz()
    diff2d_h.GetXaxis().SetTitle("True ADC value")
    diff2d_h.GetYaxis().SetTitle("Predicted ADC value")
    diff2d_h.Draw()
    line.Draw()
    c2.SaveAs(("diffs2d.png"))

    # 2d ADC difference histogram - thresholded
    c3 = TCanvas("diffs2D_thresh", "diffs2D_thresh", 600, 600)
    c3.UseCurrentStyle()
    line = TLine(10, 10, 80, 80)
    line.SetLineColor(632)
    diffs2d_thresh_h.SetOption("COLZ")
    diffs2d_thresh_h.GetXaxis().SetTitle("True ADC value")
    diffs2d_thresh_h.GetYaxis().SetTitle("Predicted ADC value")
    diffs2d_thresh_h.Draw()
    line.Draw()
    c3.SaveAs(("diffs2d_thresh.png"))

    # save results
    outputdata.finalize()
class UBInfillSparseWorker(MDPyWorkerBase):
    """Worker that runs the SparseInfill network on sparse images of one plane.

    Requests are lists of bson-serialized SparseImage frames. Each call to
    ``make_reply`` processes at most ``batch_size`` frames and returns the
    corresponding network outputs re-serialized as bson.
    """

    def __init__(self,broker_address,plane,
                 weight_file,device,batch_size,
                 use_half=False,use_compression=False,
                 **kwargs):
        """
        Constructor

        inputs
        ------
        broker_address str IP address of broker
        plane int Plane ID number. Currently [0,1,2] only
        weight_file str path to files with weights
        device str device name, "cpu" or "cuda:X"
        batch_size int maximum number of images processed in one pass (>=1)
        use_half bool stored and forwarded to load_model, which does not use it
        use_compression bool if True, frames are zlib-(de)compressed
        kwargs forwarded to the base-class constructor

        raises
        ------
        ValueError for an invalid plane, batch_size or device name
        """
        if type(plane) is not int:
            raise ValueError("'plane' argument must be integer for plane ID")
        elif plane not in [0,1,2]:
            raise ValueError("unrecognized plane argument. \
                                should be either one of [0,1,2]")
        else:
            print("PLANE GOOD: ", plane)

        # zero is rejected too: an empty batch can never make progress
        if type(batch_size) is not int or batch_size<1:
            raise ValueError("'batch_size' must be a positive integer")

        self.plane = plane
        self.batch_size = batch_size
        self._still_processing_msg = False
        self._use_half = use_half
        self._use_compression = use_compression
        # defined up front so is_model_loaded() is always safe to call
        self.model = None
        # remembered so make_reply can reload the model if it is missing
        self._weight_file = weight_file
        self._device_name = device

        service_name = "infill_plane%d"%(self.plane)

        super(UBInfillSparseWorker,self).__init__( service_name,
                                            broker_address, **kwargs)

        self.load_model(weight_file,device,self._use_half)

        if self.is_model_loaded():
            self._log.info("Loaded ubInfill model. Service={}"\
                            .format(service_name))

    def load_model(self,weight_file,device,use_half):
        """Build the SparseInfill network and load weights from weight_file.

        inputs
        ------
        weight_file str path of a checkpoint containing a "state_dict" entry
        device str "cpu" or "cuda:X"; the model is moved there after loading
        use_half bool accepted but not used by this method

        The result is stored in self.model (eval mode) and self.device.
        Raises RuntimeError if pytorch cannot be imported and ValueError for
        an unrecognized device name.
        """
        # import pytorch
        try:
            import torch
        except Exception:
            raise RuntimeError("could not load pytorch!")

        # ----------------------------------------------------------------------
        # import model
        model_dir = "../../../networks/infill"
        if model_dir not in sys.path:
            # avoid piling up duplicate entries when called more than once
            sys.path.append(model_dir)
        from sparseinfill import SparseInfill

        if "cuda" not in device and "cpu" not in device:
            raise ValueError("invalid device name [{}]. Must be str with name "
                             "\"cpu\" or \"cuda:X\" where X=device number"
                             .format(device))

        self.device = torch.device(device)

        # weights are read onto the CPU first, whichever device saved them
        self.model = SparseInfill( (512,496), 1,16,16,5, show_sizes=False)
        checkpoint = torch.load( weight_file, map_location="cpu" )

        # Checkpoints saved from a DataParallel wrapper prefix every key with
        # "module."; strip that prefix, and only where it really is a prefix.
        prefix = "module."
        state_dict = checkpoint["state_dict"]
        if any(k.startswith(prefix) for k in state_dict):
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                name = k[len(prefix):] if k.startswith(prefix) else k
                new_state_dict[name] = v
            state_dict = new_state_dict

        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

        print ("Loaded Model!")
        # ----------------------------------------------------------------------

    def make_reply(self,request,nreplies):
        """we load each image and pass it through the net.
        we run one batch before sending off partial reply.
        the attribute self._still_processing_msg is used to tell us if we
        are still in the middle of a reply.

        inputs
        ------
        request sequence of message frames: [image_bson,image_bson,...]
        nreplies not used by this method

        returns
        -------
        (reply, isfinal): reply is a list with one (optionally compressed)
        bson frame per image processed in this call; isfinal is True once
        every frame of the request has been consumed.

        Frames that cannot be decoded or that belong to another plane are
        skipped. A batch ends early when the next image has a different
        (rows, cols, nfeatures, npts) than the first image of the batch.
        """
        self._log.debug("received message with {} parts".format(len(request)))

        if not self.is_model_loaded():
            self._log.debug("model not loaded for some reason. loading.")
            self.load_model(self._weight_file,self._device_name,self._use_half)

        try:
            import torch
        except Exception:
            raise RuntimeError("could not load pytorch!")

        try:
            from ROOT import std
        except Exception:
            raise RuntimeError("could not load ROOT.std")

        # message pattern: [image_bson,image_bson,...]
        nmsgs = len(request)

        if not self._still_processing_msg:
            self._next_msg_id = 0

        # turn message pieces into sparse images
        imgdata_v  = []
        sizes    = []   # one (rows,cols,nfeatures,npts) entry per accepted image
        frames_used = []
        rseid_v = []
        totalpts = 0
        for imsg in range(self._next_msg_id,nmsgs):
            try:
                # NOTE(review): str() of a bytes object differs between
                # Python 2 and 3 -- confirm the frame type if porting.
                compressed_data = str(request[imsg])
                if self._use_compression:
                    data = zlib.decompress(compressed_data)
                else:
                    data = compressed_data
                c_run = c_int()
                c_subrun = c_int()
                c_event = c_int()
                c_id = c_int()

                imgdata = larcv.json.sparseimg_from_bson_pybytes(data,
                                        c_run, c_subrun, c_event, c_id )
            except Exception as e:
                self._log.error("Image Data in message part {}\
                                could not be converted: {}".format(imsg,str(e)))
                continue

            # the pixel list stores (2 + nfeatures) values per point
            nfeatures = imgdata.nfeatures()
            npts = int(imgdata.pixellist().size()//(2+nfeatures))
            self._log.debug("Image[{}] converted: nfeatures={} npts={}"\
                            .format(imsg,nfeatures,npts))

            # get source meta
            srcmeta = imgdata.meta_v().front()

            # check if correct plane!
            if srcmeta.plane()!=self.plane:
                self._log.debug("Image[{}] meta plane ({}) is not same as worker's ({})!"
                                .format(imsg,srcmeta.plane(),self.plane))
                continue

            # check that same size as previous images
            imgsize = ( int(srcmeta.rows()), int(srcmeta.cols()),
                        int(nfeatures), npts )
            if sizes and imgsize != sizes[0]:
                self._log.debug("Next image a different size. \
                                    we do not continue batch.")
                self._next_msg_id = imsg
                break
            # recorded for every accepted image so the loops below can index
            # sizes by image position
            sizes.append(imgsize)
            totalpts += npts
            imgdata_v.append(imgdata)
            frames_used.append(imsg)
            rseid_v.append((c_run.value,c_subrun.value,c_event.value,c_id.value))
            if len(imgdata_v)>=self.batch_size:
                self._next_msg_id = imsg+1
                break
        else:
            # Ran off the end of the request without filling a batch: mark
            # every frame as consumed, otherwise isfinal never becomes True
            # and the same trailing frames would be processed again.
            self._next_msg_id = nmsgs

        nimgs = len(imgdata_v)
        self._log.debug("converted msgs into batch of {} images. frames={}"
                        .format(nimgs,frames_used))

        isfinal = self._next_msg_id>=nmsgs
        self._still_processing_msg = not isfinal

        if nimgs==0:
            # nothing usable in the remaining frames; skip the network call
            self._log.debug("formed reply with {} frames. isfinal={}"
                            .format(0,isfinal))
            return [],isfinal

        # convert the images into numpy arrays
        # (builtin int instead of np.int: the alias was removed in NumPy 1.24)
        np_dtype = np.float32
        coord_np  = np.zeros( (totalpts,3), dtype=int )
        input_np = np.zeros( (totalpts,1), dtype=np_dtype )
        nfilled = 0
        for iimg,imgdata in enumerate(imgdata_v):
            npts = sizes[iimg][3]

            if (len(imgdata.pixellist()) == 0):
                # empty image: replace the arrays with a single placeholder
                # point at (0,0) so the network still receives an input
                start = 0
                end = 1
                totalpts = end
                coord_np  = np.zeros( (totalpts,3), dtype=int )
                input_np = np.zeros( (totalpts,1), dtype=np_dtype )
                coord_np[start:end,0] = 0
                coord_np[start:end,1] = 0
                coord_np[start:end,2]   = iimg
                input_np[start:end,0]   = 10.1
                nfilled = 1

            else:
                data_np = larcv.as_ndarray( imgdata, larcv.msg.kNORMAL )
                start = nfilled
                end   = nfilled+npts
                coord_np[start:end,0:2] = data_np[:,0:2].astype(int)
                coord_np[start:end,2]   = iimg
                input_np[start:end,0]   = data_np[:,2]
                nfilled = end

        coord_t  = torch.from_numpy( coord_np ).to(self.device)
        input_t = torch.from_numpy( input_np ).to(self.device)
        with torch.set_grad_enabled(False):
            out_t = self.model(coord_t, input_t, nimgs)

        out_np = out_t.detach().cpu().numpy()

        # convert from numpy array batch back to sparseimage and messages
        reply = []
        nfilled = 0
        for imgshape,imgdata,rseid in zip(sizes,imgdata_v,rseid_v):
            npts      = imgshape[3]
            start     = nfilled
            end       = start+npts
            nfeatures = 1
            # make numpy array to remake sparseimg
            sparse_np = np.zeros( (npts,2+nfeatures), dtype=np.float32 )
            sparse_np[:,0:2] = coord_np[start:end,0:2]
            sparse_np[:,2]   = out_np[start:end,0]

            outmeta_v = std.vector("larcv::ImageMeta")()
            outmeta_v.push_back( imgdata.meta_v().front() )

            # make the sparseimage object
            sparseimg = larcv.sparseimg_from_ndarray( sparse_np,
                                                      outmeta_v,
                                                      larcv.msg.kNORMAL )

            # convert to bson string
            bson = larcv.json.as_bson_pybytes( sparseimg,
                                               rseid[0], rseid[1], rseid[2], rseid[3] )
            # compress
            if self._use_compression:
                compressed = zlib.compress(bson)
            else:
                compressed = bson

            # add to reply message list
            reply.append(compressed)

            nfilled += npts

        self._log.debug("formed reply with {} frames. isfinal={}"
                        .format(len(reply),isfinal))
        return reply,isfinal

    def is_model_loaded(self):
        """Return True once load_model has stored a model on this worker."""
        return getattr(self,"model",None) is not None