def get_dataloader():
    """Create the CamVid training and test data loaders.

    Both datasets are read from ``'../data/CamVid11'``.  The ``'trainval'``
    pipeline resizes, randomly crops and randomly flips before converting to
    tensors; the ``'test'`` pipeline only resizes and converts.  In both, the
    first transform slot ends with a fixed per-channel mean/std normalization
    and the second slot is converted to ``torch.long`` without normalization.

    Returns:
        tuple: ``(train_loader, test_loader)``.  Both loaders use a batch
        size of 16 and ``num_workers=4``; only the training loader shuffles.
    """
    data_root = '../data/CamVid11'
    shared_loader_kwargs = dict(batch_size=16, num_workers=4)

    # Each factory returns a brand-new transform object, so the two pipelines
    # never share transform instances (same as writing them out twice).
    # Presumably slot 0 is the image and slot 1 the label map -- TODO confirm
    # against the sT.Multi / sT.Sync implementation.
    def resize_step():
        # NOTE(review): the second Resize uses the default interpolation;
        # confirm that this is safe for class-index label maps.
        return sT.Multi(sT.Resize(240), sT.Resize(240))

    def to_tensor_step():
        return sT.Multi(sT.ToTensor(),
                        sT.ToTensor(normalize=False, dtype=torch.long))

    def normalize_step():
        # ``None`` leaves the second slot without a normalization transform.
        return sT.Multi(
            sT.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]), None)

    train_steps = [
        resize_step(),
        sT.Sync(sT.RandomCrop(240), sT.RandomCrop(240)),
        sT.Sync(sT.RandomHorizontalFlip(), sT.RandomHorizontalFlip()),
        to_tensor_step(),
        normalize_step(),
    ]
    train_dst = vision.datasets.CamVid(data_root,
                                       split='trainval',
                                       transforms=sT.Compose(train_steps))

    test_steps = [resize_step(), to_tensor_step(), normalize_step()]
    test_dst = vision.datasets.CamVid(data_root,
                                      split='test',
                                      transforms=sT.Compose(test_steps))

    train_loader = torch.utils.data.DataLoader(train_dst,
                                               shuffle=True,
                                               **shared_loader_kwargs)
    test_loader = torch.utils.data.DataLoader(test_dst,
                                              shuffle=False,
                                              **shared_loader_kwargs)
    return train_loader, test_loader
# Example #2
# 0
 def _get_transform(self, metadata):
     """Build the input preprocessing pipeline described by ``metadata``.

     Args:
         metadata: Mapping with an ``'input'`` entry that provides the keys
             ``'size'``, ``'space'``, ``'range'`` and ``'normalize'``.
             ``'normalize'`` is either ``None`` or a mapping with ``'mean'``
             and ``'std'``.

     Returns:
         An ``sT.Compose`` made of resize, center crop, an optional channel
         flip (when ``'space'`` is ``'bgr'``), tensor conversion and an
         optional normalization.

     Raises:
         NotImplementedError: If ``'range'`` is neither ``[0, 1]`` nor
             ``[0, 255]``.
     """
     input_metadata = metadata['input']
     size = input_metadata['size']
     space = input_metadata['space']
     drange = input_metadata['range']
     normalize = input_metadata['normalize']

     # Identity test for None: ``== None`` defers to a possibly overloaded
     # ``__eq__`` and is not the idiomatic check.
     if size is None:
         size = 224
     if isinstance(size, (list, tuple)):
         # Only the last entry is used, for both the resize and the crop.
         size = size[-1]

     transform = [
         sT.Resize(size),
         sT.CenterCrop(size),
     ]
     if space == 'bgr':
         transform.append(sT.FlipChannels())

     # Materialize once so tuples and lists compare the same way.
     value_range = list(drange)
     if value_range == [0, 1]:
         transform.append(sT.ToTensor())
     elif value_range == [0, 255]:
         # Keep the raw value scale; only convert to a float tensor.
         transform.append(sT.ToTensor(normalize=False, dtype=torch.float))
     else:
         raise NotImplementedError(
             'unsupported input range: %r' % (value_range,))

     if normalize is not None:
         transform.append(
             sT.Normalize(mean=normalize['mean'], std=normalize['std']))
     return sT.Compose(transform)
