Example #1
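All of the examples below share a preamble that was elided from this dump. The following is a minimal sketch of the assumed imports: the Hydra-style `instantiate` and the module-level `log` are inferred from usage, while `train_net`, `test_net`, `test_corrupt_net`, `extract_features`, `individual_sort`, and the logging utility `lu` are project-local helpers assumed to be importable.

import logging
import os

import numpy as np
import omegaconf
import torch
from hydra.utils import instantiate  # assumption: Hydra-style config instantiation

log = logging.getLogger(__name__)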
def train(cfg):

    np.random.seed(cfg.rng_seed)
    torch.manual_seed(cfg.rng_seed)

    log.info(cfg.pretty())
    model = instantiate(cfg.model).cuda()
    optimizer = instantiate(cfg.optim, model.parameters())
    train_dataset = instantiate(cfg.train)
    test_dataset = instantiate(cfg.test)
    lr_policy = instantiate(cfg.optim.lr_policy)

    train_net(model=model,
              optimizer=optimizer,
              train_dataset=train_dataset,
              batch_size=cfg.train.batch_size,
              max_epoch=cfg.optim.max_epoch,
              loader_params=cfg.data_loader,
              lr_policy=lr_policy,
              save_period=cfg.train.checkpoint_period,
              weights=cfg.train.weights)

    err = test_net(model=model,
                   test_dataset=test_dataset,
                   batch_size=cfg.test.batch_size,
                   loader_params=cfg.data_loader)

    test_corrupt_net(model=model,
                     corrupt_cfg=cfg.corrupt,
                     batch_size=cfg.corrupt.batch_size,
                     loader_params=cfg.data_loader,
                     aug_string=cfg.corrupt.aug_string,
                     clean_err=err,
                     mCE_denom=cfg.corrupt.mCE_baseline_file)
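For context, `train(cfg)` in these examples is presumably invoked from a Hydra entry point along these lines; the config path and name here are hypothetical.

import hydra

@hydra.main(config_path='conf', config_name='config')  # hypothetical config location
def main(cfg):
    train(cfg)

if __name__ == '__main__':
    main()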
Example #2
def train(cfg, is_leader):

    np.random.seed(cfg.rng_seed)
    torch.manual_seed(cfg.rng_seed)

    log.info(cfg.pretty())
    cur_device = torch.cuda.current_device()
    model = instantiate(cfg.model).cuda(device=cur_device)
    if cfg.num_gpus > 1:
model = torch.nn.parallel.DistributedDataParallel(
            module=model, device_ids=[cur_device], output_device=cur_device)
    optimizer = instantiate(cfg.optim, model.parameters())
    if cfg.optim.max_epoch > 0 and cfg.train.weights is None:
        print("Loading training set...")
        train_dataset = instantiate(cfg.train)
    else:
        print("Skipping loading the training dataset, 0 epochs of training to perform "
        " or pre-trained weights provided.")
        train_dataset = None
    print("Loading test set...")
    test_dataset = instantiate(cfg.test)
    lr_policy = instantiate(cfg.optim.lr_policy)  

    print("Training...")
train_net(model=model,
              optimizer=optimizer,
              train_dataset=train_dataset,
              batch_size=cfg.train.batch_size,
              max_epoch=cfg.optim.max_epoch,
              loader_params=cfg.data_loader,
              lr_policy=lr_policy,
              save_period=cfg.train.checkpoint_period,
              weights=cfg.train.weights,
              num_gpus=cfg.num_gpus,
              is_leader=is_leader,
              jsd_num=cfg.train.params.jsd_num,
              jsd_alpha=cfg.train.jsd_alpha)

    print("Testing...")
err = test_net(model=model,
                   test_dataset=test_dataset,
                   batch_size=cfg.test.batch_size,
                   loader_params=cfg.data_loader,
                   num_gpus=cfg.num_gpus)

test_corrupt_net(model=model,
                     corrupt_cfg=cfg.corrupt,
                     batch_size=cfg.corrupt.batch_size,
                     loader_params=cfg.data_loader,
                     aug_string=cfg.corrupt.aug_string,
                     clean_err=err,
                     mCE_denom=cfg.corrupt.mCE_baseline_file,
                     num_gpus=cfg.num_gpus,
                     log_name='train_imagenet.log')
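This variant assumes a multi-process, one-process-per-GPU launch with `torch.distributed` already initialized. A sketch of how the per-process device and the `is_leader` flag are typically derived follows; the wrapper name is illustrative, not from the source.

import torch

def distributed_worker(cfg):
    # Rank 0 acts as the leader that writes checkpoints and logs.
    rank = torch.distributed.get_rank()
    torch.cuda.set_device(rank % torch.cuda.device_count())
    train(cfg, is_leader=(rank == 0))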
Example #3
def train(cfg):

    np.random.seed(cfg.rng_seed)
    torch.manual_seed(cfg.rng_seed)

    log.info(cfg.pretty())
    model = instantiate(cfg.model).cuda()
    test_dataset = instantiate(cfg.test)

    checkpoint = torch.load(cfg.weights, map_location='cpu')
    model.load_state_dict(checkpoint['model_state'])

    err = test_net(model=model,
                   test_dataset=test_dataset,
                   batch_size=cfg.test.batch_size,
                   loader_params=cfg.data_loader)

    test_corrupt_net(model=model,
                     corrupt_cfg=cfg.corrupt,
                     batch_size=cfg.corrupt.batch_size,
                     loader_params=cfg.data_loader,
                     aug_string=cfg.corrupt.aug_string,
                     clean_err=err,
                     mCE_denom=cfg.corrupt.mCE_baseline_file)
Example #4
def train(cfg, is_leader):

    np.random.seed(cfg.rng_seed)
    torch.manual_seed(cfg.rng_seed)

    log.info(cfg.pretty())
    cur_device = torch.cuda.current_device()
    model = instantiate(cfg.model).cuda(device=cur_device)
    if cfg.num_gpus > 1:
        model = torch.nn.parallel.DistributedDataParallel(
            module=model, device_ids=[cur_device], output_device=cur_device)
    print("Loading test set...")
    test_dataset = instantiate(cfg.test)

    checkpoint = torch.load(cfg.weights, map_location='cpu')
    if cfg.num_gpus > 1:
        model.module.load_state_dict(checkpoint['model_state'])
    else:
        model.load_state_dict(checkpoint['model_state'])

    print("Testing...")
    err = test_net(model=model,
                   test_dataset=test_dataset,
                   batch_size=cfg.test.batch_size,
                   loader_params=cfg.data_loader,
                   num_gpus=cfg.num_gpus)

    test_corrupt_net(model=model,
                     corrupt_cfg=cfg.corrupt,
                     batch_size=cfg.corrupt.batch_size,
                     loader_params=cfg.data_loader,
                     aug_string=cfg.corrupt.aug_string,
                     clean_err=err,
                     mCE_denom=cfg.corrupt.mCE_baseline_file,
                     num_gpus=cfg.num_gpus,
                     log_name='train_imagenet.log')
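The `model.module` indirection above is needed because DistributedDataParallel wraps the network, so checkpoint keys must match the bare module. A sketch of the save side implied by this load path; the 'model_state' key comes from the snippets, while the helper name is hypothetical.

import torch

def save_checkpoint(model, path, num_gpus):
    # Unwrap DDP so the saved state_dict keys match the unwrapped module.
    module = model.module if num_gpus > 1 else model
    torch.save({'model_state': module.state_dict()}, path)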
Example #5
def train(cfg, is_leader=True):

    np.random.seed(cfg.rng_seed)
    torch.manual_seed(cfg.rng_seed)

    log.info(cfg.pretty())
    cur_device = torch.cuda.current_device()
    model = instantiate(cfg.model).cuda(device=cur_device)
    if cfg.num_gpus > 1:
        model = torch.nn.parallel.DistributedDataParallel(
            module=model, device_ids=[cur_device], output_device=cur_device)
    optimizer = instantiate(cfg.optim, model.parameters())
    if cfg.optim.max_epoch > 0:
        train_dataset = instantiate(cfg.train)
    else:
        train_dataset = None
    test_dataset = instantiate(cfg.test)
    lr_policy = instantiate(cfg.optim.lr_policy)
    with omegaconf.open_dict(cfg):
        feature_extractor = instantiate(cfg.ft,
                                        num_gpus=cfg.num_gpus,
                                        is_leader=is_leader)
    feature_extractor.train()

    train_net(model=model,
              optimizer=optimizer,
              train_dataset=train_dataset,
              batch_size=cfg.train.batch_size,
              max_epoch=cfg.optim.max_epoch,
              loader_params=cfg.data_loader,
              lr_policy=lr_policy,
              save_period=cfg.train.checkpoint_period,
              weights=cfg.train.weights,
              num_gpus=cfg.num_gpus,
              is_leader=is_leader)

    err = test_net(model=model,
                   test_dataset=test_dataset,
                   batch_size=cfg.test.batch_size,
                   loader_params=cfg.data_loader,
                   output_name='test_epoch',
                   num_gpus=cfg.num_gpus)

    if os.path.exists(cfg.feature_file):
        feature_dict = {k: v for k, v in np.load(cfg.feature_file).items()}
    else:
        feature_dict = {}
    indices = np.load(cfg.ft_corrupt.indices_file)
    for aug in cfg.aug_string.split("--"):
        if len(aug.split("-")) > 1:
            #log.info("Severity provided in corrupt.aug_string will be weighted by given severity.")
            sev = aug.split("-")[1]
            if len(sev.split("_")) > 1:
                low = float(sev.split("_")[0])
                high = float(sev.split("_")[1])
            else:
                low = 0.0
                high = float(sev)

            sev_factor = (high - low) * cfg.severity / 10 + low
        else:
            sev_factor = cfg.severity
        aug = aug.split("-")[0]
        aug_string = "{}-{}".format(aug, sev_factor)
        if aug_string in feature_dict:
            continue
        with omegaconf.open_dict(cfg.corrupt):
            corrupt_dataset = instantiate(cfg.corrupt, aug_string=aug_string)
        err = test_net(model=model,
                       test_dataset=corrupt_dataset,
                       batch_size=cfg.corrupt.batch_size,
                       loader_params=cfg.data_loader,
                       output_name=aug_string,
                       num_gpus=cfg.num_gpus)
        with omegaconf.open_dict(cfg.ft_corrupt):
            ft_corrupt_dataset = instantiate(cfg.ft_corrupt,
                                             aug_string=aug_string)
        if cfg.ft_corrupt.params.num_transforms is not None:
            ft_corrupt_dataset = ft_corrupt_dataset.serialize(indices)
        else:
            ft_corrupt_dataset = torch.utils.data.Subset(
                ft_corrupt_dataset, indices)

        feature = extract_features(feature_extractor=feature_extractor,
                                   dataset=ft_corrupt_dataset,
                                   batch_size=cfg.ft_corrupt.batch_size,
                                   loader_params=cfg.data_loader,
                                   average=True,
                                   num_gpus=cfg.num_gpus)
        feature_dict[aug_string] = feature
        if is_leader:
            np.savez(cfg.feature_file, **feature_dict)
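The severity grammar parsed in the loop above accepts 'aug', 'aug-high', or 'aug-low_high' specs, with '--' separating augmentations. A self-contained restatement of that parsing with worked values; the function name is illustrative.

def severity_for(aug_spec, severity):
    # Mirrors the loop above: a bare name passes `severity` through,
    # 'aug-high' scales linearly on [0, high], 'aug-low_high' on [low, high].
    parts = aug_spec.split('-')
    if len(parts) == 1:
        return severity
    bounds = parts[1].split('_')
    if len(bounds) == 1:
        low, high = 0.0, float(bounds[0])
    else:
        low, high = float(bounds[0]), float(bounds[1])
    return (high - low) * severity / 10 + low

assert severity_for('contrast', 3) == 3
assert severity_for('gaussian_noise-5', 10) == 5.0  # scales 0..5 by 10/10
assert severity_for('blur-1_3', 5) == 2.0           # midpoint of 1..3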
Example #6
def train(cfg):

    np.random.seed(cfg.rng_seed)
    torch.manual_seed(cfg.rng_seed)

    log.info(cfg.pretty())
    model = instantiate(cfg.model).cuda()
    optimizer = instantiate(cfg.optim, model.parameters())
    lr_policy = instantiate(cfg.optim.lr_policy)
    if cfg.transform_file and os.path.exists(cfg.transform_file):
        log.info("Transforms found, loading feature extractor is unnecessary.  Skipping.")
    else:
        feature_extractor = instantiate(cfg.ft)
        feature_extractor.train()
    
    if cfg.transform_file and os.path.exists(cfg.transform_file):
        log.info("Transforms found, feature extraction is unnecessary.  Skipping.")
    elif cfg.aug_feature_file and os.path.exists(cfg.aug_feature_file):
        log.info("Found feature file.  Loading from {}".format(cfg.aug_feature_file))
        data = np.load(cfg.aug_feature_file)
        augmentation_features = data['features']
        indices = data['indices']
        transforms = data['transforms']
    else:
        ft_augmentation_dataset = instantiate(cfg.ft_augmentation)
        transforms = ft_augmentation_dataset.transform_list
        indices = np.random.choice(np.arange(len(ft_augmentation_dataset)), size=cfg.num_images, replace=False)
        ft_augmentation_dataset = ft_augmentation_dataset.serialize(indices)
        augmentation_features = extract_features(feature_extractor,
                                                 ft_augmentation_dataset,
                                                 cfg.ft_augmentation.batch_size,
                                                 cfg.data_loader,
                                                 average=True,
                                                 average_num=len(indices))
        if cfg.aug_feature_file:
np.savez(cfg.aug_feature_file,
                     features=augmentation_features,
                     indices=indices,
                     transforms=transforms)

    if cfg.transform_file and os.path.exists(cfg.transform_file):
        log.info("Found transform file.  Loading from {}.".format(cfg.transform_file))
        sorted_transforms = np.load(cfg.transform_file)
    else:    
        aug_strings = cfg.ft_corrupt.aug_string.split("--")
        distances = np.zeros((len(augmentation_features), len(aug_strings)))
        for i, aug in enumerate(aug_strings):
            with omegaconf.open_dict(cfg):
                ft_corrupt_dataset = instantiate(cfg.ft_corrupt, aug_string=aug)       
            if cfg.num_corrupt_images and i == 0:
                indices = np.random.choice(np.arange(len(ft_corrupt_dataset)),
                                           size=cfg.num_corrupt_images,
                                           replace=False)
            ft_corrupt_dataset = ft_corrupt_dataset.serialize(indices)
            corruption_features = extract_features(feature_extractor,
                                                   ft_corrupt_dataset,
                                                   cfg.ft_corrupt.batch_size,
                                                   cfg.data_loader,
                                                   average=True)
            
            corruption_features = corruption_features.reshape(1, -1)
            dists = np.linalg.norm(augmentation_features - corruption_features, axis=-1)

distances[:, i] = dists

        sorted_dist_args = individual_sort(distances)
        sorted_transforms = transforms[sorted_dist_args]
        if cfg.transform_file:
            np.save(cfg.transform_file, sorted_transforms)

    train_dataset = instantiate(cfg.train)
    if cfg.selection_type == 'closest':
        train_dataset.transform_list = sorted_transforms[cfg.offset:cfg.offset + cfg.num_transforms]
    elif cfg.selection_type == 'farthest':
        if cfg.offset != 0:
            train_dataset.transform_list = sorted_transforms[-cfg.offset - cfg.num_transforms:-cfg.offset]
        else:
            train_dataset.transform_list = sorted_transforms[-cfg.num_transforms:]
    else:
        train_dataset.transform_list = sorted_transforms[
            np.random.choice(np.arange(len(sorted_transforms)),
                             size=cfg.num_transforms,
                             replace=False)]

    test_dataset = instantiate(cfg.test)

train_net(model=model,
              optimizer=optimizer,
              train_dataset=train_dataset,
              batch_size=cfg.train.batch_size,
              max_epoch=cfg.optim.max_epoch,
              loader_params=cfg.data_loader,
              lr_policy=lr_policy,
              save_period=cfg.train.checkpoint_period,
              weights=cfg.train.weights)

err = test_net(model=model,
                   test_dataset=test_dataset,
                   batch_size=cfg.test.batch_size,
                   loader_params=cfg.data_loader,
                   output_name='test_epoch')

test_corrupt_net(model=model,
                     corrupt_cfg=cfg.corrupt,
                     batch_size=cfg.corrupt.batch_size,
                     loader_params=cfg.data_loader,
                     aug_string=cfg.corrupt.aug_string,
                     clean_err=err,
                     mCE_denom=cfg.corrupt.mCE_baseline_file)
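The three `selection_type` modes above slice a distance-sorted transform list (closest first). A toy illustration with hypothetical values:

import numpy as np

sorted_transforms = np.array(['t0', 't1', 't2', 't3', 't4', 't5'])  # closest first
offset, num_transforms = 1, 2

closest = sorted_transforms[offset:offset + num_transforms]         # ['t1', 't2']
if offset != 0:
    farthest = sorted_transforms[-offset - num_transforms:-offset]  # ['t3', 't4']
else:
    farthest = sorted_transforms[-num_transforms:]
random_pick = sorted_transforms[np.random.choice(np.arange(len(sorted_transforms)),
                                                 size=num_transforms,
                                                 replace=False)]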
Example #7
def train(cfg):

    np.random.seed(cfg.rng_seed)
    torch.manual_seed(cfg.rng_seed)

    log.info(cfg.pretty())
    model = instantiate(cfg.model).cuda()
    optimizer = instantiate(cfg.optim, model.parameters())
    train_dataset = instantiate(cfg.train)
    test_dataset = instantiate(cfg.test)
    lr_policy = instantiate(cfg.optim.lr_policy)
    feature_extractor = instantiate(cfg.ft)
    feature_extractor.train()

    if cfg.aug_feature_file and os.path.exists(cfg.aug_feature_file):
        log.info("Found feature file.  Loading from {}".format(
            cfg.aug_feature_file))
        data = np.load(cfg.aug_feature_file)
        augmentation_features = data['features']
        indices = data['indices']
    else:
        ft_augmentation_dataset = instantiate(cfg.ft_augmentation)
        indices = np.random.choice(np.arange(len(ft_augmentation_dataset)),
                                   size=cfg.num_images,
                                   replace=False)
        ft_augmentation_dataset = ft_augmentation_dataset.serialize(indices)
        augmentation_features = extract_features(
            feature_extractor,
            ft_augmentation_dataset,
            cfg.ft_augmentation.batch_size,
            cfg.data_loader,
            average=True,
            average_num=len(indices))
        #nf, lf = augmentation_features.shape
        #augmentation_features = np.mean(augmentation_features.reshape(len(indices), nf//len(indices), lf), axis=0)
        if cfg.aug_feature_file:
            np.savez(cfg.aug_feature_file,
                     features=augmentation_features,
                     indices=indices)

    aug_strings = cfg.ft_corrupt.aug_string.split("--")
    for aug in aug_strings:
        with omegaconf.open_dict(cfg):
            ft_corrupt_dataset = instantiate(cfg.ft_corrupt, aug_string=aug)
        ft_corrupt_dataset = ft_corrupt_dataset.serialize(indices)
        corruption_features = extract_features(feature_extractor,
                                               ft_corrupt_dataset,
                                               cfg.ft_corrupt.batch_size,
                                               cfg.data_loader,
                                               average=True,
                                               average_num=len(indices))
        nf, lf = corruption_features.shape
        #corruption_features = np.mean(corruption_features.reshape(len(indices), nf//len(indices), lf), axis=0)

        augmentation_features = augmentation_features.reshape(-1, 1, lf)
        corruption_features = corruption_features.reshape(1, -1, lf)
        mean_aug = np.mean(augmentation_features.reshape(-1, lf), axis=0)
        mean_corr = np.mean(corruption_features.reshape(-1, lf), axis=0)
        mmd = np.linalg.norm(mean_aug - mean_corr, axis=0)
        msd = np.min(np.linalg.norm(augmentation_features.reshape(-1, lf) -
                                    mean_corr.reshape(1, lf),
                                    axis=1),
                     axis=0)

        stats = {
            "_type": aug,
            "mmd": str(mmd),
            "msd": str(msd),
        }
        lu.log_json_stats(stats)

    train_net(model=model,
              optimizer=optimizer,
              train_dataset=train_dataset,
              batch_size=cfg.train.batch_size,
              max_epoch=cfg.optim.max_epoch,
              loader_params=cfg.data_loader,
              lr_policy=lr_policy,
              save_period=cfg.train.checkpoint_period,
              weights=cfg.train.weights)

    err = test_net(model=model,
                   test_dataset=test_dataset,
                   batch_size=cfg.test.batch_size,
                   loader_params=cfg.data_loader,
                   output_name='test_epoch')

    test_corrupt_net(model=model,
                     corrupt_cfg=cfg.corrupt,
                     batch_size=cfg.corrupt.batch_size,
                     loader_params=cfg.data_loader,
                     aug_string=cfg.corrupt.aug_string,
                     clean_err=err,
                     mCE_denom=cfg.corrupt.mCE_baseline_file)
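The `mmd` and `msd` statistics logged above reduce to norms over feature means: `mmd` is the distance between the mean augmentation feature and the mean corruption feature, and `msd` is the distance from the closest single augmentation feature to the corruption mean. A compact restatement with random stand-in data (shapes are illustrative):

import numpy as np

n_aug, lf = 4, 8
aug_feats = np.random.randn(n_aug, lf)  # one mean feature per augmentation
corr_feats = np.random.randn(3, lf)     # features for one corruption

mean_aug = aug_feats.mean(axis=0)
mean_corr = corr_feats.mean(axis=0)
mmd = np.linalg.norm(mean_aug - mean_corr)                 # distance between means
msd = np.linalg.norm(aug_feats - mean_corr, axis=1).min()  # closest augmentation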