Example #1
def train(cfg):

    num_classes = int(cfg['n_classes'])
    batch_size = int(cfg['batch_size'])
    n_epochs = int(cfg['n_epochs'])
    sample_size = int(cfg['fixed_size'])
    cfg['loss'] = cfg['loss'].split(' ')
    input_size = int(cfg['data_dim'])

    #### DATA LOADING
    trans_train = []
    trans_val = []
    if cfg['rnd_sampling']:
        trans_train.append(ds.RndSampling(sample_size, maintain_prop=False))
        trans_val.append(ds.RndSampling(sample_size, maintain_prop=False))
    if cfg['standardization']:
        trans_train.append(ds.SampleStandardization())
        trans_val.append(ds.SampleStandardization())
    #trans_train.append(T.Distance(norm=False))
    #trans_val.append(T.Distance(norm=False))

    if cfg['dataset'] == 'hcp20_graph':
        dataset = ds.HCP20Dataset(
            cfg['sub_list_train'],
            cfg['dataset_dir'],
            #k=4,
            act=cfg['act'],
            transform=transforms.Compose(trans_train),
            #self_loops=T.AddSelfLoops(),
            #distance=T.Distance(norm=True,cat=False),
            return_edges=True,
            load_one_full_subj=False)
    elif cfg['dataset'] == 'left_ifof_ss_sl_graph':
        dataset = ds.LeftIFOFSupersetGraphDataset(
            cfg['sub_list_train'],
            cfg['dataset_dir'],
            transform=transforms.Compose(trans_train),
            same_size=cfg['same_size'])
    elif cfg['dataset'] == 'tractseg_500k':
        dataset = ds.Tractseg500kDataset(
            cfg['sub_list_train'],
            cfg['dataset_dir'],
            act=cfg['act'],
            #fold_size=int(cfg['fold_size']),
            transform=transforms.Compose(trans_train))
    if 'graph' in cfg['dataset']:
        DL = gDataLoader
    else:
        DL = DataLoader

    dataloader = DL(dataset,
                    batch_size=batch_size,
                    shuffle=cfg['shuffling'],
                    num_workers=int(cfg['n_workers']),
                    pin_memory=True)

    print("Dataset %s loaded, found %d samples" %
          (cfg['dataset'], len(dataset)))
    if cfg['val_in_train']:
        if cfg['dataset'] == 'hcp20_graph':
            val_dataset = ds.HCP20Dataset(
                cfg['sub_list_val'],
                cfg['val_dataset_dir'],
                #k=4,
                act=cfg['act'],
                transform=transforms.Compose(trans_val),
                #distance=T.Distance(norm=True,cat=False),
                #self_loops=T.AddSelfLoops(),
                return_edges=True,
                load_one_full_subj=False)
        elif cfg['dataset'] == 'tractseg_500k':
            val_dataset = ds.Tractseg500kDataset(
                cfg['sub_list_val'],
                cfg['val_dataset_dir'],
                act=cfg['act'],
                #fold_size=int(cfg['fold_size']),
                transform=transforms.Compose(trans_val))
        elif cfg['dataset'] == 'left_ifof_ss_sl_graph':
            val_dataset = ds.LeftIFOFSupersetGraphDataset(
                cfg['sub_list_val'],
                cfg['dataset_dir'],
                transform=transforms.Compose(trans_val),
                same_size=cfg['same_size'])

        val_dataloader = DL(val_dataset,
                            batch_size=1,
                            shuffle=False,
                            num_workers=int(cfg['n_workers']),
                            pin_memory=True)
        print("Validation dataset loaded, found %d samples" %
              (len(val_dataset)))

    # summary for tensorboard
    if cfg['experiment_name'] != 'default':
        for ext in range(100):
            exp_name = cfg['experiment_name'] + '_%d' % ext
            logdir = 'runs/%s' % exp_name
            if not os.path.exists(logdir):
                writer = SummaryWriter(logdir=logdir)
                break
    else:
        exp_name = 'default'
        writer = SummaryWriter()
        logdir = writer.logdir

    tb_log_name = glob.glob('%s/events*' % logdir)[0].rsplit('/', 1)[1]
    tb_log_dir = 'tb_logs/%s' % exp_name
    os.system('mkdir -p %s' % tb_log_dir)
    os.system('ln -sr %s/%s %s/%s ' %
              (logdir, tb_log_name, tb_log_dir, tb_log_name))

    os.system('cp main_dsl_config.py %s/config.txt' % (writer.logdir))

    #### BUILD THE MODEL
    classifier = get_model(cfg)

    #### SET THE TRAINING
    if cfg['optimizer'] == 'sgd_momentum':
        optimizer = optim.SGD(classifier.parameters(),
                              lr=float(cfg['learning_rate']),
                              momentum=float(cfg['momentum']),
                              weight_decay=float(cfg['weight_decay']))
    elif cfg['optimizer'] == 'adam':
        optimizer = optim.Adam(classifier.parameters(),
                               lr=float(cfg['learning_rate']),
                               weight_decay=float(cfg['weight_decay']))

    if cfg['lr_type'] == 'step':
        lr_scheduler = optim.lr_scheduler.StepLR(optimizer,
                                                 int(cfg['lr_ep_step']),
                                                 gamma=float(cfg['lr_gamma']))
    elif cfg['lr_type'] == 'plateau':
        lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=float(cfg['lr_gamma']),
            patience=int(cfg['patience']),
            threshold=0.0001,
            min_lr=float(cfg['min_lr']))
    # cfg['loss'] was split into a list above, so test membership
    if 'nll' in cfg['loss']:
        loss_fn = F.nll_loss

    alfa = 0
    cluster_loss_fn = None

    classifier.cuda()
    num_batch = len(dataset) // batch_size
    print('num of batches per epoch: %d' % num_batch)
    cfg['num_batch'] = num_batch

    n_iter = 0
    best_pred = 0
    best_epoch = 0
    current_lr = float(cfg['learning_rate'])
    for epoch in range(n_epochs + 1):

        # bn decay as in pointnet orig
        if cfg['bn_decay'] and epoch % int(cfg['bn_decay_step']) == 0:
            bn_momentum = float(cfg['bn_decay_init']) * float(
                cfg['bn_decay_gamma'])**(epoch / int(cfg['bn_decay_step']))
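            # clamp so the effective BatchNorm momentum never drops below 0.01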
            bn_momentum = 1 - min(0.99, 1 - bn_momentum)
            print('updated bn momentum to %f' % bn_momentum)
            for module in classifier.modules():
                if isinstance(module, torch.nn.BatchNorm1d):
                    module.momentum = bn_momentum

        mean_acc, mean_prec, mean_iou, mean_recall, loss, n_iter = train_iter(
            cfg, dataloader, classifier, optimizer, writer, epoch, n_iter,
            cluster_loss_fn)

        ### validation during training
        if epoch % int(cfg['val_freq']) == 0 and cfg['val_in_train']:
            best_epoch, best_pred, loss_val = val_iter(cfg, val_dataloader,
                                                       classifier, writer,
                                                       epoch, cluster_loss_fn,
                                                       best_epoch, best_pred,
                                                       logdir)
        if cfg['lr_type'] == 'step' and current_lr >= float(cfg['min_lr']):
            lr_scheduler.step()
        if cfg['lr_type'] == 'plateau' and cfg['val_in_train']:
            lr_scheduler.step(loss_val)
        for param_group in optimizer.param_groups:
            current_lr = float(param_group['lr'])
        writer.add_scalar('train/lr', current_lr, epoch)

        ### logging
        writer.add_scalar('train/epoch_acc',
                          torch.mean(mean_acc).item(), epoch)
        writer.add_scalar('train/epoch_iou',
                          torch.mean(mean_iou).item(), epoch)
        writer.add_scalar('train/epoch_prec',
                          torch.mean(mean_prec).item(), epoch)
        writer.add_scalar('train/epoch_recall',
                          torch.mean(mean_recall).item(), epoch)

    writer.close()

    if best_epoch != n_epochs:
        if cfg['save_model']:
            modeldir = os.path.join(logdir, cfg['model_dir'])
            torch.save(classifier.state_dict(),
                       '%s/model_ep-%d.pth' % (modeldir, epoch))

        if cfg['save_pred']:
            pred_dir = writer.logdir + '/predictions_%d' % epoch
            if not os.path.exists(pred_dir):
                os.makedirs(pred_dir)
            for filename, value in pred_buffer.items():
                with open(os.path.join(pred_dir, filename) + '.pkl',
                          'wb') as f:
                    pickle.dump(value, f, protocol=pickle.HIGHEST_PROTOCOL)

        if cfg['save_softmax_out']:
            sm_dir = writer.logdir + '/sm_out_%d' % epoch
            if not os.path.exists(sm_dir):
                os.makedirs(sm_dir)
            for filename, value in sm_buffer.items():
                with open(os.path.join(sm_dir, filename) + '_sm_1.pkl',
                          'wb') as f:
                    pickle.dump(value, f, protocol=pickle.HIGHEST_PROTOCOL)
            for filename, value in sm2_buffer.items():
                with open(os.path.join(sm_dir, filename) + '_sm_2.pkl',
                          'wb') as f:
                    pickle.dump(value, f, protocol=pickle.HIGHEST_PROTOCOL)

        if cfg['save_gf']:
            gf_dir = writer.logdir + '/gf_%d' % epoch
            if not os.path.exists(gf_dir):
                os.makedirs(gf_dir)
            i = 0
            for filename, value in gf_buffer.items():
                if i == 3:
                    break
                i += 1
                with open(os.path.join(gf_dir, filename) + '.pkl', 'wb') as f:
                    pickle.dump(value, f, protocol=pickle.HIGHEST_PROTOCOL)
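The train() entry point above is driven entirely by a flat cfg dictionary of (mostly string-valued) settings. The sketch below is only a rough illustration of how it might be invoked: the key names are the ones read inside train(), but every value and path is a placeholder, and any model-specific keys consumed by get_model() are omitted.

# illustrative cfg for train(); keys match the lookups above, values are placeholders
cfg = {
    'n_classes': '2', 'batch_size': '8', 'n_epochs': '100',
    'fixed_size': '6000', 'data_dim': '3', 'loss': 'nll',
    'rnd_sampling': True, 'standardization': True,
    'dataset': 'hcp20_graph', 'act': True, 'same_size': False,
    'sub_list_train': 'data/sub_list_train.txt',
    'sub_list_val': 'data/sub_list_val.txt',
    'dataset_dir': 'data/train', 'val_dataset_dir': 'data/val',
    'shuffling': True, 'n_workers': '4',
    'val_in_train': True, 'val_freq': '5',
    'experiment_name': 'default',
    'optimizer': 'adam', 'learning_rate': '1e-3', 'weight_decay': '5e-4',
    'lr_type': 'step', 'lr_ep_step': '50', 'lr_gamma': '0.7', 'min_lr': '1e-5',
    'bn_decay': False,
    'save_model': True, 'model_dir': 'models',
    'save_pred': False, 'save_softmax_out': False, 'save_gf': False,
}
train(cfg)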
Example #2
def test(cfg):
    num_classes = int(cfg['n_classes'])
    sample_size = int(cfg['fixed_size'])
    cfg['loss'] = cfg['loss'].split(' ')
    batch_size = 1
    cfg['batch_size'] = batch_size
    epoch = int(cfg['n_epochs'])
    #n_gf = int(cfg['num_gf'])
    input_size = int(cfg['data_dim'])

    trans_val = []
    if cfg['rnd_sampling']:
        trans_val.append(ds.TestSampling(sample_size))
    if cfg['standardization']:
        trans_val.append(ds.SampleStandardization())

    if cfg['dataset'] == 'hcp20_graph':
        dataset = ds.HCP20Dataset(
            cfg['sub_list_test'],
            cfg['dataset_dir'],
            act=cfg['act'],
            transform=transforms.Compose(trans_val),
            with_gt=cfg['with_gt'],
            #distance=T.Distance(norm=True,cat=False),
            return_edges=True,
            split_obj=True,
            train=False,
            load_one_full_subj=False)

    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=0)
    print("Validation dataset loaded, found %d samples" % (len(dataset)))

    for ext in range(100):
        logdir = '%s/test_%d' % (cfg['exp_path'], ext)
        if not os.path.exists(logdir):
            break
    writer = SummaryWriter(logdir)
    if cfg['weights_path'] == '':
        cfg['weights_path'] = glob.glob(cfg['exp_path'] + '/models/best*')[0]
        epoch = int(cfg['weights_path'].rsplit('-', 1)[1].split('.')[0])
    elif 'ep-' in cfg['weights_path']:
        epoch = int(cfg['weights_path'].rsplit('-', 1)[1].split('.')[0])

    tb_log_name = glob.glob('%s/events*' % writer.logdir)[0].rsplit('/', 1)[1]
    tb_log_dir = 'tb_logs/%s' % logdir.split('/', 1)[1]
    os.system('mkdir -p %s' % tb_log_dir)
    os.system('ln -sr %s/%s %s/%s ' %
              (writer.logdir, tb_log_name, tb_log_dir, tb_log_name))

    #### BUILD THE MODEL
    classifier = get_model(cfg)

    classifier.cuda()
    classifier.load_state_dict(torch.load(cfg['weights_path']))
    classifier.eval()

    with torch.no_grad():
        pred_buffer = {}
        sm_buffer = {}
        sm2_buffer = {}
        gf_buffer = {}
        emb_buffer = {}
        print('\n\n')
        mean_val_acc = torch.tensor([])
        mean_val_iou = torch.tensor([])
        mean_val_prec = torch.tensor([])
        mean_val_recall = torch.tensor([])

        if hasattr(dataset, 'split_obj') and dataset.split_obj:
            split_obj = True
        else:
            split_obj = False
            dataset.transform = []

        consumed = not split_obj
        j = 0
        visualized = 0
        new_obj_read = True
        sls_count = 1
        while j < len(dataset):
            #while sls_count <= len(dataset):
            data = dataset[j]

            if split_obj:
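                # one subject can be delivered in several chunks: allocate
                # full-size buffers on the first chunk and keep accumulating
                # predictions until dataset.remaining[j] is empty (consumed)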
                if new_obj_read:
                    obj_pred_choice = torch.zeros(data['obj_full_size'],
                                                  dtype=torch.int).cuda()
                    obj_target = torch.zeros(data['obj_full_size'],
                                             dtype=torch.int).cuda()
                    new_obj_read = False

                if len(dataset.remaining[j]) == 0:
                    consumed = True

            sample_name = data['name'] if isinstance(
                data['name'], str) else data['name'][0]

            #print(points)
            #if len(points.shape()) == 2:
            #points = points.unsqueeze(0)
            #print(data)
            points = gBatch().from_data_list([data['points']])
            #points = data['points']
            if 'bvec' in points.keys:
                points.batch = points.bvec.clone()
                del points.bvec
            if cfg['with_gt']:
                target = points['y']
                target = target.to('cuda')
                target = target.view(-1, 1)[:, 0]
            if cfg['same_size']:
                points['lengths'] = points['lengths'][0].item()
            #if cfg['model'] == 'pointnet_cls':
            #points = points.view(len(data['obj_idxs']), -1, input_size)
            points = points.to('cuda')

            logits = classifier(points)
            logits = logits.view(-1, num_classes)
            # predicted class per element (argmax over the class logits)
            pred_choice = logits.data.max(1)[1].int()

            if split_obj:
                obj_pred_choice[data['obj_idxs']] = pred_choice
                obj_target[data['obj_idxs']] = target.int()
                #if cfg['save_embedding']:
                #    obj_embedding[data['obj_idxs']] = classifier.embedding.squeeze()
            else:
                obj_data = points
                obj_pred_choice = pred_choice
                obj_target = target
                if cfg['save_embedding']:
                    obj_embedding = classifier.embedding.squeeze()

            if cfg['with_gt'] and consumed:
                print('val max class pred ', obj_pred_choice.max().item())
                print('val min class pred ', obj_pred_choice.min().item())
                y_pred = obj_pred_choice.cpu().numpy()
                np.save(data['dir'] + '/y_pred_pointnet', y_pred)
                y_test = obj_target.cpu().numpy()
                np.save(data['dir'] + '/y_test_pointnet', y_test)
                #np.save(data['dir']+'/streamlines_lstm_GIN',streamlines)
                correct = obj_pred_choice.eq(obj_target.data.int()).cpu().sum()
                acc = correct.item() / float(obj_target.size(0))

                tp = torch.mul(
                    obj_pred_choice.data,
                    obj_target.data.int()).cpu().sum().item() + 0.00001
                fp = obj_pred_choice.gt(
                    obj_target.data.int()).cpu().sum().item()
                fn = obj_pred_choice.lt(
                    obj_target.data.int()).cpu().sum().item()
                tn = correct.item() - tp
                iou = torch.tensor([float(tp) / (tp + fp + fn)])
                prec = torch.tensor([float(tp) / (tp + fp)])
                recall = torch.tensor([float(tp) / (tp + fn)])

                mean_val_prec = torch.cat((mean_val_prec, prec), 0)
                mean_val_recall = torch.cat((mean_val_recall, recall), 0)
                mean_val_iou = torch.cat((mean_val_iou, iou), 0)
                mean_val_acc = torch.cat((mean_val_acc, torch.tensor([acc])),
                                         0)
                print('VALIDATION [%d: %d/%d] val accuracy: %f' \
                        % (epoch, j, len(dataset), acc))

            if cfg['save_pred'] and consumed:
                print('buffering prediction %s' % sample_name)
                sl_idx = np.where(
                    obj_pred_choice.data.cpu().view(-1).numpy() == 1)[0]
                pred_buffer[sample_name] = sl_idx.tolist()

            if consumed:
                print(j)
                j += 1
                if split_obj:
                    consumed = False
                    new_obj_read = True

        macro_iou = torch.mean(mean_val_iou)
        macro_prec = torch.mean(mean_val_prec)
        macro_recall = torch.mean(mean_val_recall)

        epoch_iou = macro_iou.item()

    if cfg['save_pred']:
        #os.system('rm -r %s/predictions_test*' % writer.logdir)
        pred_dir = writer.logdir + '/predictions_test_%d' % epoch
        if not os.path.exists(pred_dir):
            os.makedirs(pred_dir)
        print('saving files')
        for filename, value in pred_buffer.items():
            with open(os.path.join(pred_dir, filename) + '.pkl', 'wb') as f:
                pickle.dump(value, f, protocol=pickle.HIGHEST_PROTOCOL)

    if cfg['with_gt']:
        print('TEST ACCURACY: %f' % torch.mean(mean_val_acc).item())
        print('TEST PRECISION: %f' % macro_prec.item())
        print('TEST RECALL: %f' % macro_recall.item())
        print('TEST IOU: %f' % macro_iou.item())
        mean_val_dsc = mean_val_prec * mean_val_recall * 2 / (mean_val_prec +
                                                              mean_val_recall)
        final_scores_file = writer.logdir + '/final_scores_test_%d.txt' % epoch
        scores_file = writer.logdir + '/scores_test_%d.txt' % epoch
        if not cfg['multi_category']:
            print('saving scores')
            with open(scores_file, 'w') as f:
                f.write('acc\n')
                f.writelines('%f\n' % v for v in mean_val_acc.tolist())
                f.write('prec\n')
                f.writelines('%f\n' % v for v in mean_val_prec.tolist())
                f.write('recall\n')
                f.writelines('%f\n' % v for v in mean_val_recall.tolist())
                f.write('dsc\n')
                f.writelines('%f\n' % v for v in mean_val_dsc.tolist())
                f.write('iou\n')
                f.writelines('%f\n' % v for v in mean_val_iou.tolist())
            with open(final_scores_file, 'w') as f:
                f.write('acc\n')
                f.write('%f\n' % mean_val_acc.mean())
                f.write('%f\n' % mean_val_acc.std())
                f.write('prec\n')
                f.write('%f\n' % mean_val_prec.mean())
                f.write('%f\n' % mean_val_prec.std())
                f.write('recall\n')
                f.write('%f\n' % mean_val_recall.mean())
                f.write('%f\n' % mean_val_recall.std())
                f.write('dsc\n')
                f.write('%f\n' % mean_val_dsc.mean())
                f.write('%f\n' % mean_val_dsc.std())
                f.write('iou\n')
                f.write('%f\n' % mean_val_iou.mean())
                f.write('%f\n' % mean_val_iou.std())

    print('\n\n')
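The per-subject scores computed in the test loop boil down to element-wise comparisons between the binary prediction and the binary ground truth. Below is a self-contained restatement of those formulas; the helper name and signature are illustrative only, not part of the original code, but the counts mirror the ones above (including the small epsilon added to tp).

import torch


def binary_scores(pred, target, eps=1e-5):
    # accuracy, IoU, precision, recall and DSC for 0/1 label tensors,
    # mirroring the element-wise counts used in the test loop above
    pred, target = pred.int(), target.int()
    correct = pred.eq(target).sum().item()
    tp = torch.mul(pred, target).sum().item() + eps  # 1 predicted as 1
    fp = pred.gt(target).sum().item()                # 0 predicted as 1
    fn = pred.lt(target).sum().item()                # 1 predicted as 0
    acc = correct / float(target.numel())
    iou = tp / (tp + fp + fn)
    prec = tp / (tp + fp)
    recall = tp / (tp + fn)
    dsc = 2 * prec * recall / (prec + recall)
    return acc, iou, prec, recall, dsc


# e.g.: acc, iou, prec, recall, dsc = binary_scores(obj_pred_choice, obj_target)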