def main():
    """Train a DeepLabV3-ResNet50 segmentation model on CamVid.

    Builds the ``'trainval'`` and ``'test'`` CamVid datasets with their
    transform pipelines, an SGD optimizer with a cosine-annealed learning
    rate, and a ``BasicTrainer``.  The trainer logs ``total_loss`` / ``lr``
    every 10 steps, steps the scheduler after every step, and runs the
    ``EvalAndCkpt`` and ``VisualizeSegmentation`` callbacks after every
    epoch.  Training runs for ``len(train_loader) * 200`` iterations.
    """
    # PyTorch Part
    num_classes = 11
    model = vision.models.segmentation.deeplabv3_resnet50(
        num_classes=num_classes, pretrained_backbone=True)

    def channel_stats():
        # Fresh lists on every call so no two consumers share a list object.
        return {'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225]}

    # Step factories: every call builds new transform objects, so the train
    # and validation pipelines stay independent of each other.
    def resize_step():
        # The second slot is resized with nearest-neighbour interpolation.
        return sT.Multi(sT.Resize(240),
                        sT.Resize(240, interpolation=Image.NEAREST))

    def to_tensor_step():
        return sT.Multi(sT.ToTensor(),
                        sT.ToTensor(normalize=False, dtype=torch.long))

    def normalize_step():
        return sT.Multi(sT.Normalize(**channel_stats()),
                        sT.Lambda(lambd=lambda x: x.squeeze()))

    train_pipeline = sT.Compose([
        resize_step(),
        sT.Sync(sT.RandomRotation(5), sT.RandomRotation(5)),
        sT.Multi(sT.ColorJitter(0.2, 0.2, 0.2), None),
        sT.Sync(sT.RandomCrop(240), sT.RandomCrop(240)),
        sT.Sync(sT.RandomHorizontalFlip(), sT.RandomHorizontalFlip()),
        to_tensor_step(),
        normalize_step(),
    ])
    train_dst = vision.datasets.CamVid(
        'data/CamVid11', split='trainval', transforms=train_pipeline)

    val_pipeline = sT.Compose([
        resize_step(),
        to_tensor_step(),
        normalize_step(),
    ])
    val_dst = vision.datasets.CamVid(
        'data/CamVid11', split='test', transforms=val_pipeline)

    train_loader = torch.utils.data.DataLoader(
        train_dst, batch_size=16, shuffle=True, num_workers=4)
    val_loader = torch.utils.data.DataLoader(
        val_dst, batch_size=16, num_workers=4)

    # The scheduler is stepped once per training step (see callbacks below),
    # so its period is expressed in iterations.
    total_iters = len(train_loader) * 200
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    optim = torch.optim.SGD(
        model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(
        optim, T_max=total_iters)

    # KAE Part
    metric = kamal.tasks.StandardMetrics.segmentation(num_classes=num_classes)
    evaluator = engine.evaluator.BasicEvaluator(
        dataloader=val_loader, metric=metric, progress=False)

    task = kamal.tasks.StandardTask.segmentation()
    run_logger = kamal.utils.logger.get_logger('camvid_seg_deeplab')
    log_dir = 'run/camvid_seg_deeplab-%s' % (time.asctime().replace(' ', '_'))
    trainer = engine.trainer.BasicTrainer(
        logger=run_logger, tb_writer=SummaryWriter(log_dir=log_dir))
    trainer.setup(
        model=model,
        task=task,
        dataloader=train_loader,
        optimizer=optim,
        device=device,
    )

    # Log the loss and learning rate every 10 steps.
    every_ten_steps = engine.DefaultEvents.AFTER_STEP(every=10)
    trainer.add_callback(
        every_ten_steps,
        callbacks=callbacks.MetricsLogging(keys=('total_loss', 'lr')))

    # Advance the learning-rate schedule after every step.
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP,
        callbacks=callbacks.LRSchedulerCallback(schedulers=[sched]))

    # Per-epoch callbacks, registered in this order.
    eval_and_ckpt = callbacks.EvalAndCkpt(
        model=model,
        evaluator=evaluator,
        metric_name='miou',
        ckpt_prefix='camvid_seg_deeplabv3_resnet50')
    visualize = callbacks.VisualizeSegmentation(
        model=model,
        dataset=val_dst,
        idx_list_or_num_vis=10,
        normalizer=kamal.utils.Normalizer(reverse=True, **channel_stats()),
    )
    trainer.add_callback(
        engine.DefaultEvents.AFTER_EPOCH,
        callbacks=[eval_and_ckpt, visualize])

    # Optional learning-rate search, kept disabled:
    #import matplotlib.pyplot as plt
    #lr_finder = kamal.engine.lr_finder.LRFinder()
    #best_lr = lr_finder.find( optim, model, trainer, lr_range=[1e-8, 1.0], max_iter=100, smooth_momentum=0.9 )
    #fig = lr_finder.plot(polyfit=4)
    #plt.savefig('lr_finder_deeplab.png')
    #lr_finder.adjust_learning_rate(optim, best_lr)

    trainer.run(start_iter=0, max_iter=total_iters)
# Example #4
# 0
def main():
    """Amalgamate a Stanford Cars and an FGVC Aircraft teacher into one student.

    Two ResNet-18 teachers (196 and 102 classes) are restored from
    ``args.car_ckpt`` and ``args.aircraft_ckpt``; a ResNet-18 student with
    ``196 + 102`` outputs is trained with
    ``amalgamation.CommonFeatureAmalgamator`` on the concatenation of both
    training sets.

    Reads the module-level ``args`` namespace: ``car_ckpt``,
    ``aircraft_ckpt``, ``ckpt`` (optional student checkpoint), ``test_only``
    and ``lr``.  When ``args.test_only`` is set, the student is only evaluated
    on both validation sets and the function returns without training.
    """
    car_train_dst = vision.datasets.StanfordCars('../data/StanfordCars',
                                                 split='train')
    car_val_dst = vision.datasets.StanfordCars('../data/StanfordCars',
                                               split='test')
    aircraft_train_dst = vision.datasets.FGVCAircraft('../data/FGVCAircraft',
                                                      split='trainval')
    aircraft_val_dst = vision.datasets.FGVCAircraft('../data/FGVCAircraft',
                                                    split='test')
    car_teacher = vision.models.classification.resnet18(num_classes=196,
                                                        pretrained=False)
    aircraft_teacher = vision.models.classification.resnet18(num_classes=102,
                                                             pretrained=False)
    student = vision.models.classification.resnet18(num_classes=196 + 102,
                                                    pretrained=False)

    # Deserialize onto the CPU: without ``map_location`` a checkpoint that was
    # saved from a GPU cannot be loaded on a CPU-only host, although this
    # script explicitly falls back to CPU below.  ``load_state_dict`` copies
    # the values into the model's existing parameters either way.
    car_teacher.load_state_dict(
        torch.load(args.car_ckpt, map_location='cpu'))
    aircraft_teacher.load_state_dict(
        torch.load(args.aircraft_ckpt, map_location='cpu'))

    train_transform = sT.Compose([
        sT.RandomResizedCrop(224),
        sT.RandomHorizontalFlip(),
        sT.ToTensor(),
        sT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    val_transform = sT.Compose([
        sT.Resize(256),
        sT.CenterCrop(224),
        sT.ToTensor(),
        sT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Both datasets of a split share the same transform object.
    car_train_dst.transform = aircraft_train_dst.transform = train_transform
    car_val_dst.transform = aircraft_val_dst.transform = val_transform

    # The student output is split by column: the first 196 columns are scored
    # against car targets, the remaining columns against aircraft targets.
    car_metric = metrics.MetricCompose(
        metric_dict={
            'car_acc': metrics.Accuracy(attach_to=lambda o, t: (o[:, :196], t))
        })
    aircraft_metric = metrics.MetricCompose(metric_dict={
        'aircraft_acc':
        metrics.Accuracy(attach_to=lambda o, t: (o[:, 196:], t))
    })

    train_dst = torch.utils.data.ConcatDataset(
        [car_train_dst, aircraft_train_dst])
    train_loader = torch.utils.data.DataLoader(train_dst,
                                               batch_size=32,
                                               shuffle=True,
                                               num_workers=4)
    car_loader = torch.utils.data.DataLoader(car_val_dst,
                                             batch_size=32,
                                             shuffle=False,
                                             num_workers=4)
    aircraft_loader = torch.utils.data.DataLoader(aircraft_val_dst,
                                                  batch_size=32,
                                                  shuffle=False,
                                                  num_workers=4)

    car_evaluator = engine.evaluator.BasicEvaluator(car_loader, car_metric)
    aircraft_evaluator = engine.evaluator.BasicEvaluator(
        aircraft_loader, aircraft_metric)

    if args.ckpt is not None:
        # Same CPU-first loading as for the teachers above.
        student.load_state_dict(torch.load(args.ckpt, map_location='cpu'))
        print("Load student model from %s" % args.ckpt)
    if args.test_only:
        results_car = car_evaluator.eval(student)
        results_aircraft = aircraft_evaluator.eval(student)
        print("Stanford Cars: %s" % (results_car))
        print("FGVC Aircraft: %s" % (results_aircraft))
        return

    # The scheduler is stepped once per training step (see the AFTER_STEP
    # callback below), so its period is expressed in iterations.
    TOTAL_ITERS = len(train_loader) * 100
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    optim = torch.optim.Adam(student.parameters(),
                             lr=args.lr,
                             weight_decay=1e-4)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim,
                                                       T_max=TOTAL_ITERS)
    trainer = amalgamation.CommonFeatureAmalgamator(
        logger=utils.logger.get_logger('cfl'),
        tb_writer=SummaryWriter(log_dir='run/cfl-%s' %
                                (time.asctime().replace(' ', '_'))))

    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP(every=10),
        callbacks=callbacks.MetricsLogging(keys=('total_loss', 'loss_kd',
                                                 'loss_amal', 'loss_recons',
                                                 'lr')))
    trainer.add_callback(engine.DefaultEvents.AFTER_EPOCH,
                         callbacks=[
                             callbacks.EvalAndCkpt(model=student,
                                                   evaluator=car_evaluator,
                                                   metric_name='car_acc',
                                                   ckpt_prefix='cfl_car'),
                             callbacks.EvalAndCkpt(
                                 model=student,
                                 evaluator=aircraft_evaluator,
                                 metric_name='aircraft_acc',
                                 ckpt_prefix='cfl_aircraft'),
                         ])
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP,
        callbacks=callbacks.LRSchedulerCallback(schedulers=[sched]))

    # One layer group: the final ``fc`` layer of the student and of each
    # teacher, with the matching channel counts.
    layer_groups = [(student.fc, car_teacher.fc, aircraft_teacher.fc)]
    layer_channels = [(512, 512, 512)]

    trainer.setup(student=student,
                  teachers=[car_teacher, aircraft_teacher],
                  layer_groups=layer_groups,
                  layer_channels=layer_channels,
                  dataloader=train_loader,
                  optimizer=optim,
                  device=device,
                  on_layer_input=True,
                  weights=[1., 10., 10.])
    trainer.run(start_iter=0, max_iter=TOTAL_ITERS)
# Example #5
# 0
def main():
    """Train a single-output DeepLabV3-ResNet50 on NYUv2 depth targets.

    Builds the ``'train'`` and ``'test'`` NYUv2 datasets with
    ``target_type='depth'``, an SGD optimizer with a cosine-annealed learning
    rate, and a ``BasicTrainer``.  The trainer logs ``total_loss`` / ``lr``
    every 10 steps, steps the scheduler after every step, and runs the
    ``EvalAndCkpt`` (``'rmse'``, mode ``'min'``) and ``VisualizeDepth``
    callbacks after every epoch.  Training runs for
    ``len(train_loader) * 200`` iterations.
    """
    # PyTorch Part
    model = vision.models.segmentation.deeplabv3_resnet50(
        num_classes=1, pretrained_backbone=True)

    def channel_stats():
        # Fresh lists on every call so no two consumers share a list object.
        return {'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225]}

    # Step factories: every call builds new transform objects, so the train
    # and validation pipelines stay independent of each other.
    def resize_step():
        return sT.Multi(sT.Resize(240), sT.Resize(240))

    def to_tensor_step():
        return sT.Multi(sT.ToTensor(),
                        sT.ToTensor(normalize=False, dtype=torch.float))

    def normalize_step():
        # The second slot is divided by 1000 (presumably a unit conversion of
        # the raw depth values -- TODO confirm against the dataset).
        return sT.Multi(sT.Normalize(**channel_stats()),
                        sT.Lambda(lambda x: x / 1000))

    train_pipeline = sT.Compose([
        resize_step(),
        sT.Sync(sT.RandomRotation(5), sT.RandomRotation(5)),
        sT.Multi(sT.ColorJitter(0.2, 0.2, 0.2), None),
        sT.Sync(sT.RandomCrop(240), sT.RandomCrop(240)),
        sT.Sync(sT.RandomHorizontalFlip(), sT.RandomHorizontalFlip()),
        to_tensor_step(),
        normalize_step(),
    ])
    train_dst = vision.datasets.NYUv2('data/NYUv2',
                                      split='train',
                                      target_type='depth',
                                      transforms=train_pipeline)

    val_pipeline = sT.Compose([
        resize_step(),
        to_tensor_step(),
        normalize_step(),
    ])
    val_dst = vision.datasets.NYUv2('data/NYUv2',
                                    split='test',
                                    target_type='depth',
                                    transforms=val_pipeline)

    train_loader = torch.utils.data.DataLoader(
        train_dst, batch_size=16, shuffle=True, num_workers=4)
    val_loader = torch.utils.data.DataLoader(
        val_dst, batch_size=16, num_workers=4)

    # The scheduler is stepped once per training step (see callbacks below),
    # so its period is expressed in iterations.
    total_iters = len(train_loader) * 200
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    optim = torch.optim.SGD(
        model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(
        optim, T_max=total_iters)

    # KAE Part
    metric = kamal.tasks.StandardMetrics.monocular_depth()
    evaluator = engine.evaluator.BasicEvaluator(
        dataloader=val_loader, metric=metric, progress=False)
    task = kamal.tasks.StandardTask.monocular_depth()

    run_logger = kamal.utils.logger.get_logger('nyuv2_depth_deeplab')
    log_dir = 'run/nyuv2_depth_deeplab-%s' % (time.asctime().replace(' ', '_'))
    trainer = engine.trainer.BasicTrainer(
        logger=run_logger, tb_writer=SummaryWriter(log_dir=log_dir))
    trainer.setup(
        model=model,
        task=task,
        dataloader=train_loader,
        optimizer=optim,
        device=device,
    )

    # Log the loss and learning rate every 10 steps.
    every_ten_steps = engine.DefaultEvents.AFTER_STEP(every=10)
    trainer.add_callback(
        every_ten_steps,
        callbacks=callbacks.MetricsLogging(keys=('total_loss', 'lr')))

    # Advance the learning-rate schedule after every step.
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP,
        callbacks=callbacks.LRSchedulerCallback(schedulers=[sched]))

    # Per-epoch callbacks, registered in this order.
    eval_and_ckpt = callbacks.EvalAndCkpt(
        model=model,
        evaluator=evaluator,
        metric_name='rmse',
        metric_mode='min',
        ckpt_prefix='nyuv2_depth')
    visualize = callbacks.VisualizeDepth(
        model=model,
        dataset=val_dst,
        idx_list_or_num_vis=10,
        max_depth=10,
        normalizer=kamal.utils.Normalizer(reverse=True, **channel_stats()),
    )
    trainer.add_callback(
        engine.DefaultEvents.AFTER_EPOCH,
        callbacks=[eval_and_ckpt, visualize])

    trainer.run(start_iter=0, max_iter=total_iters)
def main():
    """Distill a segmentation and a depth teacher into one multi-task student.

    Two ``segnet_vgg16_bn`` teachers (13-class segmentation and 1-channel
    depth) are restored from ``args.seg_ckpt`` and ``args.depth_ckpt``.  A
    ``MultiTaskSegNet`` student with output channels ``[13, 1]`` is trained
    with ``engine.trainer.KDTrainer`` on NYUv2 samples that carry both a
    semantic and a depth target.

    Reads the module-level ``args`` namespace: ``seg_ckpt`` and
    ``depth_ckpt``.  Training runs for ``len(train_loader) * 200`` iterations.
    """
    # Seg + Depth
    model = MultiTaskSegNet(out_channel_list=[13, 1])
    seg_teacher = vision.models.segmentation.segnet_vgg16_bn(
        num_classes=13, pretrained_backbone=True)
    depth_teacher = vision.models.segmentation.segnet_vgg16_bn(
        num_classes=1, pretrained_backbone=True)

    # Deserialize onto the CPU: without ``map_location`` a checkpoint that was
    # saved from a GPU cannot be loaded on a CPU-only host, although this
    # script explicitly falls back to CPU below.  ``load_state_dict`` copies
    # the values into the model's existing parameters either way.
    seg_teacher.load_state_dict(
        torch.load(args.seg_ckpt, map_location='cpu'))
    depth_teacher.load_state_dict(
        torch.load(args.depth_ckpt, map_location='cpu'))

    seg_train_dst = vision.datasets.NYUv2('../data/NYUv2',
                                          split='train',
                                          target_type='semantic')
    seg_val_dst = vision.datasets.NYUv2('../data/NYUv2',
                                        split='test',
                                        target_type='semantic')
    depth_train_dst = vision.datasets.NYUv2('../data/NYUv2',
                                            split='train',
                                            target_type='depth')
    depth_val_dst = vision.datasets.NYUv2('../data/NYUv2',
                                          split='test',
                                          target_type='depth')

    # Every transform step has three slots.  Presumably they are, in order,
    # the image, the semantic label and the depth map -- TODO confirm against
    # LabelConcatDataset.  The second slot is resized with nearest-neighbour
    # interpolation and squeezed; the third slot is divided by 1e3.
    train_dst = vision.datasets.LabelConcatDataset(
        datasets=[seg_train_dst, depth_train_dst],
        transforms=sT.Compose([
            sT.Multi(sT.Resize(240), sT.Resize(240,
                                               interpolation=Image.NEAREST),
                     sT.Resize(240)),
            sT.Sync(sT.RandomCrop(240), sT.RandomCrop(240),
                    sT.RandomCrop(240)),
            sT.Sync(sT.RandomHorizontalFlip(), sT.RandomHorizontalFlip(),
                    sT.RandomHorizontalFlip()),
            sT.Multi(sT.ToTensor(),
                     sT.ToTensor(normalize=False, dtype=torch.long),
                     sT.ToTensor(normalize=False, dtype=torch.float)),
            sT.Multi(
                sT.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
                sT.Lambda(lambd=lambda x: x.squeeze()),
                sT.Lambda(lambd=lambda x: x / 1e3))
        ]))
    val_dst = vision.datasets.LabelConcatDataset(
        datasets=[seg_val_dst, depth_val_dst],
        transforms=sT.Compose([
            sT.Multi(sT.Resize(240), sT.Resize(240,
                                               interpolation=Image.NEAREST),
                     sT.Resize(240)),
            sT.Multi(sT.ToTensor(),
                     sT.ToTensor(normalize=False, dtype=torch.long),
                     sT.ToTensor(normalize=False, dtype=torch.float)),
            sT.Multi(
                sT.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
                sT.Lambda(lambd=lambda x: x.squeeze()),
                sT.Lambda(lambd=lambda x: x / 1e3))
        ]))

    train_loader = torch.utils.data.DataLoader(train_dst,
                                               batch_size=16,
                                               shuffle=True,
                                               num_workers=4)
    val_loader = torch.utils.data.DataLoader(val_dst,
                                             batch_size=16,
                                             num_workers=4)

    # The scheduler is stepped once per training step (see the AFTER_STEP
    # callback below), so its period is expressed in iterations.
    TOTAL_ITERS = len(train_loader) * 200
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    optim = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim,
                                                       T_max=TOTAL_ITERS)

    # Metrics attached to index 0 score the segmentation output, the one
    # attached to index 1 scores the depth output.  The mIoU metric is built
    # from the same confusion-matrix object that is registered under 'cm'.
    confusion_matrix = metrics.ConfusionMatrix(num_classes=13,
                                               ignore_idx=255,
                                               attach_to=0)
    metric = metrics.MetricCompose(
        metric_dict={
            'acc': metrics.Accuracy(attach_to=0),
            'cm': confusion_matrix,
            'mIoU': metrics.mIoU(confusion_matrix),
            'rmse': metrics.RootMeanSquaredError(attach_to=1)
        })
    evaluator = engine.evaluator.BasicEvaluator(dataloader=val_loader,
                                                metric=metric,
                                                progress=False)

    # One task per output index: distillation for index 0, monocular depth
    # for index 1.
    task = [
        kamal.tasks.StandardTask.distillation(attach_to=[0, 0]),
        kamal.tasks.StandardTask.monocular_depth(attach_to=[1, 1])
    ]
    trainer = engine.trainer.KDTrainer(
        logger=kamal.utils.logger.get_logger('nyuv2_simple_kd'),
        tb_writer=SummaryWriter(log_dir='run/nyuv2_simple_kd-%s' %
                                (time.asctime().replace(' ', '_'))))
    trainer.setup(student=model,
                  teacher=[seg_teacher, depth_teacher],
                  task=task,
                  dataloader=train_loader,
                  optimizer=optim,
                  device=device)
    trainer.add_callback(engine.DefaultEvents.AFTER_STEP(every=10),
                         callbacks=callbacks.MetricsLogging(keys=('total_loss',
                                                                  'lr')))
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP,
        callbacks=callbacks.LRSchedulerCallback(schedulers=[sched]))
    trainer.add_callback(
        engine.DefaultEvents.AFTER_EPOCH,
        callbacks=[
            # 'rmse' is tracked with metric_mode='min'.
            callbacks.EvalAndCkpt(model=model,
                                  evaluator=evaluator,
                                  metric_name='rmse',
                                  metric_mode='min',
                                  ckpt_prefix='nyuv2_simple_kd'),
            callbacks.VisualizeSegmentation(
                model=model,
                dataset=val_dst,
                idx_list_or_num_vis=10,
                attach_to=0,
                normalizer=kamal.utils.Normalizer(mean=[0.485, 0.456, 0.406],
                                                  std=[0.229, 0.224, 0.225],
                                                  reverse=True),
            ),
            callbacks.VisualizeDepth(
                model=model,
                dataset=val_dst,
                idx_list_or_num_vis=10,
                max_depth=10,
                attach_to=1,
                normalizer=kamal.utils.Normalizer(mean=[0.485, 0.456, 0.406],
                                                  std=[0.229, 0.224, 0.225],
                                                  reverse=True),
            ),
        ])
    trainer.run(start_iter=0, max_iter=TOTAL_ITERS)
# Example #7
# 0
def main():
    """Train a ResNet-18 classifier on one of four fine-grained datasets.

    Reads the module-level ``args`` namespace: ``dataset`` (one of
    ``'stanford_dogs'``, ``'cub200'``, ``'fgvc_aircraft'``,
    ``'stanford_cars'``), ``pretrained``, ``epochs`` and ``lr``.  After
    training for ``len(train_loader) * args.epochs`` iterations, the
    ``EvalAndCkpt`` callback's ``final_ckpt`` is called with
    ``ckpt_dir='pretrained'``.

    Raises:
        NotImplementedError: If ``args.dataset`` is not one of the names
            listed above.
    """
    # Pytorch Part
    # name -> (number of classes, dataset class name, root dir, train split)
    dataset_specs = {
        'stanford_dogs': (120, 'StanfordDogs', 'data/StanfordDogs', 'train'),
        'cub200': (200, 'CUB200', 'data/CUB200', 'train'),
        'fgvc_aircraft': (102, 'FGVCAircraft', 'data/FGVCAircraft/',
                          'trainval'),
        'stanford_cars': (196, 'StanfordCars', 'data/StanfordCars/', 'train'),
    }
    if args.dataset not in dataset_specs:
        raise NotImplementedError
    num_classes, cls_name, root, train_split = dataset_specs[args.dataset]
    # Resolved lazily so only the selected dataset class is looked up.
    dataset_cls = getattr(vision.datasets, cls_name)
    train_dst = dataset_cls(root, split=train_split)
    val_dst = dataset_cls(root, split='test')

    model = vision.models.classification.resnet18(num_classes=num_classes,
                                                  pretrained=args.pretrained)

    def normalize_step():
        # A new Normalize object per pipeline, as in a written-out version.
        return sT.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])

    train_steps = [
        sT.RandomResizedCrop(224),
        sT.RandomHorizontalFlip(),
        sT.ToTensor(),
        normalize_step(),
    ]
    train_dst.transform = sT.Compose(train_steps)

    val_steps = [
        sT.Resize(256),
        sT.CenterCrop(224),
        sT.ToTensor(),
        normalize_step(),
    ]
    val_dst.transform = sT.Compose(val_steps)

    train_loader = torch.utils.data.DataLoader(
        train_dst, batch_size=32, shuffle=True, num_workers=4)
    val_loader = torch.utils.data.DataLoader(
        val_dst, batch_size=32, num_workers=4)

    # The scheduler is stepped once per training step (see callbacks below),
    # so its period is expressed in iterations.
    total_iters = len(train_loader) * args.epochs
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    optim = torch.optim.SGD(
        model.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(
        optim, T_max=total_iters)

    # KAE Part
    # Predefined task & metrics
    task = kamal.tasks.StandardTask.classification()
    metric = kamal.tasks.StandardMetrics.classification()

    # Evaluator and Trainer
    evaluator = engine.evaluator.BasicEvaluator(
        val_loader, metric=metric, progress=True)
    run_logger = kamal.utils.logger.get_logger(args.dataset)
    log_dir = 'run/%s-%s' % (args.dataset, time.asctime().replace(' ', '_'))
    trainer = engine.trainer.BasicTrainer(
        logger=run_logger, tb_writer=SummaryWriter(log_dir=log_dir))

    # setup trainer
    trainer.setup(
        model=model,
        task=task,
        dataloader=train_loader,
        optimizer=optim,
        device=device,
    )

    # Log the loss and learning rate every 10 steps.
    every_ten_steps = engine.DefaultEvents.AFTER_STEP(every=10)
    trainer.add_callback(
        every_ten_steps,
        callbacks=callbacks.MetricsLogging(keys=('total_loss', 'lr')))

    # Advance the learning-rate schedule after every step.
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP,
        callbacks=callbacks.LRSchedulerCallback(schedulers=[sched]))

    # Keep the handle returned by add_callback: it is needed after training.
    eval_and_ckpt = callbacks.EvalAndCkpt(
        model=model,
        evaluator=evaluator,
        metric_name='acc',
        ckpt_prefix=args.dataset)
    ckpt_callback = trainer.add_callback(
        engine.DefaultEvents.AFTER_EPOCH, callbacks=eval_and_ckpt)

    trainer.run(start_iter=0, max_iter=total_iters)
    ckpt_callback.callback.final_ckpt(ckpt_dir='pretrained', add_md5=True)
