Example No. 1
import torch

# Models
from ASPP_ResNet1 import ASPP_ResNet
from caffe_uresnet import UResNet

model1 = ASPP_ResNet(inplanes=16,
                     in_channels=1,
                     num_classes=3,
                     showsizes=False)
model2 = UResNet(inplanes=20, input_channels=1, num_classes=3, showsizes=False)


# Sum the number of trainable parameters in a model.
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# Driver function
def main():

    params1 = count_parameters(model1.eval())
    params2 = count_parameters(model2.eval())
    paramDelta = abs(params2 - params1)
    percentDiff = (float(paramDelta) / params2) * 100

    print "Parameters in model1: ", params1
    print "Parameters in model2: ", params2
    print "Absolute difference:  ", paramDelta
    print "Percent difference:   ", percentDiff
Example No. 2
def main():

    global best_prec1
    global writer

    # create model, mark it to run on the GPU
    if GPUMODE:
        model = UResNet(inplanes=20,
                        input_channels=1,
                        num_classes=3,
                        showsizes=False)
        model.cuda(GPUID)
    else:
        model = UResNet(inplanes=20, input_channels=1, num_classes=3)

    # dump the model structure
    print "Loaded model: ", model
    # check where model pars are
    # for p in model.parameters():
    #    print p.is_cuda

    # define loss function (criterion) and optimizer
    if GPUMODE:
        criterion = PixelWiseNLLLoss().cuda(GPUID)
    else:
        criterion = PixelWiseNLLLoss()

    # training parameters
    lr = 2.00e-4
    momentum = 0.9
    weight_decay = 1.0e-3

    # training length
    batchsize_train = 10
    batchsize_valid = 2
    start_epoch = 0
    epochs = 1
    start_iter = 0
    num_iters = 10000
    # num_iters = None  # if None, it is derived below from the requested epochs
    iter_per_epoch = None  # determined later
    iter_per_valid = 10
    iter_per_checkpoint = 500

    nbatches_per_itertrain = 5
    itersize_train = batchsize_train * nbatches_per_itertrain
    trainbatches_per_print = 1

    nbatches_per_itervalid = 25
    itersize_valid = batchsize_valid * nbatches_per_itervalid
    validbatches_per_print = 5

    optimizer = torch.optim.SGD(model.parameters(),
                                lr,
                                momentum=momentum,
                                weight_decay=weight_decay)

    cudnn.benchmark = True

    # LOAD THE DATASET

    # define configurations
    traincfg = """ThreadDatumFillerTrain: {

  Verbosity:    2
  EnableFilter: false
  RandomAccess: true
  UseThread:    false
  #InputFiles:   ["/mnt/raid0/taritree/ssnet_training_data/train00.root","/mnt/raid0/taritree/ssnet_training_data/train01.root"]
  InputFiles:   ["/media/hdd1/larbys/ssnet_dllee_trainingdata/train00.root","/media/hdd1/larbys/ssnet_dllee_trainingdata/train01.root","/media/hdd1/larbys/ssnet_dllee_trainingdata/train02.root","/media/hdd1/larbys/ssnet_dllee_trainingdata/train03.root"]
  ProcessType:  ["SegFiller"]
  ProcessName:  ["SegFiller"]

  IOManager: {
    Verbosity: 2
    IOMode: 0
    ReadOnlyTypes: [0,0,0]
    ReadOnlyNames: ["wire","segment","ts_keyspweight"]
  }

  ProcessList: {
    SegFiller: {
      # DatumFillerBase configuration
      Verbosity: 2
      ImageProducer:     "wire"
      LabelProducer:     "segment"
      WeightProducer:    "ts_keyspweight"
      # SegFiller configuration
      Channels: [2]
      SegChannel: 2
      EnableMirror: true
      EnableCrop: false
      ClassTypeList: [0,1,2]
      ClassTypeDef: [0,0,0,2,2,2,1,1,1,1]
    }
  }
}
"""
    validcfg = """ThreadDatumFillerValid: {

  Verbosity:    2
  EnableFilter: false
  RandomAccess: true
  UseThread:    false
  #InputFiles:   ["/mnt/raid0/taritree/ssnet_training_data/train02.root"]
  InputFiles:   ["/media/hdd1/larbys/ssnet_dllee_trainingdata/val.root"]
  ProcessType:  ["SegFiller"]
  ProcessName:  ["SegFiller"]

  IOManager: {
    Verbosity: 2
    IOMode: 0
    ReadOnlyTypes: [0,0,0]
    ReadOnlyNames: ["wire","segment","ts_keyspweight"]
  }

  ProcessList: {
    SegFiller: {
      # DatumFillerBase configuration
      Verbosity: 2
      ImageProducer:     "wire"
      LabelProducer:     "segment"
      WeightProducer:    "ts_keyspweight"
      # SegFiller configuration
      Channels: [2]
      SegChannel: 2
      EnableMirror: true
      EnableCrop: false
      ClassTypeList: [0,1,2]
      ClassTypeDef: [0,0,0,2,2,2,1,1,1,1]
    }
  }
}
"""
    with open("segfiller_train.cfg", 'w') as ftrain:
        print >> ftrain, traincfg
    with open("segfiller_valid.cfg", 'w') as fvalid:
        print >> fvalid, validcfg

    iotrain = LArCV1Dataset("ThreadDatumFillerTrain", "segfiller_train.cfg")
    iovalid = LArCV1Dataset("ThreadDatumFillerValid", "segfiller_valid.cfg")
    iotrain.init()
    iovalid.init()
    iotrain.getbatch(batchsize_train)

    NENTRIES = iotrain.io.get_n_entries()
    print "Number of entries in training set: ", NENTRIES

    if NENTRIES > 0:
        iter_per_epoch = NENTRIES / (itersize_train)
        if num_iters is None:
            # set num_iters from the number of requested epochs
            num_iters = (epochs - start_epoch) * NENTRIES
        else:
            epochs = num_iters / NENTRIES
    else:
        iter_per_epoch = 1

    print "Number of epochs: ", epochs
    print "Iter per epoch: ", iter_per_epoch

    with torch.autograd.profiler.profile(enabled=RUNPROFILER) as prof:

        # Resume training option
        if RESUME_FROM_CHECKPOINT:
            checkpoint = torch.load(CHECKPOINT_FILE)
            best_prec1 = checkpoint["best_prec1"]
            model.load_state_dict(checkpoint["state_dict"])
            optimizer.load_state_dict(checkpoint['optimizer'])

        for ii in range(start_iter, num_iters):

            adjust_learning_rate(optimizer, ii, lr)
            print "Iter:%d Epoch:%d.%d " % (ii, ii / iter_per_epoch,
                                            ii % iter_per_epoch),
            for param_group in optimizer.param_groups:
                print "lr=%.3e" % (param_group['lr']),
                print

            # train for one epoch
            try:
                train_ave_loss, train_ave_acc = train(
                    iotrain, batchsize_train, model, criterion, optimizer,
                    nbatches_per_itertrain, ii, trainbatches_per_print)
            except Exception, e:
                print "Error in training routine!"
                print e.message
                print e.__class__.__name__
                traceback.print_exc()
                break
            print "Iter:%d Epoch:%d.%d train aveloss=%.3f aveacc=%.3f" % (
                ii, ii / iter_per_epoch, ii % iter_per_epoch, train_ave_loss,
                train_ave_acc)

            # evaluate on validation set
            if ii % iter_per_valid == 0:
                try:
                    prec1 = validate(iovalid, batchsize_valid, model,
                                     criterion, nbatches_per_itervalid,
                                     validbatches_per_print, ii)
                except Exception, e:
                    print "Error in validation routine!"
                    print e.message
                    print e.__class__.__name__
                    traceback.print_exc()
                    break

                # remember best prec@1 and save checkpoint
                is_best = prec1 > best_prec1
                best_prec1 = max(prec1, best_prec1)

                # check point for best model
                if is_best:
                    print "Saving best model"
                    save_checkpoint(
                        {
                            'iter': ii,
                            'epoch': ii / iter_per_epoch,
                            'state_dict': model.state_dict(),
                            'best_prec1': best_prec1,
                            'optimizer': optimizer.state_dict(),
                        }, is_best, -1)

            # periodic checkpoint
            if ii > 0 and ii % iter_per_checkpoint == 0:
                print "saving periodic checkpoint"
                save_checkpoint(
                    {
                        'iter': ii,
                        'epoch': ii / iter_per_epoch,
                        'state_dict': model.state_dict(),
                        'best_prec1': best_prec1,
                        'optimizer': optimizer.state_dict(),
                    }, False, ii)
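The loop above relies on two helpers that are not part of this snippet, adjust_learning_rate and save_checkpoint. What follows is only a plausible sketch of them (a step decay of the learning rate, and a checkpoint writer whose file naming mirrors the checkpoint.20000th.tar file loaded in the next example); the decay factor, step size, and file names are assumptions, and the real implementations in the source module may differ.

import shutil
import torch


def adjust_learning_rate(optimizer, iteration, base_lr, decay=0.5, step=5000):
    # Hypothetical step decay: scale the base learning rate down by `decay`
    # every `step` iterations and push it into every parameter group.
    lr = base_lr * (decay ** (iteration // step))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def save_checkpoint(state, is_best, tag, prefix="checkpoint"):
    # Hypothetical checkpoint writer: always dump the state, and keep a
    # separate copy whenever this is the best model seen so far.
    filename = "%s.%dth.tar" % (prefix, tag)
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, "%s.best.tar" % prefix)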
Example No. 3
import os, sys, time
import ROOT
from ROOT import std
from larcv import larcv
import numpy as np

import torch
import torch.nn

from caffe_uresnet import UResNet

GPUMODE = True
GPUID = 1

net = UResNet(inplanes=16, input_channels=1, num_classes=3)
if GPUMODE:
    net.cuda(GPUID)

weightfile = "checkpoint.20000th.tar"
checkpoint = torch.load(weightfile)
net.load_state_dict(checkpoint["state_dict"])

print "[ENTER] to end"
raw_input()
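After the state dictionary is loaded, a quick way to confirm the network is usable is to push a dummy image through it. A minimal sketch; the 512x512 input size is an assumption for illustration, not something stated in this snippet.

# Optional smoke test: one dummy forward pass through the loaded network.
net.eval()
dummy = torch.zeros(1, 1, 512, 512)
if GPUMODE:
    dummy = dummy.cuda(GPUID)
out = net(torch.autograd.Variable(dummy, requires_grad=False))
print "output size: ", out.data.size()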
Example No. 4
def main():

    global best_prec1
    global writer

    # create model, mark it to run on the GPU
    model = UResNet(inplanes=16,input_channels=1,num_classes=3)
    
    # Resume training option. Fine tuning.
    if not RESUME_FROM_CHECKPOINT:
        checkpoint = torch.load( PRETRAIN_START_FILE )
        model.load_state_dict(checkpoint["state_dict"])

        # reset last layer to output 4 classes
        numclasses = 4 # (bg, shower, track, noise)
        model.conv11 = nn.Conv2d(model.inplanes, numclasses, kernel_size=7, stride=1, padding=3, bias=True )
        n = model.conv11.kernel_size[0] * model.conv11.kernel_size[1] * model.conv11.out_channels
        model.conv11.weight.data.normal_(0, math.sqrt(2. / n))
    else:
        print "Resume training option"
        numclasses = 4 # (bg, shower, track, noise)
        model.conv11 = nn.Conv2d(model.inplanes, numclasses, kernel_size=7, stride=1, padding=3, bias=True )
        checkpoint = torch.load( RESUME_CHECKPOINT_FILE )
        best_prec1 = checkpoint["best_prec1"]
        model.load_state_dict(checkpoint["state_dict"])
    
    if GPUMODE:
        model.cuda(GPUID)

    # uncomment to dump model
    #print "Loaded model: ",model
    # check where model pars are
    #for p in model.parameters():
    #    print p.is_cuda


    # define loss function (criterion) and optimizer
    if GPUMODE:
        criterion = PixelWiseNLLLoss().cuda(GPUID)
    else:
        criterion = PixelWiseNLLLoss()

    # training parameters
    lr = 1.0e-4
    momentum = 0.9
    weight_decay = 1.0e-3

    # training length
    batchsize_train = 10
    batchsize_valid = 8
    start_epoch = 0
    epochs      = 1
    start_iter  = 0
    num_iters   = 10000
    #num_iters    = None # if None, derived below from the requested epochs
    iter_per_epoch = None # determined later
    iter_per_valid = 10
    iter_per_checkpoint = 500

    nbatches_per_itertrain = 5
    itersize_train         = batchsize_train*nbatches_per_itertrain
    trainbatches_per_print = 1
    
    nbatches_per_itervalid = 25
    itersize_valid         = batchsize_valid*nbatches_per_itervalid
    validbatches_per_print = 5

    optimizer = torch.optim.SGD(model.parameters(), lr,
                                momentum=momentum,
                                weight_decay=weight_decay)
    #if RESUME_FROM_CHECKPOINT:
    #    optimizer.load_state_dict(checkpoint['optimizer'])
    

    cudnn.benchmark = True

    # LOAD THE DATASET

    # define configurations
    traincfg = """ThreadDatumFillerTrain: {

  Verbosity:    2
  EnableFilter: false
  RandomAccess: true
  UseThread:    false
  #InputFiles:   ["/media/hdd1/larbys/ssnet_cosmic_retraining/cocktail/ssnet_retrain_cocktail_p00.root","/media/hdd1/larbys/ssnet_cosmic_retraining/cocktail/ssnet_retrain_cocktail_p01.root","/media/hdd1/larbys/ssnet_cosmic_retraining/cocktail/ssnet_retrain_cocktail_p02.root"]
  #InputFiles:   ["/cluster/kappa/90-days-archive/wongjiradlab/twongj01/ssnet_training_data/ssnet_retrain_cocktail_p00.root","/cluster/kappa/90-days-archive/wongjiradlab/twongj01/ssnet_training_data/ssnet_retrain_cocktail_p01.root","/cluster/kappa/90-days-archive/wongjiradlab/twongj01/ssnet_training_data/ssnet_retrain_cocktail_p02.root"] 
  InputFiles:   ["/tmp/ssnet_retrain_cocktail_p00.root","/tmp/ssnet_retrain_cocktail_p01.root","/tmp/ssnet_retrain_cocktail_p02.root"] 
  ProcessType:  ["SegFiller"]
  ProcessName:  ["SegFiller"]

  IOManager: {
    Verbosity: 2
    IOMode: 0
    ReadOnlyTypes: []
    ReadOnlyNames: []
  }
    
  ProcessList: {
    SegFiller: {
      # DatumFillerBase configuration
      Verbosity: 2
      ImageProducer:     "adc"
      LabelProducer:     "label"
      WeightProducer:    "weight"
      # SegFiller configuration
      Channels: [0]
      SegChannel: 0
      EnableMirror: false
      EnableCrop: false
      ClassTypeList: [0,2,1,3]
      ClassTypeDef: [0,1,2,3,0,0,0,0,0,0]
    }
  }
}
"""
    validcfg = """ThreadDatumFillerValid: {

  Verbosity:    2
  EnableFilter: false
  RandomAccess: true
  UseThread:    false
  #InputFiles:   ["/media/hdd1/larbys/ssnet_cosmic_retraining/cocktail/ssnet_retrain_cocktail_p03.root"]
  InputFiles:   ["/tmp/ssnet_retrain_cocktail_p03.root"]
  ProcessType:  ["SegFiller"]
  ProcessName:  ["SegFiller"]

  IOManager: {
    Verbosity: 2
    IOMode: 0
    ReadOnlyTypes: []
    ReadOnlyNames: []
  }
    
  ProcessList: {
    SegFiller: {
      # DatumFillerBase configuration
      Verbosity: 2
      ImageProducer:     "adc"
      LabelProducer:     "label"
      WeightProducer:    "weight"
      # SegFiller configuration
      Channels: [0]
      SegChannel: 0
      EnableMirror: false
      EnableCrop: false
      ClassTypeList: [0,2,1,3]
      ClassTypeDef: [0,1,2,3,0,0,0,0,0,0]
    }
  }
}
"""
    with open("segfiller_train.cfg",'w') as ftrain:
        print >> ftrain,traincfg
    with open("segfiller_valid.cfg",'w') as fvalid:
        print >> fvalid,validcfg
    
    iotrain = LArCV1Dataset("ThreadDatumFillerTrain","segfiller_train.cfg" )
    iovalid = LArCV1Dataset("ThreadDatumFillerValid","segfiller_valid.cfg" )
    iotrain.init()
    iovalid.init()
    iotrain.getbatch(batchsize_train)

    NENTRIES = iotrain.io.get_n_entries()
    print "Number of entries in training set: ",NENTRIES

    if NENTRIES>0:
        iter_per_epoch = NENTRIES/(itersize_train)
        if num_iters is None:
            # set num_iters from the number of requested epochs
            num_iters = (epochs-start_epoch)*NENTRIES
        else:
            epochs = num_iters/NENTRIES
    else:
        iter_per_epoch = 1

    print "Number of epochs: ",epochs
    print "Iter per epoch: ",iter_per_epoch

    with torch.autograd.profiler.profile(enabled=RUNPROFILER) as prof:
        

        for ii in range(start_iter, num_iters):

            adjust_learning_rate(optimizer, ii, lr)
            print "Iter:%d Epoch:%d.%d "%(ii,ii/iter_per_epoch,ii%iter_per_epoch),
            for param_group in optimizer.param_groups:
                print "lr=%.3e"%(param_group['lr']),
                print

            # train for one epoch
            try:
                train_ave_loss, train_ave_acc = train(iotrain, batchsize_train, model,
                                                      criterion, optimizer,
                                                      nbatches_per_itertrain, ii, trainbatches_per_print)
            except Exception,e:
                print "Error in training routine!"            
                print e.message
                print e.__class__.__name__
                traceback.print_exc()
                break
            print "Iter:%d Epoch:%d.%d train aveloss=%.3f aveacc=%.3f"%(ii,ii/iter_per_epoch,ii%iter_per_epoch,train_ave_loss,train_ave_acc)

            # evaluate on validation set
            if ii%iter_per_valid==0:
                try:
                    prec1 = validate(iovalid, batchsize_valid, model, criterion, nbatches_per_itervalid, validbatches_per_print, ii)
                except Exception,e:
                    print "Error in validation routine!"            
                    print e.message
                    print e.__class__.__name__
                    traceback.print_exc()
                    break

                # remember best prec@1 and save checkpoint
                is_best = prec1 > best_prec1
                best_prec1 = max(prec1, best_prec1)

                # check point for best model
                if is_best:
                    print "Saving best model"
                    save_checkpoint({
                        'iter':ii,
                        'epoch': ii/iter_per_epoch,
                        'state_dict': model.state_dict(),
                        'best_prec1': best_prec1,
                        'optimizer' : optimizer.state_dict(),
                    }, is_best, -1)

            # periodic checkpoint
            if ii>0 and ii%iter_per_checkpoint==0:
                print "saving periodic checkpoint"
                save_checkpoint({
                    'iter':ii,
                    'epoch': ii/iter_per_epoch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'optimizer' : optimizer.state_dict(),
                }, False, ii)
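In the snippet above the re-initialized conv11 head is trained together with all of the pretrained weights. A common variation, sketched here and not something this example does, is to freeze the pretrained layers and train only the new head at first; the sketch reuses the model, lr, momentum, and weight_decay names defined above.

# Variation (not used above): keep gradients only for the new conv11 head.
for name, param in model.named_parameters():
    param.requires_grad = name.startswith("conv11")

# Hand the optimizer only the parameters that still require gradients.
optimizer = torch.optim.SGD([p for p in model.parameters() if p.requires_grad],
                            lr,
                            momentum=momentum,
                            weight_decay=weight_decay)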
Example No. 5
def main(MODEL=MODEL):

    global best_prec1
    # training parameters
    lr = 2.0e-5
    momentum = 0.9
    weight_decay = 1.0e-3
    batchsize_valid = 2

    # Create model -- instantiate on the GPU
    if MODEL == 1:
        if GPUMODE:
            model = ASPP_ResNet(inplanes=16,
                                in_channels=1,
                                num_classes=3,
                                showsizes=False)
        else:
            model = ASPP_ResNet(inplanes=16, in_channels=1, num_classes=3)
    elif MODEL == 2:
        if GPUMODE:
            model = UResNet(inplanes=20,
                            input_channels=1,
                            num_classes=3,
                            showsizes=False)
        else:
            model = UResNet(inplanes=20, input_channels=1, num_classes=3)

    optimizer = torch.optim.SGD(model.parameters(),
                                lr,
                                momentum=momentum,
                                weight_decay=weight_decay)
    cudnn.benchmark = True

    # Load checkpoint and state dictionary
    # Map the checkpoint file to the CPU -- removes GPU mapping
    map_location = {"cuda:0": "cpu", "cuda:1": "cpu"}
    checkpoint = torch.load(CHECKPOINT_FILE, map_location=map_location)

    # Debugging block:
    # print "Checkpoint file mapped to CPU."
    # print "Press return to load best prediction tensor."
    # raw_input()

    best_prec1 = checkpoint["best_prec1"]
    print "state_dict size: ", len(checkpoint["state_dict"])
    print " "
    for p, t in checkpoint["state_dict"].items():
        print p, t.size()

    # Debugging block:
    # print " "
    # print "Best prediction tensor loaded."
    # print "Press return to load state dictionary."
    # raw_input()

    # Map checkpoint file to the desired GPU
    model.load_state_dict(checkpoint["state_dict"])
    model = model.cuda(GPUID)

    print " "
    print "State dictionary mapped to GPU: ", GPUID
    if MODEL == 1:
        modelString = "ASPP_ResNet"
    elif MODEL == 2:
        modelString = "caffe_uresnet"
    print "Press return to deploy:", modelString
    print "From checkpoint:", CHECKPOINT_FILE
    raw_input()

    # switch to evaluate mode
    model.eval()

    # LOAD THE DATASET
    validcfg = """ThreadDatumFillerValid: {

  Verbosity:    2
  EnableFilter: false
  RandomAccess: true
  UseThread:    false
  #InputFiles:   ["/mnt/raid0/taritree/ssnet_training_data/train02.root"]
  InputFiles:   ["/media/hdd1/larbys/ssnet_dllee_trainingdata/val.root"]
  ProcessType:  ["SegFiller"]
  ProcessName:  ["SegFiller"]

  IOManager: {
    Verbosity: 2
    IOMode: 0
    ReadOnlyTypes: [0,0,0]
    ReadOnlyNames: ["wire","segment","ts_keyspweight"]
  }

  ProcessList: {
    SegFiller: {
      # DatumFillerBase configuration
      Verbosity: 2
      ImageProducer:     "wire"
      LabelProducer:     "segment"
      WeightProducer:    "ts_keyspweight"
      # SegFiller configuration
      Channels: [2]
      SegChannel: 2
      EnableMirror: false
      EnableCrop: false
      ClassTypeList: [0,1,2]
      ClassTypeDef: [0,0,0,2,2,2,1,1,1,1]
    }
  }
}
"""
    with open("segfiller_valid.cfg", 'w') as fvalid:
        print >> fvalid, validcfg

    iovalid = LArCV1Dataset("ThreadDatumFillerValid", "segfiller_valid.cfg")
    iovalid.init()
    iovalid.getbatch(batchsize_valid)

    NENTRIES = iovalid.io.get_n_entries()
    NENTRIES = 10  # debug override: only process the first 10 entries
    print "Number of entries in input: ", NENTRIES

    ientry = 0
    nbatches = NENTRIES / batchsize_valid
    if NENTRIES % batchsize_valid != 0:
        nbatches += 1

    for ibatch in range(nbatches):
        data = iovalid.getbatch(batchsize_valid)

        # convert to pytorch Variable (with automatic gradient calc.)
        if GPUMODE:
            images_var = torch.autograd.Variable(data.images.cuda(GPUID))
            labels_var = torch.autograd.Variable(data.labels.cuda(GPUID),
                                                 requires_grad=False)
            weight_var = torch.autograd.Variable(data.weight.cuda(GPUID),
                                                 requires_grad=False)
        else:
            images_var = torch.autograd.Variable(data.images)
            labels_var = torch.autograd.Variable(data.labels,
                                                 requires_grad=False)
            weight_var = torch.autograd.Variable(data.weight,
                                                 requires_grad=False)

        # compute output
        output = model(images_var)
        # outputdata and inputmeta (used below) are larcv IO managers that are
        # set up outside this snippet
        ev_out_wire = outputdata.get_data(larcv.kProductImage2D, "wire")
        wire_t = images_var.data.cpu().numpy()
        weight_t = weight_var.data.cpu().numpy()
        # copy the predictions back from the gpu so they can be written out as images
        labels_np = output.data.cpu().numpy().astype(np.float32)
        labels_np = 10**labels_np

        for ib in range(batchsize_valid):
            if ientry >= NENTRIES:
                break
            inputmeta.read_entry(ientry)

            ev_meta = inputmeta.get_data(larcv.kProductImage2D, "wire")
            outmeta = ev_meta.Image2DArray()[2].meta()

            img_slice0 = labels_np[ib, 0, :, :]
            nofill_lcv = larcv.as_image2d_meta(img_slice0, outmeta)
            ev_out = outputdata.get_data(larcv.kProductImage2D, "class0")
            ev_out.Append(nofill_lcv)

            img_slice1 = labels_np[ib, 1, :, :]
            fill_lcv = larcv.as_image2d_meta(img_slice1, outmeta)
            ev_out = outputdata.get_data(larcv.kProductImage2D, "class1")
            ev_out.Append(fill_lcv)

            wire_slice = wire_t[ib, 0, :, :]
            wire_out = larcv.as_image2d_meta(wire_slice, outmeta)
            ev_out_wire.Append(wire_out)

            #weight_slice=weight_t[ib,0,:,:]
            #weights_out = larcv.as_image2d_meta(weight_slice,outmeta)
            #ev_out_weights.Append( weights_out )

            outputdata.set_id(1, 1, ibatch * batchsize_valid + ib)
            outputdata.save_entry()
            ientry += 1

    # save results
    outputdata.finalize()
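The deployment loop writes the raw per-class score images to the output file. As an illustration only (not part of the example), the same scores can be collapsed into a single per-pixel class-label image with numpy:

import numpy as np


def scores_to_label_image(labels_np, ib):
    # labels_np has shape (batch, num_classes, rows, cols) as in the loop
    # above; take the per-pixel argmax over the class axis for event `ib`.
    return np.argmax(labels_np[ib], axis=0).astype(np.int32)

# Usage inside the batch loop above:
#   label_img = scores_to_label_image(labels_np, ib)
#   print "class pixel counts: ", np.bincount(label_img.ravel())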
Example No. 6
def main(MODEL=MODEL):

    global best_prec1

    # Create model -- instantiate on the GPU
    if MODEL == 1:
        if GPUMODE:
            model = ASPP_ResNet(inplanes=16,
                                in_channels=1,
                                num_classes=3,
                                showsizes=False)
        else:
            model = ASPP_ResNet(inplanes=16, in_channels=1, num_classes=3)
    elif MODEL == 2:
        if GPUMODE:
            model = UResNet(inplanes=20,
                            input_channels=1,
                            num_classes=3,
                            showsizes=False)
        else:
            model = UResNet(inplanes=20, input_channels=1, num_classes=3)

    # Load checkpoint and state dictionary
    # Map the checkpoint file to the CPU -- removes GPU mapping
    map_location = {"cuda:0": "cpu", "cuda:1": "cpu"}
    checkpoint = torch.load(CHECKPOINT_FILE, map_location=map_location)

    best_prec1 = checkpoint["best_prec1"]
    print "state_dict size: ", len(checkpoint["state_dict"])
    print " "
    for p, t in checkpoint["state_dict"].items():
        print p, t.size()

    # Map checkpoint file to the desired GPU
    model.load_state_dict(checkpoint["state_dict"])
    model = model.cuda(GPUID)

    print "#############################################################"
    print " "
    print "State dictionary mapped to GPU: ", GPUID
    if MODEL == 1:
        modelName = 'ASPP_ResNet'
        modelinfo = 'ASPP_ResNet_Validation_Data_'
    elif MODEL == 2:
        modelName = 'caffe_uresnet'
        modelinfo = 'caffe_uresnet_Validation_Data_'
    print "Current model:                  ", modelName
    print " "
    print "Using checkpoint file:          ", CHECKPOINT_FILE
    print " "
    print "Saving to directory:            ", OUTPUT_CVS_DIR
    print "#############################################################"
    print "######### Press return to launch validation routine #########"
    print "#############################################################"
    raw_input()

    # Define loss function (criterion) -- no optimizer for validation
    if GPUMODE:
        criterion = PixelWiseNLLLoss().cuda(GPUID)
    else:
        criterion = PixelWiseNLLLoss()

    ########################
    # Validation Parameters
    ########################
    batchsize_valid = 5
    iter_per_valid = 1
    print_freq = 1
    nbatches_per_itervalid = 2000
    itersize_valid = (batchsize_valid) * (nbatches_per_itervalid)
    validbatches_per_print = 5
    ########################

    cudnn.benchmark = True

    ##########################################
    # Validation routine configuration string
    ##########################################
    validcfg = """ThreadDatumFillerValid: {

  Verbosity:    2
  EnableFilter: false
  RandomAccess: false
  UseThread:    false
  #InputFiles:   ["/mnt/raid0/taritree/ssnet_training_data/train02.root"]
  InputFiles:   ["/media/hdd1/larbys/ssnet_dllee_trainingdata/val.root"]
  ProcessType:  ["SegFiller"]
  ProcessName:  ["SegFiller"]

  IOManager: {
    Verbosity: 2
    IOMode: 0
    ReadOnlyTypes: [0,0,0]
    ReadOnlyNames: ["wire","segment","ts_keyspweight"]
  }

  ProcessList: {
    SegFiller: {
      # DatumFillerBase configuration
      Verbosity: 2
      ImageProducer:     "wire"
      LabelProducer:     "segment"
      WeightProducer:    "ts_keyspweight"
      # SegFiller configuration
      Channels: [2]
      SegChannel: 2
      EnableMirror: true
      EnableCrop: false
      ClassTypeList: [0,1,2]
      ClassTypeDef: [0,0,0,2,2,2,1,1,1,1]
    }
  }
}
"""
    with open("segfiller_valid.cfg", 'w') as fvalid:
        print >> fvalid, validcfg

    iovalid = LArCV1Dataset("ThreadDatumFillerValid", "segfiller_valid.cfg")
    iovalid.init()

    accuracies = [0] * 9

    with torch.autograd.profiler.profile(enabled=RUNPROFILER) as prof:

        # Run the validation set and collect per-class correct/total pixel
        # counts (the pixel_bucket indices are spelled out in the printout below)
        pixel_bucket = validate(iovalid, batchsize_valid, model, criterion,
                                nbatches_per_itervalid, validbatches_per_print,
                                print_freq)

    #########################################################
    # Compute the final accuracies over all validation images
    #########################################################
    pixel_acc = [0] * 5

    # Pixels for non-zero accuracy calculation
    total_corr_non_zero = pixel_bucket[2] + pixel_bucket[4]
    total_non_zero_pix = pixel_bucket[3] + pixel_bucket[5]

    # Total background accuracy
    pixel_acc[0] = (float(pixel_bucket[0] / (pixel_bucket[1] + 1.0e-8)) * 100)
    # Total track accuracy
    pixel_acc[1] = (float(pixel_bucket[2] / (pixel_bucket[3] + 1.0e-8)) * 100)
    # Total shower Accuracy
    pixel_acc[2] = (float(pixel_bucket[4] / (pixel_bucket[5] + 1.0e-8)) * 100)
    # Non-Zero Pixel Accuracy
    pixel_acc[3] = (float(total_corr_non_zero /
                          (total_non_zero_pix + 1.0e-8)) * 100)
    # Total pixel accuracy
    pixel_acc[4] = (float(pixel_bucket[6] / (pixel_bucket[7] + 1.0e-8)) * 100)
    print "FIN"
    print "PROFILER"
    print "#############################################################"
    print "############### Validation Routine Results: #################"
    print "#############################################################"
    print " "
    print "Accuracies computed over total number of pixels"
    print "in the validation set for model checkpoint:"
    print CHECKPOINT_FILE
    print " "
    print "Batch size:                ", batchsize_valid
    print "Number of batches:         ", nbatches_per_itervalid
    print "Images in validation set:  ", (batchsize_valid *
                                          nbatches_per_itervalid)
    print " "
    print "#############################################################"
    print "###############     Total Pixel Counts:      ################"
    print "#############################################################"
    print " "
    print "total_bkgrnd_correct:         ", pixel_bucket[0]
    print "total_bkgrnd_pix:             ", pixel_bucket[1]
    print "total_trk_correct:            ", pixel_bucket[2]
    print "total_trk_pix:                ", pixel_bucket[3]
    print "total_shwr_correct:           ", pixel_bucket[4]
    print "total_shwr_pix:               ", pixel_bucket[5]
    print "total_corr:                   ", pixel_bucket[6]
    print "total_pix:                    ", pixel_bucket[7]
    print " "
    print "#############################################################"
    print "###############         Accuracies:          ################"
    print "#############################################################"
    print " "
    print "Total background accuracy:    ", pixel_acc[0]
    print "Total non-zero accuracy:      ", pixel_acc[3]
    print "Total shower accuracy:        ", pixel_acc[2]
    print "Total track accuracy:         ", pixel_acc[1]
    print "Total accuracy:               ", pixel_acc[4]
    print " "

    # Create CSV file and write results to it
    filename = OUTPUT_CVS_DIR + modelinfo + '{:%m_%d_%Y}' + '.csv'
    today = pd.Timestamp('today')

    if os.path.isfile(filename.format(today)):
        flag = 'a'
    else:
        flag = 'w'

    with open(filename.format(today), flag) as vD:
        valWriter = csv.writer(vD, delimiter=',')
        valWriter.writerow([' '])
        valWriter.writerow(['Model:                     ', modelName])
        valWriter.writerow(['Checkpoint File:           ', CHECKPOINT_FILE])
        valWriter.writerow(["Total pixel count:"])
        valWriter.writerow(["total_bkgrnd_correct:      ", pixel_bucket[0]])
        valWriter.writerow(["total_bkgrnd_pix:          ", pixel_bucket[1]])
        valWriter.writerow(["total_trk_correct:         ", pixel_bucket[2]])
        valWriter.writerow(["total_trk_pix:             ", pixel_bucket[3]])
        valWriter.writerow(["total_shwr_correct:        ", pixel_bucket[4]])
        valWriter.writerow(["total_shwr_pix:            ", pixel_bucket[5]])
        valWriter.writerow(["total_corr:                ", pixel_bucket[6]])
        valWriter.writerow(["total_pix:                 ", pixel_bucket[7]])
        valWriter.writerow(["Accuracies:"])
        valWriter.writerow(["Total background accuracy: ", pixel_acc[0]])
        valWriter.writerow(["Total non-zero accuracy:   ", pixel_acc[3]])
        valWriter.writerow(["Total shower accuracy:     ", pixel_acc[2]])
        valWriter.writerow(["Total track accuracy:      ", pixel_acc[1]])
        valWriter.writerow(["Total accuracy:            ", pixel_acc[4]])