Example #1
0
def build_imagenet(model_state_dict, optimizer_state_dict, **kwargs):
    """Build data loaders, model, losses, optimizer and LR scheduler.

    The data found under ``args.data`` ('train.zip' or 'train') is shuffled
    and split into a training part and a held-out validation part.

    Args:
        model_state_dict: State dict loaded into the model, or ``None`` to
            keep the freshly initialised weights.
        optimizer_state_dict: State dict loaded into the optimizer, or
            ``None``.
        **kwargs: ``valid_ratio`` (fraction of the data held out for
            validation) or ``valid_num`` (number of held-out samples).
            ``valid_ratio`` takes precedence when both are given.

    Returns:
        Tuple ``(train_queue, valid_queue, model, train_criterion,
        eval_criterion, optimizer, scheduler)``.

    Raises:
        ValueError: If neither ``valid_ratio`` nor ``valid_num`` is given.
    """
    valid_ratio = kwargs.pop('valid_ratio', None)
    valid_num = kwargs.pop('valid_num', None)
    # Fail fast, before the (potentially very slow) data loading below.
    # An explicit raise is used because ``assert`` is stripped under ``-O``.
    if valid_ratio is None and valid_num is None:
        raise ValueError("either 'valid_ratio' or 'valid_num' must be given")

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(
            brightness=0.4,
            contrast=0.4,
            saturation=0.4,
            hue=0.1),
        transforms.ToTensor(),
        normalize,
    ])
    valid_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    # The dataset is created without a transform; the transforms are attached
    # per split through utils.WrappedDataset further down.
    if args.zip_file:
        logging.info('Loading data from zip file')
        traindir = os.path.join(args.data, 'train.zip')
        if args.lazy_load:
            data = utils.ZipDataset(traindir)
        else:
            logging.info('Loading data into memory')
            data = utils.InMemoryZipDataset(traindir, num_workers=args.num_workers)
    else:
        logging.info('Loading data from directory')
        traindir = os.path.join(args.data, 'train')
        if args.lazy_load:
            data = dset.ImageFolder(traindir)
        else:
            logging.info('Loading data into memory')
            data = utils.InMemoryDataset(traindir, num_workers=args.num_workers)

    num_data = len(data)
    indices = list(range(num_data))
    np.random.shuffle(indices)
    if valid_ratio is not None:
        # The parentheses matter: without them ``1 - valid_ratio * num_data``
        # is evaluated, which yields a negative split index.
        split = int(np.floor((1 - valid_ratio) * num_data))
        train_indices = sorted(indices[:split])
        valid_indices = sorted(indices[split:])
    else:
        # The first ``valid_num`` shuffled samples are held out.
        train_indices = sorted(indices[valid_num:])
        valid_indices = sorted(indices[:valid_num])

    train_data = utils.WrappedDataset(data, train_indices, train_transform)
    valid_data = utils.WrappedDataset(data, valid_indices, valid_transform)
    logging.info('train set = %d', len(train_data))
    logging.info('valid set = %d', len(valid_data))

    # NOTE(review): the samplers yield the same index lists that were handed
    # to WrappedDataset -- confirm WrappedDataset is addressed by original
    # dataset position rather than by position within the subset.
    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(train_indices),
        pin_memory=True, num_workers=args.num_workers, drop_last=False)
    valid_queue = torch.utils.data.DataLoader(
        valid_data, batch_size=args.eval_batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_indices),
        pin_memory=True, num_workers=args.num_workers, drop_last=False)

    model = NASNet(args.width_stages, args.n_cell_stages, args.stride_stages, args.dropout)
    model.init_model(args.model_init)
    # presumably batch-norm momentum and eps -- verify against NASNet.set_bn_param
    model.set_bn_param(0.1, 0.001)
    logging.info("param size = %d", utils.count_parameters(model))

    if model_state_dict is not None:
        model.load_state_dict(model_state_dict)

    if args.no_decay_keys:
        # Two parameter groups: one with weight decay, one without.
        keys = args.no_decay_keys.split('#')
        net_params = [model.get_parameters(keys, mode='exclude'),
                      model.get_parameters(keys, mode='include')]
        optimizer = torch.optim.SGD(
            [
                {'params': net_params[0], 'weight_decay': args.weight_decay},
                {'params': net_params[1], 'weight_decay': 0},
            ],
            args.lr,
            momentum=0.9,
            nesterov=True,
        )
    else:
        optimizer = torch.optim.SGD(
            model.parameters(),
            args.lr,
            momentum=0.9,
            weight_decay=args.weight_decay,
            nesterov=True,
        )
    if optimizer_state_dict is not None:
        optimizer.load_state_dict(optimizer_state_dict)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.decay_period, gamma=args.gamma)

    if torch.cuda.device_count() > 1:
        logging.info("Use %d %s", torch.cuda.device_count(), "GPUs !")
        model = nn.DataParallel(model)
    model = model.cuda()

    train_criterion = utils.CrossEntropyLabelSmooth(1000, args.label_smooth).cuda()
    eval_criterion = nn.CrossEntropyLoss().cuda()
    return train_queue, valid_queue, model, train_criterion, eval_criterion, optimizer, scheduler
