Beispiel #1
0
def TrainValProgram(args):
    """Run the full train/validate loop described by a YAML config file.

    The function loads the YAML file at ``args.config``, merges it with the
    command-line arguments via ``merge_config``, builds the model, loss,
    datasets, optimizer and post-processing objects through ``create_module``
    and then trains for ``config['base']['n_epoch']`` epochs. Starting at
    epoch ``config['base']['start_val']`` the model is evaluated after every
    epoch and the checkpoint with the highest ``hmean`` is saved as
    ``<algorithm>_best.pth.tar``. Every ``config['base']['save_epoch']``
    epochs a periodic checkpoint ``<algorithm>_<epoch>.pth.tar`` is written.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``config`` (path of the YAML file), ``log_str`` (suffix of the
            run directory), ``t_config`` (enables distillation when not
            ``None``), ``pruned_model_dict_path``, ``prune_model_path``,
            ``prune_type`` (fine-tuning of a pruned model) and
            ``start_epoch`` (optional override of the first epoch).

    Returns:
        None. Results are written to the run directory (checkpoints,
        ``log.txt`` and validation output).

    Raises:
        AssertionError: If ``config['base']['restore']`` is set but
            ``config['base']['restore_file']`` is not an existing file.
    """
    # Use a context manager so the config file handle is closed right away.
    # NOTE(review): yaml.FullLoader is not meant for untrusted input; the
    # config file is assumed to be trusted here.
    with open(args.config, 'r', encoding='utf-8') as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)
    config = merge_config(config, args)

    os.environ["CUDA_VISIBLE_DEVICES"] = config['base']['gpu_id']
    create_dir(config['base']['checkpoints'])
    # Run directory: ag_<algorithm>_bb_<backbone>_he_<head>_bs_<batch>_ep_<epochs>_<log_str>
    checkpoints = os.path.join(
        config['base']['checkpoints'], "ag_%s_bb_%s_he_%s_bs_%d_ep_%d_%s" %
        (config['base']['algorithm'],
         config['backbone']['function'].split(',')[-1],
         config['head']['function'].split(',')[-1],
         config['trainload']['batch_size'], config['base']['n_epoch'],
         args.log_str))
    create_dir(checkpoints)

    model = create_module(config['architectures']['model_function'])(config)
    criterion = create_module(config['architectures']['loss_function'])(config)
    train_dataset = create_module(config['trainload']['function'])(config)
    test_dataset = create_module(config['testload']['function'])(config)
    optimizer = create_module(config['optimizer']['function'])(config, model)
    optimizer_decay = create_module(config['optimizer_decay']['function'])
    img_process = create_module(config['postprocess']['function'])(config)

    # Distillation is enabled by passing a teacher config. The teacher and
    # its loss stay None otherwise, so ModelTrain always gets both arguments.
    use_distil = args.t_config is not None
    t_model, distil_loss = None, None
    if use_distil:
        t_model = GetTeacherModel(args)
        distil_loss = DistilLoss()
        if torch.cuda.is_available():
            distil_loss = distil_loss.cuda()

    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config['trainload']['batch_size'],
        shuffle=True,
        num_workers=config['trainload']['num_workers'],
        worker_init_fn=worker_init_fn,
        drop_last=True,
        pin_memory=True)

    test_data_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config['testload']['batch_size'],
        shuffle=False,
        num_workers=config['testload']['num_workers'],
        drop_last=True,
        pin_memory=True)

    loss_bin = create_loss_bin(config['base']['algorithm'], use_distil)

    if torch.cuda.is_available():
        # More than one id in gpu_id means multi-GPU training.
        if len(config['base']['gpu_id'].split(',')) > 1:
            model = torch.nn.DataParallel(model).cuda()
        else:
            model = model.cuda()
        criterion = criterion.cuda()

    start_epoch = 0
    rescall, precision, hmean = 0, 0, 0
    best_rescall, best_precision, best_hmean = 0, 0, 0

    log_path = os.path.join(checkpoints, 'log.txt')
    resume_log = False

    if args.pruned_model_dict_path is not None:
        print('finetune the pruend model.')
        model = load_prune_model(model, args.prune_model_path,
                                 args.pruned_model_dict_path, args.prune_type)
    elif config['base']['restore']:
        print('Resuming from checkpoint.')
        assert os.path.isfile(
            config['base']['restore_file']), 'Error: no checkpoint file found!'
        checkpoint = torch.load(config['base']['restore_file'])
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # The stored metrics are taken as the best values seen so far.
        best_rescall = checkpoint['rescall']
        best_precision = checkpoint['precision']
        best_hmean = checkpoint['hmean']
        resume_log = True
    else:
        print('Training from scratch.')

    if resume_log:
        log_write = Logger(log_path,
                           title=config['base']['algorithm'],
                           resume=True)
    else:
        # A fresh log needs its column names: one per loss, then the metrics.
        log_write = Logger(log_path, title=config['base']['algorithm'])
        title = list(loss_bin.keys())
        title.extend([
            'piexl_acc', 'piexl_iou', 't_rescall', 't_precision', 't_hmean',
            'b_rescall', 'b_precision', 'b_hmean'
        ])
        log_write.set_names(title)

    # An explicit --start_epoch wins over the epoch stored in a checkpoint.
    if args.start_epoch is not None:
        start_epoch = args.start_epoch

    for epoch in range(start_epoch, config['base']['n_epoch']):
        model.train()
        if use_distil:
            # NOTE(review): the teacher is switched to train mode as well;
            # confirm this is intended rather than eval mode.
            t_model.train()
        optimizer_decay(config, optimizer, epoch)
        loss_write = ModelTrain(train_data_loader, t_model, distil_loss, model,
                                criterion, optimizer, loss_bin, args, config,
                                epoch)

        if epoch >= config['base']['start_val']:
            create_dir(os.path.join(checkpoints, 'val'))
            create_dir(os.path.join(checkpoints, 'val', 'res_img'))
            create_dir(os.path.join(checkpoints, 'val', 'res_txt'))
            model.eval()
            rescall, precision, hmean = ModelEval(test_dataset,
                                                  test_data_loader, model,
                                                  img_process, checkpoints,
                                                  config)
            print('rescall:', rescall, 'precision', precision, 'hmean', hmean)
            if hmean > best_hmean:
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'state_dict': model.state_dict(),
                        'lr': config['optimizer']['base_lr'],
                        'optimizer': optimizer.state_dict(),
                        'hmean': hmean,
                        'rescall': rescall,
                        'precision': precision
                    }, checkpoints,
                    config['base']['algorithm'] + '_best' + '.pth.tar')
                best_hmean = hmean
                best_precision = precision
                best_rescall = rescall

        # Before start_val the current metrics keep their previous values
        # (initially 0), so every log row has the same number of columns.
        loss_write.extend([
            rescall, precision, hmean, best_rescall, best_precision, best_hmean
        ])
        log_write.append(loss_write)
        for key in loss_bin.keys():
            loss_bin[key].loss_clear()
        if epoch % config['base']['save_epoch'] == 0:
            # Store the best metrics so far instead of zeros: the restore
            # branch reads these keys as the best values, and zeros would
            # let a worse model overwrite the saved best checkpoint.
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'lr': config['optimizer']['base_lr'],
                    'optimizer': optimizer.state_dict(),
                    'hmean': best_hmean,
                    'rescall': best_rescall,
                    'precision': best_precision
                }, checkpoints,
                config['base']['algorithm'] + '_' + str(epoch) + '.pth.tar')
