Beispiel #1
0
 def _get_transform(self, metadata):
     """Build the input preprocessing pipeline described by ``metadata``.

     The ``size``, ``space``, ``range`` and ``normalize`` entries of
     ``metadata['input']`` select the transforms: resize + center crop,
     an optional channel flip for ``'bgr'`` inputs, a tensor conversion
     matching the declared value range, and an optional mean/std
     normalization.

     Args:
         metadata: Mapping whose ``'input'`` entry provides the keys
             ``'size'``, ``'space'``, ``'range'`` and ``'normalize'``.

     Returns:
         An ``sT.Compose`` wrapping the assembled transforms.

     Raises:
         NotImplementedError: If the value range is neither ``[0, 1]``
             nor ``[0, 255]``.
     """
     input_metadata = metadata['input']
     size = input_metadata['size']
     space = input_metadata['space']
     drange = input_metadata['range']
     normalize = input_metadata['normalize']

     if size is None:
         # No size declared: fall back to 224.
         size = 224
     elif isinstance(size, (list, tuple)):
         # Only the last entry of a sequence is used for resize and crop.
         # NOTE(review): presumably a (..., H, W)-style shape - confirm.
         size = size[-1]

     transform = [
         sT.Resize(size),
         sT.CenterCrop(size),
     ]
     if space == 'bgr':
         transform.append(sT.FlipChannels())

     # Normalise the range to a list once so tuples compare equal too.
     value_range = list(drange)
     if value_range == [0, 1]:
         transform.append(sT.ToTensor())
     elif value_range == [0, 255]:
         transform.append(sT.ToTensor(normalize=False, dtype=torch.float))
     else:
         raise NotImplementedError(
             'Unsupported input range %r; expected [0, 1] or [0, 255]' %
             (value_range,))

     if normalize is not None:
         transform.append(
             sT.Normalize(mean=normalize['mean'], std=normalize['std']))
     return sT.Compose(transform)
