Example #1
    def __init__(self, args):
        self.args = args

        self.max_epochs = args.max_epochs
        self.select_num = args.select_num
        self.population_num = args.population_num
        self.m_prob = args.m_prob
        self.crossover_num = args.crossover_num
        self.mutation_num = args.mutation_num
        self.flops_limit = args.flops_limit
        self.exp_name = args.exp_name
        # with open('sn_custom_nets_01_31.pkl', 'rb') as f: # put in correct file name 
        #     self.custom_cands = list(pickle.load(f))
        self.model = ShuffleNetV2_OneShot(input_size=args.im_size, n_class=args.num_classes)
        self.model = torch.nn.DataParallel(self.model).cuda()
        supernet_state_dict = torch.load(
            '../Supernet/models/' + self.exp_name + '/checkpoint-latest.pth.tar')['state_dict']
        self.model.load_state_dict(supernet_state_dict)

        self.log_dir = args.log_dir
        self.checkpoint_name = self.log_dir + '/' + self.exp_name + '/checkpoint.pth.tar'

        self.memory = []
        self.vis_dict = {}
        self.keep_top_k = {self.select_num: [], 50: []}
        self.epoch = 0
        self.candidates = []

        self.nr_layer = 20
        self.nr_state = 4
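A minimal sketch of the argparse namespace this constructor expects; the attribute names come from the code above, while the default values are only illustrative assumptions.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--max-epochs', type=int, default=20)            # evolutionary search epochs
parser.add_argument('--select-num', type=int, default=10)            # top-k kept per generation
parser.add_argument('--population-num', type=int, default=50)
parser.add_argument('--m_prob', type=float, default=0.1)             # mutation probability
parser.add_argument('--crossover-num', type=int, default=25)
parser.add_argument('--mutation-num', type=int, default=25)
parser.add_argument('--flops-limit', type=float, default=330 * 1e6)
parser.add_argument('--exp-name', type=str, default='supernet')
parser.add_argument('--im-size', type=int, default=224)
parser.add_argument('--num-classes', type=int, default=1000)
parser.add_argument('--log-dir', type=str, default='log')
args = parser.parse_args([])  # empty list: fall back to the illustrative defaults above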
Example #2
def test_supernet():
    """
    Test supernet(network.py)
    """
    import mxnet
    from network import ShuffleNetV2_OneShot, get_channel_mask
    stage_repeats = [4, 8, 4, 4]
    stage_out_channels = [64, 160, 320, 640]
    candidate_scales = [0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4, 1.6, 1.8, 2.0]
    architecture = [0, 0, 3, 1, 1, 1, 0, 0, 2, 0, 2, 1, 1, 0, 2, 0, 2, 1, 3, 2]
    architecture = mxnet.nd.array(architecture).astype(dtype='float32',
                                                       copy=False)
    channel_choice = (4, ) * 20
    channel_mask = get_channel_mask(channel_choice,
                                    stage_repeats,
                                    stage_out_channels,
                                    candidate_scales,
                                    dtype='float32')
    model = ShuffleNetV2_OneShot()
    print(model)
    model.hybridize()
    model._initialize(ctx=mxnet.cpu())
    test_data = mxnet.nd.random.uniform(-1, 1, shape=(5, 3, 224, 224))
    test_outputs = model(test_data, architecture, channel_mask)
    print(test_outputs.shape)
    model.collect_params().save('supernet.params')
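A quick check of what the channel choice above selects, assuming each entry indexes candidate_scales directly: index 4 picks the 1.0 width scale for every block.

candidate_scales = [0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4, 1.6, 1.8, 2.0]
channel_choice = (4,) * 20
print([candidate_scales[c] for c in channel_choice])  # twenty entries, all 1.0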
Example #3
    def __init__(self, args):
        self.args = args

        self.max_epochs = args.max_epochs
        self.select_num = args.select_num
        self.population_num = args.population_num
        self.m_prob = args.m_prob
        self.crossover_num = args.crossover_num
        self.mutation_num = args.mutation_num
        self.flops_limit = args.flops_limit

        self.model = ShuffleNetV2_OneShot()
        self.model = torch.nn.DataParallel(self.model).cuda()
        supernet_state_dict = torch.load(
            '../Supernet/models/checkpoint-latest.pth.tar')['state_dict']
        self.model.load_state_dict(supernet_state_dict)

        self.log_dir = args.log_dir
        self.checkpoint_name = os.path.join(self.log_dir, 'checkpoint.pth.tar')

        self.memory = []
        self.vis_dict = {}
        self.keep_top_k = {self.select_num: [], 50: []}
        self.epoch = 0
        self.candidates = []

        self.nr_layer = 20
        self.nr_state = 4
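The two constants at the end define the block search space: 20 layers with 4 candidate blocks each. A tiny sketch of its size:

nr_layer, nr_state = 20, 4
print(nr_state ** nr_layer)  # 1099511627776 candidate architectures (4^20)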
Example #4
def main():
    args = get_args()
    assert args.exp_name is not None
    splits = []
    start = 0
    k = 800
    end = k
    for i in range(5):
        splits += [(start, end)]
        start += k
        end += k
    print(splits)
    # cands = generate_cand_list(12000)
    # pickle.dump(cands, open( "../data/cl3_2_1.p", "wb" ) )

    # candidate_list = pickle.load(open("../data/cl3_2_1.p", "rb"))
    candidate_list = pickle.load(open("/home/bg141/SinglePathOneShot/src/data/loc_data_ed_15.p", "rb"))

    # candidate_list = [np.fromstring(c[1:-1], dtype=int, sep=',').tolist() for c in candidate_list]


    model = ShuffleNetV2_OneShot(input_size=args.im_size, n_class=args.num_classes)
    model = nn.DataParallel(model)
    device = torch.device("cuda")    
    # model = model.to(device)
    model = model.cuda()
    lastest_model, iters = get_lastest_model(args.exp_name)
    print("Iters: ", iters)
    if lastest_model is not None:
        all_iters = iters
        checkpoint = torch.load(lastest_model)
        model.load_state_dict(checkpoint['state_dict'], strict=True)
        print('load from checkpoint')

    err_list = []
    cand_list = []
    print("Split: ", splits[args.gpu])
    i = 0
    for cand in candidate_list[splits[args.gpu][0]:splits[args.gpu][1]]:
        err = get_cand_err(model, cand, args)
        err_list += [err]
        cand_list += [cand]
        i += 1
        print("Net: ", i)
        # if i%500 == 0:
        #     pickle.dump(err_list, open("./data/err-"+args.exp_name+"-"+str(args.gpu)+"-"+str(i)+".p", "wb"))
        #     pickle.dump(cand_list, open("./data/cand-"+args.exp_name+"-"+str(args.gpu)+"-"+str(i)+".p", "wb"))
    
    pickle.dump(err_list, open( "./data/err-"+args.exp_name+"-"+str(args.gpu)+"_ed_15.p", "wb" ) )
    pickle.dump(cand_list, open("./data/cand-"+args.exp_name+"-"+str(args.gpu)+"_ed_15.p", "wb"))

    print("Finished")
    return
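The split-building loop at the top of main() is equivalent to this one-liner; both produce five contiguous 800-candidate slices, one per GPU index.

k = 800
splits = [(i * k, (i + 1) * k) for i in range(5)]
print(splits)  # [(0, 800), (800, 1600), (1600, 2400), (2400, 3200), (3200, 4000)]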
Example #5
def test_load_supernet_params():
    """
    Testing the load of supernet's params

    """
    from network import ShuffleNetV2_OneShot
    import mxnet
    model = ShuffleNetV2_OneShot(search=True)
    model.collect_params().load('supernet.params',
                                ctx=mxnet.cpu(),
                                cast_dtype=True,
                                dtype_source='saved')
    print("Done!")
Example #6
def main3():
    args = get_args()
    assert args.exp_name is not None
    if not os.path.exists('./data/' + args.exp_name):
        os.makedirs('./data/' + args.exp_name)
    # Build candidate list
    get_random_cand = lambda: tuple(np.random.randint(4) for i in range(20))
    flops_l, flops_r, flops_step = 290, 360, 50
    bins = [[i, i+flops_step] for i in range(flops_l, flops_r, flops_step)]

    def get_uniform_sample_cand(*,timeout=500):
        idx = np.random.randint(len(bins))
        l, r = bins[idx]
        for i in range(timeout):
            cand = get_random_cand()
            if l*1e6 <= get_cand_flops(cand) <= r*1e6:
                return cand
        print("timeout")
        return get_random_cand()

    model = ShuffleNetV2_OneShot(input_size=args.im_size, n_class=args.num_classes)
    model = nn.DataParallel(model)
    device = torch.device("cuda")    
    # model = model.to(device)
    model = model.cuda()
    lastest_model, iters = get_lastest_model(args.exp_name)
    if lastest_model is not None:
        all_iters = iters
        checkpoint = torch.load(lastest_model)
        model.load_state_dict(checkpoint['state_dict'], strict=True)
        print('load from checkpoint')

    err_list = []
    cand_list = []
    i = 0
    print("GPU: ", args.gpu)
    for i in range(5000):
        cand = get_uniform_sample_cand()
        err = get_cand_err(model, cand, args)
        err_list += [err]
        cand_list += [cand]
        print("Net: ", i)
        if i%500 == 0:
            pickle.dump(err_list, open("./data/"+args.exp_name+"/err-"+str(args.gpu)+"-"+str(i)+".p", "wb"))
            pickle.dump(cand_list, open("./data/"+args.exp_name+"/cand-"+str(args.gpu)+"-"+str(i)+".p", "wb"))
    
    pickle.dump(err_list, open("./data/"+args.exp_name+"/err-"+str(args.gpu)+".p", "wb"))
    pickle.dump(cand_list, open("./data/"+args.exp_name+"/cand-"+str(args.gpu)+".p", "wb"))

    print("Finished")
    return
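How the FLOPs bins above come out: with flops_l=290, flops_r=360 and flops_step=50, the sampler ends up drawing uniformly from two adjacent 50 MFLOPs-wide ranges.

flops_l, flops_r, flops_step = 290, 360, 50
bins = [[i, i + flops_step] for i in range(flops_l, flops_r, flops_step)]
print(bins)  # [[290, 340], [340, 390]]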
Example #7
    def __init__(self, args):
        self.args = args
        self.context = [mx.gpu(int(gpu)) for gpu in args.gpus.split(',')] if args.gpus else [mx.cpu()]
        for ctx in self.context:
            mx.random.seed(self.args.random_seed, ctx=ctx)
        np.random.seed(self.args.random_seed)
        random.seed(self.args.random_seed)

        num_gpus = len(self.args.gpus.split(','))
        batch_size = max(1, num_gpus) * self.args.batch_size
        if self.args.use_rec:
            if self.args.use_dali:
                self.train_data = dali.get_data_rec((3, self.args.input_size, self.args.input_size), self.args.crop_ratio,
                                           self.args.rec_train, self.args.rec_train_idx,
                                           self.args.batch_size, num_workers=2, train=True, shuffle=True,
                                           backend='dali-gpu', gpu_ids=[0,1], kv_store='nccl', dtype=self.args.dtype,
                                           input_layout='NCHW')
                self.val_data = dali.get_data_rec((3, self.args.input_size, self.args.input_size), self.args.crop_ratio,
                                           self.args.rec_val, self.args.rec_val_idx,
                                           self.args.batch_size, num_workers=2, train=False, shuffle=False,
                                           backend='dali-gpu', gpu_ids=[0,1], kv_store='nccl', dtype=self.args.dtype,
                                           input_layout='NCHW')
                self.batch_fn = batch_fn
            else:
                self.train_data, self.val_data, self.batch_fn = get_data_rec(self.args.rec_train, self.args.rec_train_idx,
                                                              self.args.rec_val, self.args.rec_val_idx,
                                                              batch_size, self.args.num_workers, self.args.random_seed)
        else:
            self.train_data, self.val_data, self.batch_fn = get_data_loader(self.args.data_dir, batch_size, self.args.num_workers)

        self.model = ShuffleNetV2_OneShot(search=True)
        self.model.collect_params().load(self.args.resume_params, ctx=self.context, cast_dtype=True, dtype_source='saved')

        self.memory = []
        self.vis_dict = {}
        self.keep_top_k = {self.args.select_num: [], 50: []}
        self.epoch = 0
        self.candidates = []

        self.nr_layer = 20
        self.nr_state = 4
        self.channel_state = 10  # len(candidate_scales)
Example #8
def test_subnet():
    """
    Test subnet(subnet.py)
    """
    import mxnet
    from subnet import ShuffleNetV2_OneShot
    block_choice = (0, 0, 3, 1, 1, 1, 0, 0, 2, 0, 2, 1, 1, 0, 2, 0, 2, 1, 3, 2)
    channel_choice = (6, 5, 3, 5, 2, 6, 3, 4, 2, 5, 7, 5, 4, 6, 7, 4, 4, 5, 4,
                      3)
    model = ShuffleNetV2_OneShot(input_size=224,
                                 n_class=1000,
                                 architecture=block_choice,
                                 channels_idx=channel_choice,
                                 act_type='relu',
                                 search=False)  # define a specific subnet
    model.hybridize()
    model._initialize(ctx=mxnet.cpu())
    print(model)
    test_data = mxnet.nd.random.uniform(-1, 1, shape=(5, 3, 224, 224))
    test_outputs = model(test_data)
    print(test_outputs.shape)
Example #9
def main():
    args = get_args()

    # Log
    log_format = '[%(asctime)s] %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%d %I:%M:%S')
    t = time.time()
    local_time = time.localtime(t)
    if not os.path.exists('./log'):
        os.mkdir('./log')
    fh = logging.FileHandler(
        os.path.join('log/train-{}{:02}{}'.format(local_time.tm_year % 2000,
                                                  local_time.tm_mon, t)))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    use_gpu = False
    if torch.cuda.is_available():
        use_gpu = True

    if not args.cifar10:

        assert os.path.exists(args.train_dir)
        train_dataset = datasets.ImageFolder(
            args.train_dir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.ColorJitter(brightness=0.4,
                                       contrast=0.4,
                                       saturation=0.4),
                transforms.RandomHorizontalFlip(0.5),
                ToBGRTensor(),
            ]))
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=1,
                                                   pin_memory=use_gpu)
        train_dataprovider = DataIterator(train_loader)

        assert os.path.exists(args.val_dir)
        val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
            args.val_dir,
            transforms.Compose([
                OpencvResize(256),
                transforms.CenterCrop(224),
                ToBGRTensor(),
            ])),
                                                 batch_size=200,
                                                 shuffle=False,
                                                 num_workers=1,
                                                 pin_memory=use_gpu)
        val_dataprovider = DataIterator(val_loader)
        print('load imagenet data successfully')

    else:
        train_transform, valid_transform = data_transforms(args)

        trainset = torchvision.datasets.CIFAR10(root=os.path.join(
            args.data_dir, 'cifar'),
                                                train=True,
                                                download=True,
                                                transform=train_transform)
        train_loader = torch.utils.data.DataLoader(trainset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   pin_memory=True,
                                                   num_workers=8)
        train_dataprovider = DataIterator(train_loader)
        valset = torchvision.datasets.CIFAR10(root=os.path.join(
            args.data_dir, 'cifar'),
                                              train=False,
                                              download=True,
                                              transform=valid_transform)
        val_loader = torch.utils.data.DataLoader(valset,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 pin_memory=True,
                                                 num_workers=8)
        val_dataprovider = DataIterator(val_loader)

        print('load cifar10 data successfully')

    model = ShuffleNetV2_OneShot()

    optimizer = torch.optim.SGD(get_parameters(model),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    criterion_smooth = CrossEntropyLabelSmooth(1000, 0.1)

    if use_gpu:
        model = nn.DataParallel(model)
        loss_function = criterion_smooth.cuda()
        device = torch.device("cuda")
    else:
        loss_function = criterion_smooth
        device = torch.device("cpu")

    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lambda step: (1.0 - step / args.total_iters)
        if step <= args.total_iters else 0,
        last_epoch=-1)

    model = model.to(device)

    all_iters = 0
    if args.auto_continue:
        lastest_model, iters = get_lastest_model()
        if lastest_model is not None:
            all_iters = iters
            checkpoint = torch.load(lastest_model,
                                    map_location=None if use_gpu else 'cpu')
            model.load_state_dict(checkpoint['state_dict'], strict=True)
            print('load from checkpoint')
            for i in range(iters):
                scheduler.step()

    args.optimizer = optimizer
    args.loss_function = loss_function
    args.scheduler = scheduler
    args.train_dataprovider = train_dataprovider
    args.val_dataprovider = val_dataprovider

    if args.eval:
        if args.eval_resume is not None:
            checkpoint = torch.load(args.eval_resume,
                                    map_location=None if use_gpu else 'cpu')
            model.load_state_dict(checkpoint, strict=True)
            validate(model, device, args, all_iters=all_iters)
        exit(0)

    while all_iters < args.total_iters:
        all_iters = train(model,
                          device,
                          args,
                          val_interval=args.val_interval,
                          bn_process=False,
                          all_iters=all_iters)
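A small sketch of the LambdaLR schedule configured above: the factor decays linearly from 1 to 0 over total_iters and is clamped to 0 afterwards (the value of total_iters below is only an illustrative assumption).

total_iters = 150000  # assumption: stands in for args.total_iters
decay = lambda step: (1.0 - step / total_iters) if step <= total_iters else 0
print(decay(0), decay(total_iters // 2), decay(total_iters), decay(total_iters + 1))
# 1.0 0.5 0.0 0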
Example #10
def main():
    args = get_args()

    # Log
    log_format = '[%(asctime)s] %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
        format=log_format, datefmt='%d %I:%M:%S')
    t = time.time()
    local_time = time.localtime(t)
    if not os.path.exists('./log'):
        os.mkdir('./log')
    fh = logging.FileHandler(os.path.join('log/train-{}{:02}{}'.format(local_time.tm_year % 2000, local_time.tm_mon, t)))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    use_gpu = False
    if torch.cuda.is_available():
        use_gpu = True

    assert os.path.exists(args.train_dir)
    train_dataset = datasets.ImageFolder(
        args.train_dir,
        transforms.Compose([
            transforms.RandomResizedCrop(args.im_size),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            transforms.RandomHorizontalFlip(0.5),
            ToBGRTensor(),
        ])
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=8, pin_memory=use_gpu)
    train_dataprovider = DataIterator(train_loader)

    assert os.path.exists(args.val_dir)
    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(args.val_dir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(args.im_size),
            ToBGRTensor(),
        ])),
        batch_size=200, shuffle=False,
        num_workers=8, pin_memory=use_gpu
    )
    val_dataprovider = DataIterator(val_loader)
    print('load data successfully')

    arch_path = 'arch.pkl'

    if os.path.exists(arch_path):
        with open(arch_path, 'rb') as f:
            architecture = pickle.load(f)
    else:
        raise NotImplementedError
    channels_scales = (1.0,) * 20
    model = ShuffleNetV2_OneShot(architecture=architecture, channels_scales=channels_scales, n_class=args.num_classes, input_size=args.im_size)

    print('flops:', get_flops(model))

    optimizer = torch.optim.SGD(get_parameters(model),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    criterion_smooth = CrossEntropyLabelSmooth(args.num_classes, 0.1)

    if use_gpu:
        # model = nn.DataParallel(model)
        loss_function = criterion_smooth.cuda()
        device = torch.device("cuda")
    else:
        loss_function = criterion_smooth
        device = torch.device("cpu")

    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lambda step: (1.0 - step / args.total_iters) if step <= args.total_iters else 0,
        last_epoch=-1)

    # model = model.to(device)
    model = model.cuda()

    all_iters = 0
    if args.auto_continue:
        lastest_model, iters = get_lastest_model()
        if lastest_model is not None:
            all_iters = iters
            checkpoint = torch.load(lastest_model, map_location=None if use_gpu else 'cpu')
            model.load_state_dict(checkpoint['state_dict'], strict=True)
            print('load from checkpoint')
            for i in range(iters):
                scheduler.step()

    args.optimizer = optimizer
    args.loss_function = loss_function
    args.scheduler = scheduler
    args.train_dataprovider = train_dataprovider
    args.val_dataprovider = val_dataprovider

    if args.eval:
        if args.eval_resume is not None:
            checkpoint = torch.load(args.eval_resume, map_location=None if use_gpu else 'cpu')
            model.load_state_dict(checkpoint, strict=True)
            validate(model, device, args, all_iters=all_iters)
        exit(0)
    t = time.time()
    while all_iters < args.total_iters:
        all_iters = train(model, device, args, val_interval=args.val_interval, bn_process=False, all_iters=all_iters)
        validate(model, device, args, all_iters=all_iters)
    # all_iters = train(model, device, args, val_interval=int(1280000/args.batch_size), bn_process=True, all_iters=all_iters)
    validate(model, device, args, all_iters=all_iters)
    save_checkpoint({'state_dict': model.state_dict(),}, args.total_iters, tag='bnps-')
    print("Finished {} iters in {:.3f} seconds".format(all_iters, time.time()-t))
Example #11
def main():
    global args, best_prec1
    args = parser.parse_args()

    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    for key in config:
        for k, v in config[key].items():
            setattr(args, k, v)

    print('Enabled distributed training.')

    rank, world_size = init_dist(
        backend='nccl', port=args.port)
    args.rank = rank
    args.world_size = world_size

    np.random.seed(args.seed*args.rank)
    torch.manual_seed(args.seed*args.rank)
    torch.cuda.manual_seed(args.seed*args.rank)
    torch.cuda.manual_seed_all(args.seed*args.rank)
    print('random seed: ', args.seed*args.rank)

    # create model
    print("=> creating model '{}'".format(args.model))
    if args.SinglePath:
        architecture = 20*[0]
        channels_scales = 20*[1.0]
        model = ShuffleNetV2_OneShot(args=args, architecture=architecture, channels_scales=channels_scales)
        model.cuda()
        broadcast_params(model)
        for v in model.parameters():
            if v.requires_grad:
                if v.grad is None:
                    v.grad = torch.zeros_like(v)
        model.log_alpha.grad = torch.zeros_like(model.log_alpha)   
    
    criterion = CrossEntropyLoss(smooth_eps=0.1, smooth_dist=(torch.ones(1000)*0.001).cuda()).cuda()


    wo_wd_params = []
    wo_wd_param_names = []
    network_params = []
    network_param_names = []

    for name, mod in model.named_modules():
        if isinstance(mod, nn.BatchNorm2d):
            for key, value in mod.named_parameters():
                wo_wd_param_names.append(name+'.'+key)
        
    for key, value in model.named_parameters():
        if key != 'log_alpha':
            if value.requires_grad:
                if key in wo_wd_param_names:
                    wo_wd_params.append(value)
                else:
                    network_params.append(value)
                    network_param_names.append(key)

    params = [
        {'params': network_params,
         'lr': args.base_lr,
         'weight_decay': args.weight_decay },
        {'params': wo_wd_params,
         'lr': args.base_lr,
         'weight_decay': 0.},
    ]
    param_names = [network_param_names, wo_wd_param_names]
    if args.rank == 0:
        print('>>> params w/o weight decay: ', wo_wd_param_names)

    optimizer = torch.optim.SGD(params, momentum=args.momentum)
    if args.SinglePath:
        arch_optimizer = torch.optim.Adam(
            [param for name, param in model.named_parameters() if name == 'log_alpha'],
            lr=args.arch_learning_rate,
            betas=(0.5, 0.999),
            weight_decay=args.arch_weight_decay
        )

    # auto resume from a checkpoint
    remark = 'imagenet_'
    remark += 'epo_' + str(args.epochs) + '_layer_' + str(args.layers) + '_batch_' + str(args.batch_size) + '_lr_' + str(args.base_lr)  + '_seed_' + str(args.seed) + '_pretrain_' + str(args.pretrain_epoch)

    if args.early_fix_arch:
        remark += '_early_fix_arch'  

    if args.flops_loss:
        remark += '_flops_loss_' + str(args.flops_loss_coef)

    if args.remark != 'none':
        remark += '_'+args.remark

    args.save = 'search-{}-{}-{}'.format(args.save, time.strftime("%Y%m%d-%H%M%S"), remark)
    args.save_log = 'nas-{}-{}'.format(time.strftime("%Y%m%d-%H%M%S"), remark)
    generate_date = str(datetime.now().date())

    path = os.path.join(generate_date, args.save)
    if args.rank == 0:
        log_format = '%(asctime)s %(message)s'
        utils.create_exp_dir(generate_date, path, scripts_to_save=glob.glob('*.py'))
        logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                            format=log_format, datefmt='%m/%d %I:%M:%S %p')
        fh = logging.FileHandler(os.path.join(path, 'log.txt'))
        fh.setFormatter(logging.Formatter(log_format))
        logging.getLogger().addHandler(fh)
        logging.info("args = %s", args)
        writer = SummaryWriter('./runs/' + generate_date + '/' + args.save_log)
    else:
        writer = None

    model_dir = path
    start_epoch = 0
    
    if args.evaluate:
        load_state_ckpt(args.checkpoint_path, model)
    else:
        best_prec1, start_epoch = load_state(model_dir, model, optimizer=optimizer)

    cudnn.benchmark = True
    cudnn.enabled = True

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = ImagenetDataset(
        args.train_root,
        args.train_source,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    train_dataset_wo_ms = ImagenetDataset(
        args.train_root,
        args.train_source,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    val_dataset = ImagenetDataset(
        args.val_root,
        args.val_source,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))

    train_sampler = DistributedSampler(train_dataset)
    val_sampler = DistributedSampler(val_dataset)

    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size//args.world_size, shuffle=False,
        num_workers=args.workers, pin_memory=False, sampler=train_sampler)

    train_loader_wo_ms = DataLoader(
        train_dataset_wo_ms, batch_size=args.batch_size//args.world_size, shuffle=False,
        num_workers=args.workers, pin_memory=False, sampler=train_sampler)

    val_loader = DataLoader(
        val_dataset, batch_size=50, shuffle=False,
        num_workers=args.workers, pin_memory=False, sampler=val_sampler)

    if args.evaluate:
        validate(val_loader, model, criterion, 0, writer, logging)
        return

    niters = len(train_loader)

    lr_scheduler = LRScheduler(optimizer, niters, args)

    for epoch in range(start_epoch, 85):
        train_sampler.set_epoch(epoch)
        
        if args.early_fix_arch:
            if len(model.fix_arch_index.keys()) > 0:
                for key, value_lst in model.fix_arch_index.items():
                    model.log_alpha.data[key, :] = value_lst[1]
            sort_log_alpha = torch.topk(F.softmax(model.log_alpha.data, dim=-1), 2)
            argmax_index = (sort_log_alpha[0][:,0] - sort_log_alpha[0][:,1] >= 0.3)
            for id in range(argmax_index.size(0)):
                if argmax_index[id] == 1 and id not in model.fix_arch_index.keys():
                    model.fix_arch_index[id] = [sort_log_alpha[1][id,0].item(), model.log_alpha.detach().clone()[id, :]]
            
        if args.rank == 0 and args.SinglePath:
            logging.info('epoch %d', epoch)
            logging.info(model.log_alpha)         
            logging.info(F.softmax(model.log_alpha, dim=-1))         
            logging.info('flops %fM', model.cal_flops())  

        # train for one epoch
        if epoch >= args.epochs - 5 and args.lr_mode == 'step' and args.off_ms:
            train(train_loader_wo_ms, model, criterion, optimizer, arch_optimizer, lr_scheduler, epoch, writer, logging)
        else:
            train(train_loader, model, criterion, optimizer, arch_optimizer, lr_scheduler, epoch, writer, logging)


        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, epoch, writer, logging)
        if args.gen_max_child:
            args.gen_max_child_flag = True
            prec1 = validate(val_loader, model, criterion, epoch, writer, logging)        
            args.gen_max_child_flag = False

        if rank == 0:
            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(model_dir, {
                'epoch': epoch + 1,
                'model': args.model,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best)
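A self-contained sketch of the early_fix_arch test used inside the epoch loop: an operator choice is frozen once the gap between the two largest softmax(log_alpha) entries reaches 0.3 (the tensor values below are made up).

import torch
import torch.nn.functional as F

log_alpha = torch.tensor([[2.0, 0.1, 0.0, 0.0],
                          [0.3, 0.2, 0.1, 0.0]])
top2 = torch.topk(F.softmax(log_alpha, dim=-1), 2)
fixed = top2[0][:, 0] - top2[0][:, 1] >= 0.3
print(fixed)  # tensor([ True, False]): only the first row is confident enough to fix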
Example #12
def main():
    args = get_args()

    # Log
    log_format = '[%(asctime)s] %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%d %I:%M:%S')
    t = time.time()
    local_time = time.localtime(t)
    if not os.path.exists('./log'):
        os.mkdir('./log')
    fh = logging.FileHandler(
        os.path.join('log/train-{}{:02}{}'.format(local_time.tm_year % 2000,
                                                  local_time.tm_mon, t)))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    use_gpu = False
    if torch.cuda.is_available():
        use_gpu = True

    assert os.path.exists(args.train_dir)
    train_dataset = datasets.ImageFolder(
        args.train_dir,
        transforms.Compose([
            transforms.RandomResizedCrop(96),
            transforms.ColorJitter(brightness=0.4,
                                   contrast=0.4,
                                   saturation=0.4),
            transforms.RandomHorizontalFlip(0.5),
            ToBGRTensor(),
        ]))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4,
                                               pin_memory=use_gpu)
    train_dataprovider = DataIterator(train_loader)

    assert os.path.exists(args.val_dir)
    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(
            args.val_dir,
            transforms.Compose([
                OpencvResize(96),
                # transforms.CenterCrop(96),
                ToBGRTensor(),
            ])),
        batch_size=200,
        shuffle=False,
        num_workers=4,
        pin_memory=use_gpu)
    val_dataprovider = DataIterator(val_loader)

    arch_path = 'cl400.p'

    if os.path.exists(arch_path):
        with open(arch_path, 'rb') as f:
            architectures = pickle.load(f)
    else:
        raise NotImplementedError
    channels_scales = (1.0, ) * 20
    cands = {}
    splits = [(i, 10 + i) for i in range(0, 400, 10)]
    architectures = np.array(architectures)
    architectures = architectures[
        splits[args.split_num][0]:splits[args.split_num][1]]
    print(len(architectures))
    logging.info("Training and Validating arch: " +
                 str(splits[args.split_num]))
    for architecture in architectures:
        architecture = tuple(architecture.tolist())
        model = ShuffleNetV2_OneShot(architecture=architecture,
                                     channels_scales=channels_scales,
                                     n_class=10,
                                     input_size=96)

        print('flops:', get_flops(model))

        optimizer = torch.optim.SGD(get_parameters(model),
                                    lr=args.learning_rate,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        criterion_smooth = CrossEntropyLabelSmooth(10, 0.1)  # match n_class=10 used for the model above

        if use_gpu:
            model = nn.DataParallel(model)
            loss_function = criterion_smooth.cuda()
            device = torch.device("cuda")
        else:
            loss_function = criterion_smooth
            device = torch.device("cpu")

        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lambda step: (1.0 - step / args.total_iters)
            if step <= args.total_iters else 0,
            last_epoch=-1)

        model = model.to(device)

        all_iters = 0
        if args.auto_continue:
            lastest_model, iters = get_lastest_model()
            if lastest_model is not None:
                all_iters = iters
                checkpoint = torch.load(
                    lastest_model, map_location=None if use_gpu else 'cpu')
                model.load_state_dict(checkpoint['state_dict'], strict=True)
                print('load from checkpoint')
                for i in range(iters):
                    scheduler.step()

        args.optimizer = optimizer
        args.loss_function = loss_function
        args.scheduler = scheduler
        args.train_dataprovider = train_dataprovider
        args.val_dataprovider = val_dataprovider
        # print("BEGIN VALDATE: ", args.eval, args.eval_resume)
        if args.eval:
            if args.eval_resume is not None:
                checkpoint = torch.load(
                    args.eval_resume, map_location=None if use_gpu else 'cpu')
                model.load_state_dict(checkpoint, strict=True)
                validate(model, device, args, all_iters=all_iters)
            exit(0)
        # t1,t5 = validate(model, device, args, all_iters=all_iters)
        # print("VALDATE: ", t1, "   ", t5)

        while all_iters < args.total_iters:
            all_iters = train(model,
                              device,
                              args,
                              val_interval=args.val_interval,
                              bn_process=False,
                              all_iters=all_iters)
            validate(model, device, args, all_iters=all_iters)
        all_iters = train(model,
                          device,
                          args,
                          val_interval=int(1280000 / args.batch_size),
                          bn_process=True,
                          all_iters=all_iters)
        top1, top5 = validate(model, device, args, all_iters=all_iters)
        save_checkpoint({
            'state_dict': model.state_dict(),
        },
                        args.total_iters,
                        tag='bnps-')
        cands[architecture] = [top1, top5]
        pickle.dump(
            cands,
            open("from_scratch_split_{}.pkl".format(args.split_num), 'wb'))
Example #13
def train(model,
          meta_model,
          device,
          args,
          *,
          val_interval,
          bn_process=False,
          all_iters=None):

    optimizer = args.optimizer
    meta_optimizer = args.meta_optimizer
    loss_function = args.loss_function
    scheduler = args.scheduler
    meta_scheduler = args.meta_scheduler
    train_dataprovider = args.train_dataprovider

    t1 = time.time()
    Top1_err, Top5_err = 0.0, 0.0
    model.train()
    meta_model.train()
    for iters in range(1, val_interval + 1):
        scheduler.step()
        meta_scheduler.step()
        if bn_process:
            adjust_bn_momentum(model, iters)
            adjust_bn_momentum(meta_model, iters)

        all_iters += 1
        d_st = time.time()
        data, target = train_dataprovider.next()
        target = target.type(torch.LongTensor)
        data, target = data.to(device), target.to(device)
        data_time = time.time() - d_st

        get_random_cand = lambda: tuple(
            np.random.randint(4) for i in range(20))
        flops_l, flops_r, flops_step = 290, 360, 10
        bins = [[i, i + flops_step]
                for i in range(flops_l, flops_r, flops_step)]

        def get_uniform_sample_cand(*, timeout=500):
            idx = np.random.randint(len(bins))
            l, r = bins[idx]
            for i in range(timeout):
                cand = get_random_cand()
                if l * 1e6 <= get_cand_flops(cand) <= r * 1e6:
                    return cand
            return get_random_cand()

        if iters % 5 == 1:
            cand = get_uniform_sample_cand()

        output = meta_model(data, cand)
        loss = loss_function(output, target)
        optimizer.zero_grad()
        meta_optimizer.zero_grad()
        loss.backward()

        for p in meta_model.parameters():
            if p.grad is not None and p.grad.sum() == 0:
                p.grad = None

        if iters % 5 != 0:  # step 1: update submodel
            meta_optimizer.step()

        else:  # step 2: update original model

            # # copy gradient to original model
            for p, q in zip(model.parameters(), meta_model.parameters()):
                if q.grad is not None:
                    p.grad = q.grad.clone()

            # # check
            for p, q in zip(model.parameters(), meta_model.parameters()):
                if q.grad is not None:
                    assert torch.all(torch.eq(p.grad, q.grad))

            # # update weight
            optimizer.step()

            # load weight to submodel
            meta_model.load_state_dict(model.state_dict())

            # check
            for p, q in zip(model.parameters(), meta_model.parameters()):
                if p is not None:
                    assert torch.all(torch.eq(q, p))

        prec1, prec5 = accuracy(output, target, topk=(1, 5))

        Top1_err += 1 - prec1.item() / 100
        Top5_err += 1 - prec5.item() / 100

        if all_iters % args.display_interval == 0:
            printInfo = 'TRAIN Iter {}: lr = {:.6f},\tloss = {:.6f},\t'.format(all_iters, scheduler.get_lr()[0], loss.item()) + \
                        'Top-1 err = {:.6f},\t'.format(Top1_err / args.display_interval) + \
                        'Top-5 err = {:.6f},\t'.format(Top5_err / args.display_interval) + \
                        'data_time = {:.6f},\ttrain_time = {:.6f}'.format(data_time, (time.time() - t1) / args.display_interval)
            logging.info(printInfo)
            t1 = time.time()
            Top1_err, Top5_err = 0.0, 0.0

        if all_iters % args.save_interval == 0 or all_iters == 1:
            save_checkpoint({
                'state_dict': model.state_dict(),
            }, all_iters)
            checkpoint = torch.load(args.eval_resume)
            val_model = ShuffleNetV2_OneShot()
            val_model = nn.DataParallel(val_model)
            val_model = val_model.to(device)
            val_model.load_state_dict(checkpoint['state_dict'], strict=True)
            validate(val_model, device, args, all_iters=all_iters)

    return all_iters
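The alternating update used in this train() at a glance: within every block of five iterations, the first four step only the meta (sub)model, and the fifth copies its gradients to the full model, steps it, and syncs the weights back.

for iters in range(1, 11):
    if iters % 5 != 0:
        print(iters, 'step meta_optimizer (submodel update)')
    else:
        print(iters, 'copy grads -> optimizer.step() -> reload weights into meta_model')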
Example #14
def main():
    opt = parse_args()
    makedirs(opt.log_dir)
    filehandler = logging.FileHandler(opt.log_dir + '/' + opt.logging_file)
    streamhandler = logging.StreamHandler()
    logger = logging.getLogger('')
    logger.setLevel(logging.INFO)
    logger.addHandler(filehandler)
    logger.addHandler(streamhandler)
    logger.info(opt)
    batch_size = opt.batch_size
    classes = 1000
    num_training_samples = 1281167
    num_gpus = opt.num_gpus
    batch_size *= max(1, num_gpus)
    context = [mx.gpu(i)
               for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
    for ctx in context:
        mx.random.seed(seed_state=opt.random_seed, ctx=ctx)
    np.random.seed(opt.random_seed)
    random.seed(opt.random_seed)
    num_workers = opt.num_workers
    lr_decay = opt.lr_decay
    lr_decay_period = opt.lr_decay_period
    if opt.lr_decay_period > 0:
        lr_decay_epoch = list(
            range(lr_decay_period, opt.num_epochs, lr_decay_period))
    else:
        lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')]
    lr_decay_epoch = [e - opt.warmup_epochs for e in lr_decay_epoch]
    num_batches = num_training_samples // batch_size

    lr_scheduler = LRSequential([
        LRScheduler('linear',
                    base_lr=0,
                    target_lr=opt.lr,
                    nepochs=opt.warmup_epochs,
                    iters_per_epoch=num_batches),
        LRScheduler(opt.lr_mode,
                    base_lr=opt.lr,
                    target_lr=0,
                    nepochs=opt.num_epochs - opt.warmup_epochs,
                    iters_per_epoch=num_batches,
                    step_epoch=lr_decay_epoch,
                    step_factor=lr_decay,
                    power=2)
    ])

    sw = SummaryWriter(logdir=opt.log_dir, flush_secs=5, verbose=False)
    optimizer = 'sgd'
    optimizer_params = {
        'wd': opt.wd,
        'momentum': opt.momentum,
        'lr_scheduler': lr_scheduler
    }
    if opt.dtype != 'float32':
        optimizer_params['multi_precision'] = True
    net = ShuffleNetV2_OneShot()
    net.cast(opt.dtype)
    if opt.mode == 'hybrid':
        net.hybridize()
    if opt.resume_params != '':
        net.load_parameters(opt.resume_params, ctx=context)

    # Two functions for reading data from record file or raw images
    def get_data_rec(rec_train, rec_train_idx, rec_val, rec_val_idx,
                     batch_size, num_workers, seed):
        rec_train = os.path.expanduser(rec_train)
        rec_train_idx = os.path.expanduser(rec_train_idx)
        rec_val = os.path.expanduser(rec_val)
        rec_val_idx = os.path.expanduser(rec_val_idx)
        jitter_param = 0.4
        lighting_param = 0.1
        input_size = opt.input_size
        crop_ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
        resize = int(math.ceil(input_size / crop_ratio))
        mean_rgb = [123.68, 116.779, 103.939]
        std_rgb = [58.393, 57.12, 57.375]

        def batch_fn(batch, ctx):
            data = gluon.utils.split_and_load(batch.data[0],
                                              ctx_list=ctx,
                                              batch_axis=0)
            label = gluon.utils.split_and_load(batch.label[0],
                                               ctx_list=ctx,
                                               batch_axis=0)
            return data, label

        train_data = mx.io.ImageRecordIter(
            path_imgrec=rec_train,
            path_imgidx=rec_train_idx,
            preprocess_threads=num_workers,
            shuffle=True,
            batch_size=batch_size,
            data_shape=(3, input_size, input_size),
            mean_r=mean_rgb[0],
            mean_g=mean_rgb[1],
            mean_b=mean_rgb[2],
            std_r=std_rgb[0],
            std_g=std_rgb[1],
            std_b=std_rgb[2],
            rand_mirror=True,
            random_resized_crop=True,
            max_aspect_ratio=4. / 3.,
            min_aspect_ratio=3. / 4.,
            max_random_area=1,
            min_random_area=0.08,
            brightness=jitter_param,
            saturation=jitter_param,
            contrast=jitter_param,
            pca_noise=lighting_param,
            shuffle_chunk_seed=seed,
            seed=seed,
            seed_aug=seed,
        )
        val_data = mx.io.ImageRecordIter(
            path_imgrec=rec_val,
            path_imgidx=rec_val_idx,
            preprocess_threads=num_workers,
            shuffle=False,
            batch_size=batch_size,
            resize=resize,
            data_shape=(3, input_size, input_size),
            mean_r=mean_rgb[0],
            mean_g=mean_rgb[1],
            mean_b=mean_rgb[2],
            std_r=std_rgb[0],
            std_g=std_rgb[1],
            std_b=std_rgb[2],
        )
        return train_data, val_data, batch_fn

    def get_data_loader(data_dir, batch_size, num_workers):
        normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                         [0.229, 0.224, 0.225])
        jitter_param = 0.4
        lighting_param = 0.1
        input_size = opt.input_size
        crop_ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
        resize = int(math.ceil(input_size / crop_ratio))

        def batch_fn(batch, ctx):
            data = gluon.utils.split_and_load(batch[0],
                                              ctx_list=ctx,
                                              batch_axis=0)
            label = gluon.utils.split_and_load(batch[1],
                                               ctx_list=ctx,
                                               batch_axis=0)
            return data, label

        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(input_size),
            transforms.RandomFlipLeftRight(),
            transforms.RandomColorJitter(brightness=jitter_param,
                                         contrast=jitter_param,
                                         saturation=jitter_param),
            transforms.RandomLighting(lighting_param),
            transforms.ToTensor(), normalize
        ])
        transform_test = transforms.Compose([
            transforms.Resize(resize, keep_ratio=True),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(), normalize
        ])

        train_data = gluon.data.DataLoader(imagenet.classification.ImageNet(
            data_dir, train=True).transform_first(transform_train),
                                           batch_size=batch_size,
                                           shuffle=True,
                                           last_batch='discard',
                                           num_workers=num_workers)
        val_data = gluon.data.DataLoader(imagenet.classification.ImageNet(
            data_dir, train=False).transform_first(transform_test),
                                         batch_size=batch_size,
                                         shuffle=False,
                                         num_workers=num_workers)

        return train_data, val_data, batch_fn

    if opt.use_rec:
        if opt.use_dali:
            train_data = dali.get_data_rec((3, opt.input_size, opt.input_size),
                                           opt.crop_ratio,
                                           opt.rec_train,
                                           opt.rec_train_idx,
                                           opt.batch_size,
                                           num_workers=2,
                                           train=True,
                                           shuffle=True,
                                           backend='dali-gpu',
                                           gpu_ids=[0, 1],
                                           kv_store='nccl',
                                           dtype=opt.dtype,
                                           input_layout='NCHW')
            val_data = dali.get_data_rec((3, opt.input_size, opt.input_size),
                                         opt.crop_ratio,
                                         opt.rec_val,
                                         opt.rec_val_idx,
                                         opt.batch_size,
                                         num_workers=2,
                                         train=False,
                                         shuffle=False,
                                         backend='dali-gpu',
                                         gpu_ids=[0, 1],
                                         kv_store='nccl',
                                         dtype=opt.dtype,
                                         input_layout='NCHW')

            def batch_fn(batch, ctx):
                data = gluon.utils.split_and_load(batch[0],
                                                  ctx_list=ctx,
                                                  batch_axis=0)
                label = gluon.utils.split_and_load(batch[1],
                                                   ctx_list=ctx,
                                                   batch_axis=0)
                return data, label
        else:
            train_data, val_data, batch_fn = get_data_rec(
                opt.rec_train, opt.rec_train_idx, opt.rec_val, opt.rec_val_idx,
                batch_size, num_workers, opt.random_seed)
    else:
        train_data, val_data, batch_fn = get_data_loader(
            opt.data_dir, batch_size, num_workers)

    if opt.mixup:
        train_metric = mx.metric.RMSE()
    else:
        train_metric = mx.metric.Accuracy()
    acc_top1 = mx.metric.Accuracy()
    acc_top5 = mx.metric.TopKAccuracy(5)

    save_frequency = opt.save_frequency
    if opt.save_dir and save_frequency:
        save_dir = opt.save_dir
        makedirs(save_dir)
    else:
        save_dir = ''
        save_frequency = 0

    def mixup_transform(label, classes, lam=1, eta=0.0):
        if isinstance(label, nd.NDArray):
            label = [label]
        res = []
        for l in label:
            y1 = l.one_hot(classes,
                           on_value=1 - eta + eta / classes,
                           off_value=eta / classes)
            y2 = l[::-1].one_hot(classes,
                                 on_value=1 - eta + eta / classes,
                                 off_value=eta / classes)
            res.append(lam * y1 + (1 - lam) * y2)
        return res

    def smooth(label, classes, eta=0.1):
        if isinstance(label, nd.NDArray):
            label = [label]
        smoothed = []
        for l in label:
            res = l.one_hot(classes,
                            on_value=1 - eta + eta / classes,
                            off_value=eta / classes)
            smoothed.append(res)
        return smoothed

    def test(net,
             batch_fn,
             ctx,
             train_data,
             val_data,
             cand,
             channel_mask,
             update_images=20000,
             update_bn=False):
        if update_bn:
            if opt.use_rec:
                train_data.reset()
            net.cast('float32')
            for k, v in net._children.items():
                if isinstance(v, BatchNormNAS):
                    v.inference_update_stat = True
            for i, batch in enumerate(train_data):
                if (i + 1) * opt.batch_size * len(ctx) >= update_images:
                    break
                data, _ = batch_fn(batch, ctx)
                _ = [
                    net(
                        X.astype('float32', copy=False),
                        cand.as_in_context(X.context).astype('float32',
                                                             copy=False),
                        channel_mask.as_in_context(X.context).astype(
                            'float32', copy=False)) for X in data
                ]
            for k, v in net._children.items():
                if isinstance(v, BatchNormNAS):
                    v.inference_update_stat = False
            net.cast(opt.dtype)
        if opt.use_rec:
            val_data.reset()
        acc_top1.reset()
        acc_top5.reset()
        for i, batch in enumerate(val_data):
            data, label = batch_fn(batch, ctx)
            outputs = [
                net(X.astype(opt.dtype, copy=False),
                    cand.as_in_context(X.context),
                    channel_mask.as_in_context(X.context)) for X in data
            ]
            acc_top1.update(label, outputs)
            acc_top5.update(label, outputs)
        _, top1 = acc_top1.get()
        _, top5 = acc_top5.get()
        return (top1, top5)

    def train(ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        if opt.resume_params == '':
            net._initialize(ctx=ctx, force_reinit=True)
        if opt.no_wd:
            for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
                v.wd_mult = 0.0

        trainer = gluon.Trainer(net.collect_params(), optimizer,
                                optimizer_params)
        if opt.resume_states != '':
            trainer.load_states(opt.resume_states)

        if opt.label_smoothing or opt.mixup:
            sparse_label_loss = False
        else:
            sparse_label_loss = True

        L = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=sparse_label_loss)

        best_val_score = 0
        iteration = 0

        for epoch in range(opt.resume_epoch, opt.num_epochs):
            tic = time.time()
            if opt.use_rec:
                train_data.reset()
            train_metric.reset()
            btic = time.time()
            get_random_cand = lambda x: tuple(
                np.random.randint(x) for i in range(20))

            #Get random channel mask
            if epoch < 10:
                channel = (9, ) * 20  #Firstly, train full supernet
            elif epoch < 15:
                channel = tuple(
                    np.random.randint(2) + 8
                    for i in range(20))  #the channel choice is 8 ~ 9
            elif epoch < 20:
                channel = tuple(
                    np.random.randint(3) + 7
                    for i in range(20))  #the channel choice is 7 ~ 9
            elif epoch < 25:
                channel = tuple(
                    np.random.randint(4) + 6
                    for i in range(20))  #the channel choice is 6 ~ 9
            elif epoch < 30:
                channel = tuple(
                    np.random.randint(5) + 5
                    for i in range(20))  #the channel choice is 5 ~ 9
            elif epoch < 35:
                channel = tuple(
                    np.random.randint(6) + 4
                    for i in range(20))  #the channel choice is 4 ~ 9
            elif epoch < 45:
                channel = tuple(
                    np.random.randint(7) + 3
                    for i in range(20))  #the channel choice is 3 ~ 9
            elif epoch < 50:
                channel = tuple(
                    np.random.randint(8) + 2
                    for i in range(20))  #the channel choice is 2 ~ 9
            elif epoch < 55:
                channel = tuple(
                    np.random.randint(9) + 1
                    for i in range(20))  #the channel choice is 1 ~ 9
            else:
                channel = tuple(
                    np.random.randint(10)
                    for i in range(20))  #the channel choice is 0 ~ 9
            print('Defined Channel Choice: ', channel)
            channel_mask = get_channel_mask(channel,
                                            stage_repeats,
                                            stage_out_channels,
                                            candidate_scales,
                                            dtype=opt.dtype)

            for i, batch in enumerate(train_data):
                # Generate channel mask and random block choice
                cand = get_random_cand(4)
                print('Random Block Candidate: ', cand)
                cand = nd.array(cand)
                cand = cand.astype(opt.dtype, copy=False)
                #print(channel_mask)
                data, label = batch_fn(batch, ctx)
                if opt.mixup:
                    lam = np.random.beta(opt.mixup_alpha, opt.mixup_alpha)
                    if epoch >= opt.num_epochs - opt.mixup_off_epoch:
                        lam = 1
                    data = [lam * X + (1 - lam) * X[::-1] for X in data]

                    if opt.label_smoothing:
                        eta = 0.1
                    else:
                        eta = 0.0
                    label = mixup_transform(label, classes, lam, eta)

                elif opt.label_smoothing:
                    hard_label = label
                    label = smooth(label, classes)

                with ag.record():
                    outputs = [
                        net(X.astype(opt.dtype, copy=False),
                            cand.as_in_context(X.context),
                            channel_mask.as_in_context(X.context))
                        for X in data
                    ]
                    loss = [
                        L(yhat, y.astype(opt.dtype, copy=False))
                        for yhat, y in zip(outputs, label)
                    ]
                for l in loss:
                    l.backward()
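                # Log the mean over devices of the per-device summed loss.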
                sw.add_scalar(tag='train_loss',
                              value=sum([l.sum().asscalar()
                                         for l in loss]) / len(loss),
                              global_step=iteration)

                trainer.step(batch_size, ignore_stale_grad=True)

                if opt.mixup:
                    output_softmax = [nd.SoftmaxActivation(out.astype('float32', copy=False)) \
                                    for out in outputs]
                    train_metric.update(label, output_softmax)
                else:
                    if opt.label_smoothing:
                        train_metric.update(hard_label, outputs)
                    else:
                        train_metric.update(label, outputs)
                train_metric_name, train_metric_score = train_metric.get()
                sw.add_scalar(
                    tag='train_{}_curves'.format(train_metric_name),
                    value=('train_{}_value'.format(train_metric_name),
                           train_metric_score),
                    global_step=iteration)

                if opt.log_interval and not (i + 1) % opt.log_interval:
                    train_metric_name, train_metric_score = train_metric.get()
                    logger.info(
                        'Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f\tlr=%f'
                        % (epoch, i, batch_size * opt.log_interval /
                           (time.time() - btic), train_metric_name,
                           train_metric_score, trainer.learning_rate))
                    btic = time.time()
                iteration += 1
            if epoch == 0:
                sw.add_graph(net)

            train_metric_name, train_metric_score = train_metric.get()
            throughput = int(batch_size * i / (time.time() - tic))

            # Sample a random block choice for validation; reuse this epoch's channel mask
            cand = get_random_cand(4)
            cand = nd.array(cand)
            cand = cand.astype(opt.dtype, copy=False)
            #channel_mask = get_channel_mask(channel, stage_repeats, stage_out_channels, candidate_scales, dtype=opt.dtype)

            top1_val_acc, top5_val_acc = test(net,
                                              batch_fn,
                                              ctx,
                                              train_data,
                                              val_data,
                                              cand,
                                              channel_mask,
                                              update_images=20000,
                                              update_bn=False)
            sw.add_scalar(tag='val_acc_curves',
                          value=('valid_acc_value', top1_val_acc),
                          global_step=epoch)
            logger.info('[Epoch %d] training: %s=%f' %
                        (epoch, train_metric_name, train_metric_score))
            logger.info('[Epoch %d] speed: %d samples/sec\ttime cost: %f' %
                        (epoch, throughput, time.time() - tic))
            logger.info('[Epoch %d] validation: top1_acc=%f top5_acc=%f' %
                        (epoch, top1_val_acc, top5_val_acc))

            if top1_val_acc > best_val_score:
                best_val_score = top1_val_acc
                net.collect_params().save(
                    '%s/%.4f-supernet_imagenet-%d-best.params' %
                    (save_dir, best_val_score, epoch))
                trainer.save_states(
                    '%s/%.4f-supernet_imagenet-%d-best.states' %
                    (save_dir, best_val_score, epoch))

            if save_frequency and save_dir and (epoch +
                                                1) % save_frequency == 0:
                net.collect_params().save('%s/supernet_imagenet-%d.params' %
                                          (save_dir, epoch))
                trainer.save_states('%s/supernet_imagenet-%d.states' %
                                    (save_dir, epoch))

        sw.close()
        if save_frequency and save_dir:
            net.collect_params().save('%s/supernet_imagenet-%d.params' %
                                      (save_dir, opt.num_epochs - 1))
            trainer.save_states('%s/supernet_imagenet-%d.states' %
                                (save_dir, opt.num_epochs - 1))

    train(context)
Beispiel #15
0
def main():
    args = get_args()
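    # world_size = total number of distributed processes; rank = this
    # process's global rank across all nodes.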
    args.world_size = args.gpus * args.nodes
    args.rank = args.gpus * args.nr + args.local_rank
    print("RANK: " + str(args.rank) + ", LOCAL RANK: " + str(args.local_rank))

    # Log
    log_format = '[%(asctime)s] %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%d %I:%M:%S')
    t = time.time()
    local_time = time.localtime(t)
    if not os.path.exists('/home/admin/aihub/SinglePathOneShot/log'):
        os.mkdir('/home/admin/aihub/SinglePathOneShot/log')
    fh = logging.FileHandler(
        os.path.join(
            '/home/admin/aihub/SinglePathOneShot/log/train-{}{:02}{}'.format(
                local_time.tm_year % 2000, local_time.tm_mon, t)))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    use_gpu = torch.cuda.is_available()

    assert os.path.exists(args.train_dir)
    train_dataset = datasets.ImageFolder(
        args.train_dir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.ColorJitter(brightness=0.4,
                                   contrast=0.4,
                                   saturation=0.4),
            transforms.RandomHorizontalFlip(0.5),
            ToBGRTensor(),
        ]))

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=args.world_size, rank=args.rank)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=32,
                                               pin_memory=True,
                                               sampler=train_sampler)
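    # DataIterator (defined in the repo's utilities) is assumed to wrap the
    # loader so train() can keep pulling batches one at a time.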
    train_dataprovider = DataIterator(train_loader)

    assert os.path.exists(args.val_dir)
    val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        args.val_dir,
        transforms.Compose([
            OpencvResize(256),
            transforms.CenterCrop(224),
            ToBGRTensor(),
        ])),
                                             batch_size=200,
                                             shuffle=False,
                                             num_workers=32,
                                             pin_memory=use_gpu)
    val_dataprovider = DataIterator(val_loader)

    print('load data successfully')

    # The process group needs the global rank (args.rank), not the per-node
    # local rank; otherwise ranks collide when training on more than one node.
    dist.init_process_group(backend='nccl',
                            init_method='env://',
                            world_size=args.world_size,
                            rank=args.rank)
    #     dist.init_process_group(backend='nccl', init_method='tcp://'+args.ip+':'+str(args.port), world_size=args.world_size, rank=args.rank)
    #     dist.init_process_group(backend='nccl', init_method="file:///mnt/nas1/share_file", world_size=args.world_size, rank=args.rank)
    torch.cuda.set_device(args.local_rank)

    # Build the network with a fixed block architecture (args.arch) and every
    # block's channel scale set to 1.0.
    channels_scales = (1.0, ) * 20
    model = ShuffleNetV2_OneShot(architecture=list(args.arch),
                                 channels_scales=channels_scales)
    device = torch.device(args.local_rank)
    model = model.cuda(args.local_rank)

    optimizer = torch.optim.SGD(get_parameters(model),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    criterion_smooth = CrossEntropyLabelSmooth(1000, 0.1)

    # Linearly decay the learning rate from its base value to 0 over
    # args.total_iters steps.
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lambda step: (1.0 - step / args.total_iters)
        if step <= args.total_iters else 0,
        last_epoch=-1)

    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[args.local_rank],
        find_unused_parameters=False)  # output_device=args.local_rank is left at its default
    loss_function = criterion_smooth.cuda()

    all_iters = 0

    args.optimizer = optimizer
    args.loss_function = loss_function
    args.scheduler = scheduler
    args.train_dataprovider = train_dataprovider
    args.val_dataprovider = val_dataprovider

    if args.eval:
        if args.eval_resume is not None:
            checkpoint = torch.load(args.eval_resume,
                                    map_location=None if use_gpu else 'cpu')
            model.load_state_dict(checkpoint, strict=True)
            validate(model, device, args, all_iters=all_iters)
        exit(0)

    validate(model, device, args, all_iters=all_iters)

    while all_iters < args.total_iters:
        all_iters = train(model,
                          device,
                          args,
                          val_interval=args.val_interval,
                          bn_process=False,
                          all_iters=all_iters)
        validate(model, device, args, all_iters=all_iters)
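    # Final pass with bn_process=True, which presumably recalibrates the
    # BatchNorm running statistics before the last validation.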
    all_iters = train(model,
                      device,
                      args,
                      val_interval=int(1280000 / args.val_batch_size),
                      bn_process=True,
                      all_iters=all_iters)
    validate(model, device, args, all_iters=all_iters)