Example #1
def validate(model, device, args, *, all_iters=None, arch_loader=None):
    assert arch_loader is not None

    objs = AvgrageMeter()
    top1 = AvgrageMeter()
    top5 = AvgrageMeter()

    loss_function = args.loss_function
    val_dataloader = args.val_dataloader

    model.eval()
    # model.apply(bn_calibration_init)

    max_val_iters = 0
    t1 = time.time()

    result_dict = {}

    arch_dict = arch_loader.get_arch_dict()

    base_model = mutableResNet20(10).cuda()

    with torch.no_grad():
        for key, value in arch_dict.items():  # iterate over each candidate architecture
            max_val_iters += 1
            # print('\r ', key, ' iter:', max_val_iters, end='')

            # start fresh meters so each architecture's accuracy is not mixed with the previous one
            top1 = AvgrageMeter()
            top5 = AvgrageMeter()

            for data, target in val_dataloader:  # one full pass over the validation set
                target = target.type(torch.LongTensor)
                data, target = data.to(device), target.to(device)

                output = model(data, value["arch"])

                prec1, prec5 = accuracy(output, target, topk=(1, 5))

                print("acc1: ", prec1.item())
                n = data.size(0)

                top1.update(prec1.item(), n)
                top5.update(prec5.item(), n)

            tmp_dict = {}
            tmp_dict['arch'] = value['arch']
            tmp_dict['acc'] = top1.avg

            result_dict[key] = tmp_dict

    with open("acc_result.json", "w") as f:
        json.dump(result_dict, f)
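
A minimal sketch (not part of the original file) of how the acc_result.json written above could be consumed to rank the evaluated candidates, assuming the {"arch": ..., "acc": ...} layout dumped by validate:

import json

with open("acc_result.json") as f:
    results = json.load(f)

# sort candidate architectures by their recorded top-1 accuracy, best first
ranked = sorted(results.items(), key=lambda kv: kv[1]["acc"], reverse=True)
for key, entry in ranked[:5]:
    print(key, entry["acc"], entry["arch"])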
Example #2
def main():
    args = get_args()

    # archLoader
    arch_loader = ArchLoader(args.path)

    # Log
    log_format = '[%(asctime)s] %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%d %I:%M:%S')
    t = time.time()
    local_time = time.localtime(t)
    if not os.path.exists('./log'):
        os.mkdir('./log')
    fh = logging.FileHandler(
        os.path.join('log/train-{}{:02}{}'.format(local_time.tm_year % 2000,
                                                  local_time.tm_mon, t)))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    use_gpu = False
    if torch.cuda.is_available():
        use_gpu = True

    val_loader = torch.utils.data.DataLoader(datasets.MNIST(
        root="./data",
        train=False,
        transform=transforms.Compose([
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=4,
                                             pin_memory=True)

    print('load data successfully')

    model = mutableResNet20(10)

    # smooth over 10 classes to match mutableResNet20(10) above (1000 would mismatch the logits)
    criterion_smooth = CrossEntropyLabelSmooth(10, 0.1)

    if use_gpu:
        model = nn.DataParallel(model)
        loss_function = criterion_smooth.cuda()
        device = torch.device("cuda")
    else:
        loss_function = criterion_smooth
        device = torch.device("cpu")

    model = model.to(device)
    print("load model successfully")

    all_iters = 0
    print('load from latest checkpoint')
    lastest_model, iters = get_lastest_model()
    if lastest_model is not None:
        all_iters = iters
        checkpoint = torch.load(lastest_model,
                                map_location=None if use_gpu else 'cpu')
        model.load_state_dict(checkpoint['state_dict'], strict=True)

    # parameter setup
    args.loss_function = loss_function
    args.val_dataloader = val_loader

    print("start to validate model")

    validate(model, device, args, all_iters=all_iters, arch_loader=arch_loader)
Example #3
            layer_i[2].shortcut2[0].weight[:cand[idx +
                                                 6], :cand[idx +
                                                           4], :, :].data,
            (-1, ))

        idx += 6

        arch_vector += [
            torch.cat([
                conv1_1, conv1_2, conv2_1, conv2_2, conv3_1, conv3_2,
                shortcut1, shortcut2, shortcut3
            ],
                      dim=0)
        ]

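    # stack the per-layer slices into one flat vector describing this candidate architecture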
    return torch.cat(arch_vector, dim=0)


def generate_angle(b_model, t_model, candidate):
    vec1 = generate_arch_vector(b_model, candidate)
    vec2 = generate_arch_vector(t_model, candidate)
    cos = nn.CosineSimilarity(dim=0)
    # clamp to [-1, 1] so floating-point error cannot push acos outside its domain (NaN)
    angle = torch.acos(cos(vec1, vec2).clamp(-1.0, 1.0))
    return angle


if __name__ == "__main__":
    m1 = mutableResNet20()
    m2 = mutableResNet20()
    print(generate_angle(m1, m2, arc_representation))
Example #4
def main():
    args = get_args()
    num_gpus = torch.cuda.device_count()
    args.gpu = args.local_rank % num_gpus
    torch.cuda.set_device(args.gpu)

    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    args.world_size = torch.distributed.get_world_size()
    args.batch_size = args.batch_size // args.world_size

    # archLoader
    arch_loader = ArchLoader(args.path)

    # Log
    log_format = '[%(asctime)s] %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%m-%d %I:%M:%S')
    t = time.time()
    local_time = time.localtime(t)
    if not os.path.exists('./log'):
        os.mkdir('./log')
    fh = logging.FileHandler(
        os.path.join('log/train-{}-{:02}-{:02}-{:.3f}'.format(
            local_time.tm_year % 2000, local_time.tm_mon, local_time.tm_mday,
            t)))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    use_gpu = False
    if torch.cuda.is_available():
        use_gpu = True

    train_loader = get_train_loader(args.batch_size, args.local_rank,
                                    args.num_workers, args.total_iters)

    val_loader = get_val_loader(args.batch_size, args.num_workers)

    model = mutableResNet20()

    logging.info('load model successfully')

    optimizer = torch.optim.SGD(get_parameters(model),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    criterion_smooth = CrossEntropyLabelSmooth(1000, 0.1)

    if use_gpu:
        # model = nn.DataParallel(model)
        model = model.cuda(args.gpu)
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)
        loss_function = criterion_smooth.cuda()
    else:
        loss_function = criterion_smooth

    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=5)

    all_iters = 0

    if args.auto_continue:  # automatically resume from the latest checkpoint
        lastest_model, iters = get_lastest_model()
        if lastest_model is not None:
            all_iters = iters
            checkpoint = torch.load(lastest_model,
                                    map_location=None if use_gpu else 'cpu')
            model.load_state_dict(checkpoint['state_dict'], strict=True)
            logging.info('load from checkpoint')
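            # replay the scheduler so the learning rate matches the restored iteration count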
            for i in range(iters):
                scheduler.step()

    # parameter setup
    args.optimizer = optimizer
    args.loss_function = loss_function
    args.scheduler = scheduler
    args.train_loader = train_loader
    args.val_loader = val_loader

    if args.eval:
        if args.eval_resume is not None:
            checkpoint = torch.load(args.eval_resume,
                                    map_location=None if use_gpu else 'cpu')
            model.load_state_dict(checkpoint, strict=True)
            validate(model, args, all_iters=all_iters, arch_loader=arch_loader)
        exit(0)

    # warmup weights
    if args.warmup > 0:
        logging.info("begin warmup weights")
        while all_iters < args.warmup:
            all_iters = train_supernet(model,
                                       args,
                                       bn_process=False,
                                       all_iters=all_iters)

        validate(model, args, all_iters=all_iters, arch_loader=arch_loader)

    while all_iters < args.total_iters:
        logging.info("=" * 50)
        all_iters = train_subnet(model,
                                 args,
                                 bn_process=False,
                                 all_iters=all_iters,
                                 arch_loader=arch_loader)

        if all_iters % 200 == 0 and args.local_rank == 0:
            logging.info("validate iter {}".format(all_iters))

            validate(model, args, all_iters=all_iters, arch_loader=arch_loader)
Example #5
def main():
    args = get_args()

    # archLoader
    arch_loader = ArchLoader(args.path)

    # Log
    log_format = '[%(asctime)s] %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%d %I:%M:%S')
    t = time.time()
    local_time = time.localtime(t)
    if not os.path.exists('./log'):
        os.mkdir('./log')
    fh = logging.FileHandler(
        os.path.join('log/train-{}{:02}{}'.format(local_time.tm_year % 2000,
                                                  local_time.tm_mon, t)))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    use_gpu = False
    if torch.cuda.is_available():
        use_gpu = True

    train_dataset, val_dataset = get_dataset('cifar100')

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=16,
                                               pin_memory=True)
    # train_dataprovider = DataIterator(train_loader)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=200,
                                             shuffle=False,
                                             num_workers=12,
                                             pin_memory=True)

    # val_dataprovider = DataIterator(val_loader)
    print('load data successfully')

    model = mutableResNet20()

    print('load model successfully')

    optimizer = torch.optim.SGD(get_parameters(model),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # 100 classes to match the CIFAR-100 data loaded above
    criterion_smooth = CrossEntropyLabelSmooth(100, 0.1)

    if use_gpu:
        model = nn.DataParallel(model)
        loss_function = criterion_smooth.cuda()
        device = torch.device("cuda")
    else:
        loss_function = criterion_smooth
        device = torch.device("cpu")

    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lambda step: (1.0 - step / args.total_iters)
        if step <= args.total_iters else 0,
        last_epoch=-1)

    model = model.to(device)

    # dp_model = torch.nn.parallel.DistributedDataParallel(model)

    all_iters = 0
    if args.auto_continue:  # automatically resume from the latest checkpoint
        lastest_model, iters = get_lastest_model()
        if lastest_model is not None:
            all_iters = iters
            checkpoint = torch.load(lastest_model,
                                    map_location=None if use_gpu else 'cpu')
            model.load_state_dict(checkpoint['state_dict'], strict=True)
            print('load from checkpoint')
            for i in range(iters):
                scheduler.step()

    # parameter setup
    args.optimizer = optimizer
    args.loss_function = loss_function
    args.scheduler = scheduler
    args.train_loader = train_loader
    args.val_loader = val_loader
    # args.train_dataprovider = train_dataprovider
    # args.val_dataprovider = val_dataprovider

    if args.eval:
        if args.eval_resume is not None:
            checkpoint = torch.load(args.eval_resume,
                                    map_location=None if use_gpu else 'cpu')
            model.load_state_dict(checkpoint, strict=True)
            validate(model,
                     device,
                     args,
                     all_iters=all_iters,
                     arch_loader=arch_loader)
        exit(0)

    while all_iters < args.total_iters:
        all_iters = train(model,
                          device,
                          args,
                          val_interval=args.val_interval,
                          bn_process=False,
                          all_iters=all_iters,
                          arch_loader=arch_loader,
                          arch_batch=args.arch_batch)
Example #6
def main():
    args = get_args()

    # archLoader
    arch_loader = ArchLoader(args.path)

    # Log
    log_format = '[%(asctime)s] %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%m-%d %I:%M:%S')
    t = time.time()
    local_time = time.localtime(t)
    if not os.path.exists('./log'):
        os.mkdir('./log')
    fh = logging.FileHandler(
        os.path.join('log/train-{}-{:02}-{:02}-{:.3f}'.format(
            local_time.tm_year % 2000, local_time.tm_mon, local_time.tm_mday,
            t)))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    use_gpu = False
    if torch.cuda.is_available():
        use_gpu = True

    kwargs = {'num_workers': 4, 'pin_memory': True}

    train_loader = torch.utils.data.DataLoader(datasets.MNIST(
        root="./data",
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               **kwargs)

    val_loader = torch.utils.data.DataLoader(datasets.MNIST(
        root="./data",
        train=False,
        transform=transforms.Compose([
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             **kwargs)

    model = mutableResNet20(num_classes=10)
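    # keep an untouched copy of the freshly initialized weights; train_subnet below receives it as base_model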
    base_model = copy.deepcopy(model)

    logging.info('load model successfully')

    optimizer = torch.optim.SGD(get_parameters(model),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    criterion_smooth = CrossEntropyLabelSmooth(10, 0.1)  # 10 classes to match the MNIST model above

    if use_gpu:
        model = nn.DataParallel(model)
        loss_function = criterion_smooth.cuda()
        device = torch.device("cuda")
        base_model.cuda()
    else:
        loss_function = criterion_smooth
        device = torch.device("cpu")

    # scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
    #                                               lambda step: (1.0-step/args.total_iters) if step <= args.total_iters else 0, last_epoch=-1)
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=5)
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    #     optimizer, T_max=200)

    model = model.to(device)

    all_iters = 0

    if args.auto_continue:
        lastest_model, iters = get_lastest_model()
        if lastest_model is not None:
            all_iters = iters
            checkpoint = torch.load(lastest_model,
                                    map_location=None if use_gpu else 'cpu')
            model.load_state_dict(checkpoint['state_dict'], strict=True)
            logging.info('load from checkpoint')
            for i in range(iters):
                scheduler.step()

    # parameter setup
    args.optimizer = optimizer
    args.loss_function = loss_function
    args.scheduler = scheduler
    args.train_loader = train_loader
    args.val_loader = val_loader

    if args.eval:
        if args.eval_resume is not None:
            checkpoint = torch.load(args.eval_resume,
                                    map_location=None if use_gpu else 'cpu')
            model.load_state_dict(checkpoint, strict=True)
            validate(model,
                     device,
                     args,
                     all_iters=all_iters,
                     arch_loader=arch_loader)
        exit(0)

    # warmup weights
    if args.warmup is not None:
        logging.info("begin warmup weights")
        while all_iters < args.warmup:
            all_iters = train_supernet(model,
                                       device,
                                       args,
                                       bn_process=False,
                                       all_iters=all_iters)

        validate(model,
                 device,
                 args,
                 all_iters=all_iters,
                 arch_loader=arch_loader)

    while all_iters < args.total_iters:
        all_iters = train_subnet(model,
                                 base_model,
                                 device,
                                 args,
                                 bn_process=False,
                                 all_iters=all_iters,
                                 arch_loader=arch_loader)
        logging.info("validate iter {}".format(all_iters))

        if all_iters % 9 == 0:
            validate(model,
                     device,
                     args,
                     all_iters=all_iters,
                     arch_loader=arch_loader)

    validate(model, device, args, all_iters=all_iters, arch_loader=arch_loader)