def main():
    """Distill a segmentation teacher and a depth teacher into one NYUv2 student.

    The student ``MultiTaskSegNet`` has two output heads: head 0 with 13
    channels (semantic segmentation) and head 1 with 1 channel (depth).  Each
    head is paired with a ``segnet_vgg16_bn`` teacher restored from
    ``args.seg_ckpt`` / ``args.depth_ckpt``.  Training runs for 200 epochs
    (expressed in iterations) with Adam and per-iteration cosine annealing.
    After every epoch the student is evaluated on the NYUv2 test split,
    checkpointed on the lowest depth RMSE, and a few predictions of each
    head are visualized.

    Reads the module-level ``args`` namespace, which is defined elsewhere in
    this file.
    """
    # Function-scope import: methodcaller gives picklable stand-ins for the
    # lambdas a transform pipeline would otherwise need.
    import operator

    # Standard ImageNet statistics, used both to normalize inputs and (with
    # reverse=True) to undo the normalization for visualization.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    # Seg + Depth: one student head per task, one teacher per task.
    model = MultiTaskSegNet(out_channel_list=[13, 1])
    seg_teacher = vision.models.segmentation.segnet_vgg16_bn(
        num_classes=13, pretrained_backbone=True)
    depth_teacher = vision.models.segmentation.segnet_vgg16_bn(
        num_classes=1, pretrained_backbone=True)
    # map_location='cpu' lets checkpoints saved from a GPU run be restored on
    # a CPU-only machine (the device selection below explicitly allows CPU).
    # load_state_dict copies the values into the existing parameters, so the
    # device the modules live on is unchanged.
    seg_teacher.load_state_dict(torch.load(args.seg_ckpt, map_location='cpu'))
    depth_teacher.load_state_dict(
        torch.load(args.depth_ckpt, map_location='cpu'))

    seg_train_dst = vision.datasets.NYUv2('../data/NYUv2',
                                          split='train',
                                          target_type='semantic')
    seg_val_dst = vision.datasets.NYUv2('../data/NYUv2',
                                        split='test',
                                        target_type='semantic')
    depth_train_dst = vision.datasets.NYUv2('../data/NYUv2',
                                            split='train',
                                            target_type='depth')
    depth_val_dst = vision.datasets.NYUv2('../data/NYUv2',
                                          split='test',
                                          target_type='depth')

    # Label post-processing.  Named callables instead of lambdas keep the
    # datasets picklable, which DataLoader requires when its workers are
    # started with the 'spawn' method (the default on Windows and macOS).
    squeeze_label = operator.methodcaller('squeeze')
    # NOTE(review): presumably converts raw depth from millimetres to
    # metres -- confirm against the NYUv2 loader.
    scale_depth = operator.methodcaller('div', 1e3)

    # Each sT.Multi / sT.Sync entry takes one transform per element of what is
    # assumed to be an (image, seg_label, depth) sample -- TODO confirm
    # against LabelConcatDataset.
    train_dst = vision.datasets.LabelConcatDataset(
        datasets=[seg_train_dst, depth_train_dst],
        transforms=sT.Compose([
            # Nearest-neighbour resize keeps segmentation labels as valid
            # class ids instead of interpolated values.
            sT.Multi(sT.Resize(240),
                     sT.Resize(240, interpolation=Image.NEAREST),
                     sT.Resize(240)),
            # Sync presumably shares the random crop / flip parameters across
            # the three elements so they stay spatially aligned.
            sT.Sync(sT.RandomCrop(240), sT.RandomCrop(240),
                    sT.RandomCrop(240)),
            sT.Sync(sT.RandomHorizontalFlip(), sT.RandomHorizontalFlip(),
                    sT.RandomHorizontalFlip()),
            sT.Multi(sT.ToTensor(),
                     sT.ToTensor(normalize=False, dtype=torch.long),
                     sT.ToTensor(normalize=False, dtype=torch.float)),
            sT.Multi(sT.Normalize(mean=mean, std=std),
                     sT.Lambda(lambd=squeeze_label),
                     sT.Lambda(lambd=scale_depth)),
        ]))
    # Validation: same resize / tensor conversion, no random augmentation.
    val_dst = vision.datasets.LabelConcatDataset(
        datasets=[seg_val_dst, depth_val_dst],
        transforms=sT.Compose([
            sT.Multi(sT.Resize(240),
                     sT.Resize(240, interpolation=Image.NEAREST),
                     sT.Resize(240)),
            sT.Multi(sT.ToTensor(),
                     sT.ToTensor(normalize=False, dtype=torch.long),
                     sT.ToTensor(normalize=False, dtype=torch.float)),
            sT.Multi(sT.Normalize(mean=mean, std=std),
                     sT.Lambda(lambd=squeeze_label),
                     sT.Lambda(lambd=scale_depth)),
        ]))

    train_loader = torch.utils.data.DataLoader(train_dst,
                                               batch_size=16,
                                               shuffle=True,
                                               num_workers=4)
    val_loader = torch.utils.data.DataLoader(val_dst,
                                             batch_size=16,
                                             num_workers=4)

    # 200 epochs; the scheduler is stepped per iteration (see callback below).
    TOTAL_ITERS = len(train_loader) * 200
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    optim = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim,
                                                       T_max=TOTAL_ITERS)

    # Segmentation metrics read student output 0, depth RMSE reads output 1.
    # ignore_idx=255: label value 255 is presumably the "void" class.
    confusion_matrix = metrics.ConfusionMatrix(num_classes=13,
                                               ignore_idx=255,
                                               attach_to=0)
    metric = metrics.MetricCompose(
        metric_dict={
            'acc': metrics.Accuracy(attach_to=0),
            'cm': confusion_matrix,
            'mIoU': metrics.mIoU(confusion_matrix),
            'rmse': metrics.RootMeanSquaredError(attach_to=1)
        })
    evaluator = engine.evaluator.BasicEvaluator(dataloader=val_loader,
                                                metric=metric,
                                                progress=False)

    # Head 0 is trained with a distillation task and head 1 with a
    # monocular-depth task; attach_to=[i, i] presumably pairs student output i
    # with teacher/target i -- TODO confirm against StandardTask.
    task = [
        kamal.tasks.StandardTask.distillation(attach_to=[0, 0]),
        kamal.tasks.StandardTask.monocular_depth(attach_to=[1, 1])
    ]
    trainer = engine.trainer.KDTrainer(
        logger=kamal.utils.logger.get_logger('nyuv2_simple_kd'),
        tb_writer=SummaryWriter(log_dir='run/nyuv2_simple_kd-%s' %
                                (time.asctime().replace(' ', '_'))))
    trainer.setup(student=model,
                  teacher=[seg_teacher, depth_teacher],
                  task=task,
                  dataloader=train_loader,
                  optimizer=optim,
                  device=device)
    trainer.add_callback(engine.DefaultEvents.AFTER_STEP(every=10),
                         callbacks=callbacks.MetricsLogging(keys=('total_loss',
                                                                  'lr')))
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP,
        callbacks=callbacks.LRSchedulerCallback(schedulers=[sched]))
    # Per epoch: evaluate, keep the checkpoint with the lowest depth RMSE,
    # and dump 10 visualizations per head.
    trainer.add_callback(
        engine.DefaultEvents.AFTER_EPOCH,
        callbacks=[
            callbacks.EvalAndCkpt(model=model,
                                  evaluator=evaluator,
                                  metric_name='rmse',
                                  metric_mode='min',
                                  ckpt_prefix='nyuv2_simple_kd'),
            callbacks.VisualizeSegmentation(
                model=model,
                dataset=val_dst,
                idx_list_or_num_vis=10,
                attach_to=0,
                normalizer=kamal.utils.Normalizer(mean=mean,
                                                  std=std,
                                                  reverse=True),
            ),
            callbacks.VisualizeDepth(
                model=model,
                dataset=val_dst,
                idx_list_or_num_vis=10,
                max_depth=10,
                attach_to=1,
                normalizer=kamal.utils.Normalizer(mean=mean,
                                                  std=std,
                                                  reverse=True),
            ),
        ])
    trainer.run(start_iter=0, max_iter=TOTAL_ITERS)
