# Common Feature Learning (CFL) amalgamation: merge a Stanford Cars teacher and an
# FGVC Aircraft teacher into a single 298-class student.
# The imports below are an assumption based on the names this example references
# (vision, engine, metrics, callbacks, utils, amalgamation, sT); adjust to your setup.
import time
import torch
from torch.utils.tensorboard import SummaryWriter

from kamal import vision, engine, utils, metrics, callbacks, amalgamation
from kamal.vision import sync_transforms as sT


def main():
    # Datasets for the two teacher tasks
    car_train_dst = vision.datasets.StanfordCars('../data/StanfordCars', split='train')
    car_val_dst = vision.datasets.StanfordCars('../data/StanfordCars', split='test')
    aircraft_train_dst = vision.datasets.FGVCAircraft('../data/FGVCAircraft', split='trainval')
    aircraft_val_dst = vision.datasets.FGVCAircraft('../data/FGVCAircraft', split='test')

    # Two pretrained teachers and one student whose output space is the concatenation
    # of both label spaces (196 car classes + 102 aircraft classes)
    car_teacher = vision.models.classification.resnet18(num_classes=196, pretrained=False)
    aircraft_teacher = vision.models.classification.resnet18(num_classes=102, pretrained=False)
    student = vision.models.classification.resnet18(num_classes=196 + 102, pretrained=False)
    car_teacher.load_state_dict(torch.load(args.car_ckpt))
    aircraft_teacher.load_state_dict(torch.load(args.aircraft_ckpt))

    train_transform = sT.Compose([
        sT.RandomResizedCrop(224),
        sT.RandomHorizontalFlip(),
        sT.ToTensor(),
        sT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    val_transform = sT.Compose([
        sT.Resize(256),
        sT.CenterCrop(224),
        sT.ToTensor(),
        sT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    car_train_dst.transform = aircraft_train_dst.transform = train_transform
    car_val_dst.transform = aircraft_val_dst.transform = val_transform

    # Each metric only looks at the slice of the student logits that belongs to its task
    car_metric = metrics.MetricCompose(metric_dict={
        'car_acc': metrics.Accuracy(attach_to=lambda o, t: (o[:, :196], t))
    })
    aircraft_metric = metrics.MetricCompose(metric_dict={
        'aircraft_acc': metrics.Accuracy(attach_to=lambda o, t: (o[:, 196:], t))
    })

    train_dst = torch.utils.data.ConcatDataset([car_train_dst, aircraft_train_dst])
    train_loader = torch.utils.data.DataLoader(train_dst, batch_size=32, shuffle=True, num_workers=4)
    car_loader = torch.utils.data.DataLoader(car_val_dst, batch_size=32, shuffle=False, num_workers=4)
    aircraft_loader = torch.utils.data.DataLoader(aircraft_val_dst, batch_size=32, shuffle=False, num_workers=4)
    car_evaluator = engine.evaluator.BasicEvaluator(car_loader, car_metric)
    aircraft_evaluator = engine.evaluator.BasicEvaluator(aircraft_loader, aircraft_metric)

    if args.ckpt is not None:
        student.load_state_dict(torch.load(args.ckpt))
        print("Load student model from %s" % args.ckpt)
    if args.test_only:
        results_car = car_evaluator.eval(student)
        results_aircraft = aircraft_evaluator.eval(student)
        print("Stanford Cars: %s" % results_car)
        print("FGVC Aircraft: %s" % results_aircraft)
        return

    TOTAL_ITERS = len(train_loader) * 100
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    optim = torch.optim.Adam(student.parameters(), lr=args.lr, weight_decay=1e-4)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=TOTAL_ITERS)

    trainer = amalgamation.CommonFeatureAmalgamator(
        logger=utils.logger.get_logger('cfl'),
        tb_writer=SummaryWriter(log_dir='run/cfl-%s' % (time.asctime().replace(' ', '_'))))
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP(every=10),
        callbacks=callbacks.MetricsLogging(keys=('total_loss', 'loss_kd', 'loss_amal', 'loss_recons', 'lr')))
    trainer.add_callback(
        engine.DefaultEvents.AFTER_EPOCH,
        callbacks=[
            callbacks.EvalAndCkpt(model=student, evaluator=car_evaluator,
                                  metric_name='car_acc', ckpt_prefix='cfl_car'),
            callbacks.EvalAndCkpt(model=student, evaluator=aircraft_evaluator,
                                  metric_name='aircraft_acc', ckpt_prefix='cfl_aircraft'),
        ])
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP,
        callbacks=callbacks.LRSchedulerCallback(schedulers=[sched]))

    # Amalgamate at the inputs of the final fc layers (512-d features for resnet18)
    layer_groups = [(student.fc, car_teacher.fc, aircraft_teacher.fc)]
    layer_channels = [(512, 512, 512)]
    trainer.setup(student=student,
                  teachers=[car_teacher, aircraft_teacher],
                  layer_groups=layer_groups,
                  layer_channels=layer_channels,
                  dataloader=train_loader,
                  optimizer=optim,
                  device=device,
                  on_layer_input=True,
                  weights=[1., 10., 10.])
    trainer.run(start_iter=0, max_iter=TOTAL_ITERS)
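# The CFL example above reads its options from a module-level `args` object that the
# snippet never defines. A minimal entry-point sketch, assuming flag names taken from the
# attributes the code accesses (car_ckpt, aircraft_ckpt, ckpt, lr, test_only); the
# defaults are illustrative, not from the source.
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--car_ckpt', type=str, required=True,
                        help='checkpoint of the Stanford Cars teacher')
    parser.add_argument('--aircraft_ckpt', type=str, required=True,
                        help='checkpoint of the FGVC Aircraft teacher')
    parser.add_argument('--ckpt', type=str, default=None,
                        help='optional student checkpoint to resume from or evaluate')
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--test_only', action='store_true')
    args = parser.parse_args()  # module-level, since main() reads the global `args`
    main()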
def main():
    # PyTorch Part
    num_classes = 11
    model = vision.models.segmentation.deeplabv3_resnet50(num_classes=num_classes, pretrained_backbone=True)
    train_dst = vision.datasets.CamVid(
        'data/CamVid11', split='trainval',
        transforms=sT.Compose([
            sT.Multi(sT.Resize(240), sT.Resize(240, interpolation=Image.NEAREST)),
            sT.Sync(sT.RandomRotation(5), sT.RandomRotation(5)),
            sT.Multi(sT.ColorJitter(0.2, 0.2, 0.2), None),
            sT.Sync(sT.RandomCrop(240), sT.RandomCrop(240)),
            sT.Sync(sT.RandomHorizontalFlip(), sT.RandomHorizontalFlip()),
            sT.Multi(sT.ToTensor(), sT.ToTensor(normalize=False, dtype=torch.long)),
            sT.Multi(sT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                     sT.Lambda(lambd=lambda x: x.squeeze()))
        ]))
    val_dst = vision.datasets.CamVid(
        'data/CamVid11', split='test',
        transforms=sT.Compose([
            sT.Multi(sT.Resize(240), sT.Resize(240, interpolation=Image.NEAREST)),
            sT.Multi(sT.ToTensor(), sT.ToTensor(normalize=False, dtype=torch.long)),
            sT.Multi(sT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                     sT.Lambda(lambd=lambda x: x.squeeze()))
        ]))
    train_loader = torch.utils.data.DataLoader(train_dst, batch_size=16, shuffle=True, num_workers=4)
    val_loader = torch.utils.data.DataLoader(val_dst, batch_size=16, num_workers=4)

    TOTAL_ITERS = len(train_loader) * 200
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=TOTAL_ITERS)

    # KAE Part
    metric = kamal.tasks.StandardMetrics.segmentation(num_classes=num_classes)
    evaluator = engine.evaluator.BasicEvaluator(dataloader=val_loader, metric=metric, progress=False)
    task = kamal.tasks.StandardTask.segmentation()
    trainer = engine.trainer.BasicTrainer(
        logger=kamal.utils.logger.get_logger('camvid_seg_deeplab'),
        tb_writer=SummaryWriter(log_dir='run/camvid_seg_deeplab-%s' % (time.asctime().replace(' ', '_'))))
    trainer.setup(model=model, task=task, dataloader=train_loader, optimizer=optim, device=device)
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP(every=10),
        callbacks=callbacks.MetricsLogging(keys=('total_loss', 'lr')))
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP,
        callbacks=callbacks.LRSchedulerCallback(schedulers=[sched]))
    trainer.add_callback(
        engine.DefaultEvents.AFTER_EPOCH,
        callbacks=[
            callbacks.EvalAndCkpt(model=model, evaluator=evaluator, metric_name='miou',
                                  ckpt_prefix='camvid_seg_deeplabv3_resnet50'),
            callbacks.VisualizeSegmentation(
                model=model,
                dataset=val_dst,
                idx_list_or_num_vis=10,
                normalizer=kamal.utils.Normalizer(mean=[0.485, 0.456, 0.406],
                                                  std=[0.229, 0.224, 0.225], reverse=True),
            )])

    #import matplotlib.pyplot as plt
    #lr_finder = kamal.engine.lr_finder.LRFinder()
    #best_lr = lr_finder.find( optim, model, trainer, lr_range=[1e-8, 1.0], max_iter=100, smooth_momentum=0.9 )
    #fig = lr_finder.plot(polyfit=4)
    #plt.savefig('lr_finder_deeplab.png')
    #lr_finder.adjust_learning_rate(optim, best_lr)

    trainer.run(start_iter=0, max_iter=TOTAL_ITERS)
def main():
    # PyTorch Part
    model = vision.models.segmentation.deeplabv3_resnet50(num_classes=1, pretrained_backbone=True)
    train_dst = vision.datasets.NYUv2(
        'data/NYUv2', split='train', target_type='depth',
        transforms=sT.Compose([
            sT.Multi(sT.Resize(240), sT.Resize(240)),
            sT.Sync(sT.RandomRotation(5), sT.RandomRotation(5)),
            sT.Multi(sT.ColorJitter(0.2, 0.2, 0.2), None),
            sT.Sync(sT.RandomCrop(240), sT.RandomCrop(240)),
            sT.Sync(sT.RandomHorizontalFlip(), sT.RandomHorizontalFlip()),
            sT.Multi(sT.ToTensor(), sT.ToTensor(normalize=False, dtype=torch.float)),
            sT.Multi(sT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                     sT.Lambda(lambda x: x / 1000))
        ]))
    val_dst = vision.datasets.NYUv2(
        'data/NYUv2', split='test', target_type='depth',
        transforms=sT.Compose([
            sT.Multi(sT.Resize(240), sT.Resize(240)),
            sT.Multi(sT.ToTensor(), sT.ToTensor(normalize=False, dtype=torch.float)),
            sT.Multi(sT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                     sT.Lambda(lambda x: x / 1000))
        ]))
    train_loader = torch.utils.data.DataLoader(train_dst, batch_size=16, shuffle=True, num_workers=4)
    val_loader = torch.utils.data.DataLoader(val_dst, batch_size=16, num_workers=4)

    TOTAL_ITERS = len(train_loader) * 200
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=TOTAL_ITERS)

    # KAE Part
    metric = kamal.tasks.StandardMetrics.monocular_depth()
    evaluator = engine.evaluator.BasicEvaluator(dataloader=val_loader, metric=metric, progress=False)
    task = kamal.tasks.StandardTask.monocular_depth()
    trainer = engine.trainer.BasicTrainer(
        logger=kamal.utils.logger.get_logger('nyuv2_depth_deeplab'),
        tb_writer=SummaryWriter(log_dir='run/nyuv2_depth_deeplab-%s' % (time.asctime().replace(' ', '_'))))
    trainer.setup(model=model, task=task, dataloader=train_loader, optimizer=optim, device=device)
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP(every=10),
        callbacks=callbacks.MetricsLogging(keys=('total_loss', 'lr')))
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP,
        callbacks=callbacks.LRSchedulerCallback(schedulers=[sched]))
    trainer.add_callback(
        engine.DefaultEvents.AFTER_EPOCH,
        callbacks=[
            callbacks.EvalAndCkpt(model=model, evaluator=evaluator, metric_name='rmse',
                                  metric_mode='min', ckpt_prefix='nyuv2_depth'),
            callbacks.VisualizeDepth(
                model=model,
                dataset=val_dst,
                idx_list_or_num_vis=10,
                max_depth=10,
                normalizer=kamal.utils.Normalizer(mean=[0.485, 0.456, 0.406],
                                                  std=[0.229, 0.224, 0.225], reverse=True),
            )
        ])
    trainer.run(start_iter=0, max_iter=TOTAL_ITERS)
def main():
    # Seg + Depth
    model = MultiTaskSegNet(out_channel_list=[13, 1])
    seg_teacher = vision.models.segmentation.segnet_vgg16_bn(num_classes=13, pretrained_backbone=True)
    depth_teacher = vision.models.segmentation.segnet_vgg16_bn(num_classes=1, pretrained_backbone=True)
    seg_teacher.load_state_dict(torch.load(args.seg_ckpt))
    depth_teacher.load_state_dict(torch.load(args.depth_ckpt))

    seg_train_dst = vision.datasets.NYUv2('../data/NYUv2', split='train', target_type='semantic')
    seg_val_dst = vision.datasets.NYUv2('../data/NYUv2', split='test', target_type='semantic')
    depth_train_dst = vision.datasets.NYUv2('../data/NYUv2', split='train', target_type='depth')
    depth_val_dst = vision.datasets.NYUv2('../data/NYUv2', split='test', target_type='depth')

    train_dst = vision.datasets.LabelConcatDataset(
        datasets=[seg_train_dst, depth_train_dst],
        transforms=sT.Compose([
            sT.Multi(sT.Resize(240), sT.Resize(240, interpolation=Image.NEAREST), sT.Resize(240)),
            sT.Sync(sT.RandomCrop(240), sT.RandomCrop(240), sT.RandomCrop(240)),
            sT.Sync(sT.RandomHorizontalFlip(), sT.RandomHorizontalFlip(), sT.RandomHorizontalFlip()),
            sT.Multi(sT.ToTensor(),
                     sT.ToTensor(normalize=False, dtype=torch.long),
                     sT.ToTensor(normalize=False, dtype=torch.float)),
            sT.Multi(sT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                     sT.Lambda(lambd=lambda x: x.squeeze()),
                     sT.Lambda(lambd=lambda x: x / 1e3))
        ]))
    val_dst = vision.datasets.LabelConcatDataset(
        datasets=[seg_val_dst, depth_val_dst],
        transforms=sT.Compose([
            sT.Multi(sT.Resize(240), sT.Resize(240, interpolation=Image.NEAREST), sT.Resize(240)),
            sT.Multi(sT.ToTensor(),
                     sT.ToTensor(normalize=False, dtype=torch.long),
                     sT.ToTensor(normalize=False, dtype=torch.float)),
            sT.Multi(sT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                     sT.Lambda(lambd=lambda x: x.squeeze()),
                     sT.Lambda(lambd=lambda x: x / 1e3))
        ]))
    train_loader = torch.utils.data.DataLoader(train_dst, batch_size=16, shuffle=True, num_workers=4)
    val_loader = torch.utils.data.DataLoader(val_dst, batch_size=16, num_workers=4)

    TOTAL_ITERS = len(train_loader) * 200
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    optim = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=TOTAL_ITERS)

    confusion_matrix = metrics.ConfusionMatrix(num_classes=13, ignore_idx=255, attach_to=0)
    metric = metrics.MetricCompose(metric_dict={
        'acc': metrics.Accuracy(attach_to=0),
        'cm': confusion_matrix,
        'mIoU': metrics.mIoU(confusion_matrix),
        'rmse': metrics.RootMeanSquaredError(attach_to=1)
    })
    evaluator = engine.evaluator.BasicEvaluator(dataloader=val_loader, metric=metric, progress=False)

    task = [
        kamal.tasks.StandardTask.distillation(attach_to=[0, 0]),
        kamal.tasks.StandardTask.monocular_depth(attach_to=[1, 1])
    ]
    trainer = engine.trainer.KDTrainer(
        logger=kamal.utils.logger.get_logger('nyuv2_simple_kd'),
        tb_writer=SummaryWriter(log_dir='run/nyuv2_simple_kd-%s' % (time.asctime().replace(' ', '_'))))
    trainer.setup(student=model,
                  teacher=[seg_teacher, depth_teacher],
                  task=task,
                  dataloader=train_loader,
                  optimizer=optim,
                  device=device)
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP(every=10),
        callbacks=callbacks.MetricsLogging(keys=('total_loss', 'lr')))
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP,
        callbacks=callbacks.LRSchedulerCallback(schedulers=[sched]))
    trainer.add_callback(
        engine.DefaultEvents.AFTER_EPOCH,
        callbacks=[
            callbacks.EvalAndCkpt(model=model, evaluator=evaluator, metric_name='rmse',
                                  metric_mode='min', ckpt_prefix='nyuv2_simple_kd'),
            callbacks.VisualizeSegmentation(
                model=model, dataset=val_dst, idx_list_or_num_vis=10, attach_to=0,
                normalizer=kamal.utils.Normalizer(mean=[0.485, 0.456, 0.406],
                                                  std=[0.229, 0.224, 0.225], reverse=True),
            ),
            callbacks.VisualizeDepth(
                model=model, dataset=val_dst, idx_list_or_num_vis=10, max_depth=10, attach_to=1,
                normalizer=kamal.utils.Normalizer(mean=[0.485, 0.456, 0.406],
                                                  std=[0.229, 0.224, 0.225], reverse=True),
            ),
        ])
    trainer.run(start_iter=0, max_iter=TOTAL_ITERS)
# CIFAR-100 knowledge distillation with the distillers in kamal.slim
# (KD, hint/FitNets, attention, SP, CC, VID, SVD, PKT, NST, RKD), selected via --distill.
# The imports below are an assumption based on the names this example references;
# adjust them to your setup.
import argparse
import time
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms as T
from torchvision.datasets import CIFAR100

from kamal import vision, engine, utils, slim, metrics, callbacks


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--lr', type=float, default=0.05)
    parser.add_argument('--t_model_path', type=str, default='../pretrained/cifar100_wrn_40_2.pth')
    parser.add_argument('--T', '--temperature', type=float, default=4.0, help='temperature for KD distillation')
    parser.add_argument('-r', '--gamma', type=float, default=1)
    parser.add_argument('-a', '--alpha', type=float, default=None)
    parser.add_argument('-b', '--beta', type=float, default=None)
    parser.add_argument('--epochs', type=int, default=240)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--num_workers', type=int, default=8, help='num of workers to use')
    parser.add_argument('--distill', type=str, default='kd',
                        choices=['kd', 'hint', 'attention', 'sp', 'cc', 'vid', 'svd', 'pkt', 'nst', 'rkd'])
    # dataset
    parser.add_argument('--data_root', type=str, default='../data/torchdata')
    parser.add_argument('--img_size', type=int, default=32, help='image size of datasets')
    # hint layer
    parser.add_argument('--hint_layer', default=2, type=int, choices=[0, 1, 2, 3, 4])
    # cc embed dim
    parser.add_argument('--embed_dim', default=128, type=int, help='feature dimension')
    args = parser.parse_args()

    # prepare data
    cifar100_train = CIFAR100(
        args.data_root, train=True, download=True,
        transform=T.Compose([
            T.RandomCrop(32, padding=4),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            T.Normalize(mean=[0.5071, 0.4867, 0.4408], std=[0.2675, 0.2565, 0.2761])]))
    train_loader = torch.utils.data.DataLoader(
        cifar100_train, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True)
    val_loader = torch.utils.data.DataLoader(
        CIFAR100(args.data_root, train=False, download=True,
                 transform=T.Compose([
                     T.ToTensor(),
                     T.Normalize(mean=[0.5071, 0.4867, 0.4408], std=[0.2675, 0.2565, 0.2761])])),
        batch_size=args.batch_size, num_workers=args.num_workers)

    # prepare model
    teacher = vision.models.classification.cifar.wrn.wrn_40_2(num_classes=100)
    student = vision.models.classification.cifar.wrn.wrn_16_2(num_classes=100)
    teacher.load_state_dict(torch.load(args.t_model_path))
    print('[!] teacher loads weights from %s' % (args.t_model_path))

    # Teacher eval
    evaluator = engine.evaluator.BasicEvaluator(val_loader, metric=metrics.StandardTaskMetrics.classification())
    teacher_scores = evaluator.eval(teacher)
    print('[TEACHER] Acc=%.4f' % (teacher_scores['acc']))

    # hook module feature
    out_flags = [True, True, True, True, False]
    tea_hooks = []
    tea_layers = [teacher.conv1, teacher.block1, teacher.block2, teacher.block3, teacher.fc]
    for module in tea_layers:
        hookfeat = engine.hooks.FeatureHook(module)
        hookfeat.register()
        tea_hooks.append(hookfeat)
    stu_hooks = []
    stu_layers = [student.conv1, student.block1, student.block2, student.block3, student.fc]
    for module in stu_layers:
        hookfeat = engine.hooks.FeatureHook(module)
        hookfeat.register()
        stu_hooks.append(hookfeat)

    # distiller setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger = utils.logger.get_logger('distill_%s' % (args.distill))
    tb_writer = SummaryWriter(log_dir='run/distill_%s-%s' % (args.distill, time.asctime().replace(' ', '_')))

    if args.distill == 'kd':
        optimizer = torch.optim.SGD(student.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
        distiller = slim.KDDistiller(logger, tb_writer)
        distiller.setup(student=student, teacher=teacher, dataloader=train_loader, optimizer=optimizer,
                        T=args.T, gamma=args.gamma, alpha=args.alpha, device=device)
    elif args.distill == 'attention':
        optimizer = torch.optim.SGD(student.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
        distiller = slim.AttentionDistiller(logger, tb_writer)
        distiller.setup(student=student, teacher=teacher, dataloader=train_loader, optimizer=optimizer,
                        T=args.T, gamma=args.gamma, alpha=args.alpha, beta=args.beta,
                        stu_hooks=stu_hooks, tea_hooks=tea_hooks, out_flags=out_flags, device=device)
    elif args.distill == 'nst':
        optimizer = torch.optim.SGD(student.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
        distiller = slim.NSTDistiller(logger, tb_writer)
        distiller.setup(student=student, teacher=teacher, dataloader=train_loader, optimizer=optimizer,
                        T=args.T, gamma=args.gamma, alpha=args.alpha, beta=args.beta,
                        stu_hooks=stu_hooks, tea_hooks=tea_hooks, out_flags=out_flags, device=device)
    elif args.distill == 'sp':
        optimizer = torch.optim.SGD(student.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
        distiller = slim.SPDistiller(logger, tb_writer)
        distiller.setup(student=student, teacher=teacher, dataloader=train_loader, optimizer=optimizer,
                        T=args.T, gamma=args.gamma, alpha=args.alpha, beta=args.beta,
                        stu_hooks=stu_hooks, tea_hooks=tea_hooks, out_flags=out_flags, device=device)
    elif args.distill == 'rkd':
        optimizer = torch.optim.SGD(student.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
        distiller = slim.RKDDistiller(logger, tb_writer)
        distiller.setup(student=student, teacher=teacher, dataloader=train_loader, optimizer=optimizer,
                        T=args.T, gamma=args.gamma, alpha=args.alpha, beta=args.beta,
                        stu_hooks=stu_hooks, tea_hooks=tea_hooks, out_flags=out_flags, device=device)
    elif args.distill == 'pkt':
        optimizer = torch.optim.SGD(student.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
        distiller = slim.PKTDistiller(logger, tb_writer)
        distiller.setup(student=student, teacher=teacher, dataloader=train_loader, optimizer=optimizer,
                        T=args.T, gamma=args.gamma, alpha=args.alpha, beta=args.beta,
                        stu_hooks=stu_hooks, tea_hooks=tea_hooks, out_flags=out_flags, device=device)
    elif args.distill == 'svd':
        optimizer = torch.optim.SGD(student.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
        distiller = slim.SVDDistiller(logger, tb_writer)
        distiller.setup(student=student, teacher=teacher, dataloader=train_loader, optimizer=optimizer,
                        T=args.T, gamma=args.gamma, alpha=args.alpha, beta=args.beta,
                        stu_hooks=stu_hooks, tea_hooks=tea_hooks, out_flags=out_flags, device=device)
    elif args.distill == 'vid':
        # run one dummy forward pass so the feature hooks record layer shapes
        data = torch.randn(1, 3, args.img_size, args.img_size)
        data = data.to(device)
        student, teacher = student.to(device), teacher.to(device)
        teacher(data)
        student(data)
        s_n = [f.feat_out.shape[1] if flag else f.feat_in.shape[1]
               for (f, flag) in zip(stu_hooks[1:-1], out_flags)]
        t_n = [f.feat_out.shape[1] if flag else f.feat_in.shape[1]
               for (f, flag) in zip(tea_hooks[1:-1], out_flags)]
        train_list = nn.ModuleList([student])
        VIDRegressor_l = nn.ModuleList()
        for s, t in zip(s_n, t_n):
            vid_r = slim.VIDRegressor(s, t, t)
            VIDRegressor_l.append(vid_r)
            train_list.append(vid_r)
        optimizer = torch.optim.SGD(train_list.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
        distiller = slim.VIDDistiller(logger, tb_writer)
        distiller.setup(student=student, teacher=teacher, dataloader=train_loader, optimizer=optimizer,
                        regressor_l=VIDRegressor_l, T=args.T, gamma=args.gamma, alpha=args.alpha, beta=args.beta,
                        stu_hooks=stu_hooks, tea_hooks=tea_hooks, out_flags=out_flags, device=device)
    elif args.distill == 'hint':
        data = torch.randn(1, 3, args.img_size, args.img_size)
        data = data.to(device)
        student, teacher = student.to(device), teacher.to(device)
        teacher(data), student(data)
        feat_t = [f.feat_out if flag else f.feat_in for (f, flag) in zip(tea_hooks, out_flags)]
        feat_s = [f.feat_out if flag else f.feat_in for (f, flag) in zip(stu_hooks, out_flags)]
        fitnet = slim.Regressor(feat_s[args.hint_layer].shape, feat_t[args.hint_layer].shape)
        optimizer = torch.optim.SGD(nn.ModuleList([student, fitnet]).parameters(),
                                    lr=args.lr, momentum=0.9, weight_decay=5e-4)
        distiller = slim.HintDistiller(logger, tb_writer)
        distiller.setup(student=student, teacher=teacher, dataloader=train_loader, optimizer=optimizer,
                        regressor=fitnet, hint_layer=args.hint_layer, T=args.T, gamma=args.gamma,
                        alpha=args.alpha, beta=args.beta,
                        stu_hooks=stu_hooks, tea_hooks=tea_hooks, out_flags=out_flags, device=device)
    elif args.distill == 'cc':
        data = torch.randn(1, 3, args.img_size, args.img_size)
        data = data.to(device)
        student, teacher = student.to(device), teacher.to(device)
        teacher(data), student(data)
        feat_t = [f.feat_out if flag else f.feat_in for (f, flag) in zip(tea_hooks, out_flags)]
        feat_s = [f.feat_out if flag else f.feat_in for (f, flag) in zip(stu_hooks, out_flags)]
        embed_s = slim.LinearEmbed(feat_s[-1].shape[1], args.embed_dim)
        embed_t = slim.LinearEmbed(feat_t[-1].shape[1], args.embed_dim)
        optimizer = torch.optim.SGD(nn.ModuleList([student, embed_s, embed_t]).parameters(),
                                    lr=args.lr, momentum=0.9, weight_decay=5e-4)
        distiller = slim.CCDistiller(logger, tb_writer)
        distiller.setup(student=student, teacher=teacher, dataloader=train_loader, optimizer=optimizer,
                        embed_s=embed_s, embed_t=embed_t, T=args.T, gamma=args.gamma,
                        alpha=args.alpha, beta=args.beta,
                        stu_hooks=stu_hooks, tea_hooks=tea_hooks, out_flags=out_flags, device=device)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, len(train_loader) * args.epochs)
    distiller.add_callback(
        engine.DefaultEvents.AFTER_STEP(every=10),
        callbacks=callbacks.MetricsLogging(keys=('total_loss', 'loss_kld', 'loss_ce', 'loss_additional', 'lr')))
    distiller.add_callback(
        engine.DefaultEvents.AFTER_EPOCH,
        callbacks=callbacks.EvalAndCkpt(model=student, evaluator=evaluator, metric_name='acc',
                                        ckpt_prefix=args.distill))
    distiller.add_callback(
        engine.DefaultEvents.AFTER_STEP,
        callbacks=callbacks.LRSchedulerCallback(schedulers=[scheduler]))
    distiller.run(start_iter=0, max_iter=len(train_loader) * args.epochs)
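# Entry point for the distillation example above. The invocations below are illustrative
# (the script name and the loss-weight values are assumptions, not from the source):
#   python cifar100_distill.py --distill kd --T 4 -r 1 -a 0.9
#   python cifar100_distill.py --distill hint --hint_layer 2 -r 1 -a 1 -b 100
if __name__ == '__main__':
    main()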
def main():
    # Pytorch Part
    if args.dataset == 'stanford_dogs':
        num_classes = 120
        train_dst = vision.datasets.StanfordDogs('data/StanfordDogs', split='train')
        val_dst = vision.datasets.StanfordDogs('data/StanfordDogs', split='test')
    elif args.dataset == 'cub200':
        num_classes = 200
        train_dst = vision.datasets.CUB200('data/CUB200', split='train')
        val_dst = vision.datasets.CUB200('data/CUB200', split='test')
    elif args.dataset == 'fgvc_aircraft':
        num_classes = 102
        train_dst = vision.datasets.FGVCAircraft('data/FGVCAircraft/', split='trainval')
        val_dst = vision.datasets.FGVCAircraft('data/FGVCAircraft/', split='test')
    elif args.dataset == 'stanford_cars':
        num_classes = 196
        train_dst = vision.datasets.StanfordCars('data/StanfordCars/', split='train')
        val_dst = vision.datasets.StanfordCars('data/StanfordCars/', split='test')
    else:
        raise NotImplementedError

    model = vision.models.classification.resnet18(num_classes=num_classes, pretrained=args.pretrained)
    train_dst.transform = sT.Compose([
        sT.RandomResizedCrop(224),
        sT.RandomHorizontalFlip(),
        sT.ToTensor(),
        sT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    val_dst.transform = sT.Compose([
        sT.Resize(256),
        sT.CenterCrop(224),
        sT.ToTensor(),
        sT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    train_loader = torch.utils.data.DataLoader(train_dst, batch_size=32, shuffle=True, num_workers=4)
    val_loader = torch.utils.data.DataLoader(val_dst, batch_size=32, num_workers=4)

    TOTAL_ITERS = len(train_loader) * args.epochs
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    optim = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=TOTAL_ITERS)

    # KAE Part
    # Predefined task & metrics
    task = kamal.tasks.StandardTask.classification()
    metric = kamal.tasks.StandardMetrics.classification()

    # Evaluator and Trainer
    evaluator = engine.evaluator.BasicEvaluator(val_loader, metric=metric, progress=True)
    trainer = engine.trainer.BasicTrainer(
        logger=kamal.utils.logger.get_logger(args.dataset),
        tb_writer=SummaryWriter(log_dir='run/%s-%s' % (args.dataset, time.asctime().replace(' ', '_'))))

    # setup trainer
    trainer.setup(model=model, task=task, dataloader=train_loader, optimizer=optim, device=device)
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP(every=10),
        callbacks=callbacks.MetricsLogging(keys=('total_loss', 'lr')))
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP,
        callbacks=callbacks.LRSchedulerCallback(schedulers=[sched]))
    ckpt_callback = trainer.add_callback(
        engine.DefaultEvents.AFTER_EPOCH,
        callbacks=callbacks.EvalAndCkpt(model=model, evaluator=evaluator,
                                        metric_name='acc', ckpt_prefix=args.dataset))
    trainer.run(start_iter=0, max_iter=TOTAL_ITERS)
    ckpt_callback.callback.final_ckpt(ckpt_dir='pretrained', add_md5=True)
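# The fine-grained pretraining example above also relies on a module-level `args`
# (dataset, pretrained, lr, epochs). A minimal sketch of the missing parser; the
# defaults are illustrative assumptions, not from the source.
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, required=True,
                        choices=['stanford_dogs', 'cub200', 'fgvc_aircraft', 'stanford_cars'])
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--pretrained', action='store_true',
                        help='start from an ImageNet-pretrained backbone')
    args = parser.parse_args()
    main()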
def main():
    car_train_dst = vision.datasets.StanfordCars('../data/StanfordCars', split='train')
    car_val_dst = vision.datasets.StanfordCars('../data/StanfordCars', split='test')
    aircraft_train_dst = vision.datasets.FGVCAircraft('../data/FGVCAircraft', split='trainval')
    aircraft_val_dst = vision.datasets.FGVCAircraft('../data/FGVCAircraft', split='test')
    dog_train_dst = vision.datasets.StanfordDogs('../data/StanfordDogs', split='train')
    dog_val_dst = vision.datasets.StanfordDogs('../data/StanfordDogs', split='test')
    cub_train_dst = vision.datasets.CUB200('../data/CUB200', split='train')
    cub_val_dst = vision.datasets.CUB200('../data/CUB200', split='test')

    #car_teacher = vision.models.classification.resnet18( num_classes=196, pretrained=False )
    #aircraft_teacher = vision.models.classification.resnet18( num_classes=102, pretrained=False )
    #dog_teacher = vision.models.classification.resnet18( num_classes=120, pretrained=False )
    #cub_teacher = vision.models.classification.resnet18( num_classes=200, pretrained=False )
    student = vision.models.classification.resnet18(num_classes=196 + 102 + 120 + 200, pretrained=False)
    #car_teacher.load_state_dict( torch.load( args.car_ckpt ) )
    #aircraft_teacher.load_state_dict( torch.load( args.aircraft_ckpt ) )
    #dog_teacher.load_state_dict( torch.load( args.dog_ckpt ) )
    #cub_teacher.load_state_dict( torch.load( args.cub_ckpt ) )

    train_transform = sT.Compose([
        sT.RandomResizedCrop(224),
        sT.RandomHorizontalFlip(),
        sT.ToTensor(),
        sT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    val_transform = sT.Compose([
        sT.Resize(256),
        sT.CenterCrop(224),
        sT.ToTensor(),
        sT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    cub_train_dst.transform = dog_train_dst.transform = car_train_dst.transform = aircraft_train_dst.transform = train_transform
    cub_val_dst.transform = dog_val_dst.transform = car_val_dst.transform = aircraft_val_dst.transform = val_transform

    # shift targets so the four label spaces occupy disjoint ranges of the 618-way output
    aircraft_train_dst.target_transform = lambda t: t + 196
    dog_train_dst.target_transform = lambda t: t + 196 + 102
    cub_train_dst.target_transform = lambda t: t + 196 + 102 + 120

    car_metric = metrics.MetricCompose(metric_dict={
        'car_acc': metrics.Accuracy(attach_to=lambda o, t: (o[:, :196], t))
    })
    aircraft_metric = metrics.MetricCompose(metric_dict={
        'aircraft_acc': metrics.Accuracy(attach_to=lambda o, t: (o[:, 196:196 + 102], t))
    })
    dog_metric = metrics.MetricCompose(metric_dict={
        'dog_acc': metrics.Accuracy(attach_to=lambda o, t: (o[:, 196 + 102:196 + 102 + 120], t))
    })
    cub_metric = metrics.MetricCompose(metric_dict={
        'cub_acc': metrics.Accuracy(attach_to=lambda o, t: (o[:, 196 + 102 + 120:196 + 102 + 120 + 200], t))
    })

    train_dst = torch.utils.data.ConcatDataset(
        [car_train_dst, aircraft_train_dst, dog_train_dst, cub_train_dst])
    train_loader = torch.utils.data.DataLoader(train_dst, batch_size=32, shuffle=True, num_workers=4)
    car_loader = torch.utils.data.DataLoader(car_val_dst, batch_size=32, shuffle=False, num_workers=2)
    aircraft_loader = torch.utils.data.DataLoader(aircraft_val_dst, batch_size=32, shuffle=False, num_workers=2)
    dog_loader = torch.utils.data.DataLoader(dog_val_dst, batch_size=32, shuffle=False, num_workers=2)
    cub_loader = torch.utils.data.DataLoader(cub_val_dst, batch_size=32, shuffle=False, num_workers=2)

    car_evaluator = engine.evaluator.BasicEvaluator(car_loader, car_metric)
    aircraft_evaluator = engine.evaluator.BasicEvaluator(aircraft_loader, aircraft_metric)
    dog_evaluator = engine.evaluator.BasicEvaluator(dog_loader, dog_metric)
    cub_evaluator = engine.evaluator.BasicEvaluator(cub_loader, cub_metric)

    if args.ckpt is not None:
        student.load_state_dict(torch.load(args.ckpt))
        print("Load student model from %s" % args.ckpt)
    if args.test_only:
        results_car = car_evaluator.eval(student)
        results_aircraft = aircraft_evaluator.eval(student)
        results_dog = dog_evaluator.eval(student)
        results_cub = cub_evaluator.eval(student)
        print("Stanford Cars: %s" % results_car)
        print("FGVC Aircraft: %s" % results_aircraft)
        print("Stanford Dogs: %s" % results_dog)
        print("CUB200: %s" % results_cub)
        return

    TOTAL_ITERS = len(train_loader) * 100
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    optim = torch.optim.Adam(student.parameters(), lr=args.lr, weight_decay=1e-4)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=TOTAL_ITERS)

    task = tasks.StandardTask.classification()
    trainer = engine.trainer.BasicTrainer(
        logger=utils.logger.get_logger('scratch-4'),
        tb_writer=SummaryWriter(log_dir='run/scratch-4-%s' % (time.asctime().replace(' ', '_'))))
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP(every=10),
        callbacks=callbacks.MetricsLogging(keys=('total_loss', 'lr')))
    trainer.add_callback(
        engine.DefaultEvents.AFTER_EPOCH,
        callbacks=[
            callbacks.EvalAndCkpt(model=student, evaluator=car_evaluator,
                                  metric_name='car_acc', ckpt_prefix='cfl_car'),
            callbacks.EvalAndCkpt(model=student, evaluator=aircraft_evaluator,
                                  metric_name='aircraft_acc', ckpt_prefix='cfl_aircraft'),
            callbacks.EvalAndCkpt(model=student, evaluator=dog_evaluator,
                                  metric_name='dog_acc', ckpt_prefix='cfl_dog'),
            callbacks.EvalAndCkpt(model=student, evaluator=cub_evaluator,
                                  metric_name='cub_acc', ckpt_prefix='cfl_cub'),
        ])
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP,
        callbacks=callbacks.LRSchedulerCallback(schedulers=[sched]))
    trainer.setup(model=student, task=task, dataloader=train_loader, optimizer=optim, device=device)
    trainer.run(start_iter=0, max_iter=TOTAL_ITERS)