Example #2
0
def build_imagenet(model_config, model_state_dict, optimizer_state_dict, **kwargs):
    """Build model, losses, optimizer, data loaders and LR scheduler.

    Args:
        model_config: Configuration passed to ``NASNet.build_from_config``;
            its ``['bn']['momentum']`` and ``['bn']['eps']`` entries are
            forwarded to ``set_bn_param``.
        model_state_dict: State dict loaded into the model, or ``None``.
        optimizer_state_dict: State dict loaded into the optimizer, or
            ``None``.
        **kwargs: Must contain ``epoch`` and ``step``; they are passed as
            ``last_epoch`` to the step and cosine schedulers respectively.

    Returns:
        Tuple ``(train_queue, valid_queue, model, train_criterion,
        eval_criterion, optimizer, scheduler)``.

    Raises:
        KeyError: If ``epoch`` or ``step`` is missing from ``kwargs``.
    """
    epoch = kwargs.pop('epoch')
    step = kwargs.pop('step')

    # build model
    logging.info('Building Model')
    model = NASNet.build_from_config(model_config)
    model.init_model(args.model_init)
    model.set_bn_param(model_config['bn']['momentum'], model_config['bn']['eps'])
    print(model.config)
    logging.info("param size = %d", utils.count_parameters(model))
    logging.info("multi adds = %fM", model.get_flops(torch.ones(1, 3, 224, 224).float())[0] / 1000000)
    if model_state_dict is not None:
        model.load_state_dict(model_state_dict)

    if torch.cuda.device_count() > 1:
        logging.info("Use %d %s", torch.cuda.device_count(), "GPUs !")
        model = nn.DataParallel(model)
    model = model.cuda()

    # build criterion
    logging.info('Building Criterion')
    train_criterion = utils.CrossEntropyLabelSmooth(1000, args.label_smooth).cuda()
    eval_criterion = nn.CrossEntropyLoss().cuda()

    # build optimizer
    logging.info('Building Optimizer')
    if args.no_decay_keys:
        keys = args.no_decay_keys.split('#')
        # The model is wrapped in DataParallel only when several GPUs are
        # present, so ``.module`` exists only in that case; unwrap
        # conditionally to avoid an AttributeError on single-GPU runs.
        raw_model = model.module if isinstance(model, nn.DataParallel) else model
        # Two parameter groups: one with weight decay, one without.
        net_params = [raw_model.get_parameters(keys, mode='exclude'),
                      raw_model.get_parameters(keys, mode='include')]
        optimizer = torch.optim.SGD(
            [
                {'params': net_params[0], 'weight_decay': args.weight_decay},
                {'params': net_params[1], 'weight_decay': 0},
            ],
            args.lr,
            momentum=0.9,
            nesterov=True,
        )
    else:
        optimizer = torch.optim.SGD(
            model.parameters(),
            args.lr,
            momentum=0.9,
            weight_decay=args.weight_decay,
            nesterov=True,
        )
    if optimizer_state_dict is not None:
        optimizer.load_state_dict(optimizer_state_dict)

    # build data loader
    logging.info('Building Data')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(
            brightness=0.4,
            contrast=0.4,
            saturation=0.4,
            hue=0.1),
        transforms.ToTensor(),
        normalize,
    ])
    valid_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    if args.zip_file:
        logging.info('Loading data from zip file')
        traindir = os.path.join(args.data, 'train.zip')
        validdir = os.path.join(args.data, 'valid.zip')
        if args.lazy_load:
            train_data = utils.ZipDataset(traindir, train_transform)
            valid_data = utils.ZipDataset(validdir, valid_transform)
        else:
            logging.info('Loading data into memory')
            train_data = utils.InMemoryZipDataset(traindir, train_transform, num_workers=args.num_workers)
            valid_data = utils.InMemoryZipDataset(validdir, valid_transform, num_workers=args.num_workers)
    else:
        logging.info('Loading data from directory')
        traindir = os.path.join(args.data, 'train')
        validdir = os.path.join(args.data, 'val')
        if args.lazy_load:
            train_data = dset.ImageFolder(traindir, train_transform)
            valid_data = dset.ImageFolder(validdir, valid_transform)
        else:
            logging.info('Loading data into memory')
            train_data = utils.InMemoryDataset(traindir, train_transform, num_workers=args.num_workers)
            valid_data = utils.InMemoryDataset(validdir, valid_transform, num_workers=args.num_workers)

    logging.info('Found %d in training data', len(train_data))
    logging.info('Found %d in validation data', len(valid_data))

    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=args.num_workers)
    valid_queue = torch.utils.data.DataLoader(
        valid_data, batch_size=args.eval_batch_size, shuffle=False, pin_memory=True, num_workers=args.num_workers)

    # build lr scheduler
    logging.info('Building LR Scheduler')
    if args.lr_scheduler == 'cosine':
        # The cosine period spans every batch of every epoch, and the
        # schedule is resumed from ``step``.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=float(args.epochs) * len(train_queue),
            eta_min=0,
            last_epoch=step,
        )
    else:
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=args.decay_period,
            gamma=args.gamma,
            last_epoch=epoch,
        )
    return train_queue, valid_queue, model, train_criterion, eval_criterion, optimizer, scheduler