# Example #8
# 0
def main():
    """Train a WRN-40-2 classifier on CIFAR-100.

    Builds the CIFAR-100 train and test datasets under ``'data/torchdata'``
    (with ``download=True``), an SGD optimizer with a cosine-annealed
    learning rate, and a ``BasicTrainer``.  The ``EvalAndCkpt`` callback runs
    after every epoch and the scheduler is stepped after every step.
    Training runs for ``len(train_loader) * 200`` iterations.
    """
    # Pytorch Part
    model = vision.models.classification.cifar.wrn.wrn_40_2(num_classes=100)

    def normalize_step():
        # A new Normalize object per pipeline, as in a written-out version.
        return sT.Normalize(mean=(0.4914, 0.4822, 0.4465),
                            std=(0.2023, 0.1994, 0.2010))

    data_root = 'data/torchdata'

    train_pipeline = sT.Compose([
        sT.RandomCrop(32, padding=4),
        sT.RandomHorizontalFlip(),
        sT.ToTensor(),
        normalize_step(),
    ])
    train_dst = vision.datasets.torchvision_datasets.CIFAR100(
        data_root, train=True, download=True, transform=train_pipeline)

    # No augmentation for validation: tensor conversion and normalization.
    val_pipeline = sT.Compose([
        sT.ToTensor(),
        normalize_step(),
    ])
    val_dst = vision.datasets.torchvision_datasets.CIFAR100(
        data_root, train=False, download=True, transform=val_pipeline)

    train_loader = torch.utils.data.DataLoader(
        train_dst, batch_size=128, shuffle=True, num_workers=4)
    val_loader = torch.utils.data.DataLoader(
        val_dst, batch_size=128, num_workers=4)

    # The scheduler is stepped once per training step (see callbacks below),
    # so its period is expressed in iterations.
    total_iters = len(train_loader) * 200
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    optim = torch.optim.SGD(
        model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(
        optim, T_max=total_iters)

    # KAE Part
    # prepare evaluator
    metric = kamal.tasks.StandardMetrics.classification()
    evaluator = engine.evaluator.BasicEvaluator(
        dataloader=val_loader, metric=metric, progress=False)

    # prepare trainer
    task = kamal.tasks.StandardTask.classification()
    run_logger = kamal.utils.logger.get_logger('cifar100')
    log_dir = 'run/cifar100-%s' % (time.asctime().replace(' ', '_'))
    trainer = engine.trainer.BasicTrainer(
        logger=run_logger, tb_writer=SummaryWriter(log_dir=log_dir))
    trainer.setup(
        model=model,
        task=task,
        dataloader=train_loader,
        optimizer=optim,
        device=device,
    )

    # add callbacks
    eval_and_ckpt = callbacks.EvalAndCkpt(
        model=model,
        evaluator=evaluator,
        metric_name='acc',
        ckpt_prefix='cifar100')
    trainer.add_callback(
        engine.DefaultEvents.AFTER_EPOCH, callbacks=eval_and_ckpt)

    lr_stepper = callbacks.LRSchedulerCallback(schedulers=[sched])
    trainer.add_callback(engine.DefaultEvents.AFTER_STEP, callbacks=lr_stepper)

    # run
    trainer.run(start_iter=0, max_iter=total_iters)
def main():
    """Train (or only evaluate) one ResNet-18 student on four datasets at once.

    Stanford Cars, FGVC Aircraft, Stanford Dogs and CUB200 are concatenated
    into a single training set. The student has one output per class of every
    dataset, and each dataset owns a contiguous slice of that output vector:

        cars      [0, 196)
        aircraft  [196, 298)
        dogs      [298, 418)
        cub       [418, 618)

    Training labels are shifted into the owning slice. Validation labels are
    left un-shifted, because each per-dataset metric is handed only that
    dataset's slice of the output together with the original label.

    Reads the module-level ``args``: ``args.ckpt`` (optional checkpoint path),
    ``args.test_only`` (evaluate and return) and ``args.lr``.
    """
    # Classes contributed by each dataset, and where each slice starts.
    num_car, num_aircraft, num_dog, num_cub = 196, 102, 120, 200
    aircraft_start = num_car
    dog_start = aircraft_start + num_aircraft
    cub_start = dog_start + num_dog
    num_total = cub_start + num_cub

    car_train_dst = vision.datasets.StanfordCars('../data/StanfordCars', split='train')
    car_val_dst = vision.datasets.StanfordCars('../data/StanfordCars', split='test')
    aircraft_train_dst = vision.datasets.FGVCAircraft('../data/FGVCAircraft', split='trainval')
    aircraft_val_dst = vision.datasets.FGVCAircraft('../data/FGVCAircraft', split='test')

    dog_train_dst = vision.datasets.StanfordDogs('../data/StanfordDogs', split='train')
    dog_val_dst = vision.datasets.StanfordDogs('../data/StanfordDogs', split='test')
    cub_train_dst = vision.datasets.CUB200('../data/CUB200', split='train')
    cub_val_dst = vision.datasets.CUB200('../data/CUB200', split='test')

    student = vision.models.classification.resnet18(num_classes=num_total, pretrained=False)

    train_transform = sT.Compose([
        sT.RandomResizedCrop(224),
        sT.RandomHorizontalFlip(),
        sT.ToTensor(),
        sT.Normalize(mean=[0.485, 0.456, 0.406],
                     std=[0.229, 0.224, 0.225]),
    ])
    val_transform = sT.Compose([
        sT.Resize(256),
        sT.CenterCrop(224),
        sT.ToTensor(),
        sT.Normalize(mean=[0.485, 0.456, 0.406],
                     std=[0.229, 0.224, 0.225]),
    ])

    for dst in (cub_train_dst, dog_train_dst, car_train_dst, aircraft_train_dst):
        dst.transform = train_transform
    for dst in (cub_val_dst, dog_val_dst, car_val_dst, aircraft_val_dst):
        dst.transform = val_transform

    # Shift training labels into the dataset's slice of the shared output.
    # Cars start at 0 and need no shift. The offsets are never rebound, so the
    # closures are safe.
    aircraft_train_dst.target_transform = lambda t: t + aircraft_start
    dog_train_dst.target_transform = lambda t: t + dog_start
    cub_train_dst.target_transform = lambda t: t + cub_start

    def slice_accuracy(name, start, stop):
        # Accuracy metric that is given only output columns [start, stop).
        return metrics.MetricCompose(metric_dict={
            name: metrics.Accuracy(attach_to=lambda o, t: (o[:, start:stop], t)),
        })

    car_metric = slice_accuracy('car_acc', 0, aircraft_start)
    aircraft_metric = slice_accuracy('aircraft_acc', aircraft_start, dog_start)
    dog_metric = slice_accuracy('dog_acc', dog_start, cub_start)
    cub_metric = slice_accuracy('cub_acc', cub_start, num_total)

    train_dst = torch.utils.data.ConcatDataset(
        [car_train_dst, aircraft_train_dst, dog_train_dst, cub_train_dst])
    train_loader = torch.utils.data.DataLoader(train_dst, batch_size=32, shuffle=True, num_workers=4)
    car_loader = torch.utils.data.DataLoader(car_val_dst, batch_size=32, shuffle=False, num_workers=2)
    aircraft_loader = torch.utils.data.DataLoader(aircraft_val_dst, batch_size=32, shuffle=False, num_workers=2)
    dog_loader = torch.utils.data.DataLoader(dog_val_dst, batch_size=32, shuffle=False, num_workers=2)
    cub_loader = torch.utils.data.DataLoader(cub_val_dst, batch_size=32, shuffle=False, num_workers=2)

    car_evaluator = engine.evaluator.BasicEvaluator(car_loader, car_metric)
    aircraft_evaluator = engine.evaluator.BasicEvaluator(aircraft_loader, aircraft_metric)
    dog_evaluator = engine.evaluator.BasicEvaluator(dog_loader, dog_metric)
    cub_evaluator = engine.evaluator.BasicEvaluator(cub_loader, cub_metric)

    if args.ckpt is not None:
        # map_location='cpu' lets a checkpoint saved from a GPU be restored on
        # a CPU-only host; load_state_dict copies the values into the
        # student's own parameters regardless of where they were loaded.
        student.load_state_dict(torch.load(args.ckpt, map_location='cpu'))
        print("Load student model from %s" % (args.ckpt,))

    if args.test_only:
        results_car = car_evaluator.eval(student)
        results_aircraft = aircraft_evaluator.eval(student)
        results_dog = dog_evaluator.eval(student)
        results_cub = cub_evaluator.eval(student)
        # Explicit 1-tuples: a bare tuple-valued result would otherwise be
        # unpacked as %-format arguments.
        print("Stanford Cars: %s" % (results_car,))
        print("FGVC Aircraft: %s" % (results_aircraft,))
        print("Stanford Dogs: %s" % (results_dog,))
        print("CUB200: %s" % (results_cub,))
        return

    # Training length in iterations: 100x the number of batches in the loader.
    TOTAL_ITERS = len(train_loader) * 100
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    optim = torch.optim.Adam(student.parameters(), lr=args.lr, weight_decay=1e-4)
    # T_max is in iterations; the scheduler callback below is registered on
    # AFTER_STEP rather than AFTER_EPOCH.
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=TOTAL_ITERS)
    task = tasks.StandardTask.classification()
    trainer = engine.trainer.BasicTrainer(
        logger=utils.logger.get_logger('scratch-4'),
        tb_writer=SummaryWriter(
            log_dir='run/scratch-4-%s' % (time.asctime().replace(' ', '_'))),
    )

    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP(every=10),
        callbacks=callbacks.MetricsLogging(keys=('total_loss', 'lr')))
    # One evaluate-and-checkpoint callback per dataset, each tracking its own
    # accuracy key and writing under its own checkpoint prefix.
    trainer.add_callback(
        engine.DefaultEvents.AFTER_EPOCH,
        callbacks=[
            callbacks.EvalAndCkpt(model=student, evaluator=car_evaluator,
                                  metric_name='car_acc', ckpt_prefix='cfl_car'),
            callbacks.EvalAndCkpt(model=student, evaluator=aircraft_evaluator,
                                  metric_name='aircraft_acc', ckpt_prefix='cfl_aircraft'),
            callbacks.EvalAndCkpt(model=student, evaluator=dog_evaluator,
                                  metric_name='dog_acc', ckpt_prefix='cfl_dog'),
            callbacks.EvalAndCkpt(model=student, evaluator=cub_evaluator,
                                  metric_name='cub_acc', ckpt_prefix='cfl_cub'),
        ])
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP,
        callbacks=callbacks.LRSchedulerCallback(schedulers=[sched]))

    trainer.setup(model=student,
                  task=task,
                  dataloader=train_loader,
                  optimizer=optim,
                  device=device)
    trainer.run(start_iter=0, max_iter=TOTAL_ITERS)