Example #1
def compute_acc(Y, log_dist_a, n_se=2, return_er=False):
    """compute the accuracy of the prediction, over time
    - optionally return the standard error
    - assume n_action == y_dim + 1

    Parameters
    ----------
    Y : 3d tensor
        [n_examples, T_total, y_dim]
    log_dist_a : 3d tensor
        [n_examples, T_total, n_action]
    n_se : int
        number of SE
    return_er : bool
        whether to return SEs

    Returns
    -------
    1d array(s)
        stats for state prediction accuracy

    """
    if not isinstance(Y, np.ndarray):
        Y = to_np(Y)
    # argmax over the action distribution (the "don't know" unit included)
    argmax_dist_a = np.argmax(log_dist_a, axis=2)
    # argmax over the one-hot target vectors
    argmax_Y = np.argmax(Y, axis=2)
    # mark the time points where the prediction matches the target
    corrects = argmax_Y == argmax_dist_a
    # compute stats across trials
    acc_mu_, acc_er_ = compute_stats(corrects, axis=0, n_se=n_se)
    if return_er:
        return acc_mu_, acc_er_
    return acc_mu_
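A minimal usage sketch. The shapes below are made up for illustration, and it assumes `compute_stats` (a helper from the same module) averages over the trial axis:

import numpy as np

# hypothetical example: 8 trials, 10 time steps, 3-dim one-hot targets,
# and n_action = y_dim + 1 = 4 (the extra unit is the "don't know" action)
Y = np.eye(3)[np.random.randint(0, 3, size=(8, 10))]                 # [8, 10, 3]
log_dist_a = np.log(np.random.dirichlet(np.ones(4), size=(8, 10)))   # [8, 10, 4]

acc_mu = compute_acc(Y, log_dist_a)
print(acc_mu.shape)  # (10,) - mean accuracy at each time point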
Example #2
def compute_cell_memory_similarity(C,
                                   V,
                                   inpt,
                                   leak,
                                   comp,
                                   kernel='cosine',
                                   recall_func='LCA'):
    """Compute cell-state-to-memory similarity, raw and LCA-transformed.

    C : cell states, [n_examples, n_timepoints, n_dim]
    V : memories, one list of memory vectors per example
    inpt, leak, comp : LCA parameters, indexed by [example, time point]
    """
    n_examples, n_timepoints, n_dim = np.shape(C)
    n_memories = len(V[0])
    # prealloc
    sim_raw = np.zeros((n_examples, n_timepoints, n_memories))
    sim_lca = np.zeros((n_examples, n_timepoints, n_memories))
    for i in range(n_examples):
        # compute similarity
        for t in range(n_timepoints):
            # compute raw similarity
            sim_raw[i, t, :] = to_np(
                compute_similarities(to_pth(C[i, t]), V[i], kernel))
            # compute LCA similarity
            sim_lca[i, t, :] = transform_similarities(
                to_pth(sim_raw[i, t, :]), recall_func,
                leak=to_pth(leak[i, t]),
                comp=to_pth(comp[i, t]),
                w_input=to_pth(inpt[i, t]))
    return sim_raw, sim_lca
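The helpers `compute_similarities` and `transform_similarities` are defined elsewhere in the repo. For reference, here is a self-contained sketch of what a cosine-kernel similarity step could look like; this is an illustrative assumption, not the repo's actual implementation:

import numpy as np

def cosine_similarities(query, memories, eps=1e-8):
    """Cosine similarity between one query vector and a bank of memories.

    query    : [n_dim]
    memories : [n_memories, n_dim]
    returns  : [n_memories]
    """
    q = query / (np.linalg.norm(query) + eps)
    M = memories / (np.linalg.norm(memories, axis=1, keepdims=True) + eps)
    return M @ q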
Example #3
def compute_mistake(Y, log_dist_a, n_se=2, return_er=False):
    """Compute the mistake rate over time: a mistake is a response that
    differs from the target and is not a "don't know" response.
    """
    if not isinstance(Y, np.ndarray):
        Y = to_np(Y)
    # argmax over the action distribution (the "don't know" unit included)
    argmax_dist_a = np.argmax(log_dist_a, axis=2)
    # argmax over the one-hot target vectors
    argmax_Y = np.argmax(Y, axis=2)
    # mark the time points where the prediction differs from the target
    diff = argmax_Y != argmax_dist_a
    # mark the "don't know" responses
    dk = _compute_dk(log_dist_a)
    # mistake := different from the target and not a don't-know response
    mistakes = np.logical_and(diff, ~dk)
    # compute stats across trials
    mis_mu_, mis_er_ = compute_stats(mistakes, axis=0, n_se=n_se)
    if return_er:
        return mis_mu_, mis_er_
    return mis_mu_
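`_compute_dk` is not shown in this snippet. Given that the docstrings above assume `n_action == y_dim + 1` with the extra unit acting as the "don't know" action, a plausible sketch would be the following (an assumption for illustration, not the repo's actual helper):

def _compute_dk(log_dist_a):
    # ASSUMPTION: the last action unit is the "don't know" unit, so a
    # response counts as dk whenever argmax lands on that unit
    n_action = np.shape(log_dist_a)[-1]
    return np.argmax(log_dist_a, axis=2) == n_action - 1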
Example #4
def predict(predict_loader,
            model, model1, model2, model3):
    
    global logger
    global pred_minib_counter
    
    m = nn.Sigmoid()
    model.eval()
    model1.eval()
    model2.eval()
    model3.eval()

    temp_df = pd.DataFrame(columns=['ImageId', 'EncodedPixels'])
    
    with tqdm.tqdm(total=len(predict_loader)) as pbar:
        for i, (input, target, or_resl, target_resl, img_ids) in enumerate(predict_loader):
            
            # reshape from NHWC to PyTorch's NCHW layout
            input = input.permute(0, 3, 1, 2).contiguous().float().cuda(non_blocking=True)

            # compute the outputs of all four folds (inference only, no gradients)
            with torch.no_grad():
                output = model(input)
                output1 = model1(input)
                output2 = model2(input)
                output3 = model3(input)
            
            for k, (pred_mask, pred_mask1, pred_mask2, pred_mask3) in enumerate(zip(output, output1, output2, output3)):

                or_w = or_resl[0][k]
                or_h = or_resl[1][k]

                print(or_w, or_h)

                mask_predictions = []
                energy_predictions = []

                # collect one mask/energy prediction per fold; the TensorBoard
                # block below indexes all four folds
                for pred_msk in [pred_mask, pred_mask1, pred_mask2, pred_mask3]:
                    msk, energy = calculate_energy(pred_msk, or_h, or_w)
                    mask_predictions.append(msk)
                    energy_predictions.append(energy)
                   
                avg_mask = np.asarray(mask_predictions).mean(axis=0)
                avg_energy = np.asarray(energy_predictions).mean(axis=0)
                imsave('../examples/mask_{}.png'.format(img_ids[k]), avg_mask.astype('uint8'))
                imsave('../examples/energy_{}.png'.format(img_ids[k]), avg_energy.astype('uint8'))
                
                labels = wt_seeds(avg_mask,
                                  avg_energy,
                                  args.ths)  
                
                labels_seed = cv2.applyColorMap((labels / labels.max() * 255).astype('uint8'), cv2.COLORMAP_JET)                  
                imsave('../examples/labels_{}.png'.format(img_ids[k]),labels_seed)

                if args.tensorboard_images:
                    info = {
                        'images': to_np(input),
                        'labels_wt': np.expand_dims(labels_seed,axis=0),
                        'pred_mask_fold0': np.expand_dims(mask_predictions[0],axis=0),
                        'pred_mask_fold1': np.expand_dims(mask_predictions[1],axis=0),
                        'pred_mask_fold2': np.expand_dims(mask_predictions[2],axis=0),
                        'pred_mask_fold3': np.expand_dims(mask_predictions[3],axis=0),
                        'pred_energy_fold0': np.expand_dims(energy_predictions[0],axis=0),
                        'pred_energy_fold1': np.expand_dims(energy_predictions[1],axis=0),
                        'pred_energy_fold2': np.expand_dims(energy_predictions[2],axis=0),
                        'pred_energy_fold3': np.expand_dims(energy_predictions[3],axis=0),
                    }
                    for tag, images in info.items():
                        logger.image_summary(tag, images, pred_minib_counter)

                pred_minib_counter += 1
                
                wt_areas = []
                for j, label in enumerate(np.unique(labels)):
                    if j == 0:
                        # skip the background label
                        continue
                    wt_areas.append((labels == label) * 1)
               
                for wt_area in wt_areas:
                    append_df = pd.DataFrame(columns=['ImageId', 'EncodedPixels'])
                    append_df['ImageId'] = [img_ids[k]]
                    append_df['EncodedPixels'] = [' '.join(map(str, rle_encoding(wt_area)))]

                    # DataFrame.append was removed in pandas 2.0
                    temp_df = pd.concat([temp_df, append_df], ignore_index=True)
            
            pbar.update(1)            

    return temp_df
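`rle_encoding` is defined elsewhere in the repo. For reference, the common Kaggle-style run-length encoder looks like this (a sketch assuming the standard competition convention, not necessarily the repo's exact helper):

import numpy as np

def rle_encoding(mask):
    """Run-length encode a binary mask (Kaggle convention: pixels are
    numbered top-to-bottom, then left-to-right, starting from 1)."""
    dots = np.where(mask.T.flatten() == 1)[0] + 1
    runs = []
    prev = -2
    for d in dots:
        if d > prev + 1:
            # start a new run: (start position, length 0)
            runs.extend((d, 0))
        runs[-1] += 1
        prev = d
    return runs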
Example #5
def validate(val_loader,
             model,
             criterion,
             scheduler,
             source_resl,
             target_resl):
                                
    global valid_minib_counter
    global logger
    
    # scheduler.batch_step()    
    
    batch_time = AverageMeter()
    losses = AverageMeter()
    f1_scores = AverageMeter()
    map_scores_wt = AverageMeter()
    map_scores_wt_seed = AverageMeter()    
    
    # switch to evaluate mode
    model.eval()

    # sigmoid for f1 calculation and illustrations
    m = nn.Sigmoid()      
    
    end = time.time()
    for i, (input, target, or_resl, target_resl,img_sample) in enumerate(val_loader):
        
        # permute from NHWC to PyTorch's NCHW layout
        input = input.permute(0, 3, 1, 2).contiguous().float().cuda(non_blocking=True)
        # take only the mask and boundary channels for now
        target = target[:, :, :, 0:args.channels].permute(0, 3, 1, 2).contiguous().float().cuda(non_blocking=True)

        # compute the output and the loss (inference only, no gradients)
        with torch.no_grad():
            output = model(input)
            loss = criterion(output, target)
        
        # go over all of the predictions
        # apply the transformation to each mask
        # calculate score for each of the images
        
        averaged_maps_wt = []
        averaged_maps_wt_seed = []
        y_preds_wt = []
        y_preds_wt_seed = []
        energy_levels = []
            
        for j,pred_output in enumerate(output):
            or_w = or_resl[0][j]
            or_h = or_resl[1][j]
            
            # only the latest preset of output channels (8 in total) is handled here

            pred_mask = m(pred_output[0,:,:]).data.cpu().numpy()
            pred_mask1 = m(pred_output[1,:,:]).data.cpu().numpy()
            pred_mask2 = m(pred_output[2,:,:]).data.cpu().numpy()
            pred_mask3 = m(pred_output[3,:,:]).data.cpu().numpy()
            pred_mask0 = m(pred_output[4,:,:]).data.cpu().numpy()
            pred_border = m(pred_output[5,:,:]).data.cpu().numpy()
            # pred_distance = m(pred_output[5,:,:]).data.cpu().numpy()            
            pred_vector0 = pred_output[6,:,:].data.cpu().numpy()
            pred_vector1 = pred_output[7,:,:].data.cpu().numpy()             

            pred_mask = cv2.resize(pred_mask, (or_h,or_w), interpolation=cv2.INTER_LINEAR)
            pred_mask1 = cv2.resize(pred_mask1, (or_h,or_w), interpolation=cv2.INTER_LINEAR)
            pred_mask2 = cv2.resize(pred_mask2, (or_h,or_w), interpolation=cv2.INTER_LINEAR)
            pred_mask3 = cv2.resize(pred_mask3, (or_h,or_w), interpolation=cv2.INTER_LINEAR)
            pred_mask0 = cv2.resize(pred_mask0, (or_h,or_w), interpolation=cv2.INTER_LINEAR)
            # pred_distance = cv2.resize(pred_distance, (or_h,or_w), interpolation=cv2.INTER_LINEAR)
            pred_border = cv2.resize(pred_border, (or_h,or_w), interpolation=cv2.INTER_LINEAR)
            pred_vector0 = cv2.resize(pred_vector0, (or_h,or_w), interpolation=cv2.INTER_LINEAR) 
            pred_vector1 = cv2.resize(pred_vector1, (or_h,or_w), interpolation=cv2.INTER_LINEAR)             
            
            # predicted energy: average the five masks and rescale to [0, 255]
            pred_energy = (pred_mask + pred_mask1 + pred_mask2 + pred_mask3 + pred_mask0) / 5 * 255
            pred_mask_255 = np.copy(pred_mask) * 255            

            # read the original masks for metric evaluation
            mask_glob = glob.glob('../data/stage1_train/{}/masks/*.png'.format(img_sample[j]))
            gt_masks = imread_collection(mask_glob).concatenate()

            # simple wt
            y_pred_wt = wt_baseline(pred_mask_255, args.ths)
            
            # wt with seeds
            y_pred_wt_seed = wt_seeds(pred_mask_255,pred_energy,args.ths)            
            
            map_wt = calculate_ap(y_pred_wt, gt_masks)
            map_wt_seed = calculate_ap(y_pred_wt_seed, gt_masks)
            
            averaged_maps_wt.append(map_wt[1])
            averaged_maps_wt_seed.append(map_wt_seed[1])

            # apply colormap for easier tracking
            y_pred_wt = cv2.applyColorMap((y_pred_wt / y_pred_wt.max() * 255).astype('uint8'), cv2.COLORMAP_JET) 
            y_pred_wt_seed = cv2.applyColorMap((y_pred_wt_seed / y_pred_wt_seed.max() * 255).astype('uint8'), cv2.COLORMAP_JET)  
            
            y_preds_wt.append(y_pred_wt)
            y_preds_wt_seed.append(y_pred_wt_seed)
            energy_levels.append(pred_energy)
            
            # print('MAP for sample {} is {}'.format(img_sample[j],m_ap))
        
        y_preds_wt = np.asarray(y_preds_wt)
        y_preds_wt_seed = np.asarray(y_preds_wt_seed)
        energy_levels = np.asarray(energy_levels)
        
        averaged_maps_wt = np.asarray(averaged_maps_wt).mean()
        averaged_maps_wt_seed = np.asarray(averaged_maps_wt_seed).mean()

        #============ TensorBoard logging ============#                                            
        if args.tensorboard_images:
            if i == 0:
                if args.channels == 5:
                    info = {
                        'images': to_np(input[:2,:,:,:]),
                        'gt_mask': to_np(target[:2,0,:,:]),
                        'gt_mask1': to_np(target[:2,1,:,:]),
                        'gt_mask2': to_np(target[:2,2,:,:]),
                        'gt_mask3': to_np(target[:2,3,:,:]), 
                        'gt_mask0': to_np(target[:2,4,:,:]),
                        'pred_mask': to_np(m(output.data[:2,0,:,:])),
                        'pred_mask1': to_np(m(output.data[:2,1,:,:])),
                        'pred_mask2': to_np(m(output.data[:2,2,:,:])),
                        'pred_mask3': to_np(m(output.data[:2,3,:,:])),
                        'pred_mask0': to_np(m(output.data[:2,4,:,:])),
                        'pred_energy': energy_levels[:2,:,:], 
                        'pred_wt': y_preds_wt[:2,:,:],
                        'pred_wt_seed': y_preds_wt_seed[:2,:,:,:],
                    }
                    for tag, images in info.items():
                        logger.image_summary(tag, images, valid_minib_counter)                   
                elif args.channels == 6:
                    info = {
                        'images': to_np(input[:2,:,:,:]),
                        'gt_mask': to_np(target[:2,0,:,:]),
                        'gt_mask1': to_np(target[:2,1,:,:]),
                        'gt_mask2': to_np(target[:2,2,:,:]),
                        'gt_mask3': to_np(target[:2,3,:,:]), 
                        'gt_mask0': to_np(target[:2,4,:,:]),
                        'gt_mask_distance': to_np(target[:2,5,:,:]),
                        'pred_mask': to_np(m(output.data[:2,0,:,:])),
                        'pred_mask1': to_np(m(output.data[:2,1,:,:])),
                        'pred_mask2': to_np(m(output.data[:2,2,:,:])),
                        'pred_mask3': to_np(m(output.data[:2,3,:,:])),
                        'pred_mask0': to_np(m(output.data[:2,4,:,:])),
                        'pred_distance': to_np(m(output.data[:2,5,:,:])),
                        'pred_energy': energy_levels[:2,:,:], 
                        'pred_wt': y_preds_wt[:2,:,:],
                        'pred_wt_seed': y_preds_wt_seed[:2,:,:,:],
                    }
                    for tag, images in info.items():
                        logger.image_summary(tag, images, valid_minib_counter)
                elif args.channels == 7:
                    info = {
                        'images': to_np(input[:2,:,:,:]),
                        'gt_mask': to_np(target[:2,0,:,:]),
                        'gt_mask1': to_np(target[:2,1,:,:]),
                        'gt_mask2': to_np(target[:2,2,:,:]),
                        'gt_mask3': to_np(target[:2,3,:,:]), 
                        'gt_mask0': to_np(target[:2,4,:,:]),
                        'gt_mask_distance': to_np(target[:2,5,:,:]),
                        'gt_border': to_np(target[:2,6,:,:]),                        
                        'pred_mask': to_np(m(output.data[:2,0,:,:])),
                        'pred_mask1': to_np(m(output.data[:2,1,:,:])),
                        'pred_mask2': to_np(m(output.data[:2,2,:,:])),
                        'pred_mask3': to_np(m(output.data[:2,3,:,:])),
                        'pred_mask0': to_np(m(output.data[:2,4,:,:])),
                        'pred_distance': to_np(m(output.data[:2,5,:,:])),
                        'pred_border': to_np(m(output.data[:2,6,:,:])),                        
                        'pred_energy': energy_levels[:2,:,:], 
                        'pred_wt': y_preds_wt[:2,:,:],
                        'pred_wt_seed': y_preds_wt_seed[:2,:,:,:],
                    }
                    for tag, images in info.items():
                        logger.image_summary(tag, images, valid_minib_counter)
                elif args.channels == 8:
                    info = {
                        'images': to_np(input[:2,:,:,:]),
                        'gt_mask': to_np(target[:2,0,:,:]),
                        'gt_mask1': to_np(target[:2,1,:,:]),
                        'gt_mask2': to_np(target[:2,2,:,:]),
                        'gt_mask3': to_np(target[:2,3,:,:]), 
                        'gt_mask0': to_np(target[:2,4,:,:]),
                        'gt_border': to_np(target[:2,5,:,:]),   
                        'gt_vectors': to_np(target[:2,6,:,:]+target[:2,7,:,:]), # simple hack - just sum the vectors
                        'pred_mask': to_np(m(output.data[:2,0,:,:])),
                        'pred_mask1': to_np(m(output.data[:2,1,:,:])),
                        'pred_mask2': to_np(m(output.data[:2,2,:,:])),
                        'pred_mask3': to_np(m(output.data[:2,3,:,:])),
                        'pred_mask0': to_np(m(output.data[:2,4,:,:])),
                        'pred_border': to_np(m(output.data[:2,5,:,:])),
                        'pred_vectors': to_np(output.data[:2,6,:,:]+output.data[:2,7,:,:]),                         
                        'pred_energy': energy_levels[:2,:,:], 
                        'pred_wt': y_preds_wt[:2,:,:],
                        'pred_wt_seed': y_preds_wt_seed[:2,:,:,:],
                    }
                    for tag, images in info.items():
                        logger.image_summary(tag, images, valid_minib_counter)

        # calculate f1 scores only on the inner cell masks;
        # threshold the target first to avoid a pytorch numerical
        # issue when converting to float
        target_f1 = (target[:, 0:1, :, :] > args.ths) * 1
        f1_scores_batch = batch_f1_score(output=m(output.data[:, 0:1, :, :]),
                                         target=target_f1,
                                         threshold=args.ths)

        # measure accuracy and record loss
        losses.update(loss.item(), input.size(0))
        f1_scores.update(f1_scores_batch, input.size(0))
        map_scores_wt.update(averaged_maps_wt, input.size(0))  
        map_scores_wt_seed.update(averaged_maps_wt_seed, input.size(0)) 

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        #============ TensorBoard logging ============#
        # Log the scalar values        
        if args.tensorboard:
            info = {
                'valid_loss': losses.val,
                'f1_score_val': f1_scores.val, 
                'map_wt': averaged_maps_wt,
                'map_wt_seed': averaged_maps_wt_seed,
            }
            for tag, value in info.items():
                logger.scalar_summary(tag, value, valid_minib_counter)            
        
        valid_minib_counter += 1
        
        if i % args.print_freq == 0:
            print('Test: [{0}/{1}]\t'
                  'Time  {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss  {loss.val:.4f} ({loss.avg:.4f})\t'
                  'F1    {f1_scores.val:.4f} ({f1_scores.avg:.4f})\t'
                  'MAP1  {map_scores_wt.val:.4f} ({map_scores_wt.avg:.4f})\t'
                  'MAP2  {map_scores_wt_seed.val:.4f} ({map_scores_wt_seed.avg:.4f})\t'.format(
                   i, len(val_loader), batch_time=batch_time,
                      loss=losses,
                      f1_scores=f1_scores,
                      map_scores_wt=map_scores_wt,map_scores_wt_seed=map_scores_wt_seed))

    print(' * Avg Val  Loss {loss.avg:.4f}'.format(loss=losses))
    print(' * Avg F1   Score {f1_scores.avg:.4f}'.format(f1_scores=f1_scores))
    print(' * Avg MAP1 Score {map_scores_wt.avg:.4f}'.format(map_scores_wt=map_scores_wt)) 
    print(' * Avg MAP2 Score {map_scores_wt_seed.avg:.4f}'.format(map_scores_wt_seed=map_scores_wt_seed)) 

    return losses.avg, f1_scores.avg, map_scores_wt.avg,map_scores_wt_seed.avg
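`AverageMeter` is not shown here; the code appears to follow the canonical PyTorch ImageNet example, so a matching implementation would be:

class AverageMeter(object):
    """Stores the latest value and a running average (the standard
    helper from the PyTorch ImageNet example)."""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count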