Example #1
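# Module-level context assumed by this snippet (imports and helpers are not shown
# on this page): os, glob, pickle, torch, torch.nn.functional as F, torch.optim as
# optim, SummaryWriter, transforms.Compose, the project's dataset module `ds`,
# DataLoader / gDataLoader, and the helpers get_model, train_iter and val_iter.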
def train(cfg):

    num_classes = int(cfg['n_classes'])
    batch_size = int(cfg['batch_size'])
    n_epochs = int(cfg['n_epochs'])
    sample_size = int(cfg['fixed_size'])
    cfg['loss'] = cfg['loss'].split(' ')
    input_size = int(cfg['data_dim'])

    #### DATA LOADING
    trans_train = []
    trans_val = []
    if cfg['rnd_sampling']:
        trans_train.append(ds.RndSampling(sample_size, maintain_prop=False))
        trans_val.append(ds.RndSampling(sample_size, maintain_prop=False))
    if cfg['standardization']:
        trans_train.append(ds.SampleStandardization())
        trans_val.append(ds.SampleStandardization())
    #trans_train.append(T.Distance(norm=False))
    #trans_val.append(T.Distance(norm=False))

    if cfg['dataset'] == 'hcp20_graph':
        dataset = ds.HCP20Dataset(
            cfg['sub_list_train'],
            cfg['dataset_dir'],
            #k=4,
            act=cfg['act'],
            transform=transforms.Compose(trans_train),
            #self_loops=T.AddSelfLoops(),
            #distance=T.Distance(norm=True,cat=False),
            return_edges=True,
            load_one_full_subj=False)
    elif cfg['dataset'] == 'left_ifof_ss_sl_graph':
        dataset = ds.LeftIFOFSupersetGraphDataset(
            cfg['sub_list_train'],
            cfg['dataset_dir'],
            transform=transforms.Compose(trans_train),
            same_size=cfg['same_size'])
    elif cfg['dataset'] == 'tractseg_500k':
        dataset = ds.Tractseg500kDataset(
            cfg['sub_list_train'],
            cfg['dataset_dir'],
            act=cfg['act'],
            #fold_size=int(cfg['fold_size']),
            transform=transforms.Compose(trans_train))
    if 'graph' in cfg['dataset']:
        DL = gDataLoader
    else:
        DL = DataLoader

    dataloader = DL(dataset,
                    batch_size=batch_size,
                    shuffle=cfg['shuffling'],
                    num_workers=int(cfg['n_workers']),
                    pin_memory=True)

    print("Dataset %s loaded, found %d samples" %
          (cfg['dataset'], len(dataset)))
    if cfg['val_in_train']:
        if cfg['dataset'] == 'hcp20_graph':
            val_dataset = ds.HCP20Dataset(
                cfg['sub_list_val'],
                cfg['val_dataset_dir'],
                #k=4,
                act=cfg['act'],
                transform=transforms.Compose(trans_val),
                #distance=T.Distance(norm=True,cat=False),
                #self_loops=T.AddSelfLoops(),
                return_edges=True,
                load_one_full_subj=False)
        elif cfg['dataset'] == 'tractseg_500k':
            val_dataset = ds.Tractseg500kDataset(
                cfg['sub_list_val'],
                cfg['val_dataset_dir'],
                act=cfg['act'],
                #fold_size=int(cfg['fold_size']),
                transform=transforms.Compose(trans_val))
        elif cfg['dataset'] == 'left_ifof_ss_sl_graph':
            val_dataset = ds.LeftIFOFSupersetGraphDataset(
                cfg['sub_list_val'],
                cfg['dataset_dir'],
                transform=transforms.Compose(trans_val),
                same_size=cfg['same_size'])

        val_dataloader = DL(val_dataset,
                            batch_size=1,
                            shuffle=False,
                            num_workers=int(cfg['n_workers']),
                            pin_memory=True)
        print("Validation dataset loaded, found %d samples" %
              (len(val_dataset)))

    # summary for tensorboard
    if cfg['experiment_name'] != 'default':
        for ext in range(100):
            exp_name = cfg['experiment_name'] + '_%d' % ext
            logdir = 'runs/%s' % exp_name
            if not os.path.exists(logdir):
                writer = SummaryWriter(logdir=logdir)
                break
    else:
        writer = SummaryWriter()
        exp_name = cfg['experiment_name']
        logdir = writer.logdir

    # mirror the tensorboard event file under tb_logs/<exp_name> via a relative symlink
    tb_log_name = glob.glob('%s/events*' % logdir)[0].rsplit('/', 1)[1]
    tb_log_dir = 'tb_logs/%s' % exp_name
    os.system('mkdir -p %s' % tb_log_dir)
    os.system('ln -sr %s/%s %s/%s ' %
              (logdir, tb_log_name, tb_log_dir, tb_log_name))

    os.system('cp main_dsl_config.py %s/config.txt' % (writer.logdir))

    #### BUILD THE MODEL
    classifier = get_model(cfg)

    #### SET THE TRAINING
    if cfg['optimizer'] == 'sgd_momentum':
        optimizer = optim.SGD(classifier.parameters(),
                              lr=float(cfg['learning_rate']),
                              momentum=float(cfg['momentum']),
                              weight_decay=float(cfg['weight_decay']))
    elif cfg['optimizer'] == 'adam':
        optimizer = optim.Adam(classifier.parameters(),
                               lr=float(cfg['learning_rate']),
                               weight_decay=float(cfg['weight_decay']))

    if cfg['lr_type'] == 'step':
        lr_scheduler = optim.lr_scheduler.StepLR(optimizer,
                                                 int(cfg['lr_ep_step']),
                                                 gamma=float(cfg['lr_gamma']))
    elif cfg['lr_type'] == 'plateau':
        lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=float(cfg['lr_gamma']),
            patience=int(cfg['patience']),
            threshold=0.0001,
            min_lr=float(cfg['min_lr']))
    if 'nll' in cfg['loss']:  # cfg['loss'] was split into a list above
        loss_fn = F.nll_loss

    alfa = 0
    cluster_loss_fn = None

    classifier.cuda()
    num_batch = len(dataset) / batch_size
    print('num of batches per epoch: %d' % num_batch)
    cfg['num_batch'] = num_batch

    n_iter = 0
    best_pred = 0
    best_epoch = 0
    current_lr = float(cfg['learning_rate'])
    for epoch in range(n_epochs + 1):

        # bn decay as in pointnet orig
        if cfg['bn_decay'] and epoch % int(cfg['bn_decay_step']) == 0:
            bn_momentum = float(cfg['bn_decay_init']) * float(
                cfg['bn_decay_gamma'])**(epoch / int(cfg['bn_decay_step']))
            bn_momentum = 1 - min(0.99, 1 - bn_momentum)
            print('updated bn momentum to %f' % bn_momentum)
            for module in classifier.modules():
                if type(module) == torch.nn.BatchNorm1d:
                    module.momentum = bn_momentum

        mean_acc, mean_prec, mean_iou, mean_recall, loss, n_iter = train_iter(
            cfg, dataloader, classifier, optimizer, writer, epoch, n_iter,
            cluster_loss_fn)

        ### validation during training
        if epoch % int(cfg['val_freq']) == 0 and cfg['val_in_train']:
            best_epoch, best_pred, loss_val = val_iter(cfg, val_dataloader,
                                                       classifier, writer,
                                                       epoch, cluster_loss_fn,
                                                       best_epoch, best_pred,
                                                       logdir)
        if cfg['lr_type'] == 'step' and current_lr >= float(cfg['min_lr']):
            lr_scheduler.step()
        if cfg['lr_type'] == 'plateau':
            lr_scheduler.step(loss_val)
        for i, param_group in enumerate(optimizer.param_groups):
            current_lr = float(param_group['lr'])
        writer.add_scalar('train/lr', current_lr, epoch)

        ### logging
        writer.add_scalar('train/epoch_acc',
                          torch.mean(mean_acc).item(), epoch)
        writer.add_scalar('train/epoch_iou',
                          torch.mean(mean_iou).item(), epoch)
        writer.add_scalar('train/epoch_prec',
                          torch.mean(mean_prec).item(), epoch)
        writer.add_scalar('train/epoch_recall',
                          torch.mean(mean_recall).item(), epoch)

    writer.close()

    if best_epoch != n_epochs:
        if cfg['save_model']:
            modeldir = os.path.join(logdir, cfg['model_dir'])
            os.makedirs(modeldir, exist_ok=True)
            torch.save(classifier.state_dict(),
                       '%s/model_ep-%d.pth' % (modeldir, epoch))

        if cfg['save_pred']:
            pred_dir = writer.logdir + '/predictions_%d' % epoch
            if not os.path.exists(pred_dir):
                os.makedirs(pred_dir)
            # pred_buffer is assumed to be populated elsewhere; it is not defined in this snippet
            for filename, value in pred_buffer.items():
                with open(os.path.join(pred_dir, filename) + '.pkl',
                          'wb') as f:
                    pickle.dump(value, f, protocol=pickle.HIGHEST_PROTOCOL)

        if cfg['save_softmax_out']:
            sm_dir = writer.logdir + '/sm_out_%d' % epoch
            if not os.path.exists(sm_dir):
                os.makedirs(sm_dir)
            for filename, value in sm_buffer.items():
                with open(os.path.join(sm_dir, filename) + '_sm_1.pkl',
                          'wb') as f:
                    pickle.dump(value, f, protocol=pickle.HIGHEST_PROTOCOL)
            for filename, value in sm2_buffer.items():
                with open(os.path.join(sm_dir, filename) + '_sm_2.pkl',
                          'wb') as f:
                    pickle.dump(value, f, protocol=pickle.HIGHEST_PROTOCOL)

        if cfg['save_gf']:
            gf_dir = writer.logdir + '/gf_%d' % epoch
            if not os.path.exists(gf_dir):
                os.makedirs(gf_dir)
            i = 0
            for filename, value in gf_buffer.items():
                if i == 3:
                    break
                i += 1
                with open(os.path.join(gf_dir, filename) + '.pkl', 'wb') as f:
                    pickle.dump(value, f, protocol=pickle.HIGHEST_PROTOCOL)
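
Everything Example #1 reads comes from a flat `cfg` dictionary whose values are mostly strings. The sketch below lists the keys the function accesses, with purely illustrative values; paths, sizes and hyperparameters are placeholders, not the authors' settings.

# Illustrative cfg for Example #1; the key names come from the code above,
# every value is a placeholder to be adapted to your own data and model.
cfg = {
    'n_classes': '2', 'batch_size': '4', 'n_epochs': '100',
    'fixed_size': '8000', 'data_dim': '3', 'loss': 'nll',
    'rnd_sampling': True, 'standardization': True, 'shuffling': True,
    'dataset': 'hcp20_graph', 'dataset_dir': 'data/hcp20',
    'sub_list_train': 'data/sub_train.txt', 'act': True,
    'val_in_train': True, 'val_freq': '5',
    'sub_list_val': 'data/sub_val.txt', 'val_dataset_dir': 'data/hcp20',
    'n_workers': '4', 'experiment_name': 'default',
    'optimizer': 'adam', 'learning_rate': '1e-3', 'weight_decay': '5e-4',
    'lr_type': 'step', 'lr_ep_step': '90', 'lr_gamma': '0.7', 'min_lr': '1e-5',
    'bn_decay': False, 'bn_decay_step': '20',
    'bn_decay_init': '0.5', 'bn_decay_gamma': '0.5',
    'save_model': False, 'save_pred': False,
    'save_softmax_out': False, 'save_gf': False, 'model_dir': 'models',
}
train(cfg)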
Example #2
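# Module-level context assumed by this snippet (not shown on this page): the dataset
# module `ds` and the project helpers get_dataset, create_tb_logger, dump_code,
# get_model, get_optimizer, get_lr_scheduler, update_bn_decay, train_ep, val_ep
# and get_lr.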
def train(cfg):

    batch_size = int(cfg['batch_size'])
    n_epochs = int(cfg['n_epochs'])
    sample_size = int(cfg['fixed_size'])
    cfg['loss'] = cfg['loss'].split(' ')

    #### DATA LOADING
    trans_train = []
    trans_val = []
    if cfg['rnd_sampling']:
        trans_train.append(
            ds.RndSampling(sample_size,
                           maintain_prop=False,
                           prop_vector=[1, 1]))
        trans_val.append(ds.RndSampling(sample_size, maintain_prop=False))

    dataset, dataloader = get_dataset(cfg, trans=trans_train)
    val_dataset, val_dataloader = get_dataset(cfg, trans=trans_val, train=False)

    # summary for tensorboard
    writer = create_tb_logger(cfg)
    dump_code(cfg, writer.logdir)

    #### BUILD THE MODEL
    classifier = get_model(cfg)

    #### SET THE TRAINING
    optimizer = get_optimizer(cfg, classifier)

    lr_scheduler = get_lr_scheduler(cfg, optimizer)

    classifier.cuda()

    num_batch = len(dataset) / batch_size
    print('num of batches per epoch: %d' % num_batch)
    cfg['num_batch'] = num_batch

    n_iter = 0
    best_pred = 0
    best_epoch = 0
    current_lr = float(cfg['learning_rate'])
    for epoch in range(n_epochs + 1):

        # update bn decay
        if cfg['bn_decay'] and epoch != 0 and epoch % int(
                cfg['bn_decay_step']) == 0:
            update_bn_decay(cfg, classifier, epoch)

        loss, n_iter = train_ep(cfg, dataloader, classifier, optimizer, writer,
                                epoch, n_iter)

        ### validation during training
        if epoch % int(cfg['val_freq']) == 0 and cfg['val_in_train']:
            best_epoch, best_pred = val_ep(cfg, val_dataloader, classifier,
                                           writer, epoch, best_epoch, best_pred)

        # update lr
        if cfg['lr_type'] == 'step' and current_lr >= float(cfg['min_lr']):
            lr_scheduler.step()
        if cfg['lr_type'] == 'plateau':
            lr_scheduler.step(loss)

        current_lr = get_lr(optimizer)
        writer.add_scalar('train/lr', current_lr, epoch)

    writer.close()
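
Example #2 factors the boilerplate of Example #1 into small helpers that are not shown on this page. As a rough orientation, two of them could plausibly look like the sketch below, which mirrors the inline logic of Example #1; treat it as an assumption about their behaviour, not the project's actual implementation.

# Hypothetical sketches of two helpers used, but not defined, in Example #2.
import torch

def update_bn_decay(cfg, classifier, epoch):
    # Exponential BatchNorm-momentum decay, mirroring the inline
    # "bn decay as in pointnet orig" block of Example #1.
    bn_momentum = float(cfg['bn_decay_init']) * float(
        cfg['bn_decay_gamma'])**(epoch / int(cfg['bn_decay_step']))
    bn_momentum = 1 - min(0.99, 1 - bn_momentum)
    print('updated bn momentum to %f' % bn_momentum)
    for module in classifier.modules():
        if isinstance(module, torch.nn.BatchNorm1d):
            module.momentum = bn_momentum

def get_lr(optimizer):
    # read the current learning rate back from the optimizer
    # (all parameter groups share one lr in these examples)
    return float(optimizer.param_groups[0]['lr'])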