def main():
    """Train ``deeplabv3_resnet50`` on CamVid (11 classes) with ``BasicTrainer``.

    Builds the ``trainval`` / ``test`` CamVid datasets and their loaders, an
    SGD optimizer with a ``CosineAnnealingLR`` schedule spanning the whole run,
    registers the ``MetricsLogging``, ``LRSchedulerCallback``, ``EvalAndCkpt``
    and ``VisualizeSegmentation`` callbacks, and finally runs the trainer for
    ``len(train_loader) * 200`` iterations.
    """
    n_classes = 11

    def norm_stats():
        # A new dict holding new lists on every call, so every consumer gets
        # its own mean/std list objects.
        return {'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225]}

    # ---- PyTorch part: model, data, optimizer, scheduler ----
    net = vision.models.segmentation.deeplabv3_resnet50(
        num_classes=n_classes, pretrained_backbone=True)

    # NOTE(review): each stage pairs two transforms, presumably (image, label)
    # in that order -- confirm against the sT.Multi / sT.Sync implementation.
    train_stages = [
        sT.Multi(sT.Resize(240), sT.Resize(240, interpolation=Image.NEAREST)),
        sT.Sync(sT.RandomRotation(5), sT.RandomRotation(5)),
        sT.Multi(sT.ColorJitter(0.2, 0.2, 0.2), None),
        sT.Sync(sT.RandomCrop(240), sT.RandomCrop(240)),
        sT.Sync(sT.RandomHorizontalFlip(), sT.RandomHorizontalFlip()),
        sT.Multi(sT.ToTensor(),
                 sT.ToTensor(normalize=False, dtype=torch.long)),
        sT.Multi(sT.Normalize(**norm_stats()),
                 sT.Lambda(lambd=lambda t: t.squeeze())),
    ]
    train_set = vision.datasets.CamVid(
        'data/CamVid11', split='trainval', transforms=sT.Compose(train_stages))

    # Validation pipeline: the same resize / tensor / normalize stages, with
    # none of the random augmentation stages.
    val_stages = [
        sT.Multi(sT.Resize(240), sT.Resize(240, interpolation=Image.NEAREST)),
        sT.Multi(sT.ToTensor(),
                 sT.ToTensor(normalize=False, dtype=torch.long)),
        sT.Multi(sT.Normalize(**norm_stats()),
                 sT.Lambda(lambd=lambda t: t.squeeze())),
    ]
    val_set = vision.datasets.CamVid(
        'data/CamVid11', split='test', transforms=sT.Compose(val_stages))

    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=16, shuffle=True, num_workers=4)
    val_loader = torch.utils.data.DataLoader(
        val_set, batch_size=16, num_workers=4)

    # 200 passes over the training loader.
    total_iters = len(train_loader) * 200
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    optimizer = torch.optim.SGD(
        net.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=total_iters)

    # ---- KAE part: metrics, evaluator, task, trainer, callbacks ----
    seg_metric = kamal.tasks.StandardMetrics.segmentation(
        num_classes=n_classes)
    evaluator = engine.evaluator.BasicEvaluator(
        dataloader=val_loader, metric=seg_metric, progress=False)

    seg_task = kamal.tasks.StandardTask.segmentation()

    run_logger = kamal.utils.logger.get_logger('camvid_seg_deeplab')
    # The TensorBoard directory is suffixed with the current asctime(),
    # spaces replaced by underscores.
    stamp = time.asctime().replace(' ', '_')
    writer = SummaryWriter(log_dir='run/camvid_seg_deeplab-%s' % stamp)
    trainer = engine.trainer.BasicTrainer(logger=run_logger, tb_writer=writer)
    trainer.setup(
        model=net,
        task=seg_task,
        dataloader=train_loader,
        optimizer=optimizer,
        device=device,
    )

    # MetricsLogging is registered on every 10th AFTER_STEP event.
    every_tenth_step = engine.DefaultEvents.AFTER_STEP(every=10)
    trainer.add_callback(
        every_tenth_step,
        callbacks=callbacks.MetricsLogging(keys=('total_loss', 'lr')))

    # The scheduler callback is registered on every AFTER_STEP event, which
    # matches T_max being expressed in iterations.
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP,
        callbacks=callbacks.LRSchedulerCallback(schedulers=[scheduler]))

    # AFTER_EPOCH: EvalAndCkpt (tracking 'miou') and VisualizeSegmentation.
    eval_and_ckpt = callbacks.EvalAndCkpt(
        model=net,
        evaluator=evaluator,
        metric_name='miou',
        ckpt_prefix='camvid_seg_deeplabv3_resnet50',
    )
    seg_visualizer = callbacks.VisualizeSegmentation(
        model=net,
        dataset=val_set,
        idx_list_or_num_vis=10,
        normalizer=kamal.utils.Normalizer(**norm_stats(), reverse=True),
    )
    trainer.add_callback(
        engine.DefaultEvents.AFTER_EPOCH,
        callbacks=[eval_and_ckpt, seg_visualizer])

    # NOTE: an optional learning-rate sweep using
    # kamal.engine.lr_finder.LRFinder (find / plot / adjust_learning_rate)
    # used to sit here and is currently disabled.

    trainer.run(start_iter=0, max_iter=total_iters)