Beispiel #2
0
def TrainValProgram(config):
    """Run the train/validate loop of a recognition model from a YAML config.

    The function loads the YAML file named by the ``config`` attribute of the
    argument, merges it with the command-line arguments via ``merge_config``,
    builds the label converter, model, loss, datasets and optimizer through
    ``create_module`` and trains for ``config['base']['n_epoch']`` epochs.
    Starting at epoch ``config['base']['start_val']`` the model is evaluated
    after every epoch and the checkpoint with the highest accuracy is saved
    as ``<algorithm>_best.pth.tar``. Every ``config['base']['save_epoch']``
    epochs a periodic checkpoint ``<algorithm>_<epoch>.pth.tar`` is written.

    Args:
        config: Despite its name, the parsed command-line namespace. The
            attributes read here are ``config`` (path of the YAML file) and
            ``log_str`` (suffix of the run directory); the namespace is also
            passed to ``merge_config``. The parameter name is kept so
            existing callers are unaffected.

    Returns:
        None. Results are written to the run directory (checkpoints and
        ``log.txt``).

    Raises:
        AssertionError: If ``config['base']['restore']`` is set but
            ``config['base']['restore_file']`` is not an existing file.
    """
    # The body used to read a module-level ``args`` and ignore its own
    # argument. Bind the argument locally so the function is self-contained.
    args = config

    # Use a context manager so the config file handle is closed right away.
    # NOTE(review): yaml.FullLoader is not meant for untrusted input; the
    # config file is assumed to be trusted here.
    with open(args.config, 'r', encoding='utf-8') as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)
    config = merge_config(config, args)

    os.environ["CUDA_VISIBLE_DEVICES"] = config['base']['gpu_id']

    create_dir(config['base']['checkpoints'])
    # Run directory: ag_<algorithm>_bb_<backbone>_he_<head>_bs_<batch>_ep_<epochs>_<log_str>
    checkpoints = os.path.join(
        config['base']['checkpoints'], "ag_%s_bb_%s_he_%s_bs_%d_ep_%d_%s" %
        (config['base']['algorithm'],
         config['backbone']['function'].split(',')[-1],
         config['head']['function'].split(',')[-1],
         config['trainload']['batch_size'], config['base']['n_epoch'],
         args.log_str))
    create_dir(checkpoints)

    LabelConverter = create_module(
        config['label_transform']['function'])(config)
    # The class count is taken from the converter's alphabet and must be set
    # before the model is built from the config.
    config['base']['classes'] = len(LabelConverter.alphabet)
    model = create_module(config['architectures']['model_function'])(config)
    criterion = create_module(config['architectures']['loss_function'])(config)
    train_dataset = create_module(config['trainload']['function'])(config)
    test_dataset = create_module(config['testload']['function'])(config)
    optimizer = create_module(config['optimizer']['function'])(config, model)
    optimizer_decay = create_module(config['optimizer_decay']['function'])

    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config['trainload']['batch_size'],
        shuffle=True,
        num_workers=config['trainload']['num_workers'],
        worker_init_fn=worker_init_fn,
        collate_fn=alignCollate(),
        drop_last=True,
        pin_memory=True)

    test_data_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config['testload']['batch_size'],
        shuffle=False,
        num_workers=config['testload']['num_workers'],
        collate_fn=alignCollate(),
        drop_last=True,
        pin_memory=True)

    loss_bin = create_loss_bin(config['base']['algorithm'])

    if torch.cuda.is_available():
        # More than one id in gpu_id means multi-GPU training.
        if len(config['base']['gpu_id'].split(',')) > 1:
            model = torch.nn.DataParallel(model).cuda()
        else:
            model = model.cuda()
        criterion = criterion.cuda()

    start_epoch = 0
    val_acc = 0
    val_loss = 0
    best_acc = 0

    log_path = os.path.join(checkpoints, 'log.txt')
    if config['base']['restore']:
        print('Resuming from checkpoint.')
        assert os.path.isfile(
            config['base']['restore_file']), 'Error: no checkpoint file found!'
        checkpoint = torch.load(config['base']['restore_file'])
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        best_acc = checkpoint['best_acc']
        log_write = Logger(log_path,
                           title=config['base']['algorithm'],
                           resume=True)
    else:
        print('Training from scratch.')
        # A fresh log needs its column names: one per loss, then the metrics.
        log_write = Logger(log_path, title=config['base']['algorithm'])
        title = list(loss_bin.keys())
        title.extend(['val_loss', 'test_acc', 'best_acc'])
        log_write.set_names(title)

    for epoch in range(start_epoch, config['base']['n_epoch']):
        model.train()
        optimizer_decay(config, optimizer, epoch)
        loss_write = ModelTrain(train_data_loader, LabelConverter, model,
                                criterion, optimizer, loss_bin, config, epoch)
        if epoch >= config['base']['start_val']:
            model.eval()
            val_acc, val_loss = ModelEval(test_data_loader, LabelConverter,
                                          model, criterion, config)
            print('val_acc:', val_acc, 'val_loss', val_loss)
            if val_acc > best_acc:
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'state_dict': model.state_dict(),
                        'lr': config['optimizer']['base_lr'],
                        'optimizer': optimizer.state_dict(),
                        'best_acc': val_acc
                    }, checkpoints,
                    config['base']['algorithm'] + '_best' + '.pth.tar')
                best_acc = val_acc

        # Before start_val the validation values keep their previous values
        # (initially 0), so every log row has the same number of columns.
        loss_write.extend([val_loss, val_acc, best_acc])
        log_write.append(loss_write)
        for key in loss_bin.keys():
            loss_bin[key].loss_clear()
        if epoch % config['base']['save_epoch'] == 0:
            # Store the best accuracy so far instead of 0: the restore branch
            # reads this key as the best value, and 0 would let a worse model
            # overwrite the saved best checkpoint.
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'lr': config['optimizer']['base_lr'],
                    'optimizer': optimizer.state_dict(),
                    'best_acc': best_acc
                }, checkpoints,
                config['base']['algorithm'] + '_' + str(epoch) + '.pth.tar')