예제 #1
0
파일: cifar10.py 프로젝트: lilujunai/homura
def main():
    """Train a CIFAR-10 classifier (ResNet-20 or WRN-28-10) selected via CLI.

    Per-model hyperparameters (weight decay, LR decay factor) are looked up
    from the chosen architecture name; an unknown name raises ``KeyError``.
    """
    # Architecture constructor and hyperparameters keyed by the CLI choice.
    architectures = {"resnet20": resnet20, "wrn28_10": wrn28_10}
    decay_table = {"resnet20": 1e-4, "wrn28_10": 5e-4}
    gamma_table = {"resnet20": 0.1, "wrn28_10": 0.2}

    model = architectures[args.model](num_classes=10)
    weight_decay = decay_table[args.model]
    lr_decay = gamma_table[args.model]
    train_loader, test_loader = cifar10_loaders(args.batch_size)

    optimizer = optim.SGD(lr=1e-1, momentum=0.9, weight_decay=weight_decay)
    scheduler = lr_scheduler.MultiStepLR([100, 150], gamma=lr_decay)

    progress = reporters.TQDMReporter(range(args.epochs), verb=True)
    callback_list = [
        callbacks.AccuracyCallback(),
        callbacks.LossCallback(),
        reporters.IOReporter("."),
        reporters.TensorboardReporter("."),
        callbacks.WeightSave("."),
        progress,
    ]

    with trainers.SupervisedTrainer(model,
                                    optimizer,
                                    F.cross_entropy,
                                    callbacks=callback_list,
                                    scheduler=scheduler) as trainer:
        # One train + test pass per epoch, driven by the progress reporter.
        for _ in progress:
            trainer.train(train_loader)
            trainer.test(test_loader)
예제 #2
0
def main():
    """Train one of the ResNet-20 variants (plain / SE / GCN) on CIFAR-10."""
    train_loader, test_loader = cifar10_loaders(args.batch_size)

    choice = args.model
    if choice == "resnet":
        net = resnet20()
    elif choice == "senet":
        net = se_resnet20(num_classes=10, reduction=args.reduction)
    elif choice == "gcn":
        net = resnet20_gcn()
    else:
        # NOTE(review): TypeError for a bad CLI string is unusual (ValueError
        # would be conventional) — kept as-is so existing callers' handlers
        # still match.
        raise TypeError(f"{args.model} is not valid argument")

    optimizer = optim.SGD(lr=1e-1, momentum=0.9, weight_decay=1e-4)
    scheduler = lr_scheduler.StepLR(80, 0.1)
    progress = reporter.TQDMReporter(range(args.epochs),
                                     callbacks=[callbacks.AccuracyCallback()],
                                     save_dir='logs/',
                                     report_freq=-1)
    trainer = Trainer(net,
                      optimizer,
                      F.cross_entropy,
                      scheduler=scheduler,
                      callbacks=progress)
    # One train + test pass per epoch.
    for _ in progress:
        trainer.train(train_loader)
        trainer.test(test_loader)
예제 #3
0
def main():
    """Compare a baseline ResNet-20 against SE-ResNet-20 on CIFAR-10."""
    train_loader, test_loader = cifar10_loaders(args.batch_size)
    print(test_loader.sampler)

    if args.baseline:
        net = resnet20()
    else:
        net = se_resnet20(num_classes=10, reduction=args.reduction)

    optimizer = optim.SGD(lr=1e-1, momentum=0.9, weight_decay=1e-4)
    scheduler = lr_scheduler.StepLR(80, 0.1)
    progress = reporters.TQDMReporter(range(args.epochs), callbacks.AccuracyCallback())
    # NOTE(review): a second AccuracyCallback instance is registered here in
    # addition to the one given to the reporter — confirm the duplication is
    # intentional.
    trainer_callbacks = [progress, callbacks.AccuracyCallback()]
    with Trainer(net, optimizer, F.cross_entropy,
                 scheduler=scheduler, callbacks=trainer_callbacks) as trainer:
        for _ in progress:
            trainer.train(train_loader)
            trainer.test(test_loader)