def main():
    """Train a ``MultiTaskSegNet`` student on NYUv2 with ``KDTrainer``.

    The student is built with ``out_channel_list=[13, 1]`` and is trained
    against two ``segnet_vgg16_bn`` teachers (13 classes and 1 class) whose
    weights are loaded from ``args.seg_ckpt`` and ``args.depth_ckpt``.
    ``args`` is not defined in this function; it is expected to exist at
    module level.

    The function builds the concatenated semantic + depth datasets, the data
    loaders, an Adam optimizer with a ``CosineAnnealingLR`` schedule covering
    the whole run, the evaluation metrics, and the logging / scheduling /
    evaluation / visualization callbacks, then runs the trainer for
    ``len(train_loader) * 200`` iterations.
    """
    data_root = '../data/NYUv2'

    def norm_stats():
        # A new dict holding new lists on every call, so every consumer gets
        # its own mean/std list objects.
        return {'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225]}

    # Seg + Depth
    model = MultiTaskSegNet(out_channel_list=[13, 1])
    seg_teacher = vision.models.segmentation.segnet_vgg16_bn(
        num_classes=13, pretrained_backbone=True)
    depth_teacher = vision.models.segmentation.segnet_vgg16_bn(
        num_classes=1, pretrained_backbone=True)
    # Load checkpoints onto the CPU explicitly: without map_location a
    # checkpoint saved from CUDA tensors cannot be deserialized on a host
    # without CUDA, although this script supports the 'cpu' device below.
    # The trainer is still given the target device in setup().
    seg_teacher.load_state_dict(
        torch.load(args.seg_ckpt, map_location='cpu'))
    depth_teacher.load_state_dict(
        torch.load(args.depth_ckpt, map_location='cpu'))

    seg_train_dst = vision.datasets.NYUv2(data_root,
                                          split='train',
                                          target_type='semantic')
    seg_val_dst = vision.datasets.NYUv2(data_root,
                                        split='test',
                                        target_type='semantic')
    depth_train_dst = vision.datasets.NYUv2(data_root,
                                            split='train',
                                            target_type='depth')
    depth_val_dst = vision.datasets.NYUv2(data_root,
                                          split='test',
                                          target_type='depth')

    # NOTE(review): each stage carries three transforms, presumably for
    # (image, semantic label, depth) in that order -- confirm against
    # LabelConcatDataset and sT.Multi / sT.Sync.
    train_dst = vision.datasets.LabelConcatDataset(
        datasets=[seg_train_dst, depth_train_dst],
        transforms=sT.Compose([
            sT.Multi(sT.Resize(240),
                     sT.Resize(240, interpolation=Image.NEAREST),
                     sT.Resize(240)),
            sT.Sync(sT.RandomCrop(240), sT.RandomCrop(240),
                    sT.RandomCrop(240)),
            sT.Sync(sT.RandomHorizontalFlip(), sT.RandomHorizontalFlip(),
                    sT.RandomHorizontalFlip()),
            sT.Multi(sT.ToTensor(),
                     sT.ToTensor(normalize=False, dtype=torch.long),
                     sT.ToTensor(normalize=False, dtype=torch.float)),
            sT.Multi(
                sT.Normalize(**norm_stats()),
                sT.Lambda(lambd=lambda x: x.squeeze()),
                # Divides the third target by 1000 -- presumably a unit
                # conversion for depth; TODO confirm.
                sT.Lambda(lambd=lambda x: x / 1e3))
        ]))
    # Validation pipeline: same stages without the random crop / flip.
    val_dst = vision.datasets.LabelConcatDataset(
        datasets=[seg_val_dst, depth_val_dst],
        transforms=sT.Compose([
            sT.Multi(sT.Resize(240),
                     sT.Resize(240, interpolation=Image.NEAREST),
                     sT.Resize(240)),
            sT.Multi(sT.ToTensor(),
                     sT.ToTensor(normalize=False, dtype=torch.long),
                     sT.ToTensor(normalize=False, dtype=torch.float)),
            sT.Multi(
                sT.Normalize(**norm_stats()),
                sT.Lambda(lambd=lambda x: x.squeeze()),
                sT.Lambda(lambd=lambda x: x / 1e3))
        ]))

    train_loader = torch.utils.data.DataLoader(train_dst,
                                               batch_size=16,
                                               shuffle=True,
                                               num_workers=4)
    val_loader = torch.utils.data.DataLoader(val_dst,
                                             batch_size=16,
                                             num_workers=4)
    # 200 passes over the training loader.
    TOTAL_ITERS = len(train_loader) * 200
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    optim = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim,
                                                       T_max=TOTAL_ITERS)

    # The same ConfusionMatrix instance is registered under 'cm' and passed
    # to mIoU.
    confusion_matrix = metrics.ConfusionMatrix(num_classes=13,
                                               ignore_idx=255,
                                               attach_to=0)
    metric = metrics.MetricCompose(
        metric_dict={
            'acc': metrics.Accuracy(attach_to=0),
            'cm': confusion_matrix,
            'mIoU': metrics.mIoU(confusion_matrix),
            'rmse': metrics.RootMeanSquaredError(attach_to=1)
        })
    evaluator = engine.evaluator.BasicEvaluator(dataloader=val_loader,
                                                metric=metric,
                                                progress=False)

    # NOTE(review): attach_to presumably selects which output/target index
    # each task consumes (0 = segmentation, 1 = depth) -- confirm in kamal.
    task = [
        kamal.tasks.StandardTask.distillation(attach_to=[0, 0]),
        kamal.tasks.StandardTask.monocular_depth(attach_to=[1, 1])
    ]
    trainer = engine.trainer.KDTrainer(
        logger=kamal.utils.logger.get_logger('nyuv2_simple_kd'),
        # The TensorBoard directory is suffixed with the current asctime(),
        # spaces replaced by underscores.
        tb_writer=SummaryWriter(log_dir='run/nyuv2_simple_kd-%s' %
                                (time.asctime().replace(' ', '_'))))
    trainer.setup(student=model,
                  teacher=[seg_teacher, depth_teacher],
                  task=task,
                  dataloader=train_loader,
                  optimizer=optim,
                  device=device)

    # MetricsLogging is registered on every 10th AFTER_STEP event.
    trainer.add_callback(engine.DefaultEvents.AFTER_STEP(every=10),
                         callbacks=callbacks.MetricsLogging(keys=('total_loss',
                                                                  'lr')))
    # The scheduler callback is registered on every AFTER_STEP event, which
    # matches T_max being expressed in iterations.
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP,
        callbacks=callbacks.LRSchedulerCallback(schedulers=[sched]))
    # AFTER_EPOCH: EvalAndCkpt tracks 'rmse' with metric_mode='min', followed
    # by the segmentation (attach_to=0) and depth (attach_to=1) visualizers.
    trainer.add_callback(
        engine.DefaultEvents.AFTER_EPOCH,
        callbacks=[
            callbacks.EvalAndCkpt(model=model,
                                  evaluator=evaluator,
                                  metric_name='rmse',
                                  metric_mode='min',
                                  ckpt_prefix='nyuv2_simple_kd'),
            callbacks.VisualizeSegmentation(
                model=model,
                dataset=val_dst,
                idx_list_or_num_vis=10,
                attach_to=0,
                normalizer=kamal.utils.Normalizer(**norm_stats(),
                                                  reverse=True),
            ),
            callbacks.VisualizeDepth(
                model=model,
                dataset=val_dst,
                idx_list_or_num_vis=10,
                max_depth=10,
                attach_to=1,
                normalizer=kamal.utils.Normalizer(**norm_stats(),
                                                  reverse=True),
            ),
        ])
    trainer.run(start_iter=0, max_iter=TOTAL_ITERS)