Example #1
File: evaluate.py  Project: yangxu351/mrs
def main():
    device, _ = misc_utils.set_gpu(GPU)

    # init model
    args = network_io.load_config(MODEL_DIR)
    model = network_io.create_model(args)
    if LOAD_EPOCH:
        args['trainer']['epochs'] = LOAD_EPOCH
    ckpt_dir = os.path.join(
        MODEL_DIR, 'epoch-{}.pth.tar'.format(args['trainer']['epochs']))
    network_utils.load(model, ckpt_dir)
    print('Loaded from {}'.format(ckpt_dir))
    model.to(device)
    model.eval()

    # eval on dataset
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    tsfm_valid = A.Compose([
        A.Normalize(mean=mean, std=std),
        ToTensorV2(),
    ])
    save_dir = os.path.join(
        r'/home/wh145/results/mrs/mass_roads',
        os.path.basename(network_utils.unique_model_name(args)))
    evaluator = eval_utils.Evaluator('mnih', DATA_DIR, tsfm_valid, device)
    evaluator.evaluate(model,
                       PATCHS_SIZE,
                       2 * model.lbl_margin,
                       pred_dir=save_dir,
                       report_dir=save_dir)
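
Both evaluation examples build the same albumentations validation pipeline before handing the model to the evaluator. The following is a minimal, standalone sketch of that pipeline applied to a dummy RGB patch; the dummy image and its 512x512 size are placeholders for illustration, not values taken from the project.

import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2

mean = (0.485, 0.456, 0.406)   # ImageNet channel means
std = (0.229, 0.224, 0.225)    # ImageNet channel standard deviations
tsfm_valid = A.Compose([
    A.Normalize(mean=mean, std=std),  # scale uint8 input to [0, 1], then normalize per channel
    ToTensorV2(),                     # HWC numpy array -> CHW torch.FloatTensor
])

img = np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8)  # placeholder patch
tensor = tsfm_valid(image=img)['image']  # tensor of shape (3, 512, 512)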
Example #2
def main():
    config = json.load(open(os.path.join(MODEL_DIR, 'config.json')))

    # set gpu
    device, parallel = misc_utils.set_gpu(GPU)
    model = StackMTLNet.StackHourglassNetMTL(config['task1_classes'],
                                             config['task2_classes'],
                                             config['backbone'])
    network_utils.load(model,
                       os.path.join(MODEL_DIR,
                                    'epoch-{}.pth.tar'.format(EPOCH_NUM)),
                       disable_parallel=True)
    model.to(device)
    model.eval()

    # eval on dataset
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    tsfm_valid = A.Compose([
        A.Normalize(mean=mean, std=std),
        ToTensor(sigmoid=False),
    ])
    save_dir = os.path.join(r'/hdd/Results/line_mtl_cust',
                            os.path.basename(MODEL_DIR))
    evaluator = network_utils.Evaluator('transmission', tsfm_valid, device)
    evaluator.evaluate(model, (512, 512),
                       0,
                       pred_dir=save_dir,
                       report_dir=save_dir,
                       save_conf=True)
Example #3
def easy_load(model_dir, epoch=None):
    """
    Initialize and define model based on their corresponding configuration file
    :param model_dir: directory of the saved model
    :param epoch: number of epoch to load
    :return:
    """
    config = load_config(model_dir)
    model = create_model(config)
    if epoch:
        load_epoch = epoch
    else:
        load_epoch = config.epochs
    pretrained_dir = os.path.join(model_dir,
                                  'epoch-{}.pth.tar'.format(load_epoch - 1))
    network_utils.load(model, pretrained_dir)
    print('Loaded model from {} @ epoch {}'.format(model_dir, load_epoch))
    return model
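
network_utils.load is project code, so its exact behavior is not shown here. The helper below is a hedged sketch of what such a loader typically does for the 'epoch-N.pth.tar' checkpoints saved in Examples #4 and #5: restore the 'state_dict' entry and strip a possible nn.DataParallel prefix. The load_state name and the prefix handling are assumptions for illustration.

import torch

def load_state(model, ckpt_path, device='cpu'):
    """Load weights saved as {'state_dict': ...} into an existing model."""
    ckpt = torch.load(ckpt_path, map_location=device)
    state_dict = ckpt.get('state_dict', ckpt)  # also accept a raw state_dict file
    # strip the 'module.' prefix that nn.DataParallel adds to parameter names
    state_dict = {k[len('module.'):] if k.startswith('module.') else k: v
                  for k, v in state_dict.items()}
    model.load_state_dict(state_dict)
    return model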
Example #4
def train_model(args, device, parallel):
    # TODO: support more network options
    model = StackMTLNet.StackHourglassNetMTL(args['task1_classes'],
                                             args['task2_classes'],
                                             args['backbone'])
    log_dir = os.path.join(args['trainer']['save_dir'], 'log')
    writer = SummaryWriter(log_dir=log_dir)
    try:
        writer.add_graph(
            model, torch.rand(1, 3, *eval(args['dataset']['input_size'])))
    except (RuntimeError, TypeError):
        print(
            'Warning: could not write graph to tensorboard, this might be a bug in tensorboardX'
        )
    if parallel:
        model.encoder = nn.DataParallel(
            model.encoder,
            device_ids=[a for a in range(len(args['gpu'].split(',')))])
        model.decoder = nn.DataParallel(
            model.decoder,
            device_ids=[a for a in range(len(args['gpu'].split(',')))])

    start_epoch = 0
    if args['resume_dir'] != 'None':
        print('Resume training from {}'.format(args['resume_dir']))
        ckpt = torch.load(args['resume_dir'])
        start_epoch = ckpt['epoch']
        network_utils.load(model, args['resume_dir'], disable_parallel=True)
    elif args['finetune_dir'] != 'None':
        print('Finetune model from {}'.format(args['finetune_dir']))
        network_utils.load(model, args['finetune_dir'], disable_parallel=True)

    model.to(device)

    # make optimizer
    train_params = [{
        'params': model.encoder.parameters(),
        'lr': args['optimizer']['e_lr']
    }, {
        'params': model.decoder.parameters(),
        'lr': args['optimizer']['d_lr']
    }]
    optm = optim.SGD(train_params,
                     lr=args['optimizer']['e_lr'],
                     momentum=0.9,
                     weight_decay=5e-4)
    scheduler = optim.lr_scheduler.MultiStepLR(
        optm,
        milestones=eval(args['optimizer']['lr_drop_epoch']),
        gamma=args['optimizer']['lr_step'])
    angle_weights = torch.ones(args['task2_classes']).to(device)
    road_weights = torch.tensor(
        [1 - args['task1_classes'], args['task1_classes']],
        dtype=torch.float).to(device)
    angle_loss = metric_utils.CrossEntropyLoss2d(
        weight=angle_weights).to(device)
    road_loss = metric_utils.mIoULoss(weight=road_weights).to(device)
    iou_loss = metric_utils.IoU().to(device)

    # prepare training
    print('Total params: {:.2f}M'.format(
        sum(p.numel() for p in model.parameters()) / 1000000.0))

    # make data loader
    mean = eval(args['dataset']['mean'])
    std = eval(args['dataset']['std'])
    tsfm_train = A.Compose([
        A.Flip(),
        A.RandomRotate90(),
        A.Normalize(mean=mean, std=std),
        ToTensorV2(),
    ])
    tsfm_valid = A.Compose([
        A.Normalize(mean=mean, std=std),
        ToTensorV2(),
    ])
    train_loader = DataLoader(loader.TransmissionDataLoader(
        args['dataset']['data_dir'],
        args['dataset']['train_file'],
        transforms=tsfm_train),
                              batch_size=args['dataset']['batch_size'],
                              shuffle=True,
                              num_workers=args['dataset']['workers'])
    valid_loader = DataLoader(loader.TransmissionDataLoader(
        args['dataset']['data_dir'],
        args['dataset']['valid_file'],
        transforms=tsfm_valid),
                              batch_size=args['dataset']['batch_size'],
                              shuffle=False,
                              num_workers=args['dataset']['workers'])
    print('Start training model')
    train_val_loaders = {'train': train_loader, 'valid': valid_loader}

    # train the model
    for epoch in range(start_epoch, args['trainer']['total_epochs']):
        for phase in ['train', 'valid']:
            start_time = timeit.default_timer()
            if phase == 'train':
                model.train()
                scheduler.step()
            else:
                model.eval()

            loss_dict = model.step(train_val_loaders[phase], device, optm,
                                   phase, road_loss, angle_loss, iou_loss,
                                   True, mean, std)
            misc_utils.write_and_print(writer, phase, epoch,
                                       args['trainer']['total_epochs'],
                                       loss_dict, start_time)

        # save the model
        if epoch % args['trainer']['save_epoch'] == (
                args['trainer']['save_epoch'] - 1):
            save_name = os.path.join(args['trainer']['save_dir'],
                                     'epoch-{}.pth.tar'.format(epoch))
            torch.save(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'opt_dict': optm.state_dict(),
                    'loss': loss_dict,
                }, save_name)
            print('Saved model at {}'.format(save_name))
    writer.close()
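
The optimizer setup above assigns separate learning rates to the encoder and decoder parameter groups and decays both with MultiStepLR. Below is a standalone sketch of the same pattern; the tiny encoder/decoder module and the concrete rates and milestones are placeholders, not the project's values.

from torch import nn, optim

class TinyEncDec(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Conv2d(3, 8, 3, padding=1)
        self.decoder = nn.Conv2d(8, 2, 3, padding=1)

    def forward(self, x):
        return self.decoder(self.encoder(x))

model = TinyEncDec()
optm = optim.SGD([
    {'params': model.encoder.parameters(), 'lr': 1e-3},  # encoder learning rate
    {'params': model.decoder.parameters(), 'lr': 1e-2},  # decoder learning rate
], lr=1e-3, momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.MultiStepLR(optm, milestones=[60, 80], gamma=0.1)
# call scheduler.step() once per epoch, after that epoch's optimizer updates,
# so both group learning rates are multiplied by 0.1 at epochs 60 and 80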
Example #5
def train_model(args, device, parallel):
    """
    The function to train the model
    :param args: the class carries configuration parameters defined in config.py
    :param device: the device to run the model
    :return:
    """

    model = network_io.create_model(args)
    log_dir = os.path.join(args['save_dir'], 'log')
    writer = SummaryWriter(log_dir=log_dir)
    # TODO add write_graph back, probably need to switch to tensorboard in pytorch
    if parallel:
        model.encoder = network_utils.DataParallelPassThrough(model.encoder)
        model.decoder = network_utils.DataParallelPassThrough(model.decoder)
        if args['optimizer']['aux_loss']:
            model.cls = network_utils.DataParallelPassThrough(model.cls)
        print('Parallel training mode enabled!')
    train_params = model.set_train_params(
        (args['optimizer']['learn_rate_encoder'],
         args['optimizer']['learn_rate_decoder']))

    # make optimizer
    optm = network_io.create_optimizer(args['optimizer']['name'], train_params,
                                       args['optimizer']['learn_rate_encoder'])
    criterions = network_io.create_loss(args, device=device)
    cls_criterion = None
    with_aux = False
    if args['optimizer']['aux_loss']:
        with_aux = True
        cls_criterion = metric_utils.BCEWithLogitLoss(
            device, eval(args['trainer']['class_weight']))
    scheduler = optim.lr_scheduler.MultiStepLR(
        optm,
        milestones=eval(args['optimizer']['decay_step']),
        gamma=args['optimizer']['decay_rate'])

    # if not resume, train from scratch
    if args['trainer']['resume_epoch'] == 0 and args['trainer'][
            'finetune_dir'] == 'None':
        print('Training decoder {} with encoder {} from scratch ...'.format(
            args['decoder_name'], args['encoder_name']))
    elif args['trainer']['resume_epoch'] == 0 and args['trainer'][
            'finetune_dir']:
        print('Finetuning model from {}'.format(
            args['trainer']['finetune_dir']))
        if args['trainer']['further_train']:
            network_utils.load(model,
                               args['trainer']['finetune_dir'],
                               relax_load=True,
                               optm=optm,
                               device=device)
        else:
            network_utils.load(model,
                               args['trainer']['finetune_dir'],
                               relax_load=True)
    else:
        print('Resume training decoder {} with encoder {} from epoch {} ...'.
              format(args['decoder_name'], args['encoder_name'],
                     args['trainer']['resume_epoch']))
        network_utils.load_epoch(args['save_dir'],
                                 args['trainer']['resume_epoch'], model, optm,
                                 device)

    # prepare training
    print('Total params: {:.2f}M'.format(network_utils.get_model_size(model)))
    model.to(device)
    for c in criterions:
        c.to(device)

    # make data loader
    ds_cfgs = [a for a in sorted(args.keys()) if 'dataset' in a]
    assert ds_cfgs[0] == 'dataset'

    train_val_loaders = {'train': [], 'valid': []}
    if args['dataset']['load_func'] == 'default':
        load_func = data_utils.default_get_stats
    else:
        load_func = None
    for ds_cfg in ds_cfgs:
        mean, std = network_io.get_dataset_stats(
            args[ds_cfg]['ds_name'],
            args[ds_cfg]['data_dir'],
            mean_val=(eval(args[ds_cfg]['mean']), eval(args[ds_cfg]['std'])),
            load_func=load_func,
            file_list=args[ds_cfg]['train_file'])
        tsfm_train, tsfm_valid = network_io.create_tsfm(args, mean, std)
        train_loader = DataLoader(
            data_loader.get_loader(args[ds_cfg]['data_dir'],
                                   args[ds_cfg]['train_file'],
                                   transforms=tsfm_train,
                                   n_class=args[ds_cfg]['class_num'],
                                   with_aux=with_aux),
            batch_size=int(args[ds_cfg]['batch_size']),
            shuffle=True,
            num_workers=int(args['dataset']['num_workers']),
            drop_last=True)
        train_val_loaders['train'].append(train_loader)

        if 'valid_file' in args[ds_cfg]:
            valid_loader = DataLoader(
                data_loader.get_loader(args[ds_cfg]['data_dir'],
                                       args[ds_cfg]['valid_file'],
                                       transforms=tsfm_valid,
                                       n_class=args[ds_cfg]['class_num'],
                                       with_aux=with_aux),
                batch_size=int(args[ds_cfg]['batch_size']),
                shuffle=False,
                num_workers=int(args[ds_cfg]['num_workers']))
            print('Training model on the {} dataset'.format(
                args[ds_cfg]['ds_name']))
            train_val_loaders['valid'].append(valid_loader)

    # train the model
    loss_dict = {}
    for epoch in range(int(args['trainer']['resume_epoch']),
                       int(args['trainer']['epochs'])):
        # each epoch has a training and validation step
        for phase in ['train', 'valid']:
            start_time = timeit.default_timer()
            if phase == 'train':
                model.train()
            else:
                model.eval()

            # TODO align aux loss and normal train
            loss_dict = model.step(
                train_val_loaders[phase],
                device,
                optm,
                phase,
                criterions,
                eval(args['trainer']['bp_loss_idx']),
                True,
                mean,
                std,
                loss_weights=eval(args['trainer']['loss_weights']),
                use_emau=args['use_emau'],
                use_ocr=args['use_ocr'],
                cls_criterion=cls_criterion,
                cls_weight=args['optimizer']['aux_loss_weight'])
            network_utils.write_and_print(writer, phase, epoch,
                                          int(args['trainer']['epochs']),
                                          loss_dict, start_time)

        scheduler.step()
        # save the model
        if epoch % int(args['trainer']['save_epoch']) == 0 and epoch != 0:
            save_name = os.path.join(args['save_dir'],
                                     'epoch-{}.pth.tar'.format(epoch))
            network_utils.save(model, epoch, optm, loss_dict, save_name)
    # save model one last time
    save_name = os.path.join(
        args['save_dir'],
        'epoch-{}.pth.tar'.format(int(args['trainer']['epochs'])))
    network_utils.save(model, int(args['trainer']['epochs']), optm, loss_dict,
                       save_name)
    writer.close()
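
Both training loops persist and restore checkpoints with the same convention: one dict per epoch named 'epoch-N.pth.tar' holding the epoch counter, model weights, optimizer state, and the last loss values. The helpers below are a hedged sketch of that convention; the names save_checkpoint and resume_checkpoint are illustrative, not the project's network_utils API.

import os
import torch

def save_checkpoint(save_dir, epoch, model, optm, loss_dict):
    save_name = os.path.join(save_dir, 'epoch-{}.pth.tar'.format(epoch))
    torch.save({
        'epoch': epoch + 1,                # the epoch to resume from
        'state_dict': model.state_dict(),
        'opt_dict': optm.state_dict(),
        'loss': loss_dict,
    }, save_name)
    return save_name

def resume_checkpoint(ckpt_path, model, optm, device='cpu'):
    ckpt = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(ckpt['state_dict'])
    optm.load_state_dict(ckpt['opt_dict'])
    return ckpt['epoch']                   # first epoch to run after resuming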