Example no. 1
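# Imports implied by the code below; they are not part of the original excerpt. The
# standard-library / torch / numpy imports follow directly from usage. Project-specific
# names (parser, DATASET_DICT, ModelFactory, construct_transforms, select_gpu, the loss
# and trainer classes) come from the surrounding package and are not reproduced here.
import os
import sys
from pathlib import Path

import numpy as np
import torch
import torch.optim as optim
import torch.utils.data as data
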
def main():
    args = parser.parse_args()
    if not os.path.isdir('CMDs'):
        os.mkdir('CMDs')
    with open('CMDs/step_train_dpn.cmd', 'a') as f:
        f.write(' '.join(sys.argv) + '\n')
        f.write('--------------------------------\n')

    model_dir = Path(args.model_dir)
    checkpoint_path = args.checkpoint_path
    if checkpoint_path is None:
        checkpoint_path = model_dir / 'model'
    # Check that we are training on a sensible GPU
    assert max(args.gpu) <= torch.cuda.device_count() - 1

    device = select_gpu(args.gpu)
    # Load up the model
    ckpt = torch.load(model_dir / 'model/model.tar', map_location=device)
    model = ModelFactory.model_from_checkpoint(ckpt)
    if len(args.gpu) > 1 and torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model, device_ids=args.gpu)
        print('Using Multi-GPU training.')
    model.to(device)

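    # Normalise with the dataset's own statistics when requested; otherwise rescale inputs to roughly [-1, 1].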
    if args.normalize:
        mean = DATASET_DICT[args.id_dataset].mean
        std = DATASET_DICT[args.id_dataset].std
    else:
        mean = (0.5, 0.5, 0.5)
        std = (0.5, 0.5, 0.5)

    # Load the in-domain training and validation data
    train_dataset = DATASET_DICT[args.id_dataset](root=args.data_path,
                                                  transform=construct_transforms(
                                                      n_in=ckpt['n_in'],
                                                      mode='train',
                                                      mean=mean,
                                                      std=std,
                                                      augment=args.augment,
                                                      rotation=args.rotate,
                                                      jitter=args.jitter),
                                                  target_transform=None,
                                                  download=True,
                                                  split='train')

    val_dataset = DATASET_DICT[args.id_dataset](root=args.data_path,
                                                transform=construct_transforms(
                                                    n_in=ckpt['n_in'],
                                                    mean=mean,
                                                    std=std,
                                                    mode='eval',
                                                    rotation=args.rotate,
                                                    jitter=args.jitter),
                                                target_transform=None,
                                                download=True,
                                                split='val')

    # Load the out-of-domain training dataset
    ood_dataset = DATASET_DICT[args.ood_dataset](root=args.data_path,
                                                 transform=construct_transforms(
                                                     n_in=ckpt['n_in'],
                                                     mean=mean,
                                                     std=std,
                                                     mode='ood'),
                                                 target_transform=None,
                                                 download=True,
                                                 split='train')
    ood_val_dataset = DATASET_DICT[args.ood_dataset](root=args.data_path,
                                                     transform=construct_transforms(
                                                         n_in=ckpt['n_in'],
                                                         mean=mean,
                                                         std=std,
                                                         mode='eval'),
                                                     target_transform=None,
                                                     download=True,
                                                     split='val')

    # Combine ID and OOD training datasets into a single dataset for
    # training (necessary for DataParallel training)
    assert len(val_dataset) == len(ood_val_dataset)

    # Even out dataset lengths.
    id_ratio = 1.0
    if len(train_dataset) < len(ood_dataset):
        id_ratio = np.ceil(float(len(ood_dataset)) / float(len(train_dataset)))
        assert id_ratio.is_integer()
        dataset_list = [train_dataset, ] * (int(id_ratio))
        train_dataset = data.ConcatDataset(dataset_list)

    if len(train_dataset) > len(ood_dataset):
        ratio = np.ceil(float(len(train_dataset)) / float(len(ood_dataset)))
        assert ratio.is_integer()
        dataset_list = [ood_dataset, ] * int(ratio)
        ood_dataset = data.ConcatDataset(dataset_list)

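        # The replication above rounds up, so trim the OOD dataset back to the ID length.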
        if len(ood_dataset) > len(train_dataset):
            ood_dataset = data.Subset(ood_dataset, np.arange(0, len(train_dataset)))

    assert len(train_dataset) == len(ood_dataset)
    print(f"Validation dataset length: {len(val_dataset)}")
    print(f"Train dataset length: {len(train_dataset)}")

    # Set up training and test criteria
    id_criterion = DirichletKLLoss(target_concentration=args.target_concentration,
                                   concentration=args.concentration,
                                   reverse=args.reverse_KL)

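    # With target_concentration=0.0 the OOD loss pushes predictions towards a flat Dirichlet.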
    ood_criterion = DirichletKLLoss(target_concentration=0.0,
                                    concentration=args.concentration,
                                    reverse=args.reverse_KL)

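    # Mix the two criteria; the OOD term is weighted by args.gamma relative to the in-domain term.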
    criterion = PriorNetMixedLoss([id_criterion, ood_criterion], mixing_params=[1.0, args.gamma])

    # Select optimizer and optimizer params
    optimizer, optimizer_params = choose_optimizer(args.optimizer,
                                                   args.lr,
                                                   args.weight_decay)

    # Setup model trainer and train model
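    # The ID data was replicated id_ratio times above, so the LR milestones (and the epoch
    # count passed to trainer.train below) are divided by id_ratio accordingly.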
    lrc = [int(milestone / id_ratio) for milestone in args.lrc]
    trainer = TrainerWithOOD(model=model,
                             criterion=criterion,
                             id_criterion=id_criterion,
                             ood_criterion=ood_criterion,
                             test_criterion=criterion,
                             ood_dataset=ood_dataset,
                             test_ood_dataset=ood_val_dataset,
                             train_dataset=train_dataset,
                             test_dataset=val_dataset,
                             optimizer=optimizer,
                             device=device,
                             checkpoint_path=checkpoint_path,
                             scheduler=optim.lr_scheduler.MultiStepLR,
                             optimizer_params=optimizer_params,
                             scheduler_params={'milestones': lrc, 'gamma': args.lr_decay},
                             batch_size=args.batch_size,
                             clip_norm=args.clip_norm)
    if args.resume:
        try:
            trainer.load_checkpoint(True, True, map_location=device)
        except Exception:
            print('No checkpoint found, training from scratch.')
    trainer.train(int(args.n_epochs / id_ratio), resume=args.resume)

    # Save final model
    if len(args.gpu) > 1 and torch.cuda.device_count() > 1:
        model = model.module
    ModelFactory.checkpoint_model(path=model_dir / 'model/model.tar',
                                  model=model,
                                  arch=ckpt['arch'],
                                  dropout_rate=ckpt['dropout_rate'],
                                  n_channels=ckpt['n_channels'],
                                  num_classes=ckpt['num_classes'],
                                  small_inputs=ckpt['small_inputs'],
                                  n_in=ckpt['n_in'])
Example no. 2
def main():
    args = parser.parse_args()
    if not os.path.isdir('CMDs'):
        os.mkdir('CMDs')
    with open('CMDs/step_train_distillation.cmd', 'a') as f:
        f.write(' '.join(sys.argv) + '\n')
        f.write('--------------------------------\n')

    model_dir = Path(args.model_dir)

    # Check that we are training on a sensible GPU
    assert max(args.gpu) <= torch.cuda.device_count() - 1

    device = select_gpu(args.gpu)
    # Load up the model
    ckpt = torch.load(model_dir / 'model/model.tar', map_location=device)
    model = ModelFactory.model_from_checkpoint(ckpt)
    if len(args.gpu) > 1 and torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model, device_ids=args.gpu)
        print('Using Multi-GPU training.')
    model.to(device)

    # Load the in-domain training and validation data
    train_dataset_class = DATASET_DICT[args.dataset]
    train_dataset_parameters = {
        'root': args.data_path,
        'transform': construct_transforms(n_in=ckpt['n_in'],
                                          mode='train',
                                          mean=DATASET_DICT[args.dataset].mean,
                                          std=DATASET_DICT[args.dataset].std,
                                          augment=args.augment,
                                          rotation=args.rotate,
                                          jitter=args.jitter),
        'target_transform': None,
        'download': True,
        'split': 'train'
    }
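    # EnsembleDataset presumably pairs each input with the cached predictions of the
    # n_models ensemble members stored under ensemble_path, to use as distillation targets.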
    train_dataset = EnsembleDataset(
        dataset=train_dataset_class,
        dataset_parameters=train_dataset_parameters,
        ensemble_path=args.ensemble_path,
        model_dirs=args.model,
        n_models=args.n_models,
        folder='train')

    val_dataset_class = DATASET_DICT[args.dataset]
    val_dataset_parameters = {
        'root': args.data_path,
        'transform': construct_transforms(n_in=ckpt['n_in'],
                                          mean=DATASET_DICT[args.dataset].mean,
                                          std=DATASET_DICT[args.dataset].std,
                                          mode='eval'),
        'target_transform': None,
        'download': True,
        'split': 'val'
    }
    val_dataset = EnsembleDataset(dataset=val_dataset_class,
                                  dataset_parameters=val_dataset_parameters,
                                  ensemble_path=args.ensemble_path,
                                  model_dirs=args.model,
                                  n_models=args.n_models,
                                  folder='eval')

    if args.ood:
        assert args.ood_folder is not None
        ood_dataset_class = DATASET_DICT[args.ood_dataset]
        ood_dataset_parameters = {
            'root': args.data_path,
            'transform': construct_transforms(n_in=ckpt['n_in'],
                                              mean=DATASET_DICT[args.dataset].mean,
                                              std=DATASET_DICT[args.dataset].std,
                                              mode='train'),
            'target_transform': None,
            'download': True,
            'split': 'train'
        }
        ood_dataset = EnsembleDataset(
            dataset=ood_dataset_class,
            dataset_parameters=ood_dataset_parameters,
            ensemble_path=args.ensemble_path,
            model_dirs=args.model,
            n_models=args.n_models,
            folder=args.ood_folder)

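        # Distil on the combined in-domain and OOD data.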
        train_dataset = data.ConcatDataset([train_dataset, ood_dataset])

    # Set up training and test criteria
    test_criterion = torch.nn.CrossEntropyLoss()
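    # Select the distillation loss: the Dirichlet EnDD* losses distil the ensemble's output
    # distribution (Ensemble Distribution Distillation); the default EnDLoss is ordinary
    # ensemble distillation.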
    if args.endd:
        train_criterion = DirichletEnDDLoss()
    elif args.endd_entemp:
        train_criterion = DirichletEnDDEnTempLoss()
    elif args.endd_dirtemp:
        train_criterion = DirichletEnDDDirTempLoss()
    elif args.endd_revtemp:
        train_criterion = DirichletEnDDRevLoss()
    else:
        train_criterion = EnDLoss()

    # Select optimizer and optimizer params
    optimizer, optimizer_params = choose_optimizer(args.optimizer, args.lr,
                                                   args.weight_decay)

    # Setup model trainer and train model
    trainer = TrainerDistillation(model=model,
                                  criterion=train_criterion,
                                  test_criterion=test_criterion,
                                  train_dataset=train_dataset,
                                  test_dataset=val_dataset,
                                  optimizer=optimizer,
                                  device=device,
                                  checkpoint_path=model_dir / 'model',
                                  scheduler=optim.lr_scheduler.MultiStepLR,
                                  temp_scheduler=LRTempScheduler,
                                  optimizer_params=optimizer_params,
                                  scheduler_params={
                                      'milestones': args.lrc,
                                      'gamma': args.lr_decay
                                  },
                                  temp_scheduler_params={
                                      'init_temp': args.temperature,
                                      'min_temp': args.min_temperature,
                                      'decay_epoch': args.tdecay_epoch,
                                      'decay_length': args.tdecay_length
                                  },
                                  batch_size=args.batch_size,
                                  clip_norm=args.clip_norm)
    if args.resume:
        trainer.load_checkpoint(model_dir / 'model/checkpoint.tar',
                                True,
                                True,
                                map_location=device)
    trainer.train(args.n_epochs, resume=args.resume)

    # Save final model
    if len(args.gpu) > 1 and torch.cuda.device_count() > 1:
        model = model.module
    ModelFactory.checkpoint_model(path=model_dir / 'model/model.tar',
                                  model=model,
                                  arch=ckpt['arch'],
                                  n_channels=ckpt['n_channels'],
                                  num_classes=ckpt['num_classes'],
                                  small_inputs=ckpt['small_inputs'],
                                  n_in=ckpt['n_in'])
Example no. 3
def main():
    args = parser.parse_args()
    if not os.path.isdir('CMDs'):
        os.mkdir('CMDs')
    with open('CMDs/step_train_dnn.cmd', 'a') as f:
        f.write(' '.join(sys.argv) + '\n')
        f.write('--------------------------------\n')

    model_dir = Path(args.model_dir)
    checkpoint_path = args.checkpoint_path
    if checkpoint_path is None:
        checkpoint_path = model_dir / 'model'
    # Check that we are training on a sensible GPU
    assert max(args.gpu) <= torch.cuda.device_count() - 1

    device = select_gpu(args.gpu)
    # Load up the model
    ckpt = torch.load(model_dir / 'model/model.tar', map_location=device)
    model = ModelFactory.model_from_checkpoint(ckpt)
    if len(args.gpu) > 1 and torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model, device_ids=args.gpu)
        print('Using Multi-GPU training.')
    model.to(device)

    if args.normalize:
        mean = DATASET_DICT[args.dataset].mean
        std = DATASET_DICT[args.dataset].std
    else:
        mean = (0.5, 0.5, 0.5)
        std = (0.5, 0.5, 0.5)

    # Load the in-domain training and validation data
    train_dataset = DATASET_DICT[args.dataset](root=args.data_path,
                                               transform=construct_transforms(
                                                   n_in=ckpt['n_in'],
                                                   mode='train',
                                                   mean=mean,
                                                   std=std,
                                                   augment=args.augment,
                                                   rotation=args.rotate,
                                                   jitter=args.jitter),
                                               target_transform=None,
                                               download=True,
                                               split='train')

    val_dataset = DATASET_DICT[args.dataset](root=args.data_path,
                                             transform=construct_transforms(
                                                 n_in=ckpt['n_in'],
                                                 mean=mean,
                                                 std=std,
                                                 mode='eval'),
                                             target_transform=None,
                                             download=True,
                                             split='val')

    # Set up training and test criteria
    criterion = torch.nn.CrossEntropyLoss()

    # Select optimizer and optimizer params
    optimizer, optimizer_params = choose_optimizer(args.optimizer, args.lr,
                                                   args.weight_decay)

    # Setup model trainer and train model
    trainer = Trainer(model=model,
                      criterion=criterion,
                      test_criterion=criterion,
                      train_dataset=train_dataset,
                      test_dataset=val_dataset,
                      optimizer=optimizer,
                      device=device,
                      checkpoint_path=checkpoint_path,
                      scheduler=optim.lr_scheduler.MultiStepLR,
                      optimizer_params=optimizer_params,
                      scheduler_params={
                          'milestones': args.lrc,
                          'gamma': args.lr_decay
                      },
                      batch_size=args.batch_size,
                      clip_norm=args.clip_norm)
    if args.resume:
        try:
            trainer.load_checkpoint(True, True, map_location=device)
        except Exception:
            print('No checkpoint found, training from scratch.')
    trainer.train(args.n_epochs, resume=args.resume)

    # Save final model
    if len(args.gpu) > 1 and torch.cuda.device_count() > 1:
        model = model.module
    ModelFactory.checkpoint_model(path=model_dir / 'model/model.tar',
                                  model=model,
                                  arch=ckpt['arch'],
                                  dropout_rate=ckpt['dropout_rate'],
                                  n_channels=ckpt['n_channels'],
                                  num_classes=ckpt['num_classes'],
                                  small_inputs=ckpt['small_inputs'],
                                  n_in=ckpt['n_in'])