Example #1
def trainDictionary(train_loader, test_loader, sigLen, codeLen, datName,
                    maxEpoch = 2,
                    learnRate = 1,
                    lrdFreq = 30,
                    learnRateDecay = 1,
                    fistaIters = 50,
                    useL1Loss = False,
                    l1w = 0.5,
                    useCUDA = False,
                    imSave = True,
                    imSavePath = "./",
                    extension = ".png",
                    printFreq = 10,
                    saveFreq  = 100,
                    **kwargs):
    """
    Inputs:
      train_loader : DataLoader over the training set
      test_loader : DataLoader over the test set (unused in the loop below)
      sigLen : length of each (vectorized) input signal / patch
      codeLen : length of each sparse code, i.e. the number of dictionary atoms
      datName : name of the dataset (used in saved image file names)
      maxEpoch : number of training epochs (1 epoch = one pass through all data)
      learnRate : weight multiplied onto the gradient update
      lrdFreq : number of batches between learnRate scalings
      learnRateDecay : factor applied to learnRate every lrdFreq batches, i.e.
        batch_LR = learnRate*(learnRateDecay**(numBatchesSoFar//lrdFreq))
      fistaIters : number of FISTA iterations run when generating the codes
      useL1Loss : set to True to penalize sparsity of the network output (BUGGY)
      l1w : the L1-norm weight; balances data fidelity against sparsity
      useCUDA : set to True for GPU acceleration
      imSave : whether to save images
      imSavePath : directory in which images will be saved
      extension : image file type to save (e.g. ".png", ".pdf")
      printFreq : number of batches between print statements during training
      saveFreq : number of batches between image saves during training
      kwargs : optional arguments
        atomImName : name for the dictionary-atom image other than "dictAtoms"
        dictInitWeights : initial weights (instead of a normal distribution)
    Outputs:
      Dict: the trained dictionary / decoder
      lossHist : loss function history (per batch)
      errHist : reconstruction error history (per batch)
      spstyHist : sparsity history (per batch). i.e. the percent zeros
          achieved during encoding with the dictionary. Encoding is
          performed using FISTA.
    """

    
    # MISC SETUP
    # FISTA:
    fistaOptions = {"returnCodes"  : True,
                    "returnCost"   : False,
                    "returnFidErr" : False}
    # Saving dictionary atom images:
    if "atomImName" in kwargs:
        dictAtoms = kwargs["atomImName"]
    else:
        dictAtoms = datName + "dictAtoms"
    dictAtomImgName = imSavePath + dictAtoms + extension

    # Recordbooks:
    lossHist  = []
    errHist   = []
    spstyHist = []
    
    # INITIALIZE DICTIONARY
    Dict = dictionary(sigLen, codeLen, datName, useCUDA)
    if "dictInitWeights" in kwargs:
        Dict.setWeights(kwargs["dictInitWeights"])
    Dict.normalizeAtoms()
    Dict.zero_grad()
    
    # Loss Function:  .5 ||y-Ax||_2^2 + alpha||x||_1
    mseLoss = nn.MSELoss()
    mseLoss.size_average=True
    if useL1Loss:
      l1Loss = Variable( torch.FloatTensor(1), requires_grad=True)
    else:
      l1Loss = 0

    # Optimizer
    OPT = torch.optim.SGD(Dict.parameters(), lr=learnRate)
    # For learning rate decay
    scheduler = torch.optim.lr_scheduler.StepLR(OPT, step_size = lrdFreq,
                                                gamma = learnRateDecay)
###########################
### DICTIONARY LEARNING ###
###########################
    for it in range(maxEpoch):
        epochLoss      = 0
        epoch_sparsity  = 0
        epoch_rec_error = 0
    
#================================================
    # TRAINING
        numBatch = 0
        for batch_idx, (batch, labels) in enumerate(train_loader):
          gc.collect()
          numBatch += 1
        
          if useCUDA:
              Y = Variable(extractPatches(batch).cuda())
          else:
              Y = Variable(extractPatches(batch))
              
      ## CODE INFERENCE
          fistaOut = FISTA(Y, Dict, l1w, fistaIters, fistaOptions)
          X = fistaOut["codes"]
          gc.collect()
         
      ## FORWARD PASS
          Y_est     = Dict.forward(X)   #try decoding the optimal codes
          
          # loss
          reconErr   = mseLoss(Y_est,Y)
          if useL1Loss:
            l1Loss  = Y_est.norm(1) / X.size(0)
         
      ## BACKWARD PASS
          batchLoss = reconErr + l1w * l1Loss
          batchLoss.backward()
          OPT.step()
          scheduler.step()
          Dict.zero_grad()
          Dict.normalizeAtoms()
          del Dict.maxEig
         
      ## Housekeeping
          sampleLoss      = batchLoss.data[0]
          epochLoss      +=   sampleLoss
          lossHist.append( epochLoss/numBatch )
          
          sample_rec_error = reconErr.data[0]
          epoch_rec_error += sample_rec_error
          errHist.append( epoch_rec_error/numBatch )

          sample_sparsity = ((X.data==0).sum())/X.numel()
          epoch_sparsity  +=  sample_sparsity
          spstyHist.append( epoch_sparsity/ numBatch )

      ## SANITY CHECK:
          # If the codes are practically all-zero, stop fast.
          # TODO: something smarter to do here? Lower L1 weight?
          if np.abs(sample_sparsity - 1.0) < 0.001 :
            print("CODES NEARLY ALL ZERO. SKIP TO NEXT EPOCH.")
            break
 
     ## Print stuff.
     # You may wish to print some figures here too. See bottom of page.
          if batch_idx % printFreq == 0:
              print('Train Epoch: {} [{}/{} ({:.0f}%)]'.format(
                it, batch_idx * len(batch), len(train_loader.dataset),
                100* batch_idx / len(train_loader)))
              print('Loss: {:.6f} \tRecon Err: {:.6f} \tSparsity: {:.6f} '.format(
                     sampleLoss,sample_rec_error,sample_sparsity))
          
          if batch_idx % saveFreq == 0:
              Dict.printAtomImage(dictAtomImgName)
              printProgressFigs(imSavePath, extension, lossHist, errHist, spstyHist)

      ## end "TRAINING" batch-loop
#================================================
    
    ## Epoch averages (training); a matching block could be computed for the test set.
        epoch_average_loss = epochLoss/numBatch
        epoch_avg_recErr   = epoch_rec_error/numBatch
        epoch_avg_sparsity = epoch_sparsity/numBatch
    
#        lossHist[it]  = epoch_average_loss
#        errHist[it]   = epoch_avg_recErr
#        spstyHist[it] = epoch_avg_sparsity
        
        print('- - - - - - - - - - - - - - - - - - - - -')
        print('EPOCH ', it + 1,'/',maxEpoch, " STATS")
        print('LOSS = ', epoch_average_loss)
        print('RECON ERR = ',epoch_avg_recErr)
        print('SPARSITY = ',epoch_avg_sparsity)
        print('XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX')
  ## end "EPOCH" loop

    # Convert recordbooks to numpy arrays:
    lossHist  = np.asarray(lossHist)
    errHist   = np.asarray(errHist)
    spstyHist = np.asarray(spstyHist)

    # Save the dictionary/decoder:
#    torch.save(save_dir..'decoder_'..datName..'psz'..pSz..'op'..outPlane..'.t7', decoder) 
    Dict.printAtomImage(dictAtomImgName)
    printProgressFigs(imSavePath, extension, lossHist, errHist, spstyHist)
    return Dict,lossHist,errHist,spstyHist
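
A minimal usage sketch for the function above, assuming the dictionary, FISTA, extractPatches, and printProgressFigs helpers are importable from this repository. The MNIST loaders, patch length of 64, and code length of 128 are illustrative choices, not values taken from the original.

import torch
import torchvision
import torchvision.transforms as T

# Hypothetical data loaders; any image dataset yielding (batch, labels) works.
train_set = torchvision.datasets.MNIST("./data", train=True, download=True,
                                       transform=T.ToTensor())
test_set = torchvision.datasets.MNIST("./data", train=False, download=True,
                                      transform=T.ToTensor())
train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=64)

# Illustrative sizes: 8x8 patches (sigLen=64) and a 2x-overcomplete code (codeLen=128).
Dict, lossHist, errHist, spstyHist = trainDictionary(
    train_loader, test_loader, sigLen=64, codeLen=128, datName="mnist",
    maxEpoch=2, learnRate=1, learnRateDecay=0.99, l1w=0.5,
    useCUDA=torch.cuda.is_available(), imSavePath="./figs/")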
Example #2
def encoder_class_sanity():
    """
  Tests functionality of encoder class.
  """
    import matplotlib.pyplot as plt
    torch.manual_seed(3)
    device = 'cpu'

    data_size = 15
    code_size = 15
    n_iter = 75
    e = ENCODER(data_size, code_size, device=device, n_iter=n_iter)

    # Create solutions (sparse codes).
    batch_size = 30
    codeHeight = 1.5
    x_true = torch.zeros(batch_size, code_size).to(device)
    for batch in range(batch_size):
        randi = torch.LongTensor(int(code_size / 2)).random_(0, code_size)
        x_true[batch][randi] = codeHeight * torch.rand(
            x_true[batch][randi].size()).to(device)

    # Create the dictionary.
    D = dictionary(data_size, code_size, use_cuda=(device != 'cpu'))

    # Create the observations based on the signal model
    #    y = Dx + w
    # where w is additive noise (uniform here, scaled by sigma).
    sigma = 1.5
    data = D(x_true) + sigma * torch.rand(batch_size, data_size).to(device)

    # Dictionary for collecting results.
    forwardMethods = ['ista', 'fista', 'salsa']
    all_loss = {}
    for f in forwardMethods:
        all_loss[f] = {}

    # Run the experiment.
    for meth in forwardMethods:
        print('{} init with {} algorithm--------------------------'.format(
            meth, meth))
        e.initialize_weights_(Dict=D, init_type=meth, mu=1, L1_weight=0.05)
        e.change_encode_algorithm_(meth)
        with torch.no_grad():
            x, loss = e(data, return_loss=True)

        # All three methods record their results the same way.
        all_loss[meth]['loss'] = loss
        all_loss[meth]['x'] = x

    # Visualize results.
    for k in all_loss:
        plt.figure(1)
        plt.clf()
        plt.plot(all_loss[k]['loss'])
        plt.annotate('%0.3f' % all_loss[k]['loss'][-1],
                     xy=(1, all_loss[k]['loss'][-1]),
                     xytext=(8, 0),
                     xycoords=('axes fraction', 'data'),
                     textcoords='offset points')
        plt.title(k + ' loss')
        plt.savefig('./sanity_ims/fixedEncoders/' + k + '_loss.png')

        plt.figure(2)
        plt.clf()
        plt.stem(all_loss[k]['x'][0].numpy())
        plt.title(k + ' solution 0 ')
        plt.savefig('./sanity_ims/fixedEncoders/' + k + '_x.png')

    ###############################################################################
    print(
        'xxxxxxxxxxx Now try a little learning on each architecture xxxxxxxxxxx'
    )

    all_loss = {}
    max_epoch = 150
    for meth in ['ista', 'salsa', 'fista']:
        all_loss[meth] = {}
        print('TRAINING w/ init = {} and alg = {}--------------------------'.
              format(meth, meth))
        e.initialize_weights_(Dict=D, init_type=meth, mu=1, L1_weight=0.05)
        e.change_encode_algorithm_(meth)

        # Set up optimizer.
        opt = torch.optim.Adam(e.parameters(), lr=0.001)

        # Compute labels ("optimal codes").
        with torch.no_grad():
            e.change_n_iter_(1000)
            labels = e(data)
            e.change_n_iter_(10)

        # Loss function
        loss = lambda x: F.mse_loss(x, labels)
        loss_hist = []

        for epoch in range(1, max_epoch):
            opt.zero_grad()
            # Forward Pass
            x = e(data.detach())
            # Backward Pass
            err = loss(x)
            err.backward()
            opt.step()
            loss_hist += [err.item()]

        all_loss[meth]['loss'] = loss_hist
        all_loss[meth]['x'] = x.detach()

    # Visualize results.
    for k in all_loss:
        plt.figure(1)
        plt.clf()
        plt.plot(all_loss[k]['loss'])
        plt.annotate('%0.3f' % all_loss[k]['loss'][-1],
                     xy=(1, all_loss[k]['loss'][-1]),
                     xytext=(8, 0),
                     xycoords=('axes fraction', 'data'),
                     textcoords='offset points')
        plt.title(k + ' Training Loss')
        plt.savefig('./sanity_ims/trainedEncoders/' + k + '_loss.png')

        plt.figure(2)
        plt.clf()
        plt.stem(all_loss[k]['x'][0].numpy())
        plt.title(k + ' solution after training')
        plt.savefig('./sanity_ims/trainedEncoders/' + k + '_x.png')
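
A small driver sketch, assuming the module-level imports (torch, torch.nn.functional as F, and the dictionary/ENCODER classes) are already in place; it only creates the folders that the savefig calls above expect and then runs the test.

import os

if __name__ == "__main__":
    # The savefig calls above write into these directories; create them first.
    os.makedirs("./sanity_ims/fixedEncoders", exist_ok=True)
    os.makedirs("./sanity_ims/trainedEncoders", exist_ok=True)
    encoder_class_sanity()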
Example #3
    else:
        return True


################################################
############### DICTIONARY TESTS ###############
# Create a trivial dictionary and double check
# all its methods.
print("Testing dictionary methods ", end="... ")
PASS = True

## Basic multiplication.
M = 40
N = 50
x = Variable(torch.ones(1, N))
decoder = dictionary(M, N, "testDict", False)
decoder.setWeights(torch.ones(M, N))

SDx = torch.sum(decoder(x)).item()
PASS = PASS and testEq(SDx, M * N, "\nError: dictionary not multiplying right")

# Maximum eigenvalue and scaling.
decoder.setWeights(torch.eye(M, N))
decoder.atoms.weight.data[0, 0] = np.sqrt(23.0 / 2)
decoder.scaleWeights(np.sqrt(2))
decoder.getMaxEigVal()
PASS = PASS and testEq(decoder.maxEig, 23,
                       "\nError: maximum eigenvalue computed incorrectly")

# Normalization of atoms/ columns.
decoder.setWeights(torch.rand(M, N))
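
The excerpt ends mid-test. A hedged sketch of how the normalization check might continue, assuming normalizeAtoms() rescales each column (atom) to unit L2 norm, as it is used in Example #1; this is illustrative, not the original test code.

# Hedged continuation sketch: after normalization, every column of the random
# weight matrix should have unit L2 norm.
decoder.normalizeAtoms()
atomNorms = decoder.atoms.weight.data.norm(p=2, dim=0)
PASS = PASS and testEq(atomNorms.max().item(), 1.0,
                       "\nError: atoms not normalized to unit L2 norm")
print("PASSED" if PASS else "FAILED")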
Example #4
    def initialize_weights_(self,
                            Dict=None,
                            L1_weight=None,
                            init_type='ista',
                            mu=None):
        """
    Fully initializes the encoder using given weight matrices
      (or randomly, if none are given).
    """
        self.init_type = init_type
        # fix-up L1 weights and mu's.
        if self.L1_weight is None:
            if (L1_weight is None):
                print('Using default L1 weight (0.1).')
                self.L1_weight = 0.1
            else:
                self.L1_weight = L1_weight
        if self.mu is None:
            if mu is None:
                self.mu = 1
                if init_type == 'salsa':
                    print('Using default mu value (1).')
            else:
                self.mu = mu
        # If a dictionary is not provided for initialization, initialize randomly.
        if Dict is None:
            Dict = dictionary(self.data_size, self.code_size, use_cuda=False)

        #-------------------------------------
        # Initialize the loss function.
        self.initialize_cvx_lossFcn_(Dict)

        #-------------------------------------
        # Initialize ISTA-style (first order).
        Wd = Dict.getDecWeights().cpu()
        if init_type == 'ista':
            # Get the maximum eigenvalue.
            Dict.getMaxEigVal()
            self.L = Dict.maxEig
            # Initialize.
            self.We.weight.data = (1 / self.L) * (Wd.detach()).t()
            self.S.weight.data = torch.eye(
                Dict.n) - (1 / self.L) * (torch.mm(Wd.t(), Wd)).detach()
            self.thresh = (self.L1_weight / self.L)
            # Set up the nonlinearity, aka soft-thresholding function.

        #-------------------------------------
        # Initialize FISTA-style (first order).
        elif init_type == 'fista':
            # Get the maximum eigenvalue.
            Dict.getMaxEigVal()
            self.L = Dict.maxEig
            # Initialize.
            self.We.weight.data = Wd.detach().t()
            self.thresh = (self.L1_weight / self.L)
            # Set up the nonlinearity, aka soft-thresholding function.

        #---------------------------------------
        # Initialize SALSA-style (second order).
        elif init_type == 'salsa':
            # Initialize matrices.
            self.We.weight.data = Wd.detach().t()
            AA = torch.mm(Wd.t(), Wd).cpu()
            S_weights = (self.mu * torch.eye(Dict.n) + AA).inverse()
            self.S.weight.data = S_weights.detach()
            self.thresh = (self.L1_weight / self.mu)
            # Set up the nonlinearity, aka soft-thresholding function.

        else:
            raise ValueError(
                'Encoders can only be initialized for the "ista", "fista", and '
                '"salsa" algorithm families.'
            )

        #-------------------------------------
        # Print status of the newly created encoder.
        # print('Encoder, threshold, and loss functions are initialized for '
        #       '{}-type algorithms.'.format(init_type))

        #-------------------------------------
        # Finally, put to device if requested.
        self.We = self.We.to(self.device)
        self.S = self.S.to(self.device)
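
For orientation, a hedged sketch of the update that the 'ista' initialization above is built around: one classical ISTA step for 0.5*||y - Dx||_2^2 + L1_weight*||x||_1. The function below is an illustration only (it is not part of the encoder class); F.softshrink stands in for the soft-thresholding nonlinearity the comments refer to.

import torch
import torch.nn.functional as F

def ista_step_sketch(x, y, Wd, L, L1_weight):
    # Illustrative ISTA step matching the 'ista' initialization above:
    #   We = (1/L)*Wd^T,  S = I - (1/L)*Wd^T Wd,  thresh = L1_weight / L,
    # so x_{k+1} = soft_threshold(We @ y + S @ x_k, thresh), written here for
    # row-vector batches x (batch, code_size) and y (batch, data_size).
    We = (1.0 / L) * Wd.t()                                   # (code_size, data_size)
    S = torch.eye(Wd.shape[1]) - (1.0 / L) * (Wd.t() @ Wd)    # (code_size, code_size)
    return F.softshrink(y @ We.t() + x @ S.t(), lambd=L1_weight / L)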
Example #5
    #     after each training epoch.
    SH.tr_perf = ddict(bw_loss=[],
                       epoch_loss=[],
                       bw_sparsity=[],
                       epoch_sparsity=[],
                       bw_reconErr=[],
                       epoch_reconErr=[])
    SH.te_perf = ddict(loss=[], sparsity=[], reconErr=[])

    #######################################################
    # (2) Set up everything.
    #######################################################
    # 2.b) dictionary.
    print(f"this is code_size:{code_size}")
    print(f"this is dust_size:{data_size}")
    model = dictionary(data_size, code_size, use_cuda=(device != 'cpu'))
    # 2.c) encoder.
    encoder = ENCODER(data_size,
                      code_size,
                      device=device,
                      n_iter=args.encode_iters)
    encoder.change_encode_algorithm_(args.encode_alg)
    setup_encoder = lambda m: encoder.initialize_weights_(
        m, init_type=args.encode_alg, L1_weight=args.L1_weight, mu=args.mu)
    # 2.d) optimizer.
    opt = optimizer_module(model.parameters(), **optParams)

    #######################################################
    # (3) Train model.
    #######################################################
    print('XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX')
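    # Hedged sketch only -- the original training loop is cut off above. One
    # plausible alternating step with the pieces set up here: re-initialize the
    # encoder from the current dictionary, encode a batch without gradients,
    # then take a gradient step on the dictionary's reconstruction error.
    # `train_loader` is assumed to be defined elsewhere in the script.
    for batch, _ in train_loader:
        y = batch.view(batch.size(0), -1).to(device)
        setup_encoder(model)                 # sync encoder weights with the dictionary
        with torch.no_grad():
            codes = encoder(y)               # sparse codes for this batch
        opt.zero_grad()
        recon_err = torch.nn.functional.mse_loss(model(codes), y)
        recon_err.backward()
        opt.step()
        model.normalizeAtoms()               # keep atoms at unit norm, as in Example #1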