def bfp_quant(model_name,
              dataset_dir,
              num_classes,
              gpus,
              mantisa_bit,
              exp_bit,
              batch_size=1,
              num_bins=8001,
              eps=0.0001,
              num_workers=2,
              num_examples=10,
              std=None,
              mean=None,
              resize=256,
              crop=224,
              exp_act=None,
              bfp_act_chnl=1,
              bfp_weight_chnl=1,
              bfp_quant=1,
              target_module_list=None,
              act_bins_factor=3,
              fc_bins_factor=4,
              is_online=0):
    """Calibrate shared BFP exponents for a 2-D image model and build the BFP model.

    NOTE(review): a second ``def bfp_quant`` follows later in this file and
    rebinds the name when the module is executed, so this 2-D version is not
    what callers of ``bfp_quant`` get. It also builds ``bfp_model`` without
    evaluating or returning it -- it looks unfinished; confirm before relying on it.

    Steps:
      1. Run one batch of ``num_examples`` images from ``<dataset_dir>/val``
         through a pretrained ``model_factory`` network with statistic hooks
         (``Stat_Collector.insert_hook``) attached to ``target_module_list``.
      2. For the network input and every hooked module, determine an optimal
         and a max shared exponent (``Utils.find_exp_act`` /
         ``Utils.find_exp_fc``). Selected shortcut layers of resnet50 /
         mobilenetv2 are calibrated jointly with an earlier hook output.
      3. For plain conv/BN outputs, log a histogram of the activation quantized
         with its optimal exponent to the module-level ``writer``.
      4. Build the BFP network with ``model_factory.get_network`` using the
         optimal exponents when ``exp_act == 'kl'``, otherwise the max ones.

    Args:
        model_name: network name passed to ``model_factory``; ``"br_"`` is
            prefixed to it when ``is_online == 1``.
        dataset_dir: root containing a ``val`` ImageFolder directory.
        num_classes, num_bins, batch_size: unused by this function.
        gpus: iterable of GPU id strings written to ``CUDA_VISIBLE_DEVICES``.
        mantisa_bit, exp_bit: BFP mantissa / exponent widths.
        eps: forwarded to the exponent search helpers.
        num_workers: DataLoader workers for the calibration loader.
        num_examples: number of images used for calibration (one batch).
        std, mean, resize, crop: preprocessing parameters for the loader.
        exp_act: ``'kl'`` selects the optimal exponents, anything else the max.
        bfp_act_chnl, bfp_weight_chnl: channel grouping for activations / weights.
        bfp_quant: the model is built with ``bfp=True`` when this equals 1.
        target_module_list: module types to hook.
        act_bins_factor, fc_bins_factor: bin factors for the exponent search.
        is_online: 1 selects the ``br_`` model variant.
    """
    # Setting up gpu environment
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(gpus)

    def _flatten_hw(features):
        # (N, C, H, W) -> (N, C, H*W), the layout fed to the exponent search.
        shape = features.shape
        return torch.reshape(features,
                             (shape[0], shape[1], shape[2] * shape[3]))

    def _act_exp(features, group):
        # (optimal exponent, max exponent) for an activation tensor.
        return Utils.find_exp_act(features,
                                  mantisa_bit,
                                  exp_bit,
                                  group=group,
                                  eps=eps,
                                  bins_factor=act_bins_factor)

    # Loader used only to collect calibration statistics (a single batch).
    valdir = os.path.join(dataset_dir, 'val')
    normalize = transforms.Normalize(mean=mean, std=std)
    collect_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(resize),
                transforms.CenterCrop(crop),
                transforms.ToTensor(),
                normalize,
            ])),
        batch_size=num_examples,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True)

    # Loading the model and inserting hooks that record each target module's
    # inputs/outputs into intern_outputs.
    model, _ = model_factory.get_network(model_name, pretrained=True)
    model, intern_outputs = Stat_Collector.insert_hook(model,
                                                       target_module_list)
    model.cuda()
    model.eval()

    logging.info("Collecting the statistics while running image examples....")
    images_statistc = torch.empty((1))
    with torch.no_grad():
        for images, _ in collect_loader:
            images = images.cuda()
            # The forward pass itself fills intern_outputs; the logits are unused.
            model(images)
            images_statistc = _flatten_hw(images)
            break

    logging.info(
        "Determining the optimal exponent by minimizing the KL divergence....")
    start = time.time()
    # Entry 0 holds the network input; every hook below appends one entry, so
    # entry k + 1 belongs to intern_outputs[k].
    opt_exp_act_list = []
    max_exp_act_list = []
    opt_exp, max_exp = _act_exp(images_statistc, group=3)
    opt_exp_act_list.append(opt_exp)
    max_exp_act_list.append(max_exp)
    # Hook indices of (presumably) shortcut layers whose exponents are
    # calibrated jointly with an earlier hook output.
    sc_layer_num = [7, 10, 17, 20, 23, 30, 33, 36, 39, 42, 49, 52]
    ds_sc_layer_num = [14, 27, 46]
    mobilev2_sc_layer_num = [9, 15, 18, 24, 27, 30, 36, 39, 45, 48]
    for i, intern_output in enumerate(intern_outputs):
        if isinstance(intern_output.m, (nn.Conv2d, nn.BatchNorm2d)):
            out_features = intern_output.out_features
            partner = None  # earlier activation calibrated together with this one
            overwrite_prev = False  # also replace the previous hook's entry
            if model_name == "resnet50" and i in sc_layer_num:
                partner = intern_outputs[i - 3].out_features
            elif model_name == "resnet50" and i in ds_sc_layer_num:
                partner = intern_outputs[i - 1].out_features
                overwrite_prev = True
            elif (model_name == "mobilenetv2"
                  and i in mobilev2_sc_layer_num):
                partner = intern_outputs[i - 3].out_features

            if partner is None:
                features = out_features
            else:
                # Concatenate along the batch axis so the exponents are chosen
                # over both tensors.
                features = torch.cat((out_features, partner), 0)
            opt_exp, max_exp = _act_exp(_flatten_hw(features),
                                        group=bfp_act_chnl)
            opt_exp_act_list.append(opt_exp)
            max_exp_act_list.append(max_exp)
            if overwrite_prev:
                # Index i is the entry of intern_outputs[i - 1] (index 0 is the
                # input), so that layer now shares this joint exponent.
                opt_exp_act_list[i] = opt_exp
                max_exp_act_list[i] = max_exp

            if partner is None:
                # Visualize the activation quantized with its optimal exponent.
                quant_tensor = BFPActivation.transform_activation_offline(
                    out_features, exp_bit, mantisa_bit, opt_exp)
                writer.add_histogram("layer%d" % (i),
                                     quant_tensor.cpu().data.numpy(),
                                     bins='auto')
        elif isinstance(intern_output.m, nn.Linear):
            out_features = intern_output.out_features
            opt_exp, max_exp = Utils.find_exp_fc(
                out_features,
                mantisa_bit,
                exp_bit,
                block_size=out_features.shape[1],
                eps=eps,
                bins_factor=fc_bins_factor)
            # NOTE(review): the max exponent is stored in *both* lists for FC
            # outputs, so exp_act='kl' does not affect them -- confirm intended.
            opt_exp_act_list.append(max_exp)
            max_exp_act_list.append(max_exp)
        else:
            # Any other hooked module is calibrated on its first input.
            opt_exp, max_exp = _act_exp(
                _flatten_hw(intern_output.in_features[0]), group=bfp_act_chnl)
            opt_exp_act_list.append(opt_exp)
            max_exp_act_list.append(max_exp)

    end = time.time()
    logging.info(
        "It took %f second to determine the optimal shared exponent for each block."
        % ((float)(end - start)))
    # np.shape() raises on NumPy >= 1.24 for a list of unequal-length entries,
    # so report the entry count instead.
    logging.info("The number of collected exponent entries: %d" %
                 len(opt_exp_act_list))

    # Drop the calibration network and its recorded activations so that
    # empty_cache() can return most of their GPU memory.
    del model, intern_outputs
    torch.cuda.empty_cache()
    exp_act_list = opt_exp_act_list if exp_act == 'kl' else max_exp_act_list
    if is_online == 1:
        model_name = "br_" + model_name
    # NOTE(review): the BFP model is built but never evaluated or returned here.
    bfp_model, weight_exp_list = model_factory.get_network(
        model_name,
        pretrained=True,
        bfp=(bfp_quant == 1),
        group=bfp_weight_chnl,
        mantisa_bit=mantisa_bit,
        exp_bit=exp_bit,
        opt_exp_act_list=exp_act_list)

    writer.close()