Example #3
0
def build_imagenet(model_config, model_state_dict, **kwargs):
    """Build the evaluation-only setup: validation loader, model and loss.

    Args:
        model_config: Configuration passed to ``NASNet.build_from_config``;
            its ``['bn']`` entry supplies ``momentum`` and ``eps`` for
            ``set_bn_param``.
        model_state_dict: State dict loaded into the model, or ``None``.
        **kwargs: Accepted but not used by this function.

    Returns:
        Tuple ``(valid_queue, model, criterion)``.
    """
    # --- model ---
    logging.info('Building Model')
    net = NASNet.build_from_config(model_config)
    bn_cfg = model_config['bn']
    net.set_bn_param(bn_cfg['momentum'], bn_cfg['eps'])
    print(net.config)
    logging.info("param size = %d", utils.count_parameters(net))
    probe = torch.ones(1, 3, 224, 224).float()
    logging.info("multi adds = %fM", net.get_flops(probe)[0] / 1000000)
    if model_state_dict is not None:
        net.load_state_dict(model_state_dict)

    gpu_count = torch.cuda.device_count()
    if gpu_count > 1:
        logging.info("Use %d %s", torch.cuda.device_count(), "GPUs !")
        net = nn.DataParallel(net)
    net = net.cuda()

    # --- criterion ---
    logging.info('Building Criterion')
    loss_fn = nn.CrossEntropyLoss().cuda()

    # --- data ---
    logging.info('Building Data')
    eval_steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]
    valid_transform = transforms.Compose(eval_steps)

    if not args.zip_file:
        logging.info('Loading data from directory')
        validdir = os.path.join(args.data, 'val')
        if not args.lazy_load:
            logging.info('Loading data into memory')
            valid_data = utils.InMemoryDataset(
                validdir, valid_transform, num_workers=args.num_workers)
        else:
            valid_data = dset.ImageFolder(validdir, valid_transform)
    else:
        logging.info('Loading data from zip file')
        validdir = os.path.join(args.data, 'valid.zip')
        if not args.lazy_load:
            logging.info('Loading data into memory')
            valid_data = utils.InMemoryZipDataset(
                validdir, valid_transform, num_workers=args.num_workers)
        else:
            valid_data = utils.ZipDataset(validdir, valid_transform)

    logging.info('Found %d in validation data', len(valid_data))

    # Fixed order (no shuffling) for evaluation.
    loader_opts = {
        'batch_size': args.eval_batch_size,
        'shuffle': False,
        'pin_memory': True,
        'num_workers': args.num_workers,
    }
    valid_queue = torch.utils.data.DataLoader(valid_data, **loader_opts)

    return valid_queue, net, loss_fn
