def get_sample(data):
    gdata = gBatch().from_data_list([data['points']])
    gdata = gdata.to(DEVICE)
    gdata.batch = gdata.bvec.clone()
    del gdata.bvec
    gdata['lengths'] = gdata['lengths'][0].item()

    return gdata
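
# A minimal usage sketch of get_sample on hypothetical toy data. Assumptions:
# gBatch is torch_geometric.data.Batch (the older PyG API used throughout
# these snippets), DEVICE is a torch.device in this module, and 'points' is a
# Data object carrying bvec (treated here as a point-to-streamline index) and
# lengths (points per streamline) attributes.
import torch
from torch_geometric.data import Data

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

pts = Data(x=torch.randn(12, 3))                 # 12 points, 3 features each
pts.bvec = torch.arange(3).repeat_interleave(4)  # 3 streamlines, 4 points each
pts.lengths = torch.tensor([4])
gdata = get_sample({'points': pts})
print(gdata.batch)    # the old bvec: tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2])
print(gdata.lengths)  # 4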
Example #2
def get_gbatch_sample(sample, sample_size, same_size, return_name=False):
    data_list = []
    name_list = []
    ori_batch = []
    for i, d in enumerate(sample):
        if 'bvec' in d['points'].keys:
            d['points'].bvec += sample_size * i
        data_list.append(d['points'])
        name_list.append(d['name'])
        ori_batch.append([i] * sample_size)
    points = gBatch().from_data_list(data_list)
    points.ori_batch = torch.tensor(ori_batch).flatten().long()
    if 'bvec' in points.keys:
        #points.batch = points.bvec.copy()
        points.batch = points.bvec.clone()
        del points.bvec
    if same_size:
        points['lengths'] = points['lengths'][0].item()

    if return_name:
        return points, name_list
    return points
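
# Hypothetical usage of get_gbatch_sample with two toy samples. sample_size is
# the number of streamlines per sample, so the bvec offset keeps streamline
# indices unique across the whole batch (same assumed Data layout as above).
import torch
from torch_geometric.data import Data

samples = []
for name in ('subj_a', 'subj_b'):
    pts = Data(x=torch.randn(8, 3))
    pts.bvec = torch.arange(2).repeat_interleave(4)  # 2 streamlines, 4 points
    pts.lengths = torch.tensor([4])
    samples.append({'points': pts, 'name': name})

points, names = get_gbatch_sample(samples, sample_size=2, same_size=True,
                                  return_name=True)
print(points.batch.unique())  # tensor([0, 1, 2, 3]): subj_b ids shifted by 2
print(points.ori_batch)       # tensor([0, 0, 1, 1]): streamline -> sample map
print(names)                  # ['subj_a', 'subj_b']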
Example #3
def val_iter(cfg, val_dataloader, classifier, writer, epoch, cluster_loss_fn,
             best_epoch, best_pred, logdir):

    num_classes = int(cfg['n_classes'])
    #batch_size = int(cfg['batch_size'])
    batch_size = 1
    n_epochs = int(cfg['n_epochs'])
    sample_size = int(cfg['fixed_size'])
    input_size = int(cfg['data_dim'])
    num_batch = cfg['num_batch']
    alfa = 0
    ep_loss = 0.

    classifier.eval()

    with torch.no_grad():
        pred_buffer = {}
        sm_buffer = {}
        sm2_buffer = {}
        gf_buffer = {}
        print('\n\n')
        mean_val_acc = torch.tensor([])
        mean_val_iou = torch.tensor([])
        mean_val_prec = torch.tensor([])
        mean_val_recall = torch.tensor([])
        mean_val_iou_c = torch.tensor([])

        for j, data in enumerate(val_dataloader):
            if 'graph' not in cfg['dataset']:
                points = data['points']
                target = data['gt']
                points, target = Variable(points), Variable(target)
                points, target = points.cuda(), target.cuda()
            else:
                data_list = []
                name_list = []
                for i, d in enumerate(data):
                    if 'bvec' in d['points'].keys:
                        d['points'].bvec += sample_size * i
                    data_list.append(d['points'])
                    name_list.append(d['name'])
                points = gBatch().from_data_list(data_list)
                if 'bvec' in points.keys:
                    points.batch = points.bvec.clone()
                    del points.bvec
                target = points['y']
                if cfg['same_size']:
                    points['lengths'] = points['lengths'][0].item()
                data = {'points': points, 'gt': target, 'name': name_list}
                points, target = points.to('cuda'), target.to('cuda')

            sample_name = data['name'][0]

            logits = classifier(points)

            if len(cfg['loss']) == 2:
                if epoch <= int(cfg['switch_loss_epoch']):
                    loss_type = cfg['loss'][0]
                else:
                    loss_type = cfg['loss'][1]
            else:
                loss_type = cfg['loss'][0]

            if loss_type == 'nll':
                pred = F.log_softmax(logits, dim=-1)
                pred = pred.view(-1, num_classes)
                probas = torch.exp(pred.data)
                pred_choice = pred.data.max(1)[1].int()
                if cfg['nll_w']:
                    ce_w = torch.tensor([1.5e-2] + [1.] *
                                        (num_classes - 1)).cuda()
                else:
                    ce_w = torch.tensor([1.] * num_classes).cuda()
                #print(pred.shape, target.shape)
                loss_seg = F.nll_loss(pred, target.long(), weight=ce_w)
            elif loss_type == 'LLh':
                pred_choice = (logits.data > 0).int()
                loss_seg = L.lovasz_hinge(logits.view(batch_size, sample_size,
                                                      1),
                                          target.view(batch_size, sample_size,
                                                      1),
                                          per_image=False)
            elif loss_type == 'LLm':
                pred = F.softmax(logits, dim=-1)
                probas = pred.data
                pred_choice = pred.data.max(1)[1].int()
                loss_seg = L.lovasz_softmax_flat(
                    pred,
                    target,
                    op=cfg['llm_op'],
                    only_present=cfg['multi_category'])

            loss = loss_seg

            ep_loss += loss
            print('val max class pred ', pred_choice.max().item())
            print('val min class pred ', pred_choice.min().item())
            print('# class pred ', len(np.unique(pred_choice.cpu().numpy())))
            correct = pred_choice.eq(target.data.int()).cpu().sum()
            acc = correct.item() / float(target.size(0))

            tp = torch.mul(pred_choice.data,
                           target.data.int()).cpu().sum().item() + 0.00001
            fp = pred_choice.gt(target.data.int()).cpu().sum().item()
            fn = pred_choice.lt(target.data.int()).cpu().sum().item()
            tn = correct.item() - tp
            iou = torch.tensor([float(tp) / (tp + fp + fn)])
            prec = torch.tensor([float(tp) / (tp + fp)])
            recall = torch.tensor([float(tp) / (tp + fn)])

            print('VALIDATION [%d: %d/%d] val loss: %f acc: %f iou: %f' %
                  (epoch, j, len(val_dataloader), loss, acc, iou))

            mean_val_prec = torch.cat((mean_val_prec, prec), 0)
            mean_val_recall = torch.cat((mean_val_recall, recall), 0)
            mean_val_iou = torch.cat((mean_val_iou, iou), 0)
            mean_val_acc = torch.cat((mean_val_acc, torch.tensor([acc])), 0)

            if cfg['save_pred']:
                sl_idx = np.where(
                    pred_choice.data.cpu().view(-1).numpy() == 1)[0]
                pred_buffer[sample_name] = sl_idx.tolist()

        macro_iou = torch.mean(mean_val_iou)
        macro_prec = torch.mean(mean_val_prec)
        macro_recall = torch.mean(mean_val_recall)
        macro_iou_c = torch.mean(mean_val_iou_c)

        epoch_iou = macro_iou.item()

        writer.add_scalar('val/epoch_acc',
                          torch.mean(mean_val_acc).item(), epoch)
        writer.add_scalar('val/epoch_iou', epoch_iou, epoch)
        writer.add_scalar('val/epoch_prec', macro_prec.item(), epoch)
        writer.add_scalar('val/epoch_recall', macro_recall.item(), epoch)
        writer.add_scalar('val/epoch_iou_c', macro_iou_c.item(), epoch)
        writer.add_scalar('val/loss', ep_loss / (j + 1), epoch)
        print('VALIDATION ACCURACY: %f' % torch.mean(mean_val_acc).item())
        print('VALIDATION IOU: %f' % epoch_iou)
        print('VALIDATION IOUC: %f' % macro_iou_c.item())
        print('\n\n')

        if epoch_iou > best_pred:
            best_pred = epoch_iou
            best_epoch = epoch

            if cfg['save_model']:
                modeldir = os.path.join(logdir, cfg['model_dir'])
                if not os.path.exists(modeldir):
                    os.makedirs(modeldir)
                else:
                    os.system('rm %s/best_model*.pth' % modeldir)
                torch.save(
                    classifier.state_dict(), '%s/best_model_iou-%f_ep-%d.pth' %
                    (modeldir, best_pred, epoch))
        return best_epoch, best_pred, ep_loss
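
# The binary counts above exploit {0, 1} integer labels: pred * target marks
# true positives, pred > target false positives, and pred < target false
# negatives. A quick sanity check of that identity on toy tensors:
import torch

pred = torch.tensor([1, 1, 0, 0, 1], dtype=torch.int)
target = torch.tensor([1, 0, 0, 1, 1], dtype=torch.int)
tp = torch.mul(pred, target).sum().item()  # 2
fp = pred.gt(target).sum().item()          # 1
fn = pred.lt(target).sum().item()          # 1
print(tp / (tp + fp + fn))                 # IoU of the positive class: 0.5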
Example #4
def train_iter(cfg, dataloader, classifier, optimizer, writer, epoch, n_iter,
               cluster_loss_fn):

    num_classes = int(cfg['n_classes'])
    batch_size = int(cfg['batch_size'])
    n_epochs = int(cfg['n_epochs'])
    sample_size = int(cfg['fixed_size'])
    input_size = int(cfg['data_dim'])
    num_batch = cfg['num_batch']
    alfa = 0

    ep_loss = 0.
    ep_seg_loss = 0.
    ep_cluster_loss = 0.
    mean_acc = torch.tensor([])
    mean_iou = torch.tensor([])
    mean_prec = torch.tensor([])
    mean_recall = torch.tensor([])

    ### state that the model will run in train mode
    classifier.train()

    #d_list=[]
    #for dat in dataloader:
    #for d in dat:
    #d_list.append(d)
    #points = gBatch().from_data_list(d_list)
    #target = points['y']
    #name = dataset['name']
    #points, target = points.to('cuda'), target.to('cuda')

    for i_batch, sample_batched in enumerate(dataloader):

        ### get batch
        if 'graph' not in cfg['dataset']:
            points = sample_batched['points']
            target = sample_batched['gt']
            #if cfg['model'] == 'pointnet_cls':
            #points = points.view(batch_size*sample_size, -1, input_size)
            #target = target.view(batch_size*sample_size, -1)

            #batch_size = batch_size*sample_size
            #sample_size = points.shape[1]
            points, target = Variable(points), Variable(target)
            points, target = points.cuda(), target.cuda()

        else:
            data_list = []
            name_list = []
            for i, d in enumerate(sample_batched):
                if 'bvec' in d['points'].keys:
                    d['points'].bvec += sample_size * i
                data_list.append(d['points'])
                name_list.append(d['name'])
            points = gBatch().from_data_list(data_list)
            if 'bvec' in points.keys:
                #points.batch = points.bvec.copy()
                points.batch = points.bvec.clone()
                del points.bvec
            #if 'bslices' in points.keys():
            #    points.__slices__ = torch.cum(
            target = points['y']
            if cfg['same_size']:
                points['lengths'] = points['lengths'][0].item()
            sample_batched = {
                'points': points,
                'gt': target,
                'name': name_list
            }
            #print('points:',points)

            #if (epoch != 0) and (epoch % 20 == 0):
            #    assert(len(dataloader.dataset) % int(cfg['fold_size']) == 0)
            #    folds = len(dataloader.dataset)/int(cfg['fold_size'])
            #    n_fold = (dataloader.dataset.n_fold + 1) % folds
            #    if n_fold != dataloader.dataset.n_fold:
            #        dataloader.dataset.n_fold = n_fold
            #        dataloader.dataset.load_fold()
            points, target = points.to('cuda'), target.to('cuda')
        #print(len(points.lengths),target.shape)

        ### initialize gradients
        #if not cfg['accumulation_interval'] or i_batch == 0:
        optimizer.zero_grad()

        ### forward
        logits = classifier(points)
        ### minimize the loss
        if len(cfg['loss']) == 2:
            if epoch <= int(cfg['switch_loss_epoch']):
                loss_type = cfg['loss'][0]
            else:
                loss_type = cfg['loss'][1]
        else:
            loss_type = cfg['loss'][0]

        if loss_type == 'nll':
            pred = F.log_softmax(logits, dim=-1)
            pred = pred.view(-1, num_classes)
            pred_choice = pred.data.max(1)[1].int()

            if cfg['nll_w']:
                ce_w = torch.tensor([1.5e-2] + [1.] * (num_classes - 1)).cuda()
            else:
                ce_w = torch.tensor([1.] * num_classes).cuda()
            #print(pred.shape)
            loss = F.nll_loss(pred, target.long(), weight=ce_w)
        elif loss_type == 'LLh':
            pred_choice = (logits.data > 0).int()
            loss = L.lovasz_hinge(logits.view(batch_size, sample_size, 1),
                                  target.view(batch_size, sample_size, 1),
                                  per_image=False)
        elif loss_type == 'LLm':
            pred = F.softmax(logits, dim=-1)
            pred_choice = pred.data.max(1)[1].int()
            loss = L.lovasz_softmax_flat(pred,
                                         target,
                                         op=cfg['llm_op'],
                                         only_present=cfg['multi_category'])

        ep_loss += loss.item()  # accumulate a float so each batch graph can be freed
        if cfg['print_bwgraph']:
            #with torch.onnx.set_training(classifier, False):
            #    trace, _ = torch.jit.get_trace_graph(classifier, args=(points.transpose(2,1),))
            #g = make_dot_from_trace(trace)
            from torchviz import make_dot, make_dot_from_trace
            g = make_dot(loss, params=dict(classifier.named_parameters()))
            #   g = make_dot(loss,
            #                           params=None)
            g.view('pointnet_mgf')
            print('classifier parameters: %d' %
                  int(count_parameters(classifier)))
            os.system('rm -r runs/%s' % writer.logdir.split('/', 1)[1])
            os.system('rm -r tb_logs/%s' % writer.logdir.split('/', 1)[1])
            import sys
            sys.exit()
        #print('memory allocated in MB: ', torch.cuda.memory_allocated()/2**20)
        #import sys; sys.exit()
        loss.backward()

        #if int(cfg['accumulation_interval']) % (i_batch+1) == 0:
        optimizer.step()
        #elif not cfg['accumulation_interval']:
        #    optimizer.step()

        ### compute performance
        correct = pred_choice.eq(target.data.int()).sum()
        acc = correct.item() / float(target.size(0))

        tp = torch.mul(pred_choice.data,
                       target.data.int()).sum().item() + 0.00001
        fp = pred_choice.gt(target.data.int()).sum().item()
        fn = pred_choice.lt(target.data.int()).sum().item()
        tn = correct.item() - tp
        iou = float(tp) / (tp + fp + fn)
        prec = float(tp) / (tp + fp)
        recall = float(tp) / (tp + fn)

        print('[%d: %d/%d] train loss: %f acc: %f iou: %f' \
              % (epoch, i_batch, num_batch, loss.item(), acc, iou))

        mean_prec = torch.cat((mean_prec, torch.tensor([prec])), 0)
        mean_recall = torch.cat((mean_recall, torch.tensor([recall])), 0)
        mean_acc = torch.cat((mean_acc, torch.tensor([acc])), 0)
        mean_iou = torch.cat((mean_iou, torch.tensor([iou])), 0)
        n_iter += 1

    writer.add_scalar('train/epoch_loss', ep_loss / (i_batch + 1), epoch)

    return (mean_acc, mean_prec, mean_iou, mean_recall,
            ep_loss / (i_batch + 1), n_iter)
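
# The two-entry cfg['loss'] used above switches the criterion after
# switch_loss_epoch. A minimal sketch of the selection rule with a
# hypothetical cfg:
cfg_demo = {'loss': ['nll', 'LLm'], 'switch_loss_epoch': 50}
for ep in (0, 50, 51):
    if len(cfg_demo['loss']) == 2 and ep > int(cfg_demo['switch_loss_epoch']):
        loss_type = cfg_demo['loss'][1]
    else:
        loss_type = cfg_demo['loss'][0]
    print(ep, loss_type)  # -> nll, nll, LLm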
Example #5
def test(cfg):
    num_classes = int(cfg['n_classes'])
    sample_size = int(cfg['fixed_size'])
    cfg['loss'] = cfg['loss'].split(' ')
    batch_size = 1
    cfg['batch_size'] = batch_size
    epoch = int(cfg['n_epochs'])
    #n_gf = int(cfg['num_gf'])
    input_size = int(cfg['data_dim'])

    trans_val = []
    if cfg['rnd_sampling']:
        trans_val.append(ds.TestSampling(sample_size))
    #if cfg['standardization']:
    #    trans_val.append(ds.SampleStandardization())

    if cfg['dataset'] == 'left_ifof_ss_sl':
        dataset = ds.LeftIFOFSupersetDataset(
            cfg['sub_list_test'],
            cfg['dataset_dir'],
            transform=transforms.Compose(trans_val),
            uniform_size=True,
            train=False,
            split_obj=True,
            with_gt=cfg['with_gt'])
    elif cfg['dataset'] == 'hcp20_graph':
        dataset = ds.HCP20Dataset(
            cfg['sub_list_test'],
            cfg['dataset_dir'],
            transform=transforms.Compose(trans_val),
            with_gt=cfg['with_gt'],
            #distance=T.Distance(norm=True,cat=False),
            return_edges=True,
            split_obj=True,
            train=False,
            load_one_full_subj=False,
            standardize=cfg['standardization'])
    elif cfg['dataset'] == 'left_ifof_ss_sl_graph':
        dataset = ds.LeftIFOFSupersetGraphDataset(
            cfg['sub_list_test'],
            cfg['dataset_dir'],
            transform=transforms.Compose(trans_val),
            train=False,
            split_obj=True,
            with_gt=cfg['with_gt'])
    elif cfg['dataset'] == 'left_ifof_emb':
        dataset = ds.EmbDataset(cfg['sub_list_test'],
                                cfg['emb_dataset_dir'],
                                cfg['gt_dataset_dir'],
                                transform=transforms.Compose(trans_val),
                                load_all=cfg['load_all_once'],
                                precompute_graph=cfg['precompute_graph'],
                                k_graph=int(cfg['knngraph']))
    elif cfg['dataset'] == 'psb_airplane':
        dataset = ds.PsbAirplaneDataset(cfg['dataset_dir'], train=False)
    elif cfg['dataset'] == 'shapes':
        dataset = ds.ShapesDataset(cfg['dataset_dir'],
                                   train=False,
                                   multi_cat=cfg['multi_category'])
    elif cfg['dataset'] == 'shapenet':
        dataset = ds.ShapeNetCore(cfg['dataset_dir'],
                                  train=False,
                                  multi_cat=cfg['multi_category'])
    elif cfg['dataset'] == 'modelnet':
        dataset = ds.ModelNetDataset(cfg['dataset_dir'],
                                     split=cfg['mn40_split'],
                                     fold_size=int(cfg['mn40_fold_size']),
                                     load_all=cfg['load_all_once'])
    elif cfg['dataset'] == 'scanobj':
        dataset = ds.ScanObjNNDataset(cfg['dataset_dir'],
                                      run='test',
                                      variant=cfg['scanobj_variant'],
                                      background=cfg['scanobj_bg'],
                                      load_all=cfg['load_all_once'])
    else:
        dataset = ds.DRLeftIFOFSupersetDataset(
            cfg['sub_list_test'],
            cfg['val_dataset_dir'],
            transform=transforms.Compose(trans_val),
            with_gt=cfg['with_gt'])

    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=0)
    print("Validation dataset loaded, found %d samples" % (len(dataset)))

    for ext in range(100):
        logdir = '%s/test_%d' % (cfg['exp_path'], ext)
        if not os.path.exists(logdir):
            break
    writer = SummaryWriter(logdir)
    if cfg['weights_path'] == '':
        cfg['weights_path'] = glob.glob(cfg['exp_path'] + '/models/best*')[0]
        epoch = int(cfg['weights_path'].rsplit('-', 1)[1].split('.')[0])
    elif 'ep-' in cfg['weights_path']:
        epoch = int(cfg['weights_path'].rsplit('-', 1)[1].split('.')[0])

    tb_log_name = glob.glob('%s/events*' % writer.logdir)[0].rsplit('/', 1)[1]
    tb_log_dir = 'tb_logs/%s' % logdir.split('/', 1)[1]
    os.system('mkdir -p %s' % tb_log_dir)
    os.system('ln -sr %s/%s %s/%s ' %
              (writer.logdir, tb_log_name, tb_log_dir, tb_log_name))

    #### BUILD THE MODEL
    classifier = get_model(cfg)

    classifier.cuda()
    classifier.load_state_dict(torch.load(cfg['weights_path']))
    classifier.eval()

    with torch.no_grad():
        pred_buffer = {}
        sm_buffer = {}
        sm2_buffer = {}
        gf_buffer = {}
        emb_buffer = {}
        print('\n\n')
        mean_val_acc = torch.tensor([])
        mean_val_iou = torch.tensor([])
        mean_val_prec = torch.tensor([])
        mean_val_recall = torch.tensor([])

        if 'split_obj' in dir(dataset) and dataset.split_obj:
            split_obj = True
        else:
            split_obj = False
            dataset.transform = []

        if split_obj:
            consumed = False
        else:
            consumed = True
        j = 0
        visualized = 0
        new_obj_read = True
        sls_count = 1
        while j < len(dataset):
            #while sls_count <= len(dataset):
            data = dataset[j]

            if split_obj:
                if new_obj_read:
                    obj_pred_choice = torch.zeros(data['obj_full_size'],
                                                  dtype=torch.int).cuda()
                    obj_target = torch.zeros(data['obj_full_size'],
                                             dtype=torch.int).cuda()
                    new_obj_read = False
                    #if cfg['save_embedding']:
                    #obj_embedding = torch.empty((data['obj_full_size'], int(cfg['embedding_size']))).cuda()

                if len(dataset.remaining[j]) == 0:
                    consumed = True

            sample_name = (data['name'] if isinstance(data['name'], str)
                           else data['name'][0])

            #print(points)
            #if len(points.shape()) == 2:
            #points = points.unsqueeze(0)
            if 'graph' not in cfg['dataset']:
                points = data['points']
                if cfg['with_gt']:
                    target = data['gt']
                    target = target.to('cuda')
                    target = target.view(-1, 1)[:, 0]
            else:
                #print(data)
                points = gBatch().from_data_list([data['points']])
                #points = data['points']
                if 'bvec' in points.keys:
                    points.batch = points.bvec.clone()
                    del points.bvec
                # points.ori_batch = torch.zeros(points.x.size(0)).long()
                if cfg['with_gt']:
                    target = points['y']
                    target = target.to('cuda')
                    target = target.view(-1, 1)[:, 0]
                if cfg['same_size']:
                    points['lengths'] = points['lengths'][0].item()
            #if cfg['model'] == 'pointnet_cls':
            #points = points.view(len(data['obj_idxs']), -1, input_size)
            points = points.to('cuda')
            #print('streamline number:',sls_count)
            #sls_count+=1
            #print('lengths:',points['lengths'].item())
            ### add one-hot labels if multi-category task
            #new_k = points['lengths'].item()*(5/16)
            #print('new k:',new_k,'rounded k:',int(round(new_k)))
            #classifier.conv2.k = int(round(new_k))

            #if cfg['multi_loss']:
            #logits, gf = classifier(points)
            #else:
            logits = classifier(points)
            logits = logits.view(-1, num_classes)

            if len(cfg['loss']) == 2:
                if epoch <= int(cfg['switch_loss_epoch']):
                    loss_type = cfg['loss'][0]
                else:
                    loss_type = cfg['loss'][1]
            else:
                loss_type = cfg['loss'][0]

            if loss_type == 'nll':
                pred = F.log_softmax(logits, dim=-1)
                pred = pred.view(-1, num_classes)
                probas = torch.exp(pred.data)
                pred_choice = pred.data.max(1)[1].int()
                if cfg['with_gt']:
                    loss_seg = F.nll_loss(pred, target.long())
            elif loss_type == 'LLh':
                pred_choice = (logits.data > 0).int()
                if cfg['with_gt']:
                    loss_seg = L.lovasz_hinge(
                        logits.view(batch_size, sample_size, 1),
                        target.view(batch_size, sample_size, 1),
                        per_image=False)
                #loss = L.lovasz_hinge_flat(pred.view(-1), target.view(-1))
            elif loss_type == 'LLm':
                pred = F.softmax(logits, dim=-1)
                probas = pred.data
                pred_choice = pred.data.max(1)[1].int()
                if cfg['with_gt']:
                    loss_seg = L.lovasz_softmax_flat(
                        pred,
                        target,
                        op=cfg['llm_op'],
                        only_present=cfg['multi_category'])
            #print('pred:',pred)
            #print('pred shape:',pred.shape)
            #print('pred choice:',pred_choice)
            #print('pred choice shape:',pred_choice.shape)
            #if visualized < int(cfg['viz_clusters']):
            #    visualized += 1
            #    colors = torch.from_numpy(get_spaced_colors(n_gf))
            #    sm_out = classifier.feat.mf.softmax_out[0,:,:].max(1)[1].squeeze().int()
            #    writer.add_mesh('latent clustering', points, colors[sm_out.tolist()].unsqueeze(0))
            #    if 'bg' in data.keys():
            #        bg_msk = data['bg']*-1
            #        writer.add_mesh('bg_mask', points, colors[bg_msk.tolist()].unsqueeze(0))

            if split_obj:
                obj_pred_choice[data['obj_idxs']] = pred_choice
                obj_target[data['obj_idxs']] = target.int()
                #if cfg['save_embedding']:
                #    obj_embedding[data['obj_idxs']] = classifier.embedding.squeeze()
            else:
                obj_data = points
                obj_pred_choice = pred_choice
                obj_target = target
                if cfg['save_embedding']:
                    obj_embedding = classifier.embedding.squeeze()

            if cfg['with_gt'] and consumed:
                #if cfg['multi_loss']:
                #    loss_cluster = cluster_loss_fn(gf.squeeze(3))
                #    loss = loss_seg + alfa * loss_cluster

                #pred_choice = torch.sigmoid(pred.view(-1,1)).data.round().type_as(target.data)
                #print('points:',points['streamlines'])
                #print('points shape:',points['streamlines'].shape)
                #print('streamlines:',
                data_dir = cfg['dataset_dir']
                #streamlines, head, leng, idxs = load_streamlines(data['dir']+'/'+data['name']+'.trk')
                #print('tract:',len(streamlines))
                #print('pred:',obj_pred_choice)
                #print('taget:',obj_target)
                #print('pred shape:',obj_pred_choice.shape)
                #print('target shape:',obj_target.shape)
                print('val max class pred ', obj_pred_choice.max().item())
                print('val min class pred ', obj_pred_choice.min().item())
                y_pred = obj_pred_choice.cpu().numpy()
                np.save(data['dir'] + '/y_pred_sDEC_k5_16pts_nodropout',
                        y_pred)
                y_test = obj_target.cpu().numpy()
                np.save(data['dir'] + '/y_test_sDEC_k5_16pts_nodropout',
                        y_test)
                #np.save(data['dir']+'/streamlines_lstm_GIN',streamlines)
                correct = obj_pred_choice.eq(obj_target.data.int()).cpu().sum()
                acc = correct.item() / float(obj_target.size(0))

                if num_classes > 2:
                    iou, prec, recall = L.iou_multi(
                        obj_pred_choice.data.int().cpu().numpy(),
                        obj_target.data.int().cpu().numpy(),
                        num_classes,
                        multi_cat=cfg['multi_category'])
                    assert (np.isnan(iou).sum() == 0)
                    if cfg['multi_category']:
                        s, n_parts = data['gt_offset']
                        e = s + n_parts
                        iou[0, :s], prec[0, :s], recall[0, :s] = 0., 0., 0.
                        iou[0, e:], prec[0, e:], recall[0, e:] = 0., 0., 0.
                        iou = torch.from_numpy(iou).float()
                        prec = torch.from_numpy(prec).float()
                        recall = torch.from_numpy(recall).float()
                        category = data['category'].squeeze().nonzero().float()
                        iou = torch.cat([iou, category], 1)
                    else:
                        iou = torch.tensor([iou.mean()])
                        prec = torch.tensor([prec.mean()])
                        recall = torch.tensor([recall.mean()])
                    assert (torch.isnan(iou).sum() == 0)

                else:
                    tp = torch.mul(
                        obj_pred_choice.data,
                        obj_target.data.int()).cpu().sum().item() + 0.00001
                    fp = obj_pred_choice.gt(
                        obj_target.data.int()).cpu().sum().item()
                    fn = obj_pred_choice.lt(
                        obj_target.data.int()).cpu().sum().item()
                    tn = correct.item() - tp
                    iou = torch.tensor([float(tp) / (tp + fp + fn)])
                    prec = torch.tensor([float(tp) / (tp + fp)])
                    recall = torch.tensor([float(tp) / (tp + fn)])

                mean_val_prec = torch.cat((mean_val_prec, prec), 0)
                mean_val_recall = torch.cat((mean_val_recall, recall), 0)
                mean_val_iou = torch.cat((mean_val_iou, iou), 0)
                mean_val_acc = torch.cat((mean_val_acc, torch.tensor([acc])),
                                         0)
                print('VALIDATION [%d: %d/%d] val accuracy: %f' \
                        % (epoch, j, len(dataset), acc))

            if cfg['save_pred'] and consumed:
                print('buffering prediction %s' % sample_name)
                if num_classes > 2:
                    for cl in range(num_classes):
                        sl_idx = np.where(
                            obj_pred_choice.data.cpu().view(-1).numpy() == cl)[0]
                        if cl == 0:
                            pred_buffer[sample_name] = []
                        pred_buffer[sample_name].append(sl_idx.tolist())
                else:
                    sl_idx = np.where(
                        obj_pred_choice.data.cpu().view(-1).numpy() == 1)[0]
                    pred_buffer[sample_name] = sl_idx.tolist()
            #if cfg['save_softmax_out']:
            #    if cfg['model'] in 'pointnet_mgfml':
            #        if sample_name not in sm_buffer.keys():
            #            sm_buffer[sample_name] = []
            #        if classifier.feat.multi_feat > 1:
            #            sm_buffer[sample_name].append(
            #                classifier.feat.mf.softmax_out.cpu().numpy())
            #    if cfg['model'] == 'pointnet_mgfml':
            #        for l in classifier.layers:
            #            sm_buffer[sample_name].append(
            #                    l.mf.softmax_out.cpu().numpy())
            #    sm2_buffer[sample_name] = probas.cpu().numpy()
            #if cfg['save_gf']:
            #   gf_buffer[sample_name] = np.unique(
            #           classifier.feat.globalfeat.data.cpu().squeeze().numpy(), axis = 0)
            #gf_buffer[sample_name] = classifier.globalfeat
            #if cfg['save_embedding'] and consumed:
            #    emb_buffer[sample_name] = obj_embedding

            if consumed:
                print(j)
                j += 1
                if split_obj:
                    consumed = False
                    new_obj_read = True

        macro_iou = torch.mean(mean_val_iou)
        macro_prec = torch.mean(mean_val_prec)
        macro_recall = torch.mean(mean_val_recall)

        epoch_iou = macro_iou.item()

    if cfg['save_pred']:
        #os.system('rm -r %s/predictions_test*' % writer.logdir)
        pred_dir = writer.logdir + '/predictions_test_%d' % epoch
        if not os.path.exists(pred_dir):
            os.makedirs(pred_dir)
        print('saving files')
        for filename, value in pred_buffer.items():
            with open(os.path.join(pred_dir, filename) + '.pkl', 'wb') as f:
                pickle.dump(value, f, protocol=pickle.HIGHEST_PROTOCOL)

    #if cfg['save_softmax_out']:
    #    os.system('rm -r %s/sm_out_test*' % writer.logdir)
    #    sm_dir = writer.logdir + '/sm_out_test_%d' % epoch
    #    if not os.path.exists(sm_dir):
    #        os.makedirs(sm_dir)
    #    for filename, value in sm_buffer.iteritems():
    #        with open(os.path.join(sm_dir, filename) + '_sm_1.pkl', 'wb') as f:
    #            pickle.dump(
    #                value, f, protocol=pickle.HIGHEST_PROTOCOL)
    #    for filename, value in sm2_buffer.iteritems():
    #        with open(os.path.join(sm_dir, filename) + '_sm_2.pkl', 'wb') as f:
    #            pickle.dump(
    #                value, f, protocol=pickle.HIGHEST_PROTOCOL)

    #if cfg['save_gf']:
    #os.system('rm -r %s/gf_test*' % writer.logdir)
    #    gf_dir = writer.logdir + '/gf_test_%d' % epoch
    #    if not os.path.exists(gf_dir):
    #        os.makedirs(gf_dir)
    #    i = 0
    #    for filename, value in gf_buffer.items():
    #        if i == 3:
    #            break
    #        with open(os.path.join(gf_dir, filename) + '.pkl', 'wb') as f:
    #            pickle.dump(value, f, protocol=pickle.HIGHEST_PROTOCOL)

    #if cfg['save_embedding']:
    #    print('saving embedding')
    #    emb_dir = writer.logdir + '/embedding_test_%d' % epoch
    #    if not os.path.exists(emb_dir):
    #        os.makedirs(emb_dir)
    #    for filename, value in emb_buffer.iteritems():
    #        np.save(os.path.join(emb_dir, filename), value.cpu().numpy())

    if cfg['with_gt']:
        print('TEST ACCURACY: %f' % torch.mean(mean_val_acc).item())
        print('TEST PRECISION: %f' % macro_prec.item())
        print('TEST RECALL: %f' % macro_recall.item())
        print('TEST IOU: %f' % macro_iou.item())
        mean_val_dsc = mean_val_prec * mean_val_recall * 2 / (mean_val_prec +
                                                              mean_val_recall)
        final_scores_file = writer.logdir + '/final_scores_test_%d.txt' % epoch
        scores_file = writer.logdir + '/scores_test_%d.txt' % epoch
        print('saving scores')
        with open(scores_file, 'w') as f:
            f.write('acc\n')
            f.writelines('%f\n' % v for v in mean_val_acc.tolist())
            f.write('prec\n')
            f.writelines('%f\n' % v for v in mean_val_prec.tolist())
            f.write('recall\n')
            f.writelines('%f\n' % v for v in mean_val_recall.tolist())
            f.write('dsc\n')
            f.writelines('%f\n' % v for v in mean_val_dsc.tolist())
            f.write('iou\n')
            f.writelines('%f\n' % v for v in mean_val_iou.tolist())
        with open(final_scores_file, 'w') as f:
            f.write('acc\n')
            f.write('%f\n' % mean_val_acc.mean())
            f.write('%f\n' % mean_val_acc.std())
            f.write('prec\n')
            f.write('%f\n' % mean_val_prec.mean())
            f.write('%f\n' % mean_val_prec.std())
            f.write('recall\n')
            f.write('%f\n' % mean_val_recall.mean())
            f.write('%f\n' % mean_val_recall.std())
            f.write('dsc\n')
            f.write('%f\n' % mean_val_dsc.mean())
            f.write('%f\n' % mean_val_dsc.std())
            f.write('iou\n')
            f.write('%f\n' % mean_val_iou.mean())
            f.write('%f\n' % mean_val_iou.std())

    print('\n\n')
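
# With split_obj, a full tractogram is consumed in fixed-size chunks and each
# chunk's predictions are scattered back into a full-size buffer through
# obj_idxs, the same assignment pattern as in the loop above. A toy
# illustration of that scatter:
import torch

obj_pred_demo = torch.zeros(6, dtype=torch.int)
chunks = [(torch.tensor([0, 2, 4]), torch.tensor([1, 0, 1], dtype=torch.int)),
          (torch.tensor([1, 3, 5]), torch.tensor([0, 1, 0], dtype=torch.int))]
for idxs, chunk_pred in chunks:
    obj_pred_demo[idxs] = chunk_pred
print(obj_pred_demo)  # tensor([1, 0, 0, 1, 1, 0], dtype=torch.int32)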
Example #6
def test(cfg):
    num_classes = int(cfg['n_classes'])
    sample_size = int(cfg['fixed_size'])
    cfg['loss'] = cfg['loss'].split(' ')
    batch_size = 1
    cfg['batch_size'] = batch_size
    epoch = int(cfg['n_epochs'])
    #n_gf = int(cfg['num_gf'])
    input_size = int(cfg['data_dim'])

    trans_val = []
    if cfg['rnd_sampling']:
        trans_val.append(TestSampling(sample_size))
    if cfg['standardization']:
        trans_val.append(SampleStandardization())

    if cfg['dataset'] == 'hcp20_graph':
        dataset = ds.HCP20Dataset(
            cfg['sub_list_test'],
            cfg['dataset_dir'],
            transform=transforms.Compose(trans_val),
            with_gt=cfg['with_gt'],
            #distance=T.Distance(norm=True,cat=False),
            return_edges=True,
            split_obj=True,
            train=False,
            load_one_full_subj=False,
            labels_dir=cfg['labels_dir'])

    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=0)
    print("Validation dataset loaded, found %d samples" % (len(dataset)))

    for ext in range(100):
        logdir = '%s/test_%d' % (cfg['exp_path'], ext)
        if not os.path.exists(logdir):
            break
    writer = SummaryWriter(logdir)
    if cfg['weights_path'] == '':
        cfg['weights_path'] = glob.glob(cfg['exp_path'] + '/models/best*')[0]
        epoch = int(cfg['weights_path'].rsplit('-', 1)[1].split('.')[0])
    elif 'ep-' in cfg['weights_path']:
        epoch = int(cfg['weights_path'].rsplit('-', 1)[1].split('.')[0])

    tb_log_name = glob.glob('%s/events*' % writer.logdir)[0].rsplit('/', 1)[1]
    tb_log_dir = 'tb_logs/%s' % logdir.split('/', 1)[1]
    os.system('mkdir -p %s' % tb_log_dir)
    os.system('ln -sr %s/%s %s/%s ' %
              (writer.logdir, tb_log_name, tb_log_dir, tb_log_name))

    #### BUILD THE MODEL
    classifier = get_model(cfg)

    classifier.cuda()
    classifier.load_state_dict(torch.load(cfg['weights_path']))
    classifier.eval()

    with torch.no_grad():
        pred_buffer = {}
        sm_buffer = {}
        sm2_buffer = {}
        gf_buffer = {}
        emb_buffer = {}
        print('\n\n')
        #mean_val_acc = torch.tensor([])
        #mean_val_iou = torch.tensor([])
        #mean_val_prec = torch.tensor([])
        #mean_val_recall = torch.tensor([])
        mean_val_mse = torch.tensor([])
        mean_val_mae = torch.tensor([])
        mean_val_rho = torch.tensor([])

        if 'split_obj' in dir(dataset) and dataset.split_obj:
            split_obj = True
        else:
            split_obj = False
            dataset.transform = []

        if split_obj:
            consumed = False
        else:
            consumed = True
        j = 0
        visualized = 0
        new_obj_read = True
        sls_count = 1
        while j < len(dataset):
            #while sls_count <= len(dataset):
            data = dataset[j]

            if split_obj:
                if new_obj_read:
                    #obj_pred_choice = torch.zeros(data['obj_full_size'], dtype=torch.int).cuda()
                    #obj_target = torch.zeros(data['obj_full_size'], dtype=torch.int).cuda()
                    obj_pred_choice = torch.zeros(data['obj_full_size'],
                                                  dtype=torch.float32).cuda()
                    obj_target = torch.zeros(data['obj_full_size'],
                                             dtype=torch.float32).cuda()
                    new_obj_read = False

                if len(dataset.remaining[j]) == 0:
                    consumed = True

            sample_name = (data['name'] if isinstance(data['name'], str)
                           else data['name'][0])
            print(sample_name)

            #print(points)
            #if len(points.shape()) == 2:
            #points = points.unsqueeze(0)
            #print(data)
            points = gBatch().from_data_list([data['points']])
            #points = data['points']
            if 'bvec' in points.keys:
                points.batch = points.bvec.clone()
                del points.bvec
            if cfg['with_gt']:
                target = points['y']
                target = target.to('cuda')
            if cfg['same_size']:
                points['lengths'] = points['lengths'][0].item()
            #if cfg['model'] == 'pointnet_cls':
            #points = points.view(len(data['obj_idxs']), -1, input_size)
            points = points.to('cuda')

            pred = classifier(points)

            #logits = classifier(points)
            #logits = logits.view(-1, num_classes)

            #pred = F.log_softmax(logits, dim=-1).view(-1, num_classes)
            #pred_choice = pred.data.max(1)[1].int()

            if split_obj:
                obj_pred_choice[data['obj_idxs']] = pred.view(-1)
                #obj_pred_choice[data['obj_idxs']] = pred_choice
                obj_target[data['obj_idxs']] = target.float()
                #print(obj_pred_choice)
                #print(obj_target)
                #obj_target[data['obj_idxs']] = target.int()
                #if cfg['save_embedding']:
                #    obj_embedding[data['obj_idxs']] = classifier.embedding.squeeze()
            else:
                obj_data = points
                obj_pred_choice = pred.view(-1)  # pred is the raw regression output here
                obj_target = target
                if cfg['save_embedding']:
                    obj_embedding = classifier.embedding.squeeze()

            if cfg['with_gt'] and consumed:
                print('val max class pred ', obj_pred_choice.max().item())
                print('val min class pred ', obj_pred_choice.min().item())
                print('val max class target ', obj_target.max().item())
                print('val min class target ', obj_target.min().item())
                #obj_pred_choice = obj_pred_choice.view(-1,1)
                #obj_target = obj_target.view(-1,1)

                mae = torch.mean(
                    abs(obj_target.data.cpu() -
                        obj_pred_choice.data.cpu())).item()
                mse = torch.mean((obj_target.data.cpu() -
                                  obj_pred_choice.data.cpu())**2).item()
                rho, pval = spearmanr(obj_target.data.cpu().numpy(),
                                      obj_pred_choice.data.cpu().numpy())
                np.save(writer.logdir + '/predictions_' + sample_name + '.npy',
                        obj_pred_choice.data.cpu().numpy())
                #correct = obj_pred_choice.eq(obj_target.data.int()).cpu().sum()
                #acc = correct.item()/float(obj_target.size(0))
                #tp = torch.mul(obj_pred_choice.data, obj_target.data.int()).cpu().sum().item()+0.00001
                #fp = obj_pred_choice.gt(obj_target.data.int()).cpu().sum().item()
                #fn = obj_pred_choice.lt(obj_target.data.int()).cpu().sum().item()
                #tn = correct.item() - tp
                #iou = torch.tensor([float(tp)/(tp+fp+fn)])
                #prec = torch.tensor([float(tp)/(tp+fp)])
                #recall = torch.tensor([float(tp)/(tp+fn)])

                mean_val_mae = torch.cat((mean_val_mae, torch.tensor([mae])),
                                         0)
                mean_val_mse = torch.cat((mean_val_mse, torch.tensor([mse])),
                                         0)
                mean_val_rho = torch.cat((mean_val_rho, torch.tensor([rho])),
                                         0)
                #mean_val_prec = torch.cat((mean_val_prec, prec), 0)
                #mean_val_recall = torch.cat((mean_val_recall, recall), 0)
                #mean_val_iou = torch.cat((mean_val_iou, iou), 0)
                #mean_val_acc = torch.cat((mean_val_acc, torch.tensor([acc])), 0)
                print('VALIDATION [%d: %d/%d] val mse: %f val mae: %f val rho: %f' \
                        % (epoch, j, len(dataset), mse, mae, rho))

            if cfg['save_pred'] and consumed:
                print('buffering prediction %s' % sample_name)
                #sl_idx = np.where(obj_pred.data.cpu().view(-1).numpy() == 1)[0]
                #pred_buffer[sample_name] = sl_idx.tolist()

            if consumed:
                print(j)
                j += 1
                if split_obj:
                    consumed = False
                    new_obj_read = True

        #macro_iou = torch.mean(mean_val_iou)
        #macro_prec = torch.mean(mean_val_prec)
        #macro_recall = torch.mean(mean_val_recall)

        #epoch_iou = macro_iou.item()

    #if cfg['save_pred']:
    #os.system('rm -r %s/predictions_test*' % writer.logdir)
    #   pred_dir = writer.logdir + '/predictions_test_%d' % epoch
    #  if not os.path.exists(pred_dir):
    #     os.makedirs(pred_dir)
    #print('saving files')
    #for filename, value in pred_buffer.items():
    #    with open(os.path.join(pred_dir, filename) + '.pkl', 'wb') as f:
    #        pickle.dump(
    #            value, f, protocol=pickle.HIGHEST_PROTOCOL)

    if cfg['with_gt']:
        print('TEST MSE: %f' % torch.mean(mean_val_mse).item())
        print('TEST MAE: %f' % torch.mean(mean_val_mae).item())
        print('TEST RHO: %f' % torch.mean(mean_val_rho).item())
        #print('TEST ACCURACY: %f' % torch.mean(mean_val_acc).item())
        #print('TEST PRECISION: %f' % macro_prec.item())
        #print('TEST RECALL: %f' % macro_recall.item())
        #print('TEST IOU: %f' % macro_iou.item())
        #mean_val_dsc = mean_val_prec * mean_val_recall * 2 / (mean_val_prec + mean_val_recall)
        final_scores_file = writer.logdir + '/final_scores_test_%d.txt' % epoch
        scores_file = writer.logdir + '/scores_test_%d.txt' % epoch
        print('saving scores')
        with open(scores_file, 'w') as f:
            f.write('mse\n')
            f.writelines('%f\n' % v for v in mean_val_mse.tolist())
            f.write('mae\n')
            f.writelines('%f\n' % v for v in mean_val_mae.tolist())
            f.write('rho\n')
            f.writelines('%f\n' % v for v in mean_val_rho.tolist())
            #f.write('acc\n')
            #f.writelines('%f\n' % v for v in  mean_val_acc.tolist())
            #f.write('prec\n')
            #f.writelines('%f\n' % v for v in  mean_val_prec.tolist())
            #f.write('recall\n')
            #f.writelines('%f\n' % v for v in  mean_val_recall.tolist())
            #f.write('dsc\n')
            #f.writelines('%f\n' % v for v in  mean_val_dsc.tolist())
            #f.write('iou\n')
            #f.writelines('%f\n' % v for v in  mean_val_iou.tolist())
        with open(final_scores_file, 'w') as f:
            f.write('mse\n')
            f.write('%f\n' % mean_val_mse.mean())
            f.write('%f\n' % mean_val_mse.std())
            f.write('mae\n')
            f.write('%f\n' % mean_val_mae.mean())
            f.write('%f\n' % mean_val_mae.std())
            f.write('rho\n')
            f.write('%f\n' % mean_val_rho.mean())
            f.write('%f\n' % mean_val_rho.std())
            #f.write('acc\n')
            #f.write('%f\n' % mean_val_acc.mean())
            #f.write('%f\n' % mean_val_acc.std())
            #f.write('prec\n')
            #f.write('%f\n' % mean_val_prec.mean())
            #f.write('%f\n' % mean_val_prec.std())
            #f.write('recall\n')
            #f.write('%f\n' % mean_val_recall.mean())
            #f.write('%f\n' % mean_val_recall.std())
            #f.write('dsc\n')
            #f.write('%f\n' % mean_val_dsc.mean())
            #f.write('%f\n' % mean_val_dsc.std())
            #f.write('iou\n')
            #f.write('%f\n' % mean_val_iou.mean())
            #f.write('%f\n' % mean_val_iou.std())

    print('\n\n')
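
# This variant scores a regression head, so the metrics are MAE, MSE and
# Spearman's rho rather than IoU. A toy computation mirroring the calls above:
import torch
from scipy.stats import spearmanr

target = torch.tensor([0.1, 0.4, 0.8, 0.3])
pred = torch.tensor([0.2, 0.5, 0.6, 0.35])
mae = torch.mean(torch.abs(target - pred)).item()  # 0.1125
mse = torch.mean((target - pred) ** 2).item()
rho, pval = spearmanr(target.numpy(), pred.numpy())
print(mae, mse, rho)  # rho == 1.0 here: the predicted ranking matches exactly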
Example #7
def val_iter(cfg, val_dataloader, classifier, writer, epoch, best_epoch,
             best_pred, logdir):

    num_classes = int(cfg['n_classes'])
    batch_size = 1
    sample_size = int(cfg['fixed_size'])
    ep_loss = 0.

    classifier.eval()

    with torch.no_grad():
        print('\n\n')
        mean_val_acc = torch.tensor([])
        mean_val_iou = torch.tensor([])
        mean_val_prec = torch.tensor([])
        mean_val_recall = torch.tensor([])
        mean_val_iou_c = torch.tensor([])

        for j, data in enumerate(val_dataloader):
            data_list = []
            name_list = []
            for i, d in enumerate(data):
                if 'bvec' in d['points'].keys:
                    d['points'].bvec += sample_size * i
                data_list.append(d['points'])
                name_list.append(d['name'])
            points = gBatch().from_data_list(data_list)
            if 'bvec' in points.keys:
                points.batch = points.bvec.clone()
                del points.bvec
            target = points['y']
            if cfg['same_size']:
                points['lengths'] = points['lengths'][0].item()
            data = {'points': points, 'gt': target, 'name': name_list}
            points, target = points.to('cuda'), target.to('cuda')

            logits = classifier(points)

            pred = F.log_softmax(logits, dim=-1)
            pred = pred.view(-1, num_classes)
            pred_choice = pred.data.max(1)[1].int()
            loss = F.nll_loss(pred, target.long())

            ep_loss += loss
            print('val max class pred ', pred_choice.max().item())
            print('val min class pred ', pred_choice.min().item())
            print('# class pred ', len(np.unique(pred_choice.cpu().numpy())))
            correct = pred_choice.eq(target.data.int()).cpu().sum()
            acc = correct.item() / float(target.size(0))

            tp = torch.mul(pred_choice.data,
                           target.data.int()).cpu().sum().item() + 0.00001
            fp = pred_choice.gt(target.data.int()).cpu().sum().item()
            fn = pred_choice.lt(target.data.int()).cpu().sum().item()
            tn = correct.item() - tp
            iou = torch.tensor([float(tp) / (tp + fp + fn)])
            prec = torch.tensor([float(tp) / (tp + fp)])
            recall = torch.tensor([float(tp) / (tp + fn)])

            print('VALIDATION [%d: %d/%d] val loss: %f acc: %f iou: %f' %
                  (epoch, j, len(val_dataloader), loss, acc, iou))

            mean_val_prec = torch.cat((mean_val_prec, prec), 0)
            mean_val_recall = torch.cat((mean_val_recall, recall), 0)
            mean_val_iou = torch.cat((mean_val_iou, iou), 0)
            mean_val_acc = torch.cat((mean_val_acc, torch.tensor([acc])), 0)

        macro_iou = torch.mean(mean_val_iou)
        macro_prec = torch.mean(mean_val_prec)
        macro_recall = torch.mean(mean_val_recall)
        macro_iou_c = torch.mean(mean_val_iou_c)

        epoch_iou = macro_iou.item()

        writer.add_scalar('val/epoch_acc',
                          torch.mean(mean_val_acc).item(), epoch)
        writer.add_scalar('val/epoch_iou', epoch_iou, epoch)
        writer.add_scalar('val/epoch_prec', macro_prec.item(), epoch)
        writer.add_scalar('val/epoch_recall', macro_recall.item(), epoch)
        writer.add_scalar('val/epoch_iou_c', macro_iou_c.item(), epoch)
        writer.add_scalar('val/loss', ep_loss / (j + 1), epoch)
        print('VALIDATION ACCURACY: %f' % torch.mean(mean_val_acc).item())
        print('VALIDATION IOU: %f' % epoch_iou)
        print('VALIDATION IOUC: %f' % macro_iou_c.item())
        print('\n\n')

        return best_epoch, best_pred, ep_loss
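
# F.log_softmax followed by F.nll_loss, as used in this val_iter, is the
# two-step spelling of cross-entropy; a quick equivalence check:
import torch
import torch.nn.functional as F

logits = torch.randn(5, 2)
target = torch.tensor([0, 1, 1, 0, 1])
two_step = F.nll_loss(F.log_softmax(logits, dim=-1), target)
direct = F.cross_entropy(logits, target)
assert torch.allclose(two_step, direct)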
Example #8
def train_iter(cfg, dataloader, classifier, optimizer, writer, epoch, n_iter,
               cluster_loss_fn):

    num_classes = int(cfg['n_classes'])
    batch_size = int(cfg['batch_size'])
    n_epochs = int(cfg['n_epochs'])
    sample_size = int(cfg['fixed_size'])
    input_size = int(cfg['data_dim'])
    num_batch = cfg['num_batch']
    alfa = 0

    ep_loss = 0.
    ep_seg_loss = 0.
    ep_cluster_loss = 0.
    mean_acc = torch.tensor([])
    mean_iou = torch.tensor([])
    mean_prec = torch.tensor([])
    mean_recall = torch.tensor([])

    ### state that the model will run in train mode
    classifier.train()

    #d_list=[]
    #for dat in dataloader:
    #for d in dat:
    #d_list.append(d)
    #points = gBatch().from_data_list(d_list)
    #target = points['y']
    #name = dataset['name']
    #points, target = points.to('cuda'), target.to('cuda')

    for i_batch, sample_batched in enumerate(dataloader):

        ### get batch
        data_list = []
        name_list = []
        for i, d in enumerate(sample_batched):
            if 'bvec' in d['points'].keys:
                d['points'].bvec += sample_size * i
            data_list.append(d['points'])
            name_list.append(d['name'])
        points = gBatch().from_data_list(data_list)
        if 'bvec' in points.keys:
            #points.batch = points.bvec.copy()
            points.batch = points.bvec.clone()
            del points.bvec
        #if 'bslices' in points.keys():
        #    points.__slices__ = torch.cum(
        target = points['y']
        if cfg['same_size']:
            points['lengths'] = points['lengths'][0].item()
        sample_batched = {'points': points, 'gt': target, 'name': name_list}
        #print('points:',points)

        #if (epoch != 0) and (epoch % 20 == 0):
        #    assert(len(dataloader.dataset) % int(cfg['fold_size']) == 0)
        #    folds = len(dataloader.dataset)/int(cfg['fold_size'])
        #    n_fold = (dataloader.dataset.n_fold + 1) % folds
        #    if n_fold != dataloader.dataset.n_fold:
        #        dataloader.dataset.n_fold = n_fold
        #        dataloader.dataset.load_fold()
        points, target = points.to('cuda'), target.to('cuda')
        #print(len(points.lengths),target.shape)

        ### initialize gradients
        #if not cfg['accumulation_interval'] or i_batch == 0:
        optimizer.zero_grad()

        ### forward
        logits = classifier(points)
        ### minimize the loss
        pred = F.log_softmax(logits, dim=-1)
        pred = pred.view(-1, num_classes)
        pred_choice = pred.data.max(1)[1].int()

        loss = F.nll_loss(pred, target.long())

        ep_loss += loss.item()  # accumulate a float so each batch graph can be freed
        #print('memory allocated in MB: ', torch.cuda.memory_allocated()/2**20)
        #import sys; sys.exit()
        loss.backward()

        #if int(cfg['accumulation_interval']) % (i_batch+1) == 0:
        optimizer.step()
        #optimizer.zero_grad
        #elif not cfg['accumulation_interval']:
        #    optimizer.step()

        ### compute performance
        correct = pred_choice.eq(target.data.int()).sum()
        acc = correct.item() / float(target.size(0))

        tp = torch.mul(pred_choice.data,
                       target.data.int()).sum().item() + 0.00001
        fp = pred_choice.gt(target.data.int()).sum().item()
        fn = pred_choice.lt(target.data.int()).sum().item()
        tn = correct.item() - tp
        iou = float(tp) / (tp + fp + fn)
        prec = float(tp) / (tp + fp)
        recall = float(tp) / (tp + fn)

        print('[%d: %d/%d] train loss: %f acc: %f iou: %f' \
              % (epoch, i_batch, num_batch, loss.item(), acc, iou))

        mean_prec = torch.cat((mean_prec, torch.tensor([prec])), 0)
        mean_recall = torch.cat((mean_recall, torch.tensor([recall])), 0)
        mean_acc = torch.cat((mean_acc, torch.tensor([acc])), 0)
        mean_iou = torch.cat((mean_iou, torch.tensor([iou])), 0)
        n_iter += 1

    writer.add_scalar('train/epoch_loss', ep_loss / (i_batch + 1), epoch)

    return (mean_acc, mean_prec, mean_iou, mean_recall,
            ep_loss / (i_batch + 1), n_iter)
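
# Per-batch scalars are accumulated by concatenating one-element tensors, so
# each epoch summary is just a mean over the collected values. A tiny sketch:
import torch

mean_acc_demo = torch.tensor([])
for acc in (0.8, 0.9, 0.85):
    mean_acc_demo = torch.cat((mean_acc_demo, torch.tensor([acc])), 0)
print(mean_acc_demo.mean().item())  # ~0.85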