def bfp_quant(model_name, dataset_dir, num_classes, gpus, mantisa_bit, exp_bit, batch_size=1,
              num_bins=8001, eps=0.0001, num_workers=2, num_examples=10, std=None, mean=None,
              resize=256, crop=224, exp_act=None, bfp_act_chnl=1, bfp_weight_chnl=1, bfp_quant=1,
              target_module_list=None, act_bins_factor=3, fc_bins_factor=4, is_online=0):
    """Calibrate BFP exponents for a 3-D video model and evaluate it on UCF101.

    When ``bfp_quant == 1``, a pretrained ``model_factory_3d`` network is run on
    one batch of ``num_examples`` validation clips. Statistic hooks
    (``Stat_Collector.insert_hook``) are attached to ``target_module_list``.
    For the input and every hooked Conv3d / BatchNorm3d / Linear module, an
    optimal and a max shared exponent are determined with
    ``Utils.find_exp_act_3d`` / ``Utils.find_exp_fc``. Otherwise no activation
    exponents are passed to the model factory.

    Two BFP networks are then built: one with the caller's ``is_online`` and a
    "dynamic" one with ``is_online=1``. Both are evaluated on the UCF101 test
    split. Overall accuracies are logged, per-class accuracies printed, and the
    confusion matrices saved to ``static_bfp_r3d.pt`` and
    ``dynamic_bfp_r3d.pt`` (relative paths).

    Args:
        model_name: network name for ``model_factory_3d`` and ``VideoDataset``.
        dataset_dir, num_bins, num_workers, std, mean, resize, crop: accepted
            for signature compatibility with the 2-D version; unused here
            (``VideoDataset`` is created with fixed arguments and 4 workers).
        num_classes: size of the confusion matrices.
        gpus: iterable of GPU id strings written to ``CUDA_VISIBLE_DEVICES``.
        mantisa_bit, exp_bit: BFP mantissa / exponent widths.
        batch_size: batch size of the train/test loaders.
        eps: forwarded to the exponent search helpers.
        num_examples: number of validation clips used for calibration.
        exp_act: ``'kl'`` selects the optimal exponents, anything else the max;
            also forwarded to ``model_factory_3d.get_network``.
        bfp_act_chnl, bfp_weight_chnl: channel grouping for activations / weights.
        bfp_quant: 1 enables calibration and ``bfp=True`` model construction.
        target_module_list: module types to hook during calibration.
        act_bins_factor, fc_bins_factor: bin factors for the exponent search.
        is_online: forwarded to ``get_network`` for the static model.
    """
    # Setting up gpu environment
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(gpus)

    def _flatten_hw(features):
        # (N, C, D, H, W) -> (N, C, D, H*W), the layout fed to find_exp_act_3d.
        # Derived from the tensor itself, so concatenated inputs are handled too.
        shape = features.shape
        return torch.reshape(features,
                             (shape[0], shape[1], shape[2], shape[3] * shape[4]))

    def _act_exp(features, group):
        # (optimal exponent, max exponent) for an activation tensor.
        return Utils.find_exp_act_3d(features, mantisa_bit, exp_bit, group=group,
                                     eps=eps, bins_factor=act_bins_factor)

    def _prepare_for_eval(net):
        # Wrap in DataParallel except for the MobileNetV2 variants. (An
        # `!= a or != b` test would be always true and wrap every model.)
        if model_name not in ("br_mobilenetv2", "mobilenetv2"):
            net = nn.DataParallel(net)
        net.cuda()
        net.eval()
        return net

    def _accumulate(matrix, targets, preds):
        # matrix[target, pred] += 1 per sample; repeated index pairs are summed.
        matrix.index_put_((targets, preds), torch.ones(targets.numel()),
                          accumulate=True)

    # NOTE: train_dataloader is never iterated; it is still constructed in case
    # VideoDataset's constructor has side effects (e.g. preprocessing).
    train_dataloader = DataLoader(
        VideoDataset(dataset='ucf101', split='train', clip_len=16, model_name=model_name),
        batch_size=batch_size, shuffle=True, num_workers=4)
    # Calibration uses a single batch of num_examples validation clips.
    val_dataloader = DataLoader(
        VideoDataset(dataset='ucf101', split='val', clip_len=16, model_name=model_name),
        batch_size=num_examples, num_workers=4)
    test_dataloader = DataLoader(
        VideoDataset(dataset='ucf101', split='test', clip_len=16, model_name=model_name),
        batch_size=batch_size, num_workers=4)

    if bfp_quant == 1:
        # Float model with hooks that record each target module's output.
        model, _ = model_factory_3d.get_network(model_name, pretrained=True)
        model, intern_outputs = Stat_Collector.insert_hook(model, target_module_list)
        model.cuda()
        model.eval()

        logging.info("Collecting the statistics while running image examples....")
        images_statistc = torch.empty((1))
        with torch.no_grad():
            for images, _ in val_dataloader:
                images = images.cuda()
                # The forward pass fills intern_outputs; the logits are unused.
                model(images)
                images_statistc = _flatten_hw(images)
                break

        logging.info("Determining the optimal exponent by minimizing the KL divergence....")
        start = time.time()
        opt_exp_act_list = []
        max_exp_act_list = []
        # Entry 0 belongs to the network input.
        opt_exp, max_exp = _act_exp(images_statistc, group=3)
        opt_exp_act_list.append(opt_exp)
        max_exp_act_list.append(max_exp)
        # Hook indices of (presumably) shortcut layers whose exponents are
        # calibrated jointly with another hook output.
        if model_name == "r3d_18":
            sc_layer = [2, 4, 9, 14, 19]
            ds_sc_layer = [6, 7, 11, 12, 16, 17]
        else:
            sc_layer = [2, 4, 6, 11, 13, 15, 20, 22, 24, 26, 28, 33, 35]
            ds_sc_layer = [8, 9, 17, 18, 30, 31]
        # NOTE(review): the joint branches test model_name == "r3d", which never
        # equals "r3d_18"; as written they only run for model_name "r3d" (with
        # the non-r3d_18 tables). Confirm the intended names before changing this:
        # enabling them alters the number of exponent entries passed to the factory.
        joint_enabled = model_name == "r3d"
        for i, intern_output in enumerate(intern_outputs):
            out_features = intern_output.out_features
            logging.debug("i-th %d  shape: %s  name: %s", i, out_features.shape, intern_output.m)
            if isinstance(intern_output.m, (nn.Conv3d, nn.BatchNorm3d)):
                if joint_enabled and i in sc_layer:
                    features = torch.cat((out_features, intern_outputs[i - 2].out_features), 0)
                elif joint_enabled and i in ds_sc_layer:
                    if (i + 1) not in ds_sc_layer:
                        continue  # Use the same exp as previous layer
                    features = torch.cat((out_features, intern_outputs[i + 1].out_features), 0)
                else:
                    features = out_features
                opt_exp, max_exp = _act_exp(_flatten_hw(features), group=bfp_act_chnl)
                logging.debug("i-th %d  length: %d", i, len(opt_exp))
                opt_exp_act_list.append(opt_exp)
                max_exp_act_list.append(max_exp)
            elif isinstance(intern_output.m, nn.Linear):
                opt_exp, max_exp = Utils.find_exp_fc(out_features, mantisa_bit, exp_bit,
                                                     block_size=out_features.shape[1], eps=eps,
                                                     bins_factor=fc_bins_factor)
                # NOTE(review): FC outputs store the max exponent in both lists.
                opt_exp_act_list.append(max_exp)
                max_exp_act_list.append(max_exp)
            # Any other hooked module type contributes no exponent entry.

        end = time.time()
        logging.info("It took %f second to determine the optimal shared exponent for each block." % ((float)(end - start)))
        # np.shape() raises on NumPy >= 1.24 for a list of unequal-length
        # entries, so report the entry count instead.
        logging.info("The number of collected exponent entries: %d" % len(opt_exp_act_list))

        # Release the calibration network and its recorded activations before
        # the two BFP models are built, so empty_cache() can return most of it.
        del model, intern_outputs
        torch.cuda.empty_cache()
        exp_act_list = opt_exp_act_list if exp_act == 'kl' else max_exp_act_list
    else:
        exp_act_list = None

    bfp_model, _ = model_factory_3d.get_network(
        model_name, pretrained=True, bfp=(bfp_quant == 1), group=bfp_weight_chnl,
        mantisa_bit=mantisa_bit, exp_bit=exp_bit, opt_exp_act_list=exp_act_list,
        is_online=is_online, exp_act=exp_act)
    # "Dynamic" variant: same arguments except is_online is forced to 1.
    dynamic_bfp_model, _ = model_factory_3d.get_network(
        model_name, pretrained=True, bfp=(bfp_quant == 1), group=bfp_weight_chnl,
        mantisa_bit=mantisa_bit, exp_bit=exp_bit, opt_exp_act_list=exp_act_list,
        is_online=1, exp_act=exp_act)

    confusion_matrix = torch.zeros(num_classes, num_classes)
    dynamic_confusion_matrix = torch.zeros(num_classes, num_classes)
    logging.info("Evaluation Block Floating Point quantization....")
    correct = 0
    dynamic_correct = 0
    total = 0
    bfp_model = _prepare_for_eval(bfp_model)
    dynamic_bfp_model = _prepare_for_eval(dynamic_bfp_model)
    softmax = nn.Softmax(dim=1)
    with torch.no_grad():
        for images, labels in test_dataloader:
            images = images.cuda()
            _, predicted = torch.max(softmax(bfp_model(images)), 1)
            predicted = predicted.cpu()
            _, dynamic_predicted = torch.max(softmax(dynamic_bfp_model(images)), 1)
            dynamic_predicted = dynamic_predicted.cpu()
            targets = labels.view(-1).long()
            _accumulate(confusion_matrix, targets, predicted.view(-1).long())
            _accumulate(dynamic_confusion_matrix, targets, dynamic_predicted.view(-1).long())
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            dynamic_correct += (dynamic_predicted == labels).sum().item()
            logging.info("Current images: %d" % (total))
    if total == 0:
        logging.warning("No test samples were evaluated; accuracy is undefined.")
    else:
        logging.info("Total: %d, Accuracy: %f " % (total, float(correct / total)))
        logging.info("Total: %d, Dynamic Accuracy: %f " % (total, float(dynamic_correct / total)))
    logging.info("Floating conv weight and fc(act and weight), act bins_factor is %d,fc bins_factor is %d, exp_opt for act is %s, act group is %d" % (act_bins_factor, fc_bins_factor, exp_act, bfp_act_chnl))
    print("Per class accuracy:", (confusion_matrix.diag() / confusion_matrix.sum(1)))
    print("Per class dynamic accuracy:", (dynamic_confusion_matrix.diag() / dynamic_confusion_matrix.sum(1)))
    torch.save(confusion_matrix, "static_bfp_r3d.pt")
    torch.save(dynamic_confusion_matrix, "dynamic_bfp_r3d.pt")
    writer.close()