Example #4
0
def build_imagenet(model_state_dict=None, optimizer_state_dict=None, **kwargs):
    """Build loaders, weight-sharing network, losses, optimizer and scheduler.

    The training data under ``args.data`` is shuffled and cut in two: the
    first ``ratio`` share feeds ``train_queue``, the remainder feeds
    ``valid_queue``. Both loaders read from the same dataset object.

    Args:
        model_state_dict: State dict loaded into the model, or ``None``.
        optimizer_state_dict: State dict loaded into the optimizer, or
            ``None``.
        **kwargs: Must contain ``ratio``, the fraction of samples used for
            training.

    Returns:
        Tuple ``(train_queue, valid_queue, model, train_criterion,
        eval_criterion, optimizer, scheduler)``.
    """
    ratio = kwargs.pop('ratio')

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    jitter = transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                    saturation=0.4, hue=0.2)
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        jitter,
        transforms.ToTensor(),
        normalize,
    ])
    # NOTE(review): valid_transform is built but never used below; the
    # held-out split is served from train_data and therefore goes through
    # train_transform. Confirm whether that is intended.
    valid_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    if not args.zip_file:
        logging.info('Loading data from directory')
        traindir = os.path.join(args.data, 'train')
        if not args.lazy_load:
            logging.info('Loading data into memory')
            train_data = utils.InMemoryDataset(
                traindir, train_transform, num_workers=32)
        else:
            train_data = dset.ImageFolder(traindir, train_transform)
    else:
        logging.info('Loading data from zip file')
        traindir = os.path.join(args.data, 'train.zip')
        if not args.lazy_load:
            logging.info('Loading data into memory')
            train_data = utils.InMemoryZipDataset(
                traindir, train_transform, num_workers=32)
        else:
            train_data = utils.ZipDataset(traindir, train_transform)

    # Random split of the sample positions into train / held-out parts.
    num_train = len(train_data)
    shuffled = list(range(num_train))
    np.random.shuffle(shuffled)
    cut = int(np.floor(ratio * num_train))
    train_indices = sorted(shuffled[:cut])
    valid_indices = sorted(shuffled[cut:])

    subset_sampler = torch.utils.data.sampler.SubsetRandomSampler
    shared_opts = {'pin_memory': True, 'num_workers': 16}
    train_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.child_batch_size,
        sampler=subset_sampler(train_indices),
        **shared_opts)
    valid_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.child_eval_batch_size,
        sampler=subset_sampler(valid_indices),
        **shared_opts)

    model = NASWSNetworkImageNet(
        1000,
        args.child_layers,
        args.child_nodes,
        args.child_channels,
        args.child_keep_prob,
        args.child_drop_path_keep_prob,
        args.child_use_aux_head,
        args.steps,
    ).cuda()
    train_criterion = CrossEntropyLabelSmooth(1000, args.child_label_smooth).cuda()
    eval_criterion = nn.CrossEntropyLoss().cuda()
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    optimizer = torch.optim.SGD(model.parameters(),
                                args.child_lr,
                                momentum=0.9,
                                weight_decay=args.child_l2_reg)

    # Restore saved states, if any, before the scheduler is attached.
    if model_state_dict is not None:
        model.load_state_dict(model_state_dict)
    if optimizer_state_dict is not None:
        optimizer.load_state_dict(optimizer_state_dict)

    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, args.child_decay_period, gamma=args.child_gamma)
    return (train_queue, valid_queue, model, train_criterion, eval_criterion,
            optimizer, scheduler)
