Example #1
        elif params.teacher == "wrn":
            teacher_model = wrn.WideResNet(depth=28,
                                           num_classes=10,
                                           widen_factor=10,
                                           dropRate=0.3)
            teacher_checkpoint = 'experiments/base_wrn/best.pth.tar'
            teacher_model = nn.DataParallel(teacher_model).cuda()

        elif params.teacher == "densenet":
            teacher_model = densenet.DenseNet(depth=100, growthRate=12)
            teacher_checkpoint = 'experiments/base_densenet/best.pth.tar'
            teacher_model = nn.DataParallel(teacher_model).cuda()

        elif params.teacher == "resnext29":
            teacher_model = resnext.CifarResNeXt(cardinality=8,
                                                 depth=29,
                                                 num_classes=10)
            teacher_checkpoint = 'experiments/base_resnext29/best.pth.tar'
            teacher_model = nn.DataParallel(teacher_model).cuda()

        elif params.teacher == "preresnet110":
            teacher_model = preresnet.PreResNet(depth=110, num_classes=10)
            teacher_checkpoint = 'experiments/base_preresnet110/best.pth.tar'
            teacher_model = nn.DataParallel(teacher_model).cuda()

        utils.load_checkpoint(teacher_checkpoint, teacher_model)

        # Train the model with KD
        logging.info("Experiment - model version: {}".format(
            params.model_version))
        logging.info("Starting training for {} epoch(s)".format(
Example #2
def main():
    # Load the parameters from json file
    args = parser.parse_args()
    json_path = os.path.join(args.model_dir, 'params.json')
    assert os.path.isfile(
        json_path), "No json configuration file found at {}".format(json_path)
    params = utils.Params(json_path)

    # Set the random seed for reproducible experiments
    random.seed(230)
    torch.manual_seed(230)
    np.random.seed(230)
    torch.cuda.manual_seed(230)
    warnings.filterwarnings("ignore")

    # Set the logger
    utils.set_logger(os.path.join(args.model_dir, 'train.log'))

    # Create the input data pipeline
    logging.info("Loading the datasets...")

    # fetch dataloaders, considering full-set vs. sub-set scenarios
    if params.subset_percent < 1.0:
        train_dl = data_loader.fetch_subset_dataloader('train', params)
    else:
        train_dl = data_loader.fetch_dataloader('train', params)

    dev_dl = data_loader.fetch_dataloader('dev', params)

    logging.info("- done.")
    """
    Load the student and teacher models
    """
    if "distill" in params.model_version:

        # Specify the student models
        if params.model_version == "cnn_distill":  # 5-layers Plain CNN
            print("Student model: {}".format(params.model_version))
            model = net.Net(params).cuda()

        elif params.model_version == "shufflenet_v2_distill":
            print("Student model: {}".format(params.model_version))
            model = shufflenet.shufflenetv2(class_num=args.num_class).cuda()

        elif params.model_version == "mobilenet_v2_distill":
            print("Student model: {}".format(params.model_version))
            model = mobilenet.mobilenetv2(class_num=args.num_class).cuda()

        elif params.model_version == 'resnet18_distill':
            print("Student model: {}".format(params.model_version))
            model = resnet.ResNet18(num_classes=args.num_class).cuda()

        elif params.model_version == 'resnet50_distill':
            print("Student model: {}".format(params.model_version))
            model = resnet.ResNet50(num_classes=args.num_class).cuda()

        elif params.model_version == "alexnet_distill":
            print("Student model: {}".format(params.model_version))
            model = alexnet.alexnet(num_classes=args.num_class).cuda()

        elif params.model_version == "vgg19_distill":
            print("Student model: {}".format(params.model_version))
            model = models.vgg19_bn(num_classes=args.num_class).cuda()

        elif params.model_version == "googlenet_distill":
            print("Student model: {}".format(params.model_version))
            model = googlenet.GoogleNet(num_class=args.num_class).cuda()

        elif params.model_version == "resnext29_distill":
            print("Student model: {}".format(params.model_version))
            model = resnext.CifarResNeXt(cardinality=8,
                                         depth=29,
                                         num_classes=args.num_class).cuda()

        elif params.model_version == "densenet121_distill":
            print("Student model: {}".format(params.model_version))
            model = densenet.densenet121(num_class=args.num_class).cuda()

        # optimizer: base learning rate is scaled linearly with batch size (relative to 128)
        if params.model_version == "cnn_distill":
            optimizer = optim.Adam(model.parameters(),
                                   lr=params.learning_rate *
                                   (params.batch_size / 128))
        else:
            optimizer = optim.SGD(model.parameters(),
                                  lr=params.learning_rate *
                                  (params.batch_size / 128),
                                  momentum=0.9,
                                  weight_decay=5e-4)

        iter_per_epoch = len(train_dl)
        # warm up the learning rate for the first args.warm epoch(s)
        warmup_scheduler = utils.WarmUpLR(optimizer,
                                          iter_per_epoch * args.warm)

        # specify loss function
        if args.self_training:
            print(
                '>>>>>>>>>>>>>>>>>>>>>>>>Self Training>>>>>>>>>>>>>>>>>>>>>>>>'
            )
            loss_fn_kd = loss_kd_self
        else:
            loss_fn_kd = loss_kd
        """ 
            Specify the pre-trained teacher models for knowledge distillation
            Checkpoints can be obtained by regular training or downloading our pretrained models
            For model which is pretrained in multi-GPU, use "nn.DaraParallel" to correctly load the model weights.
        """
        if params.teacher == "resnet18":
            print("Teacher model: {}".format(params.teacher))
            teacher_model = resnet.ResNet18(num_classes=args.num_class)
            teacher_checkpoint = 'experiments/pretrained_teacher_models/base_resnet18/best.pth.tar'
            if args.pt_teacher:  # poorly-trained teacher for Defective KD experiments
                teacher_checkpoint = 'experiments/pretrained_teacher_models/base_resnet18/0.pth.tar'
            teacher_model = teacher_model.cuda()

        elif params.teacher == "alexnet":
            print("Teacher model: {}".format(params.teacher))
            teacher_model = alexnet.alexnet(num_classes=args.num_class)
            teacher_checkpoint = 'experiments/pretrained_teacher_models/base_alexnet/best.pth.tar'
            teacher_model = teacher_model.cuda()

        elif params.teacher == "googlenet":
            print("Teacher model: {}".format(params.teacher))
            teacher_model = googlenet.GoogleNet(num_class=args.num_class)
            teacher_checkpoint = 'experiments/pretrained_teacher_models/base_googlenet/best.pth.tar'
            teacher_model = teacher_model.cuda()

        elif params.teacher == "vgg19":
            print("Teacher model: {}".format(params.teacher))
            teacher_model = models.vgg19_bn(num_classes=args.num_class)
            teacher_checkpoint = 'experiments/pretrained_teacher_models/base_vgg19/best.pth.tar'
            teacher_model = teacher_model.cuda()

        elif params.teacher == "resnet50":
            print("Teacher model: {}".format(params.teacher))
            teacher_model = resnet.ResNet50(num_classes=args.num_class).cuda()
            teacher_checkpoint = 'experiments/pretrained_teacher_models/base_resnet50/best.pth.tar'
            if args.pt_teacher:  # poorly-trained teacher for Defective KD experiments
                teacher_checkpoint = 'experiments/pretrained_teacher_models/base_resnet50/50.pth.tar'

        elif params.teacher == "resnet101":
            print("Teacher model: {}".format(params.teacher))
            teacher_model = resnet.ResNet101(num_classes=args.num_class).cuda()
            teacher_checkpoint = 'experiments/pretrained_teacher_models/base_resnet101/best.pth.tar'

        elif params.teacher == "densenet121":
            print("Teacher model: {}".format(params.teacher))
            teacher_model = densenet.densenet121(
                num_class=args.num_class).cuda()
            teacher_checkpoint = 'experiments/pretrained_teacher_models/base_densenet121/best.pth.tar'
            # teacher_model = nn.DataParallel(teacher_model).cuda()

        elif params.teacher == "resnext29":
            print("Teacher model: {}".format(params.teacher))
            teacher_model = resnext.CifarResNeXt(
                cardinality=8, depth=29, num_classes=args.num_class).cuda()
            teacher_checkpoint = 'experiments/pretrained_teacher_models/base_resnext29/best.pth.tar'
            if args.pt_teacher:  # poorly-trained teacher for Defective KD experiments
                teacher_checkpoint = 'experiments/pretrained_teacher_models/base_resnext29/50.pth.tar'
                teacher_model = nn.DataParallel(teacher_model).cuda()

        elif params.teacher == "mobilenet_v2":
            print("Teacher model: {}".format(params.teacher))
            teacher_model = mobilenet.mobilenetv2(
                class_num=args.num_class).cuda()
            teacher_checkpoint = 'experiments/pretrained_teacher_models/base_mobilenet_v2/best.pth.tar'

        elif params.teacher == "shufflenet_v2":
            print("Teacher model: {}".format(params.teacher))
            teacher_model = shufflenet.shufflenetv2(
                class_num=args.num_class).cuda()
            teacher_checkpoint = 'experiments/pretrained_teacher_models/base_shufflenet_v2/best.pth.tar'

        utils.load_checkpoint(teacher_checkpoint, teacher_model)

        # Train the model with KD
        logging.info("Starting training for {} epoch(s)".format(
            params.num_epochs))
        train_and_evaluate_kd(model, teacher_model, train_dl, dev_dl,
                              optimizer, loss_fn_kd, warmup_scheduler, params,
                              args, args.restore_file)

    # non-KD mode: regular training to obtain a baseline model
    else:
        print("Train base model")
        if params.model_version == "cnn":
            model = net.Net(params).cuda()

        elif params.model_version == "mobilenet_v2":
            print("model: {}".format(params.model_version))
            model = mobilenet.mobilenetv2(class_num=args.num_class).cuda()

        elif params.model_version == "shufflenet_v2":
            print("model: {}".format(params.model_version))
            model = shufflenet.shufflenetv2(class_num=args.num_class).cuda()

        elif params.model_version == "alexnet":
            print("model: {}".format(params.model_version))
            model = alexnet.alexnet(num_classes=args.num_class).cuda()

        elif params.model_version == "vgg19":
            print("model: {}".format(params.model_version))
            model = models.vgg19_bn(num_classes=args.num_class).cuda()

        elif params.model_version == "googlenet":
            print("model: {}".format(params.model_version))
            model = googlenet.GoogleNet(num_class=args.num_class).cuda()

        elif params.model_version == "densenet121":
            print("model: {}".format(params.model_version))
            model = densenet.densenet121(num_class=args.num_class).cuda()

        elif params.model_version == "resnet18":
            model = resnet.ResNet18(num_classes=args.num_class).cuda()

        elif params.model_version == "resnet50":
            model = resnet.ResNet50(num_classes=args.num_class).cuda()

        elif params.model_version == "resnet101":
            model = resnet.ResNet101(num_classes=args.num_class).cuda()

        elif params.model_version == "resnet152":
            model = resnet.ResNet152(num_classes=args.num_class).cuda()

        elif params.model_version == "resnext29":
            model = resnext.CifarResNeXt(cardinality=8,
                                         depth=29,
                                         num_classes=args.num_class).cuda()
            # model = nn.DataParallel(model).cuda()

        if args.regularization:
            print(
                ">>>>>>>>>>>>>>>>>>>>>>>>Loss of Regularization>>>>>>>>>>>>>>>>>>>>>>>>"
            )
            loss_fn = loss_kd_regularization
        elif args.label_smoothing:
            print(
                ">>>>>>>>>>>>>>>>>>>>>>>>Label Smoothing>>>>>>>>>>>>>>>>>>>>>>>>"
            )
            loss_fn = loss_label_smoothing
        else:
            print(
                ">>>>>>>>>>>>>>>>>>>>>>>>Normal Training>>>>>>>>>>>>>>>>>>>>>>>>"
            )
            loss_fn = nn.CrossEntropyLoss()
            if args.double_training:  # double training, compare to self-KD
                print(
                    ">>>>>>>>>>>>>>>>>>>>>>>>Double Training>>>>>>>>>>>>>>>>>>>>>>>>"
                )
                checkpoint = 'experiments/pretrained_teacher_models/base_' + str(
                    params.model_version) + '/best.pth.tar'
                utils.load_checkpoint(checkpoint, model)

        if params.model_version == "cnn":
            optimizer = optim.Adam(model.parameters(),
                                   lr=params.learning_rate *
                                   (params.batch_size / 128))
        else:
            optimizer = optim.SGD(model.parameters(),
                                  lr=params.learning_rate *
                                  (params.batch_size / 128),
                                  momentum=0.9,
                                  weight_decay=5e-4)

        iter_per_epoch = len(train_dl)
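        # warm up the learning rate for the first args.warm epoch(s)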
        warmup_scheduler = utils.WarmUpLR(optimizer,
                                          iter_per_epoch * args.warm)

        # Train the model
        logging.info("Starting training for {} epoch(s)".format(
            params.num_epochs))
        train_and_evaluate(model, train_dl, dev_dl, optimizer, loss_fn, params,
                           args.model_dir, warmup_scheduler, args,
                           args.restore_file)
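
The KD branch above selects loss_fn_kd = loss_kd (or loss_kd_self for self-training) without showing either function. As a point of reference, here is a minimal sketch of the standard Hinton-style distillation loss that loss_kd most likely implements; params.alpha and params.temperature are assumed hyperparameter names, and the exact weighting may differ in the real repository:

import torch.nn.functional as F

def loss_kd(outputs, labels, teacher_outputs, params):
    """Distillation loss: KL divergence between temperature-softened student
    and teacher distributions, blended with cross-entropy on hard labels."""
    alpha = params.alpha  # weight on the soft-target term (assumed name)
    T = params.temperature  # softmax temperature (assumed name)
    soft = F.kl_div(F.log_softmax(outputs / T, dim=1),
                    F.softmax(teacher_outputs / T, dim=1),
                    reduction='batchmean') * (alpha * T * T)
    hard = F.cross_entropy(outputs, labels) * (1.0 - alpha)
    return soft + hard

The T * T factor compensates for the 1/T^2 scaling that the temperature introduces into the soft-target gradients, keeping the two terms on comparable scales across temperatures (Hinton et al., 2015).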