Beispiel #2
0
def main():
    """Amalgamate a Stanford Cars and an FGVC Aircraft teacher into a student.

    Two ResNet-18 teachers (196 and 102 classes) are restored from
    ``args.car_ckpt`` and ``args.aircraft_ckpt``. A ResNet-18 student with
    196 + 102 outputs is optionally restored from ``args.ckpt`` and then
    either evaluated on both test splits (``args.test_only``) or trained
    with ``amalgamation.CommonFeatureAmalgamator``.

    Reads the free name ``args`` (``car_ckpt``, ``aircraft_ckpt``, ``ckpt``,
    ``test_only``, ``lr``), which must be defined by the enclosing module.
    """
    num_car_classes = 196
    num_aircraft_classes = 102

    car_train_dst = vision.datasets.StanfordCars('../data/StanfordCars',
                                                 split='train')
    car_val_dst = vision.datasets.StanfordCars('../data/StanfordCars',
                                               split='test')
    aircraft_train_dst = vision.datasets.FGVCAircraft('../data/FGVCAircraft',
                                                      split='trainval')
    aircraft_val_dst = vision.datasets.FGVCAircraft('../data/FGVCAircraft',
                                                    split='test')

    car_teacher = vision.models.classification.resnet18(
        num_classes=num_car_classes, pretrained=False)
    aircraft_teacher = vision.models.classification.resnet18(
        num_classes=num_aircraft_classes, pretrained=False)
    student = vision.models.classification.resnet18(
        num_classes=num_car_classes + num_aircraft_classes, pretrained=False)

    # Load checkpoints onto the CPU first so that files saved from CUDA
    # tensors can still be restored on a CPU-only machine (the device
    # selection below explicitly falls back to 'cpu').
    car_teacher.load_state_dict(
        torch.load(args.car_ckpt, map_location='cpu'))
    aircraft_teacher.load_state_dict(
        torch.load(args.aircraft_ckpt, map_location='cpu'))

    train_transform = sT.Compose([
        sT.RandomResizedCrop(224),
        sT.RandomHorizontalFlip(),
        sT.ToTensor(),
        sT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    val_transform = sT.Compose([
        sT.Resize(256),
        sT.CenterCrop(224),
        sT.ToTensor(),
        sT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    for train_split in (car_train_dst, aircraft_train_dst):
        train_split.transform = train_transform
    for val_split in (car_val_dst, aircraft_val_dst):
        val_split.transform = val_transform

    # Each metric only looks at the slice of the student's output that
    # belongs to its dataset: cars come first, aircraft after them.
    car_metric = metrics.MetricCompose(
        metric_dict={
            'car_acc':
            metrics.Accuracy(
                attach_to=lambda o, t: (o[:, :num_car_classes], t))
        })
    aircraft_metric = metrics.MetricCompose(
        metric_dict={
            'aircraft_acc':
            metrics.Accuracy(
                attach_to=lambda o, t: (o[:, num_car_classes:], t))
        })

    train_dst = torch.utils.data.ConcatDataset(
        [car_train_dst, aircraft_train_dst])
    train_loader = torch.utils.data.DataLoader(train_dst,
                                               batch_size=32,
                                               shuffle=True,
                                               num_workers=4)
    car_loader = torch.utils.data.DataLoader(car_val_dst,
                                             batch_size=32,
                                             shuffle=False,
                                             num_workers=4)
    aircraft_loader = torch.utils.data.DataLoader(aircraft_val_dst,
                                                  batch_size=32,
                                                  shuffle=False,
                                                  num_workers=4)

    car_evaluator = engine.evaluator.BasicEvaluator(car_loader, car_metric)
    aircraft_evaluator = engine.evaluator.BasicEvaluator(
        aircraft_loader, aircraft_metric)

    if args.ckpt is not None:
        student.load_state_dict(torch.load(args.ckpt, map_location='cpu'))
        print("Load student model from %s" % args.ckpt)
    if args.test_only:
        # Evaluation-only run: report both accuracies and stop.
        results_car = car_evaluator.eval(student)
        results_aircraft = aircraft_evaluator.eval(student)
        print("Stanford Cars: %s" % (results_car))
        print("FGVC Aircraft: %s" % (results_aircraft))
        return

    # 100 passes over the training loader; the scheduler callback below is
    # registered on AFTER_STEP, so T_max is expressed in iterations.
    TOTAL_ITERS = len(train_loader) * 100
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    optim = torch.optim.Adam(student.parameters(),
                             lr=args.lr,
                             weight_decay=1e-4)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim,
                                                       T_max=TOTAL_ITERS)
    trainer = amalgamation.CommonFeatureAmalgamator(
        logger=utils.logger.get_logger('cfl'),
        tb_writer=SummaryWriter(log_dir='run/cfl-%s' %
                                (time.asctime().replace(' ', '_'))))

    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP(every=10),
        callbacks=callbacks.MetricsLogging(keys=('total_loss', 'loss_kd',
                                                 'loss_amal', 'loss_recons',
                                                 'lr')))
    trainer.add_callback(engine.DefaultEvents.AFTER_EPOCH,
                         callbacks=[
                             callbacks.EvalAndCkpt(model=student,
                                                   evaluator=car_evaluator,
                                                   metric_name='car_acc',
                                                   ckpt_prefix='cfl_car'),
                             callbacks.EvalAndCkpt(
                                 model=student,
                                 evaluator=aircraft_evaluator,
                                 metric_name='aircraft_acc',
                                 ckpt_prefix='cfl_aircraft'),
                         ])
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP,
        callbacks=callbacks.LRSchedulerCallback(schedulers=[sched]))

    # One group per amalgamated layer: (student, car teacher, aircraft
    # teacher), with a matching channel tuple. With on_layer_input=True the
    # trainer is pointed at the input of each `fc` layer.
    # NOTE(review): 512 is presumably the ResNet-18 fc input width - confirm.
    layer_groups = [(student.fc, car_teacher.fc, aircraft_teacher.fc)]
    layer_channels = [(512, 512, 512)]

    trainer.setup(student=student,
                  teachers=[car_teacher, aircraft_teacher],
                  layer_groups=layer_groups,
                  layer_channels=layer_channels,
                  dataloader=train_loader,
                  optimizer=optim,
                  device=device,
                  on_layer_input=True,
                  weights=[1., 10., 10.])
    trainer.run(start_iter=0, max_iter=TOTAL_ITERS)
Beispiel #3
0
def main():
    """Train a ResNet-18 classifier on the dataset named by ``args.dataset``.

    Supported names are ``'stanford_dogs'``, ``'cub200'``,
    ``'fgvc_aircraft'`` and ``'stanford_cars'``. The model is trained with
    SGD and a cosine learning-rate schedule for ``args.epochs`` passes over
    the training loader, evaluated and checkpointed after every epoch, and
    a final checkpoint is written to ``'pretrained'``.

    Reads the free name ``args`` (``dataset``, ``pretrained``, ``epochs``,
    ``lr``), which must be defined by the enclosing module.

    Raises:
        NotImplementedError: If ``args.dataset`` is not a supported name.
    """
    # Pytorch Part
    # name -> (number of classes, dataset class, data root, training split)
    dataset_specs = {
        'stanford_dogs': (120, vision.datasets.StanfordDogs,
                          'data/StanfordDogs', 'train'),
        'cub200': (200, vision.datasets.CUB200, 'data/CUB200', 'train'),
        'fgvc_aircraft': (102, vision.datasets.FGVCAircraft,
                          'data/FGVCAircraft/', 'trainval'),
        'stanford_cars': (196, vision.datasets.StanfordCars,
                          'data/StanfordCars/', 'train'),
    }
    if args.dataset not in dataset_specs:
        raise NotImplementedError(
            'Unsupported dataset %r; expected one of: %s' %
            (args.dataset, ', '.join(sorted(dataset_specs))))
    num_classes, dataset_cls, data_root, train_split = \
        dataset_specs[args.dataset]
    train_dst = dataset_cls(data_root, split=train_split)
    val_dst = dataset_cls(data_root, split='test')

    model = vision.models.classification.resnet18(num_classes=num_classes,
                                                  pretrained=args.pretrained)
    train_dst.transform = sT.Compose([
        sT.RandomResizedCrop(224),
        sT.RandomHorizontalFlip(),
        sT.ToTensor(),
        sT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    val_dst.transform = sT.Compose([
        sT.Resize(256),
        sT.CenterCrop(224),
        sT.ToTensor(),
        sT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    train_loader = torch.utils.data.DataLoader(train_dst,
                                               batch_size=32,
                                               shuffle=True,
                                               num_workers=4)
    val_loader = torch.utils.data.DataLoader(val_dst,
                                             batch_size=32,
                                             num_workers=4)

    # The scheduler callback below is registered on AFTER_STEP, so the
    # cosine period is expressed in iterations rather than epochs.
    TOTAL_ITERS = len(train_loader) * args.epochs
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    optim = torch.optim.SGD(model.parameters(),
                            lr=args.lr,
                            momentum=0.9,
                            weight_decay=5e-4)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim,
                                                       T_max=TOTAL_ITERS)

    # KAE Part
    # Predefined task & metrics
    task = kamal.tasks.StandardTask.classification()
    metric = kamal.tasks.StandardMetrics.classification()
    # Evaluator and Trainer
    evaluator = engine.evaluator.BasicEvaluator(val_loader,
                                                metric=metric,
                                                progress=True)
    trainer = engine.trainer.BasicTrainer(
        logger=kamal.utils.logger.get_logger(args.dataset),
        tb_writer=SummaryWriter(
            log_dir='run/%s-%s' %
            (args.dataset, time.asctime().replace(' ', '_'))))
    # setup trainer
    trainer.setup(model=model,
                  task=task,
                  dataloader=train_loader,
                  optimizer=optim,
                  device=device)
    trainer.add_callback(engine.DefaultEvents.AFTER_STEP(every=10),
                         callbacks=callbacks.MetricsLogging(keys=('total_loss',
                                                                  'lr')))
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP,
        callbacks=callbacks.LRSchedulerCallback(schedulers=[sched]))
    # Keep the handle returned by add_callback: it is used after training
    # to reach the EvalAndCkpt instance and write the final checkpoint.
    ckpt_callback = trainer.add_callback(engine.DefaultEvents.AFTER_EPOCH,
                                         callbacks=callbacks.EvalAndCkpt(
                                             model=model,
                                             evaluator=evaluator,
                                             metric_name='acc',
                                             ckpt_prefix=args.dataset))
    trainer.run(start_iter=0, max_iter=TOTAL_ITERS)
    ckpt_callback.callback.final_ckpt(ckpt_dir='pretrained', add_md5=True)
def main():
    """Train one ResNet-18 from scratch on four fine-grained datasets.

    Stanford Cars, FGVC Aircraft, Stanford Dogs and CUB200 are concatenated
    into a single training set whose labels are shifted into disjoint
    ranges of a 196 + 102 + 120 + 200 way classifier. The student is
    optionally restored from ``args.ckpt`` and then either evaluated on the
    four test splits (``args.test_only``) or trained with
    ``engine.trainer.BasicTrainer``. No teacher models are involved.

    Reads the free name ``args`` (``ckpt``, ``test_only``, ``lr``), which
    must be defined by the enclosing module.
    """
    import functools
    import operator

    # Layout of the student's output: [cars | aircraft | dogs | birds].
    num_car, num_aircraft, num_dog, num_cub = 196, 102, 120, 200
    aircraft_offset = num_car
    dog_offset = aircraft_offset + num_aircraft
    cub_offset = dog_offset + num_dog
    num_total = cub_offset + num_cub

    car_train_dst = vision.datasets.StanfordCars('../data/StanfordCars', split='train')
    car_val_dst = vision.datasets.StanfordCars('../data/StanfordCars', split='test')
    aircraft_train_dst = vision.datasets.FGVCAircraft('../data/FGVCAircraft', split='trainval')
    aircraft_val_dst = vision.datasets.FGVCAircraft('../data/FGVCAircraft', split='test')

    dog_train_dst = vision.datasets.StanfordDogs('../data/StanfordDogs', split='train')
    dog_val_dst = vision.datasets.StanfordDogs('../data/StanfordDogs', split='test')
    cub_train_dst = vision.datasets.CUB200('../data/CUB200', split='train')
    cub_val_dst = vision.datasets.CUB200('../data/CUB200', split='test')

    student = vision.models.classification.resnet18(num_classes=num_total, pretrained=False)

    train_transform = sT.Compose([
        sT.RandomResizedCrop(224),
        sT.RandomHorizontalFlip(),
        sT.ToTensor(),
        sT.Normalize(mean=[0.485, 0.456, 0.406],
                     std=[0.229, 0.224, 0.225]),
    ])
    val_transform = sT.Compose([
        sT.Resize(256),
        sT.CenterCrop(224),
        sT.ToTensor(),
        sT.Normalize(mean=[0.485, 0.456, 0.406],
                     std=[0.229, 0.224, 0.225]),
    ])
    for train_split in (car_train_dst, aircraft_train_dst, dog_train_dst, cub_train_dst):
        train_split.transform = train_transform
    for val_split in (car_val_dst, aircraft_val_dst, dog_val_dst, cub_val_dst):
        val_split.transform = val_transform

    # Shift training labels into each dataset's slice of the joint output.
    # partial(operator.add, k) is used instead of a lambda because these
    # callables live on datasets handed to DataLoader workers, and lambdas
    # cannot be pickled when workers are started with the spawn method.
    # Validation targets stay local: each metric below slices the output
    # down to its own dataset's range instead.
    aircraft_train_dst.target_transform = functools.partial(operator.add, aircraft_offset)
    dog_train_dst.target_transform = functools.partial(operator.add, dog_offset)
    cub_train_dst.target_transform = functools.partial(operator.add, cub_offset)

    car_metric = metrics.MetricCompose(metric_dict={
        'car_acc': metrics.Accuracy(
            attach_to=lambda o, t: (o[:, :aircraft_offset], t))})
    aircraft_metric = metrics.MetricCompose(metric_dict={
        'aircraft_acc': metrics.Accuracy(
            attach_to=lambda o, t: (o[:, aircraft_offset:dog_offset], t))})
    dog_metric = metrics.MetricCompose(metric_dict={
        'dog_acc': metrics.Accuracy(
            attach_to=lambda o, t: (o[:, dog_offset:cub_offset], t))})
    cub_metric = metrics.MetricCompose(metric_dict={
        'cub_acc': metrics.Accuracy(
            attach_to=lambda o, t: (o[:, cub_offset:num_total], t))})

    train_dst = torch.utils.data.ConcatDataset(
        [car_train_dst, aircraft_train_dst, dog_train_dst, cub_train_dst])
    train_loader = torch.utils.data.DataLoader(train_dst, batch_size=32, shuffle=True, num_workers=4)
    car_loader = torch.utils.data.DataLoader(car_val_dst, batch_size=32, shuffle=False, num_workers=2)
    aircraft_loader = torch.utils.data.DataLoader(aircraft_val_dst, batch_size=32, shuffle=False, num_workers=2)
    dog_loader = torch.utils.data.DataLoader(dog_val_dst, batch_size=32, shuffle=False, num_workers=2)
    cub_loader = torch.utils.data.DataLoader(cub_val_dst, batch_size=32, shuffle=False, num_workers=2)

    car_evaluator = engine.evaluator.BasicEvaluator(car_loader, car_metric)
    aircraft_evaluator = engine.evaluator.BasicEvaluator(aircraft_loader, aircraft_metric)
    dog_evaluator = engine.evaluator.BasicEvaluator(dog_loader, dog_metric)
    cub_evaluator = engine.evaluator.BasicEvaluator(cub_loader, cub_metric)

    if args.ckpt is not None:
        # Load onto the CPU first so a checkpoint saved from CUDA tensors
        # can still be restored on a CPU-only machine.
        student.load_state_dict(torch.load(args.ckpt, map_location='cpu'))
        print("Load student model from %s"%args.ckpt)

    if args.test_only:
        # Evaluation-only run: report the four accuracies and stop.
        results_car = car_evaluator.eval(student)
        results_aircraft = aircraft_evaluator.eval(student)
        results_dog = dog_evaluator.eval(student)
        results_cub = cub_evaluator.eval(student)
        print("Stanford Cars: %s"%( results_car ))
        print("FGVC Aircraft: %s"%( results_aircraft ))
        print("Stanford Dogs: %s"%( results_dog ))
        print("CUB200: %s"%( results_cub ))
        return

    # 100 passes over the training loader; the scheduler callback below is
    # registered on AFTER_STEP, so T_max is expressed in iterations.
    TOTAL_ITERS = len(train_loader) * 100
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    optim = torch.optim.Adam(student.parameters(), lr=args.lr, weight_decay=1e-4)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=TOTAL_ITERS)
    task = tasks.StandardTask.classification()
    trainer = engine.trainer.BasicTrainer(
        logger=utils.logger.get_logger('scratch-4'),
        tb_writer=SummaryWriter(
            log_dir='run/scratch-4-%s'%( time.asctime().replace(' ', '_') )))

    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP(every=10),
        callbacks=callbacks.MetricsLogging(keys=('total_loss', 'lr')))
    # NOTE(review): the 'cfl_*' checkpoint prefixes match the ones used by
    # the CFL amalgamation script although this run logs as 'scratch-4';
    # confirm that sharing the prefix (and possibly overwriting those
    # checkpoints) is intended.
    trainer.add_callback(
        engine.DefaultEvents.AFTER_EPOCH,
        callbacks=[
            callbacks.EvalAndCkpt(model=student, evaluator=car_evaluator,
                                  metric_name='car_acc', ckpt_prefix='cfl_car'),
            callbacks.EvalAndCkpt(model=student, evaluator=aircraft_evaluator,
                                  metric_name='aircraft_acc', ckpt_prefix='cfl_aircraft'),
            callbacks.EvalAndCkpt(model=student, evaluator=dog_evaluator,
                                  metric_name='dog_acc', ckpt_prefix='cfl_dog'),
            callbacks.EvalAndCkpt(model=student, evaluator=cub_evaluator,
                                  metric_name='cub_acc', ckpt_prefix='cfl_cub'),
        ])
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP,
        callbacks=callbacks.LRSchedulerCallback(schedulers=[sched]))

    trainer.setup(model=student,
                  task=task,
                  dataloader=train_loader,
                  optimizer=optim,
                  device=device)
    trainer.run(start_iter=0, max_iter=TOTAL_ITERS)