Example #5
0
def build_imagenet(model_state_dict, optimizer_state_dict, **kwargs):
    """Build loaders, network, losses, optimizer and cosine LR scheduler.

    Side effect: overwrites ``args.steps`` with the number of batches per
    epoch multiplied by the visible GPU count and ``args.epochs``.

    Args:
        model_state_dict: State dict loaded into the model, or ``None``.
        optimizer_state_dict: State dict loaded into the optimizer, or
            ``None``.
        **kwargs: Must contain ``epoch``, passed to the scheduler as
            ``last_epoch``.

    Returns:
        Tuple ``(train_queue, valid_queue, model, train_criterion,
        eval_criterion, optimizer, scheduler)``.
    """
    epoch = kwargs.pop('epoch')

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    jitter = transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                    saturation=0.4, hue=0.2)
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        jitter,
        transforms.ToTensor(),
        normalize,
    ])
    valid_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    if not args.zip_file:
        logging.info('Loading data from directory')
        traindir = os.path.join(args.data, 'train')
        # The directory layout uses 'test' for the evaluation split.
        validdir = os.path.join(args.data, 'test')
        if not args.lazy_load:
            logging.info('Loading data into memory')
            train_data = utils.InMemoryDataset(
                traindir, train_transform, num_workers=32)
            valid_data = utils.InMemoryDataset(
                validdir, valid_transform, num_workers=32)
        else:
            train_data = dset.ImageFolder(traindir, train_transform)
            valid_data = dset.ImageFolder(validdir, valid_transform)
    else:
        logging.info('Loading data from zip file')
        traindir = os.path.join(args.data, 'train.zip')
        validdir = os.path.join(args.data, 'valid.zip')
        if not args.lazy_load:
            logging.info('Loading data into memory')
            train_data = utils.InMemoryZipDataset(
                traindir, train_transform, num_workers=32)
            valid_data = utils.InMemoryZipDataset(
                validdir, valid_transform, num_workers=32)
        else:
            train_data = utils.ZipDataset(traindir, train_transform)
            valid_data = utils.ZipDataset(validdir, valid_transform)

    logging.info('Found %d in training data', len(train_data))
    logging.info('Found %d in validation data', len(valid_data))

    batches_per_epoch = int(np.ceil(len(train_data) / (args.batch_size)))
    args.steps = batches_per_epoch * torch.cuda.device_count() * args.epochs

    # Both loaders use the training batch size.
    shared_opts = {
        'batch_size': args.batch_size,
        'pin_memory': True,
        'num_workers': 16,
    }
    train_queue = torch.utils.data.DataLoader(train_data, shuffle=True, **shared_opts)
    valid_queue = torch.utils.data.DataLoader(valid_data, shuffle=False, **shared_opts)

    # Number of output classes per dataset name; anything else falls back
    # to 1000 classes.
    class_counts = {'sport8': 8, 'mit67': 67, 'flowers102': 102}
    num_classes = class_counts.get(args.dataset, 1000)
    model = NASNetworkImageNet(args, num_classes, args.layers, args.nodes,
                               args.channels, args.keep_prob,
                               args.drop_path_keep_prob, args.use_aux_head,
                               args.steps, args.arch)
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
    logging.info("multi adds = %fM", model.multi_adds / 1000000)
    if model_state_dict is not None:
        model.load_state_dict(model_state_dict)

    if torch.cuda.device_count() > 1:
        logging.info("Use %d %s", torch.cuda.device_count(), "GPUs !")
        model = nn.DataParallel(model)
    model = model.cuda()

    # Label smoothing is not applied here: training and evaluation both use
    # plain cross-entropy.
    train_criterion = nn.CrossEntropyLoss().cuda()
    eval_criterion = nn.CrossEntropyLoss().cuda()

    # The optimizer starts from args.lr_max; the cosine schedule anneals the
    # rate towards args.lr_min over args.epochs.
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr_max,
                                momentum=0.9,
                                weight_decay=args.l2_reg)
    if optimizer_state_dict is not None:
        optimizer.load_state_dict(optimizer_state_dict)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=float(args.epochs),
        eta_min=args.lr_min,
        last_epoch=epoch,
    )
    return (train_queue, valid_queue, model, train_criterion, eval_criterion,
            optimizer, scheduler)