Example #1
def main():
    car_train_dst = vision.datasets.StanfordCars('../data/StanfordCars',
                                                 split='train')
    car_val_dst = vision.datasets.StanfordCars('../data/StanfordCars',
                                               split='test')
    aircraft_train_dst = vision.datasets.FGVCAircraft('../data/FGVCAircraft',
                                                      split='trainval')
    aircraft_val_dst = vision.datasets.FGVCAircraft('../data/FGVCAircraft',
                                                    split='test')
    car_teacher = vision.models.classification.resnet18(num_classes=196,
                                                        pretrained=False)
    aircraft_teacher = vision.models.classification.resnet18(num_classes=102,
                                                             pretrained=False)
    student = vision.models.classification.resnet18(num_classes=196 + 102,
                                                    pretrained=False)
    car_teacher.load_state_dict(torch.load(args.car_ckpt))
    aircraft_teacher.load_state_dict(torch.load(args.aircraft_ckpt))
    train_transform = sT.Compose([
        sT.RandomResizedCrop(224),
        sT.RandomHorizontalFlip(),
        sT.ToTensor(),
        sT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    val_transform = sT.Compose([
        sT.Resize(256),
        sT.CenterCrop(224),
        sT.ToTensor(),
        sT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    car_train_dst.transform = aircraft_train_dst.transform = train_transform
    car_val_dst.transform = aircraft_val_dst.transform = val_transform

    car_metric = metrics.MetricCompose(
        metric_dict={
            'car_acc': metrics.Accuracy(attach_to=lambda o, t: (o[:, :196], t))
        })
    aircraft_metric = metrics.MetricCompose(metric_dict={
        'aircraft_acc':
        metrics.Accuracy(attach_to=lambda o, t: (o[:, 196:], t))
    })

    train_dst = torch.utils.data.ConcatDataset(
        [car_train_dst, aircraft_train_dst])
    train_loader = torch.utils.data.DataLoader(train_dst,
                                               batch_size=32,
                                               shuffle=True,
                                               num_workers=4)
    car_loader = torch.utils.data.DataLoader(car_val_dst,
                                             batch_size=32,
                                             shuffle=False,
                                             num_workers=4)
    aircraft_loader = torch.utils.data.DataLoader(aircraft_val_dst,
                                                  batch_size=32,
                                                  shuffle=False,
                                                  num_workers=4)

    car_evaluator = engine.evaluator.BasicEvaluator(car_loader, car_metric)
    aircraft_evaluator = engine.evaluator.BasicEvaluator(
        aircraft_loader, aircraft_metric)

    if args.ckpt is not None:
        student.load_state_dict(torch.load(args.ckpt))
        print("Load student model from %s" % args.ckpt)
    if args.test_only:
        results_car = car_evaluator.eval(student)
        results_aircraft = aircraft_evaluator.eval(student)
        print("Stanford Cars: %s" % (results_car))
        print("FGVC Aircraft: %s" % (results_aircraft))
        return

    TOTAL_ITERS = len(train_loader) * 100
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    optim = torch.optim.Adam(student.parameters(),
                             lr=args.lr,
                             weight_decay=1e-4)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim,
                                                       T_max=TOTAL_ITERS)
    trainer = amalgamation.CommonFeatureAmalgamator(
        logger=utils.logger.get_logger('cfl'),
        tb_writer=SummaryWriter(log_dir='run/cfl-%s' %
                                (time.asctime().replace(' ', '_'))))

    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP(every=10),
        callbacks=callbacks.MetricsLogging(keys=('total_loss', 'loss_kd',
                                                 'loss_amal', 'loss_recons',
                                                 'lr')))
    trainer.add_callback(engine.DefaultEvents.AFTER_EPOCH,
                         callbacks=[
                             callbacks.EvalAndCkpt(model=student,
                                                   evaluator=car_evaluator,
                                                   metric_name='car_acc',
                                                   ckpt_prefix='cfl_car'),
                             callbacks.EvalAndCkpt(
                                 model=student,
                                 evaluator=aircraft_evaluator,
                                 metric_name='aircraft_acc',
                                 ckpt_prefix='cfl_aircraft'),
                         ])
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP,
        callbacks=callbacks.LRSchedulerCallback(schedulers=[sched]))

    # Align the 512-d features feeding each network's fc layer (on_layer_input=True
    # hooks the layer inputs); 512 is the penultimate feature width of ResNet-18.
    layer_groups = [(student.fc, car_teacher.fc, aircraft_teacher.fc)]
    layer_channels = [(512, 512, 512)]

    trainer.setup(student=student,
                  teachers=[car_teacher, aircraft_teacher],
                  layer_groups=layer_groups,
                  layer_channels=layer_channels,
                  dataloader=train_loader,
                  optimizer=optim,
                  device=device,
                  on_layer_input=True,
                  weights=[1., 10., 10.])
    trainer.run(start_iter=0, max_iter=TOTAL_ITERS)
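The `main()` above relies on module-level names (`args`, `sT`, `SummaryWriter` and the kamal submodules) that the listing does not show. A minimal preamble is sketched below; the import form and the argument names are assumptions inferred from how `main()` uses them and may differ from the upstream KamalEngine script.
# Assumed module-level scaffolding for Example #1 (a sketch, not the original file header).
import time
import argparse
import torch
from torch.utils.tensorboard import SummaryWriter

# kamal (KamalEngine) submodules referenced above; the exact import form is assumed.
from kamal import vision, engine, utils, amalgamation, metrics, callbacks
from kamal.vision import sync_transforms as sT   # alias assumed

parser = argparse.ArgumentParser()
parser.add_argument('--car_ckpt', type=str, required=True)       # trained Stanford Cars teacher
parser.add_argument('--aircraft_ckpt', type=str, required=True)  # trained FGVC Aircraft teacher
parser.add_argument('--ckpt', type=str, default=None)            # optional student checkpoint to resume from
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--test_only', action='store_true')
args = parser.parse_args()

if __name__ == '__main__':
    main()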
Example #2
def main():
    # PyTorch Part
    num_classes = 11
    model = vision.models.segmentation.deeplabv3_resnet50(num_classes=num_classes, pretrained_backbone=True)
    # sT.Sync replays identical random parameters on image and label, while sT.Multi
    # applies each transform to its corresponding input independently.
    train_dst = vision.datasets.CamVid(
        'data/CamVid11', split='trainval', transforms=sT.Compose([
            sT.Multi( sT.Resize(240), sT.Resize(240, interpolation=Image.NEAREST)),
            sT.Sync(  sT.RandomRotation(5),  sT.RandomRotation(5)),
            sT.Multi( sT.ColorJitter(0.2, 0.2, 0.2), None),
            sT.Sync(  sT.RandomCrop(240),  sT.RandomCrop(240)),
            sT.Sync(  sT.RandomHorizontalFlip(), sT.RandomHorizontalFlip() ),
            sT.Multi( sT.ToTensor(), sT.ToTensor( normalize=False, dtype=torch.long) ),
            sT.Multi( sT.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ), sT.Lambda(lambd=lambda x: x.squeeze()) )
        ]) )
    val_dst = vision.datasets.CamVid( 
        'data/CamVid11', split='test', transforms=sT.Compose([
            sT.Multi( sT.Resize(240), sT.Resize(240, interpolation=Image.NEAREST)),
            sT.Multi( sT.ToTensor(),  sT.ToTensor( normalize=False, dtype=torch.long ) ),
            sT.Multi( sT.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ), sT.Lambda(lambd=lambda x: x.squeeze()) )
        ]) )
    
    train_loader = torch.utils.data.DataLoader( train_dst, batch_size=16, shuffle=True, num_workers=4 )
    val_loader = torch.utils.data.DataLoader( val_dst, batch_size=16, num_workers=4 )
    TOTAL_ITERS=len(train_loader) * 200
    device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' )
    optim = torch.optim.SGD( model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4 )
    sched = torch.optim.lr_scheduler.CosineAnnealingLR( optim, T_max=TOTAL_ITERS )

    # KAE Part
    metric = kamal.tasks.StandardMetrics.segmentation(num_classes=num_classes)
    evaluator = engine.evaluator.BasicEvaluator( dataloader=val_loader, metric=metric, progress=False )

    task = kamal.tasks.StandardTask.segmentation()
    trainer = engine.trainer.BasicTrainer( 
        logger=kamal.utils.logger.get_logger('camvid_seg_deeplab'), 
        tb_writer=SummaryWriter( log_dir='run/camvid_seg_deeplab-%s'%( time.asctime().replace( ' ', '_' ) ) ) 
    )
    trainer.setup( model=model, 
                   task=task,
                   dataloader=train_loader,
                   optimizer=optim,
                   device=device )

    trainer.add_callback( 
        engine.DefaultEvents.AFTER_STEP(every=10), 
        callbacks=callbacks.MetricsLogging(keys=('total_loss', 'lr')))
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP,
        callbacks=callbacks.LRSchedulerCallback(schedulers=[sched]))
    trainer.add_callback( 
        engine.DefaultEvents.AFTER_EPOCH, 
        callbacks=[ 
            callbacks.EvalAndCkpt(model=model, evaluator=evaluator, metric_name='miou', ckpt_prefix='camvid_seg_deeplabv3_resnet50'),
            callbacks.VisualizeSegmentation(
                model=model,
                dataset=val_dst, 
                idx_list_or_num_vis=10,
                normalizer=kamal.utils.Normalizer( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], reverse=True),
            )])

    #import matplotlib.pyplot as plt
    #lr_finder = kamal.engine.lr_finder.LRFinder()
    #best_lr = lr_finder.find( optim, model, trainer, lr_range=[1e-8, 1.0], max_iter=100, smooth_momentum=0.9 )
    #fig = lr_finder.plot(polyfit=4)
    #plt.savefig('lr_finder_deeplab.png')
    #lr_finder.adjust_learning_rate(optim, best_lr)

    trainer.run( start_iter=0, max_iter=TOTAL_ITERS )
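The CamVid example needs no CLI arguments, but it does assume PIL's `Image` (for `Image.NEAREST` label resizing), the `kamal` package and a TensorBoard writer at module level. An assumed header, as a sketch:
# Assumed file header for the CamVid segmentation example (imports are inferred, not verbatim).
import time
import torch
from PIL import Image                                 # Image.NEAREST for label resizing
from torch.utils.tensorboard import SummaryWriter

import kamal
from kamal import vision, engine, callbacks
from kamal.vision import sync_transforms as sT        # paired image/label transforms (alias assumed)

if __name__ == '__main__':
    main()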
Example #3
def main():
    # PyTorch Part
    model = vision.models.segmentation.deeplabv3_resnet50(
        num_classes=1, pretrained_backbone=True)
    train_dst = vision.datasets.NYUv2(
        'data/NYUv2',
        split='train',
        target_type='depth',
        transforms=sT.Compose([
            sT.Multi(sT.Resize(240), sT.Resize(240)),
            sT.Sync(sT.RandomRotation(5), sT.RandomRotation(5)),
            sT.Multi(sT.ColorJitter(0.2, 0.2, 0.2), None),
            sT.Sync(sT.RandomCrop(240), sT.RandomCrop(240)),
            sT.Sync(sT.RandomHorizontalFlip(), sT.RandomHorizontalFlip()),
            sT.Multi(sT.ToTensor(),
                     sT.ToTensor(normalize=False, dtype=torch.float)),
            sT.Multi(
                sT.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
                sT.Lambda(lambda x: x / 1000))
        ]))
    val_dst = vision.datasets.NYUv2(
        'data/NYUv2',
        split='test',
        target_type='depth',
        transforms=sT.Compose([
            sT.Multi(sT.Resize(240), sT.Resize(240)),
            sT.Multi(sT.ToTensor(),
                     sT.ToTensor(normalize=False, dtype=torch.float)),
            sT.Multi(
                sT.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
                sT.Lambda(lambda x: x / 1000))
        ]))
    train_loader = torch.utils.data.DataLoader(train_dst,
                                               batch_size=16,
                                               shuffle=True,
                                               num_workers=4)
    val_loader = torch.utils.data.DataLoader(val_dst,
                                             batch_size=16,
                                             num_workers=4)
    TOTAL_ITERS = len(train_loader) * 200
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    optim = torch.optim.SGD(model.parameters(),
                            lr=0.01,
                            momentum=0.9,
                            weight_decay=1e-4)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim,
                                                       T_max=TOTAL_ITERS)

    # KAE Part
    metric = kamal.tasks.StandardMetrics.monocular_depth()
    evaluator = engine.evaluator.BasicEvaluator(dataloader=val_loader,
                                                metric=metric,
                                                progress=False)
    task = kamal.tasks.StandardTask.monocular_depth()
    trainer = engine.trainer.BasicTrainer(
        logger=kamal.utils.logger.get_logger('nyuv2_depth_deeplab'),
        tb_writer=SummaryWriter(log_dir='run/nyuv2_depth_deeplab-%s' %
                                (time.asctime().replace(' ', '_'))))
    trainer.setup(model=model,
                  task=task,
                  dataloader=train_loader,
                  optimizer=optim,
                  device=device)
    trainer.add_callback(engine.DefaultEvents.AFTER_STEP(every=10),
                         callbacks=callbacks.MetricsLogging(keys=('total_loss',
                                                                  'lr')))
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP,
        callbacks=callbacks.LRSchedulerCallback(schedulers=[sched]))
    trainer.add_callback(engine.DefaultEvents.AFTER_EPOCH,
                         callbacks=[
                             callbacks.EvalAndCkpt(model=model,
                                                   evaluator=evaluator,
                                                   metric_name='rmse',
                                                   metric_mode='min',
                                                   ckpt_prefix='nyuv2_depth'),
                             callbacks.VisualizeDepth(
                                 model=model,
                                 dataset=val_dst,
                                 idx_list_or_num_vis=10,
                                 max_depth=10,
                                 normalizer=kamal.utils.Normalizer(
                                     mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225],
                                     reverse=True),
                             )
                         ])
    trainer.run(start_iter=0, max_iter=TOTAL_ITERS)
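Once training finishes, the same `evaluator` can score a checkpoint written by `EvalAndCkpt`. A sketch of what could be appended at the end of `main()`; the checkpoint filename is hypothetical (the callback derives the actual name from `ckpt_prefix`):
    # Hypothetical post-training evaluation -- ckpt_path is an assumed filename.
    ckpt_path = 'nyuv2_depth_best.pth'
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    scores = evaluator.eval(model)   # metric dict from StandardMetrics.monocular_depth(), e.g. 'rmse'
    print(scores)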
Example #4
def main():
    # Seg + Depth
    model = MultiTaskSegNet(out_channel_list=[13, 1])
    seg_teacher = vision.models.segmentation.segnet_vgg16_bn(
        num_classes=13, pretrained_backbone=True)
    depth_teacher = vision.models.segmentation.segnet_vgg16_bn(
        num_classes=1, pretrained_backbone=True)
    seg_teacher.load_state_dict(torch.load(args.seg_ckpt))
    depth_teacher.load_state_dict(torch.load(args.depth_ckpt))

    seg_train_dst = vision.datasets.NYUv2('../data/NYUv2',
                                          split='train',
                                          target_type='semantic')
    seg_val_dst = vision.datasets.NYUv2('../data/NYUv2',
                                        split='test',
                                        target_type='semantic')
    depth_train_dst = vision.datasets.NYUv2('../data/NYUv2',
                                            split='train',
                                            target_type='depth')
    depth_val_dst = vision.datasets.NYUv2('../data/NYUv2',
                                          split='test',
                                          target_type='depth')
    train_dst = vision.datasets.LabelConcatDataset(
        datasets=[seg_train_dst, depth_train_dst],
        transforms=sT.Compose([
            sT.Multi(sT.Resize(240), sT.Resize(240,
                                               interpolation=Image.NEAREST),
                     sT.Resize(240)),
            sT.Sync(sT.RandomCrop(240), sT.RandomCrop(240),
                    sT.RandomCrop(240)),
            sT.Sync(sT.RandomHorizontalFlip(), sT.RandomHorizontalFlip(),
                    sT.RandomHorizontalFlip()),
            sT.Multi(sT.ToTensor(),
                     sT.ToTensor(normalize=False, dtype=torch.long),
                     sT.ToTensor(normalize=False, dtype=torch.float)),
            sT.Multi(
                sT.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
                sT.Lambda(lambd=lambda x: x.squeeze()),
                sT.Lambda(lambd=lambda x: x / 1e3))
        ]))
    val_dst = vision.datasets.LabelConcatDataset(
        datasets=[seg_val_dst, depth_val_dst],
        transforms=sT.Compose([
            sT.Multi(sT.Resize(240), sT.Resize(240,
                                               interpolation=Image.NEAREST),
                     sT.Resize(240)),
            sT.Multi(sT.ToTensor(),
                     sT.ToTensor(normalize=False, dtype=torch.long),
                     sT.ToTensor(normalize=False, dtype=torch.float)),
            sT.Multi(
                sT.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
                sT.Lambda(lambd=lambda x: x.squeeze()),
                sT.Lambda(lambd=lambda x: x / 1e3))
        ]))

    train_loader = torch.utils.data.DataLoader(train_dst,
                                               batch_size=16,
                                               shuffle=True,
                                               num_workers=4)
    val_loader = torch.utils.data.DataLoader(val_dst,
                                             batch_size=16,
                                             num_workers=4)
    TOTAL_ITERS = len(train_loader) * 200
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    optim = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim,
                                                       T_max=TOTAL_ITERS)

    confusion_matrix = metrics.ConfusionMatrix(num_classes=13,
                                               ignore_idx=255,
                                               attach_to=0)
    metric = metrics.MetricCompose(
        metric_dict={
            'acc': metrics.Accuracy(attach_to=0),
            'cm': confusion_matrix,
            'mIoU': metrics.mIoU(confusion_matrix),
            'rmse': metrics.RootMeanSquaredError(attach_to=1)
        })
    evaluator = engine.evaluator.BasicEvaluator(dataloader=val_loader,
                                                metric=metric,
                                                progress=False)

    task = [
        kamal.tasks.StandardTask.distillation(attach_to=[0, 0]),
        kamal.tasks.StandardTask.monocular_depth(attach_to=[1, 1])
    ]
    trainer = engine.trainer.KDTrainer(
        logger=kamal.utils.logger.get_logger('nyuv2_simple_kd'),
        tb_writer=SummaryWriter(log_dir='run/nyuv2_simple_kd-%s' %
                                (time.asctime().replace(' ', '_'))))
    trainer.setup(student=model,
                  teacher=[seg_teacher, depth_teacher],
                  task=task,
                  dataloader=train_loader,
                  optimizer=optim,
                  device=device)
    trainer.add_callback(engine.DefaultEvents.AFTER_STEP(every=10),
                         callbacks=callbacks.MetricsLogging(keys=('total_loss',
                                                                  'lr')))
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP,
        callbacks=callbacks.LRSchedulerCallback(schedulers=[sched]))
    trainer.add_callback(
        engine.DefaultEvents.AFTER_EPOCH,
        callbacks=[
            callbacks.EvalAndCkpt(model=model,
                                  evaluator=evaluator,
                                  metric_name='rmse',
                                  metric_mode='min',
                                  ckpt_prefix='nyuv2_simple_kd'),
            callbacks.VisualizeSegmentation(
                model=model,
                dataset=val_dst,
                idx_list_or_num_vis=10,
                attach_to=0,
                normalizer=kamal.utils.Normalizer(mean=[0.485, 0.456, 0.406],
                                                  std=[0.229, 0.224, 0.225],
                                                  reverse=True),
            ),
            callbacks.VisualizeDepth(
                model=model,
                dataset=val_dst,
                idx_list_or_num_vis=10,
                max_depth=10,
                attach_to=1,
                normalizer=kamal.utils.Normalizer(mean=[0.485, 0.456, 0.406],
                                                  std=[0.229, 0.224, 0.225],
                                                  reverse=True),
            ),
        ])
    trainer.run(start_iter=0, max_iter=TOTAL_ITERS)
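Besides the kamal imports, this example assumes a project-local `MultiTaskSegNet` (presumably a SegNet-style network with one output head per entry of `out_channel_list`) and two teacher checkpoints passed on the command line. A sketched argument block; names and defaults are assumptions:
# Assumed CLI arguments for the Seg + Depth distillation example (a sketch).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--seg_ckpt', type=str, required=True)    # SegNet segmentation teacher weights
parser.add_argument('--depth_ckpt', type=str, required=True)  # SegNet depth teacher weights
args = parser.parse_args()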
Example #5
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--lr', type=float, default=0.05)
    parser.add_argument('--t_model_path', type=str, default='../pretrained/cifar100_wrn_40_2.pth')
    parser.add_argument('--T', '--temperature', type=float,
                        default=4.0, help='temperature for KD distillation')
    parser.add_argument('-r', '--gamma', type=float, default=1)
    parser.add_argument('-a', '--alpha', type=float, default=None)
    parser.add_argument('-b', '--beta', type=float, default=None)
    parser.add_argument('--epochs', type=int, default=240)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--num_workers', type=int,
                        default=8, help='num of workers to use')

    parser.add_argument('--distill', type=str, default='kd', choices=[
                        'kd', 'hint', 'attention', 'sp', 'cc', 'vid', 'svd', 'pkt', 'nst', 'rkd'])
    # dataset
    parser.add_argument('--data_root', type=str, default='../data/torchdata')
    parser.add_argument('--img_size', type=int, default=32,
                        help='image size of datasets')
    # hint layer
    parser.add_argument('--hint_layer', default=2,
                        type=int, choices=[0, 1, 2, 3, 4])
    # cc embed dim
    parser.add_argument('--embed_dim', default=128,
                        type=int, help='feature dimension')

    args = parser.parse_args()

    # prepare data
    cifar100_train = CIFAR100(args.data_root, train=True, download=True,
                              transform=T.Compose([
                                  T.RandomCrop(32, padding=4),
                                  T.RandomHorizontalFlip(),
                                  T.ToTensor(),
                                  T.Normalize(mean=[0.5071, 0.4867, 0.4408],
                                              std=[0.2675, 0.2565, 0.2761])])
                              )

    train_loader = torch.utils.data.DataLoader(
        cifar100_train, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True
    )
    val_loader = torch.utils.data.DataLoader(
        CIFAR100(args.data_root, train=False, download=True,
                 transform=T.Compose([
                     T.ToTensor(),
                     T.Normalize(mean=[0.5071, 0.4867, 0.4408],
                                 std=[0.2675, 0.2565, 0.2761])])
                 ), batch_size=args.batch_size, num_workers=args.num_workers
    )
    # prepare model
    teacher = vision.models.classification.cifar.wrn.wrn_40_2(num_classes=100)
    student = vision.models.classification.cifar.wrn.wrn_16_2(num_classes=100)
    teacher.load_state_dict(torch.load(args.t_model_path))
    print('[!] teacher loads weights from %s' % (args.t_model_path))

    # Teacher eval
    evaluator = engine.evaluator.BasicEvaluator(val_loader, metric=metrics.StandardTaskMetrics.classification())
    teacher_scores = evaluator.eval(teacher)
    print('[TEACHER] Acc=%.4f' % (teacher_scores['acc']))

    # hook module feature
    out_flags = [True, True, True, True, False]
    tea_hooks = []
    tea_layers = [teacher.conv1, teacher.block1,
                  teacher.block2, teacher.block3, teacher.fc]
    for module in tea_layers:
        hookfeat = engine.hooks.FeatureHook(module)
        hookfeat.register()
        tea_hooks.append(hookfeat)

    stu_hooks = []
    stu_layers = [student.conv1, student.block1,
                  student.block2, student.block3, student.fc]
    for module in stu_layers:
        hookfeat = engine.hooks.FeatureHook(module)
        hookfeat.register()
        stu_hooks.append(hookfeat)

    # distiller setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger = utils.logger.get_logger('distill_%s' % (args.distill))
    tb_writer = SummaryWriter(log_dir='run/distill_%s-%s' %
                        (args.distill, time.asctime().replace(' ', '_')))
    if args.distill == 'kd':
        optimizer = torch.optim.SGD(student.parameters(
        ), lr=args.lr, momentum=0.9, weight_decay=5e-4)
        distiller = slim.KDDistiller(logger, tb_writer)
        distiller.setup(student=student, teacher=teacher, dataloader=train_loader,
                        optimizer=optimizer, T=args.T, gamma=args.gamma, alpha=args.alpha, device=device)
    if args.distill == 'attention':
        optimizer = torch.optim.SGD(student.parameters(
        ), lr=args.lr, momentum=0.9, weight_decay=5e-4)
        distiller = slim.AttentionDistiller(logger, tb_writer)
        distiller.setup(student=student, teacher=teacher, dataloader=train_loader, optimizer=optimizer, T=args.T, 
                        gamma=args.gamma, alpha=args.alpha, beta=args.beta, stu_hooks=stu_hooks, tea_hooks=tea_hooks, out_flags=out_flags, device=device)
    if args.distill == 'nst':
        optimizer = torch.optim.SGD(student.parameters(
        ), lr=args.lr, momentum=0.9, weight_decay=5e-4)
        distiller = slim.NSTDistiller(logger, tb_writer)
        distiller.setup(student=student, teacher=teacher, dataloader=train_loader, optimizer=optimizer, T=args.T, 
                        gamma=args.gamma, alpha=args.alpha, beta=args.beta, stu_hooks=stu_hooks, tea_hooks=tea_hooks, out_flags=out_flags, device=device)
    if args.distill == 'sp':
        optimizer = torch.optim.SGD(student.parameters(
        ), lr=args.lr, momentum=0.9, weight_decay=5e-4)
        distiller = slim.SPDistiller(logger, tb_writer)
        distiller.setup(student=student, teacher=teacher, dataloader=train_loader, optimizer=optimizer, T=args.T, 
                        gamma=args.gamma, alpha=args.alpha, beta=args.beta, stu_hooks=stu_hooks, tea_hooks=tea_hooks, out_flags=out_flags, device=device)
    if args.distill == 'rkd':
        optimizer = torch.optim.SGD(student.parameters(
        ), lr=args.lr, momentum=0.9, weight_decay=5e-4)
        distiller = slim.RKDDistiller(logger, tb_writer)
        distiller.setup(student=student, teacher=teacher, dataloader=train_loader, optimizer=optimizer, T=args.T, 
                        gamma=args.gamma, alpha=args.alpha, beta=args.beta, stu_hooks=stu_hooks, tea_hooks=tea_hooks, out_flags=out_flags, device=device)
    if args.distill == 'pkt':
        optimizer = torch.optim.SGD(student.parameters(
        ), lr=args.lr, momentum=0.9, weight_decay=5e-4)
        distiller = slim.PKTDistiller(logger, tb_writer)
        distiller.setup(student=student, teacher=teacher, dataloader=train_loader, optimizer=optimizer, T=args.T, 
                        gamma=args.gamma, alpha=args.alpha, beta=args.beta, stu_hooks=stu_hooks, tea_hooks=tea_hooks, out_flags=out_flags, device=device)
    if args.distill == 'svd':
        optimizer = torch.optim.SGD(student.parameters(
        ), lr=args.lr, momentum=0.9, weight_decay=5e-4)
        distiller = slim.SVDDistiller(logger, tb_writer)
        distiller.setup(student=student, teacher=teacher, dataloader=train_loader, optimizer=optimizer, T=args.T, 
                        gamma=args.gamma, alpha=args.alpha, beta=args.beta, stu_hooks=stu_hooks, tea_hooks=tea_hooks, out_flags=out_flags, device=device)
    if args.distill == 'vid':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        data = torch.randn(1, 3, args.img_size, args.img_size)
        data = data.to(device)
        student, teacher = student.to(device), teacher.to(device)
        teacher(data)
        student(data)
        s_n = [f.feat_out.shape[1] if flag else f.feat_in.shape[1]
               for (f, flag) in zip(stu_hooks[1:-1], out_flags)]
        t_n = [f.feat_out.shape[1] if flag else f.feat_in.shape[1]
               for (f, flag) in zip(tea_hooks[1:-1], out_flags)]
        train_list = nn.ModuleList([student])
        VIDRegressor_l = nn.ModuleList()
        for s, t in zip(s_n, t_n):
            vid_r = slim.VIDRegressor(s, t, t)
            VIDRegressor_l.append(vid_r)
            train_list.append(vid_r)
        optimizer = torch.optim.SGD(train_list.parameters(
        ), lr=args.lr, momentum=0.9, weight_decay=5e-4)
        distiller = slim.VIDDistiller(logger, tb_writer)
        distiller.setup(student=student, teacher=teacher, dataloader=train_loader, optimizer=optimizer, regressor_l=VIDRegressor_l, T=args.T, 
                        gamma=args.gamma, alpha=args.alpha, beta=args.beta, stu_hooks=stu_hooks, tea_hooks=tea_hooks, out_flags=out_flags, device=device)
    if args.distill == 'hint':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        data = torch.randn(1, 3, args.img_size, args.img_size)
        data = data.to(device)
        student, teacher = student.to(device), teacher.to(device)
        teacher(data), student(data)
        feat_t = [f.feat_out if flag else f.feat_in for (
            f, flag) in zip(tea_hooks, out_flags)]
        feat_s = [f.feat_out if flag else f.feat_in for (
            f, flag) in zip(stu_hooks, out_flags)]
        fitnet = slim.Regressor(
            feat_s[args.hint_layer].shape, feat_t[args.hint_layer].shape)
        optimizer = torch.optim.SGD(nn.ModuleList(
            [student, fitnet]).parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
        distiller = slim.HintDistiller(logger, tb_writer)
        distiller.setup(student=student, teacher=teacher, dataloader=train_loader, optimizer=optimizer, regressor=fitnet, hint_layer=args.hint_layer,
                                                 T=args.T, gamma=args.gamma, alpha=args.alpha, beta=args.beta, stu_hooks=stu_hooks, tea_hooks=tea_hooks, out_flags=out_flags, device=device)
    if args.distill == 'cc':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        data = torch.randn(1, 3, args.img_size, args.img_size)
        data = data.to(device)
        student, teacher = student.to(device), teacher.to(device)
        teacher(data), student(data)
        feat_t = [f.feat_out if flag else f.feat_in for (
            f, flag) in zip(tea_hooks, out_flags)]
        feat_s = [f.feat_out if flag else f.feat_in for (
            f, flag) in zip(stu_hooks, out_flags)]
        embed_s = slim.LinearEmbed(feat_s[-1].shape[1], args.embed_dim)
        embed_t = slim.LinearEmbed(feat_t[-1].shape[1], args.embed_dim)
        optimizer = torch.optim.SGD(nn.ModuleList([student, embed_s, embed_t]).parameters(
        ), lr=args.lr, momentum=0.9, weight_decay=5e-4)
        distiller = slim.CCDistiller(logger, tb_writer)
        distiller.setup(student=student, teacher=teacher, dataloader=train_loader, optimizer=optimizer, embed_s=embed_s, embed_t=embed_t,
                                               T=args.T, gamma=args.gamma, alpha=args.alpha, beta=args.beta, stu_hooks=stu_hooks, tea_hooks=tea_hooks, out_flags=out_flags, device=device)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, len(train_loader)*args.epochs)
    distiller.add_callback( 
        engine.DefaultEvents.AFTER_STEP(every=10), 
        callbacks=callbacks.MetricsLogging(keys=('total_loss', 'loss_kld', 'loss_ce', 'loss_additional', 'lr')))
    distiller.add_callback( 
        engine.DefaultEvents.AFTER_EPOCH, 
        callbacks=callbacks.EvalAndCkpt(model=student, evaluator=evaluator, metric_name='acc', ckpt_prefix=args.distill) )
    distiller.add_callback(
        engine.DefaultEvents.AFTER_STEP,
        callbacks=callbacks.LRSchedulerCallback(schedulers=[scheduler]))
    distiller.run(start_iter=0, max_iter=len(train_loader)*args.epochs)
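The distillation script parses its own hyperparameters inside `main()` but still needs module-level imports. The header below is a sketch; in particular the import paths for `slim` (which hosts the distillers here) and for the CIFAR WideResNet models are inferred from usage, not verified:
# Assumed file header for the CIFAR-100 distillation example.
import time
import argparse
import torch
import torch.nn as nn
import torchvision.transforms as T
from torchvision.datasets import CIFAR100
from torch.utils.tensorboard import SummaryWriter

from kamal import vision, engine, utils, metrics, callbacks, slim   # import form assumed

if __name__ == '__main__':
    main()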
Example #6
def main():
    # Pytorch Part
    if args.dataset == 'stanford_dogs':
        num_classes = 120
        train_dst = vision.datasets.StanfordDogs('data/StanfordDogs',
                                                 split='train')
        val_dst = vision.datasets.StanfordDogs('data/StanfordDogs',
                                               split='test')
    elif args.dataset == 'cub200':
        num_classes = 200
        train_dst = vision.datasets.CUB200('data/CUB200', split='train')
        val_dst = vision.datasets.CUB200('data/CUB200', split='test')
    elif args.dataset == 'fgvc_aircraft':
        num_classes = 102
        train_dst = vision.datasets.FGVCAircraft('data/FGVCAircraft/',
                                                 split='trainval')
        val_dst = vision.datasets.FGVCAircraft('data/FGVCAircraft/',
                                               split='test')
    elif args.dataset == 'stanford_cars':
        num_classes = 196
        train_dst = vision.datasets.StanfordCars('data/StanfordCars/',
                                                 split='train')
        val_dst = vision.datasets.StanfordCars('data/StanfordCars/',
                                               split='test')
    else:
        raise NotImplementedError

    model = vision.models.classification.resnet18(num_classes=num_classes,
                                                  pretrained=args.pretrained)
    train_dst.transform = sT.Compose([
        sT.RandomResizedCrop(224),
        sT.RandomHorizontalFlip(),
        sT.ToTensor(),
        sT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    val_dst.transform = sT.Compose([
        sT.Resize(256),
        sT.CenterCrop(224),
        sT.ToTensor(),
        sT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    train_loader = torch.utils.data.DataLoader(train_dst,
                                               batch_size=32,
                                               shuffle=True,
                                               num_workers=4)
    val_loader = torch.utils.data.DataLoader(val_dst,
                                             batch_size=32,
                                             num_workers=4)
    TOTAL_ITERS = len(train_loader) * args.epochs
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    optim = torch.optim.SGD(model.parameters(),
                            lr=args.lr,
                            momentum=0.9,
                            weight_decay=5e-4)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim,
                                                       T_max=TOTAL_ITERS)

    # KAE Part
    # Predefined task & metrics
    task = kamal.tasks.StandardTask.classification()
    metric = kamal.tasks.StandardMetrics.classification()
    # Evaluator and Trainer
    evaluator = engine.evaluator.BasicEvaluator(val_loader,
                                                metric=metric,
                                                progress=True)
    trainer = engine.trainer.BasicTrainer(
        logger=kamal.utils.logger.get_logger(args.dataset),
        tb_writer=SummaryWriter(
            log_dir='run/%s-%s' %
            (args.dataset, time.asctime().replace(' ', '_'))))
    # setup trainer
    trainer.setup(model=model,
                  task=task,
                  dataloader=train_loader,
                  optimizer=optim,
                  device=device)
    trainer.add_callback(engine.DefaultEvents.AFTER_STEP(every=10),
                         callbacks=callbacks.MetricsLogging(keys=('total_loss',
                                                                  'lr')))
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP,
        callbacks=callbacks.LRSchedulerCallback(schedulers=[sched]))
    ckpt_callback = trainer.add_callback(engine.DefaultEvents.AFTER_EPOCH,
                                         callbacks=callbacks.EvalAndCkpt(
                                             model=model,
                                             evaluator=evaluator,
                                             metric_name='acc',
                                             ckpt_prefix=args.dataset))
    trainer.run(start_iter=0, max_iter=TOTAL_ITERS)
    ckpt_callback.callback.final_ckpt(ckpt_dir='pretrained', add_md5=True)
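A parser matching the flags this training script reads (`args.dataset`, `args.pretrained`, `args.lr`, `args.epochs`); the defaults here are guesses:
# Assumed CLI for the fine-grained classification example (defaults are guesses).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='cub200',
                    choices=['stanford_dogs', 'cub200', 'fgvc_aircraft', 'stanford_cars'])
parser.add_argument('--pretrained', action='store_true')   # start from an ImageNet-pretrained backbone
parser.add_argument('--lr', type=float, default=0.01)
parser.add_argument('--epochs', type=int, default=100)
args = parser.parse_args()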
Example #7
def main():
    car_train_dst = vision.datasets.StanfordCars( '../data/StanfordCars', split='train')
    car_val_dst = vision.datasets.StanfordCars( '../data/StanfordCars', split='test')
    aircraft_train_dst = vision.datasets.FGVCAircraft( '../data/FGVCAircraft', split='trainval')
    aircraft_val_dst = vision.datasets.FGVCAircraft( '../data/FGVCAircraft', split='test')

    dog_train_dst = vision.datasets.StanfordDogs( '../data/StanfordDogs', split='train')
    dog_val_dst = vision.datasets.StanfordDogs( '../data/StanfordDogs', split='test')
    cub_train_dst = vision.datasets.CUB200( '../data/CUB200', split='train')
    cub_val_dst = vision.datasets.CUB200( '../data/CUB200', split='test')

    #car_teacher = vision.models.classification.resnet18( num_classes=196, pretrained=False )
    #aircraft_teacher = vision.models.classification.resnet18( num_classes=102, pretrained=False )
    #dog_teacher = vision.models.classification.resnet18( num_classes=120, pretrained=False )
    #cub_teacher = vision.models.classification.resnet18( num_classes=200, pretrained=False )
    student = vision.models.classification.resnet18( num_classes=196+102+120+200, pretrained=False )

    #car_teacher.load_state_dict( torch.load( args.car_ckpt ) )
    #aircraft_teacher.load_state_dict( torch.load( args.aircraft_ckpt ) )
    #dog_teacher.load_state_dict( torch.load( args.dog_ckpt ) )
    #cub_teacher.load_state_dict( torch.load( args.cub_ckpt ) )

    train_transform = sT.Compose( [
                            sT.RandomResizedCrop(224),
                            sT.RandomHorizontalFlip(),
                            sT.ToTensor(),
                            sT.Normalize( mean=[0.485, 0.456, 0.406],
                                        std=[0.229, 0.224, 0.225] )
                        ] )
    val_transform = sT.Compose( [
                            sT.Resize(256),
                            sT.CenterCrop( 224 ),
                            sT.ToTensor(),
                            sT.Normalize( mean=[0.485, 0.456, 0.406],
                                        std=[0.229, 0.224, 0.225] )
                        ] )
    
    cub_train_dst.transform = dog_train_dst.transform = car_train_dst.transform = aircraft_train_dst.transform = train_transform
    cub_val_dst.transform = dog_val_dst.transform = car_val_dst.transform = aircraft_val_dst.transform = val_transform
    # Shift each dataset's labels into its slice of the student's 196+102+120+200-way
    # output: cars occupy [0, 196), aircraft [196, 298), dogs [298, 418), CUB [418, 618).
    aircraft_train_dst.target_transform = lambda t: t+196
    dog_train_dst.target_transform = lambda t: t+196+102
    cub_train_dst.target_transform = lambda t: t+196+102+120

    car_metric =        metrics.MetricCompose(metric_dict={ 'car_acc':      metrics.Accuracy(attach_to=lambda o, t: (o[:, :196],t) ) })
    aircraft_metric =   metrics.MetricCompose(metric_dict={ 'aircraft_acc': metrics.Accuracy(attach_to=lambda o, t: (o[:, 196:196+102],t) ) })
    dog_metric =        metrics.MetricCompose(metric_dict={ 'dog_acc':      metrics.Accuracy(attach_to=lambda o, t: (o[:, 196+102:196+102+120],t) ) })
    cub_metric =        metrics.MetricCompose(metric_dict={ 'cub_acc':      metrics.Accuracy(attach_to=lambda o, t: (o[:, 196+102+120:196+102+120+200],t) ) })

    train_dst = torch.utils.data.ConcatDataset( [car_train_dst, aircraft_train_dst, dog_train_dst, cub_train_dst] )
    train_loader = torch.utils.data.DataLoader( train_dst, batch_size=32, shuffle=True, num_workers=4 )
    car_loader = torch.utils.data.DataLoader( car_val_dst, batch_size=32, shuffle=False, num_workers=2 )
    aircraft_loader = torch.utils.data.DataLoader( aircraft_val_dst, batch_size=32, shuffle=False, num_workers=2 )
    dog_loader = torch.utils.data.DataLoader( dog_val_dst, batch_size=32, shuffle=False, num_workers=2 )
    cub_loader = torch.utils.data.DataLoader( cub_val_dst, batch_size=32, shuffle=False, num_workers=2 )

    car_evaluator = engine.evaluator.BasicEvaluator( car_loader, car_metric )
    aircraft_evaluator = engine.evaluator.BasicEvaluator( aircraft_loader, aircraft_metric )
    dog_evaluator = engine.evaluator.BasicEvaluator( dog_loader, dog_metric )
    cub_evaluator = engine.evaluator.BasicEvaluator( cub_loader, cub_metric )
    
    if args.ckpt is not None:
        student.load_state_dict( torch.load( args.ckpt ) )
        print("Load student model from %s"%args.ckpt)

    if args.test_only:
        results_car = car_evaluator.eval( student )
        results_aircraft = aircraft_evaluator.eval( student )
        results_dog = dog_evaluator.eval( student )
        results_cub = cub_evaluator.eval( student )
        print("Stanford Cars: %s"%( results_car ))
        print("FGVC Aircraft: %s"%( results_aircraft ))
        print("Stanford Dogs: %s"%( results_dog ))
        print("CUB200: %s"%( results_cub ))
        return 

    TOTAL_ITERS=len(train_loader) * 100
    device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' )
    optim = torch.optim.Adam( student.parameters(), lr=args.lr, weight_decay=1e-4)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR( optim, T_max=TOTAL_ITERS )
    task = tasks.StandardTask.classification()
    trainer = engine.trainer.BasicTrainer( 
        logger=utils.logger.get_logger('scratch-4'), 
        tb_writer=SummaryWriter( log_dir='run/scratch-4-%s'%( time.asctime().replace( ' ', '_' ) ) ) 
    )
    
    trainer.add_callback( 
        engine.DefaultEvents.AFTER_STEP(every=10), 
        callbacks=callbacks.MetricsLogging(keys=('total_loss', 'lr')))
    trainer.add_callback( 
        engine.DefaultEvents.AFTER_EPOCH, 
        callbacks=[
            callbacks.EvalAndCkpt(model=student, evaluator=car_evaluator, metric_name='car_acc', ckpt_prefix='cfl_car'),
            callbacks.EvalAndCkpt(model=student, evaluator=aircraft_evaluator, metric_name='aircraft_acc', ckpt_prefix='cfl_aircraft'),
            callbacks.EvalAndCkpt(model=student, evaluator=dog_evaluator, metric_name='dog_acc', ckpt_prefix='cfl_dog'),
            callbacks.EvalAndCkpt(model=student, evaluator=cub_evaluator, metric_name='cub_acc', ckpt_prefix='cfl_cub'),
        ] )
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP,
        callbacks=callbacks.LRSchedulerCallback(schedulers=[sched]))

    trainer.setup( model=student,
                   task=task,
                   dataloader=train_loader,
                   optimizer=optim,
                   device=device )
    trainer.run(start_iter=0, max_iter=TOTAL_ITERS)
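Like Example #1, this joint-training baseline reads `args.ckpt`, `args.lr` and `args.test_only` from a module-level parser; a minimal sketch:
# Assumed CLI for the four-dataset joint-training baseline (a sketch).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--ckpt', type=str, default=None)   # resume or evaluate an existing student
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--test_only', action='store_true')
args = parser.parse_args()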