예제 #4
0
def main(batch_size):
    """Train ResNet-20, then measure per-layer CCA distance between weight
    snapshots and the final model.

    ``batch_size`` sizes only the probe batch fed through the CCA hooks; the
    training loader itself uses a fixed batch size of 128 — presumably
    deliberate, but worth confirming against the caller.
    """
    layers = ["layer1.0.conv1", "layer2.0.conv1", "layer3.0.conv1", "fc"]
    train_loader, test_loader = cifar10_loaders(128)
    weight_save = callbacks.WeightSave("checkpoints")
    model = resnet20(num_classes=10)
    # Untrained copy — reloaded below with saved snapshots for comparison.
    model2 = deepcopy(model)
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=0.1,
                                momentum=0.9,
                                weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 50)
    trainer = trainers.SupervisedTrainer(model,
                                         optimizer,
                                         F.cross_entropy,
                                         scheduler=scheduler,
                                         callbacks=weight_save,
                                         verb=False)
    for _ in trange(100, ncols=80):
        trainer.train(train_loader)

    live_hooks = [CCAHook(model, layer, svd_device=args.device) for layer in layers]
    snapshot_hooks = [CCAHook(model2, layer, svd_device=args.device) for layer in layers]
    device = next(model.parameters()).device
    model2.to(device)
    # Shared probe batch driven through both networks ("probe" avoids
    # shadowing the builtin `input`).
    probe = live_hooks[0].data(train_loader.dataset,
                               batch_size=batch_size).to(device)
    history = []

    def distance():
        # Forward both nets on the probe, then read the per-layer CCA
        # distances off the paired hooks.
        model.eval()
        model2.eval()
        with torch.no_grad():
            model(probe)
            model2(probe)
        return [a.distance(b) for a, b in zip(live_hooks, snapshot_hooks)]

    # Epoch-0 copy vs. final (epoch-99) weights.
    history.append(distance())

    # Snapshots at epochs 29, 49, 99 vs. final weights.
    for ep in (29, 49, 99):
        saved = torch.load(weight_save.save_path / f"{ep}.pkl")
        model2.load_state_dict(saved["model"])
        history.append(distance())
    plot(history, layers)
예제 #5
0
def main():
    """Distill a trained teacher network into a student on CIFAR-10.

    Phase 1 trains the teacher with plain cross-entropy; phase 2 trains the
    student with the distillation trainer against the frozen teacher.
    """
    net = MODELS[args.teacher_model](num_classes=10)
    train_loader, test_loader = cifar10_loaders(args.batch_size)
    # ResNet-family models use weaker weight decay and a sharper LR drop.
    is_resnet = "resnet" in args.teacher_model
    weight_decay = 1e-4 if is_resnet else 5e-4
    lr_decay = 0.1 if is_resnet else 0.2
    optimizer = optim.SGD(lr=1e-1, momentum=0.9, weight_decay=weight_decay)
    scheduler = lr_scheduler.MultiStepLR([50, 80], gamma=lr_decay)

    trainer = trainers.SupervisedTrainer(net,
                                         optimizer,
                                         F.cross_entropy,
                                         scheduler=scheduler)
    trainer.logger.info("Train the teacher model!")
    for _ in trange(args.teacher_epochs, ncols=80):
        trainer.train(train_loader)
        trainer.test(test_loader)

    # Freeze the teacher in eval mode for the distillation phase.
    teacher_model = net.eval()

    is_resnet = "resnet" in args.student_model
    weight_decay = 1e-4 if is_resnet else 5e-4
    lr_decay = 0.1 if is_resnet else 0.2
    optimizer = optim.SGD(lr=1e-1, momentum=0.9, weight_decay=weight_decay)
    scheduler = lr_scheduler.MultiStepLR([50, 80], gamma=lr_decay)
    net = MODELS[args.student_model](num_classes=10)

    shared_callbacks = [callbacks.AccuracyCallback(), callbacks.LossCallback(), kl_loss]
    with reporters.TQDMReporter(
            range(args.student_epochs),
            callbacks=shared_callbacks) as tq, reporters.TensorboardReporter(shared_callbacks) as tb:
        trainer = DistillationTrainer(net,
                                      optimizer,
                                      F.cross_entropy,
                                      callbacks=[tq, tb],
                                      scheduler=scheduler,
                                      teacher_model=teacher_model,
                                      temperature=args.temperature)
        trainer.logger.info("Train the student model!")
        for _ in tq:
            trainer.train(train_loader)
            trainer.test(test_loader)