コード例 #1
0
 # Mutator that samples sub-architectures of the supernet; candidates are
 # filtered via flops_func against the [flops_lb, flops_ub] window
 # (presumably a FLOPs budget of 290M-360M -- confirm against the mutator).
 mutator = SPOSSupernetTrainingMutator(
     model, flops_func=flops_func, flops_lb=290E6, flops_ub=360E6)
 # Label-smoothed cross entropy over 1000 classes.
 criterion = CrossEntropyLabelSmooth(1000, args.label_smoothing)
 optimizer = torch.optim.SGD(
     model.parameters(),
     lr=args.learning_rate,
     momentum=args.momentum,
     weight_decay=args.weight_decay,
 )

 def _linear_lr_factor(step):
     # Multiplier applied to the base LR: falls linearly from 1.0 at step 0
     # to 0.0 at step == args.epochs and is clamped to 0 beyond that.
     if step > args.epochs:
         return 0
     return 1.0 - step / args.epochs

 scheduler = torch.optim.lr_scheduler.LambdaLR(
     optimizer, lr_lambda=_linear_lr_factor, last_epoch=-1)
 # Both splits share directory, batch size and worker count; only the split
 # name differs.
 _dali_common = (args.imagenet_dir, args.batch_size, args.workers)
 train_loader = get_imagenet_iter_dali(
     "train", *_dali_common, spos_preprocessing=args.spos_preprocessing)
 valid_loader = get_imagenet_iter_dali(
     "val", *_dali_common, spos_preprocessing=args.spos_preprocessing)
 trainer = SPOSSupernetTrainer(model,
                               criterion,
                               accuracy,
                               optimizer,
                               args.epochs,
                               train_loader,
                               valid_loader,
コード例 #2
0
ファイル: tester.py プロジェクト: xxlya/COS598D_Assignment3
    # Seed Python's and NumPy's RNGs and request deterministic cuDNN kernels so
    # repeated evaluations with the same seed are reproducible.
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    # Explicit check instead of `assert`: asserts are removed under
    # `python -O`, which would let the script continue without CUDA and fail
    # later (e.g. at `model.cuda()`) with a less obvious error.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required to run this evaluation script.")

    model = ShuffleNetV2OneShot()
    # Label-smoothed cross entropy over 1000 classes, smoothing factor 0.1.
    criterion = CrossEntropyLabelSmooth(1000, 0.1)
    # Fix the architecture choice on the one-shot model before loading the
    # checkpoint weights, then move the model to the GPU.
    get_and_apply_next_architecture(model)
    model.load_state_dict(load_and_parse_state_dict(filepath=args.checkpoint))
    model.cuda()

    train_loader = get_imagenet_iter_dali(
        "train",
        args.imagenet_dir,
        args.train_batch_size,
        args.workers,
        spos_preprocessing=args.spos_preprocessing,
        seed=args.seed,
        device_id=0)
    # NOTE(review): the validation split is shuffled -- presumably so that
    # evaluate_acc sees a random subset of batches; confirm this is intended.
    val_loader = get_imagenet_iter_dali(
        "val",
        args.imagenet_dir,
        args.test_batch_size,
        args.workers,
        spos_preprocessing=args.spos_preprocessing,
        shuffle=True,
        seed=args.seed,
        device_id=0)
    # Make the training stream endless. NOTE(review): if `cycle` is
    # itertools.cycle, it keeps a copy of every batch yielded on the first
    # pass -- confirm that memory cost is acceptable.
    train_loader = cycle(train_loader)

    evaluate_acc(model, criterion, args, train_loader, val_loader)