예제 #2
0
def main():
    """Amalgamate a Stanford Cars and an FGVC Aircraft teacher into one student.

    Two ResNet-18 teachers (196 and 102 classes) are restored from
    ``args.car_ckpt`` / ``args.aircraft_ckpt`` and combined into a single
    ResNet-18 student with ``196 + 102`` outputs using
    ``amalgamation.CommonFeatureAmalgamator`` (CFL).  The first 196 logits of
    the student are scored against Stanford Cars and the remaining ones
    against FGVC Aircraft.

    If ``args.ckpt`` is set the student is restored from it first; with
    ``args.test_only`` the student is only evaluated and the function returns.
    Otherwise it trains for 100 epochs with Adam (``args.lr``) and
    per-iteration cosine annealing, checkpointing on each dataset's accuracy.

    Reads the module-level ``args`` namespace, which is defined elsewhere in
    this file.
    """
    # Class counts; the student's logits are [cars | aircraft] in this order.
    num_car, num_aircraft = 196, 102

    car_train_dst = vision.datasets.StanfordCars('../data/StanfordCars',
                                                 split='train')
    car_val_dst = vision.datasets.StanfordCars('../data/StanfordCars',
                                               split='test')
    aircraft_train_dst = vision.datasets.FGVCAircraft('../data/FGVCAircraft',
                                                      split='trainval')
    aircraft_val_dst = vision.datasets.FGVCAircraft('../data/FGVCAircraft',
                                                    split='test')
    car_teacher = vision.models.classification.resnet18(
        num_classes=num_car, pretrained=False)
    aircraft_teacher = vision.models.classification.resnet18(
        num_classes=num_aircraft, pretrained=False)
    student = vision.models.classification.resnet18(
        num_classes=num_car + num_aircraft, pretrained=False)
    # map_location='cpu' lets checkpoints saved from a GPU run be restored on
    # a CPU-only machine (the device selection below explicitly allows CPU).
    # load_state_dict copies the values into the existing parameters, so the
    # device the modules live on is unchanged.
    car_teacher.load_state_dict(torch.load(args.car_ckpt, map_location='cpu'))
    aircraft_teacher.load_state_dict(
        torch.load(args.aircraft_ckpt, map_location='cpu'))

    # ImageNet-style preprocessing: random crops for training, resize +
    # center crop for evaluation.
    train_transform = sT.Compose([
        sT.RandomResizedCrop(224),
        sT.RandomHorizontalFlip(),
        sT.ToTensor(),
        sT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    val_transform = sT.Compose([
        sT.Resize(256),
        sT.CenterCrop(224),
        sT.ToTensor(),
        sT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    car_train_dst.transform = aircraft_train_dst.transform = train_transform
    car_val_dst.transform = aircraft_val_dst.transform = val_transform

    # Each metric scores only its own slice of the student's logits against
    # that dataset's (un-shifted) labels.
    car_metric = metrics.MetricCompose(
        metric_dict={
            'car_acc':
            metrics.Accuracy(attach_to=lambda o, t: (o[:, :num_car], t))
        })
    aircraft_metric = metrics.MetricCompose(
        metric_dict={
            'aircraft_acc':
            metrics.Accuracy(attach_to=lambda o, t: (o[:, num_car:], t))
        })

    train_dst = torch.utils.data.ConcatDataset(
        [car_train_dst, aircraft_train_dst])
    train_loader = torch.utils.data.DataLoader(train_dst,
                                               batch_size=32,
                                               shuffle=True,
                                               num_workers=4)
    car_loader = torch.utils.data.DataLoader(car_val_dst,
                                             batch_size=32,
                                             shuffle=False,
                                             num_workers=4)
    aircraft_loader = torch.utils.data.DataLoader(aircraft_val_dst,
                                                  batch_size=32,
                                                  shuffle=False,
                                                  num_workers=4)

    car_evaluator = engine.evaluator.BasicEvaluator(car_loader, car_metric)
    aircraft_evaluator = engine.evaluator.BasicEvaluator(
        aircraft_loader, aircraft_metric)

    if args.ckpt is not None:
        student.load_state_dict(torch.load(args.ckpt, map_location='cpu'))
        print("Load student model from %s" % args.ckpt)
    if args.test_only:
        results_car = car_evaluator.eval(student)
        results_aircraft = aircraft_evaluator.eval(student)
        print("Stanford Cars: %s" % (results_car))
        print("FGVC Aircraft: %s" % (results_aircraft))
        return

    # 100 epochs; the scheduler is stepped per iteration (see callback below).
    TOTAL_ITERS = len(train_loader) * 100
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    optim = torch.optim.Adam(student.parameters(),
                             lr=args.lr,
                             weight_decay=1e-4)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim,
                                                       T_max=TOTAL_ITERS)
    trainer = amalgamation.CommonFeatureAmalgamator(
        logger=utils.logger.get_logger('cfl'),
        tb_writer=SummaryWriter(log_dir='run/cfl-%s' %
                                (time.asctime().replace(' ', '_'))))

    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP(every=10),
        callbacks=callbacks.MetricsLogging(keys=('total_loss', 'loss_kd',
                                                 'loss_amal', 'loss_recons',
                                                 'lr')))
    # Per epoch: evaluate on both datasets and keep a best checkpoint for each.
    trainer.add_callback(engine.DefaultEvents.AFTER_EPOCH,
                         callbacks=[
                             callbacks.EvalAndCkpt(model=student,
                                                   evaluator=car_evaluator,
                                                   metric_name='car_acc',
                                                   ckpt_prefix='cfl_car'),
                             callbacks.EvalAndCkpt(
                                 model=student,
                                 evaluator=aircraft_evaluator,
                                 metric_name='aircraft_acc',
                                 ckpt_prefix='cfl_aircraft'),
                         ])
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP,
        callbacks=callbacks.LRSchedulerCallback(schedulers=[sched]))

    # Features are aligned at the classifier: with on_layer_input=True the
    # amalgamator presumably hooks the *input* of each fc layer (the 512-d
    # pooled ResNet-18 features) -- TODO confirm in CommonFeatureAmalgamator.
    layer_groups = [(student.fc, car_teacher.fc, aircraft_teacher.fc)]
    layer_channels = [(512, 512, 512)]

    trainer.setup(student=student,
                  teachers=[car_teacher, aircraft_teacher],
                  layer_groups=layer_groups,
                  layer_channels=layer_channels,
                  dataloader=train_loader,
                  optimizer=optim,
                  device=device,
                  on_layer_input=True,
                  weights=[1., 10., 10.])
    trainer.run(start_iter=0, max_iter=TOTAL_ITERS)
def main():
    """Train a four-dataset ResNet-18 classifier from scratch (no teachers).

    Stanford Cars (196 classes), FGVC Aircraft (102), Stanford Dogs (120) and
    CUB-200 (200) are concatenated into one training set.  The student's
    logits hold each dataset's classes in that order, and the training labels
    of every dataset after the first are shifted by the number of preceding
    classes so they index into the right slice.  Each dataset is evaluated
    separately on its own logit slice using its raw labels.

    If ``args.ckpt`` is set the student is restored from it first; with
    ``args.test_only`` the student is only evaluated and the function returns.
    Otherwise it is trained directly on the labels for 100 epochs with Adam
    (``args.lr``) and per-iteration cosine annealing.

    Reads the module-level ``args`` namespace, which is defined elsewhere in
    this file.
    """
    # Function-scope imports: functools.partial(operator.add, k) is a
    # picklable replacement for ``lambda t: t + k``.
    import functools
    import operator

    # Per-dataset class counts and their offsets in the concatenated logits.
    num_car, num_aircraft, num_dog, num_cub = 196, 102, 120, 200
    aircraft_offset = num_car
    dog_offset = aircraft_offset + num_aircraft
    cub_offset = dog_offset + num_dog
    num_classes = cub_offset + num_cub

    car_train_dst = vision.datasets.StanfordCars('../data/StanfordCars',
                                                 split='train')
    car_val_dst = vision.datasets.StanfordCars('../data/StanfordCars',
                                               split='test')
    aircraft_train_dst = vision.datasets.FGVCAircraft('../data/FGVCAircraft',
                                                      split='trainval')
    aircraft_val_dst = vision.datasets.FGVCAircraft('../data/FGVCAircraft',
                                                    split='test')
    dog_train_dst = vision.datasets.StanfordDogs('../data/StanfordDogs',
                                                 split='train')
    dog_val_dst = vision.datasets.StanfordDogs('../data/StanfordDogs',
                                               split='test')
    cub_train_dst = vision.datasets.CUB200('../data/CUB200', split='train')
    cub_val_dst = vision.datasets.CUB200('../data/CUB200', split='test')

    student = vision.models.classification.resnet18(num_classes=num_classes,
                                                    pretrained=False)

    # ImageNet-style preprocessing: random crops for training, resize +
    # center crop for evaluation.
    train_transform = sT.Compose([
        sT.RandomResizedCrop(224),
        sT.RandomHorizontalFlip(),
        sT.ToTensor(),
        sT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    val_transform = sT.Compose([
        sT.Resize(256),
        sT.CenterCrop(224),
        sT.ToTensor(),
        sT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    for dst in (car_train_dst, aircraft_train_dst, dog_train_dst,
                cub_train_dst):
        dst.transform = train_transform
    for dst in (car_val_dst, aircraft_val_dst, dog_val_dst, cub_val_dst):
        dst.transform = val_transform

    # Shift training labels into the student's global label space.  A
    # functools.partial (unlike a lambda) can be pickled, which DataLoader
    # needs when its workers are started with 'spawn' (Windows/macOS).
    aircraft_train_dst.target_transform = functools.partial(
        operator.add, aircraft_offset)
    dog_train_dst.target_transform = functools.partial(operator.add,
                                                       dog_offset)
    cub_train_dst.target_transform = functools.partial(operator.add,
                                                       cub_offset)

    # Validation labels are left un-shifted; each metric instead scores only
    # the matching slice of the student's logits.
    car_metric = metrics.MetricCompose(metric_dict={
        'car_acc':
        metrics.Accuracy(attach_to=lambda o, t: (o[:, :aircraft_offset], t))
    })
    aircraft_metric = metrics.MetricCompose(metric_dict={
        'aircraft_acc':
        metrics.Accuracy(
            attach_to=lambda o, t: (o[:, aircraft_offset:dog_offset], t))
    })
    dog_metric = metrics.MetricCompose(metric_dict={
        'dog_acc':
        metrics.Accuracy(
            attach_to=lambda o, t: (o[:, dog_offset:cub_offset], t))
    })
    cub_metric = metrics.MetricCompose(metric_dict={
        'cub_acc':
        metrics.Accuracy(
            attach_to=lambda o, t: (o[:, cub_offset:num_classes], t))
    })

    train_dst = torch.utils.data.ConcatDataset(
        [car_train_dst, aircraft_train_dst, dog_train_dst, cub_train_dst])
    train_loader = torch.utils.data.DataLoader(train_dst,
                                               batch_size=32,
                                               shuffle=True,
                                               num_workers=4)
    car_loader = torch.utils.data.DataLoader(car_val_dst,
                                             batch_size=32,
                                             shuffle=False,
                                             num_workers=2)
    aircraft_loader = torch.utils.data.DataLoader(aircraft_val_dst,
                                                  batch_size=32,
                                                  shuffle=False,
                                                  num_workers=2)
    dog_loader = torch.utils.data.DataLoader(dog_val_dst,
                                             batch_size=32,
                                             shuffle=False,
                                             num_workers=2)
    cub_loader = torch.utils.data.DataLoader(cub_val_dst,
                                             batch_size=32,
                                             shuffle=False,
                                             num_workers=2)

    car_evaluator = engine.evaluator.BasicEvaluator(car_loader, car_metric)
    aircraft_evaluator = engine.evaluator.BasicEvaluator(
        aircraft_loader, aircraft_metric)
    dog_evaluator = engine.evaluator.BasicEvaluator(dog_loader, dog_metric)
    cub_evaluator = engine.evaluator.BasicEvaluator(cub_loader, cub_metric)

    if args.ckpt is not None:
        # map_location='cpu' lets a checkpoint saved from a GPU run be
        # restored on a CPU-only machine; load_state_dict copies the values
        # into the existing parameters either way.
        student.load_state_dict(torch.load(args.ckpt, map_location='cpu'))
        print("Load student model from %s" % args.ckpt)

    if args.test_only:
        results_car = car_evaluator.eval(student)
        results_aircraft = aircraft_evaluator.eval(student)
        results_dog = dog_evaluator.eval(student)
        results_cub = cub_evaluator.eval(student)
        print("Stanford Cars: %s" % (results_car))
        print("FGVC Aircraft: %s" % (results_aircraft))
        print("Stanford Dogs: %s" % (results_dog))
        print("CUB200: %s" % (results_cub))
        return

    # 100 epochs; the scheduler is stepped per iteration (see callback below).
    TOTAL_ITERS = len(train_loader) * 100
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    optim = torch.optim.Adam(student.parameters(), lr=args.lr,
                             weight_decay=1e-4)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim,
                                                       T_max=TOTAL_ITERS)
    task = tasks.StandardTask.classification()
    trainer = engine.trainer.BasicTrainer(
        logger=utils.logger.get_logger('scratch-4'),
        tb_writer=SummaryWriter(log_dir='run/scratch-4-%s' %
                                (time.asctime().replace(' ', '_'))))

    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP(every=10),
        callbacks=callbacks.MetricsLogging(keys=('total_loss', 'lr')))
    # Per epoch: evaluate each dataset and keep a best checkpoint for each.
    # The prefixes name this run ('scratch-4') so its checkpoints do not
    # overwrite the CFL amalgamation example's 'cfl_*' checkpoints.
    trainer.add_callback(
        engine.DefaultEvents.AFTER_EPOCH,
        callbacks=[
            callbacks.EvalAndCkpt(model=student,
                                  evaluator=car_evaluator,
                                  metric_name='car_acc',
                                  ckpt_prefix='scratch-4_car'),
            callbacks.EvalAndCkpt(model=student,
                                  evaluator=aircraft_evaluator,
                                  metric_name='aircraft_acc',
                                  ckpt_prefix='scratch-4_aircraft'),
            callbacks.EvalAndCkpt(model=student,
                                  evaluator=dog_evaluator,
                                  metric_name='dog_acc',
                                  ckpt_prefix='scratch-4_dog'),
            callbacks.EvalAndCkpt(model=student,
                                  evaluator=cub_evaluator,
                                  metric_name='cub_acc',
                                  ckpt_prefix='scratch-4_cub'),
        ])
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP,
        callbacks=callbacks.LRSchedulerCallback(schedulers=[sched]))

    trainer.setup(model=student,
                  task=task,
                  dataloader=train_loader,
                  optimizer=optim,
                  device=device)
    trainer.run(start_iter=0, max_iter=TOTAL_ITERS)