Example #1
def main():
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = False
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    cudnn.deterministic = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    if args.loss_func == 'cce':
        criterion = nn.CrossEntropyLoss().cuda()
    elif args.loss_func == 'rll':
        criterion = utils.RobustLogLoss(alpha=args.alpha).cuda()
    else:
        assert False, "Invalid loss function '{}' given. Must be in {'cce', 'rll'}".format(
            args.loss_func)

    model = Network(args.init_channels, CIFAR_CLASSES, args.layers, criterion)
    model = model.cuda()
    model.train()
    model.apply(weights_init)
    nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)

    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    train_transform, valid_transform = utils._data_transforms_cifar10(args)

    # Load dataset
    if args.dataset == 'cifar10':
        train_data = CIFAR10(root=args.data,
                             train=True,
                             gold=False,
                             gold_fraction=0.0,
                             corruption_prob=args.corruption_prob,
                             corruption_type=args.corruption_type,
                             transform=train_transform,
                             download=True,
                             seed=args.seed)
        gold_train_data = CIFAR10(root=args.data,
                                  train=True,
                                  gold=True,
                                  gold_fraction=1.0,
                                  corruption_prob=args.corruption_prob,
                                  corruption_type=args.corruption_type,
                                  transform=train_transform,
                                  download=True,
                                  seed=args.seed)
    elif args.dataset == 'cifar100':
        train_data = CIFAR100(root=args.data,
                              train=True,
                              gold=False,
                              gold_fraction=0.0,
                              corruption_prob=args.corruption_prob,
                              corruption_type=args.corruption_type,
                              transform=train_transform,
                              download=True,
                              seed=args.seed)
        gold_train_data = CIFAR100(root=args.data,
                                   train=True,
                                   gold=True,
                                   gold_fraction=1.0,
                                   corruption_prob=args.corruption_prob,
                                   corruption_type=args.corruption_type,
                                   transform=train_transform,
                                   download=True,
                                   seed=args.seed)
    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))

    clean_train_queue = torch.utils.data.DataLoader(
        gold_train_data,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True,
        num_workers=0)
    noisy_train_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True,
        num_workers=0)

    clean_valid_queue = torch.utils.data.DataLoader(
        gold_train_data,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:]),
        pin_memory=True,
        num_workers=0)
    noisy_valid_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:]),
        pin_memory=True,
        num_workers=0)

    clean_train_list, clean_valid_list, noisy_train_list, noisy_valid_list = [], [], [], []
    for dst_list, queue in [
        (clean_train_list, clean_train_queue),
        (clean_valid_list, clean_valid_queue),
        (noisy_train_list, noisy_train_queue),
        (noisy_valid_list, noisy_valid_queue),
    ]:
        # volatile Variables and `.cuda(async=True)` were removed from PyTorch;
        # use torch.no_grad() and non_blocking transfers instead.
        with torch.no_grad():
            for input, target in queue:
                input = input.cuda(non_blocking=True)
                target = target.cuda(non_blocking=True)
                dst_list.append((input, target))

    for epoch in range(args.epochs):
        logging.info('Epoch %d, random architecture with fixed weights', epoch)

        genotype = model.genotype()
        logging.info('genotype = %s', genotype)

        logging.info(F.softmax(model.alphas_normal, dim=-1))
        logging.info(F.softmax(model.alphas_reduce, dim=-1))

        # training
        clean_train_acc, clean_train_obj = infer(clean_train_list,
                                                 model,
                                                 criterion,
                                                 kind='clean_train')
        logging.info('clean_train_acc %f, clean_train_loss %f',
                     clean_train_acc, clean_train_obj)

        noisy_train_acc, noisy_train_obj = infer(noisy_train_list,
                                                 model,
                                                 criterion,
                                                 kind='noisy_train')
        logging.info('noisy_train_acc %f, noisy_train_loss %f',
                     noisy_train_acc, noisy_train_obj)

        # validation
        clean_valid_acc, clean_valid_obj = infer(clean_valid_list,
                                                 model,
                                                 criterion,
                                                 kind='clean_valid')
        logging.info('clean_valid_acc %f, clean_valid_loss %f',
                     clean_valid_acc, clean_valid_obj)

        noisy_valid_acc, noisy_valid_obj = infer(noisy_valid_list,
                                                 model,
                                                 criterion,
                                                 kind='noisy_valid')
        logging.info('noisy_valid_acc %f, noisy_valid_loss %f',
                     noisy_valid_acc, noisy_valid_obj)

        utils.save(model, os.path.join(args.save, 'weights.pt'))

        # Randomly change the alphas
        k = sum(1 for i in range(model._steps) for n in range(2 + i))
        num_ops = len(PRIMITIVES)
        model.alphas_normal.data.copy_(torch.randn(k, num_ops))
        model.alphas_reduce.data.copy_(torch.randn(k, num_ops))
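
The loop in Example #1 evaluates the model through an infer helper that is not shown. Below is a minimal sketch of what such a helper could look like; it assumes DARTS-style utils.AvgrageMeter and utils.accuracy helpers, and its signature is inferred from the calls above rather than taken from the original source.

def infer(batches, model, criterion, kind=''):
    # Evaluate the model on a list of pre-loaded (input, target) batches and
    # return (top-1 accuracy, average loss). utils.AvgrageMeter and
    # utils.accuracy are assumed to be the usual DARTS helpers.
    objs = utils.AvgrageMeter()
    top1 = utils.AvgrageMeter()
    model.eval()
    with torch.no_grad():
        for input, target in batches:
            logits = model(input)
            loss = criterion(logits, target)
            prec1, _ = utils.accuracy(logits, target, topk=(1, 5))
            n = input.size(0)
            objs.update(loss.item(), n)
            top1.update(prec1.item(), n)
    return top1.avg, objs.avg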
Example #2
def main():
    start = time.time()
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    torch.cuda.set_device(config.local_rank % len(config.gpus))
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    config.world_size = torch.distributed.get_world_size()
    config.total_batch = config.world_size * config.batch_size

    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)
    torch.backends.cudnn.benchmark = True

    CLASSES = 1000
    channels = SEARCH_SPACE['channel_size']
    strides = SEARCH_SPACE['strides']

    # Model
    model = Network(channels, strides, CLASSES)
    model = model.to(device)
    model.apply(utils.weights_init)
    model = DDP(model, delay_allreduce=True)
    # Use apex DDP here: a custom loss cannot use model.parameters() on an
    # apex-wrapped model, see https://github.com/NVIDIA/apex/issues/457 and
    # https://github.com/NVIDIA/apex/issues/107
    # model = torch.nn.parallel.DistributedDataParallel(
    #    model, device_ids=[config.local_rank], output_device=config.local_rank)
    logger.info("param size = %fMB", utils.count_parameters_in_MB(model))

    if config.target_hardware is None:
        config.ref_value=None
    else:
        config.ref_value = ref_values[config.target_hardware]['%.2f' % config.width_mult]

    # Loss
    # config.gpus is a sequence, so call .cuda() without an argument; the
    # device was already selected with torch.cuda.set_device() above.
    criterion = LatencyLoss(config, channels, strides).cuda()
    normal_criterion = nn.CrossEntropyLoss()

    alpha_weight = model.module.arch_parameters()
    # weight = [param for param in model.parameters() if not utils.check_tensor_in_list(param, alpha_weight)]
    weight = model.weight_parameters()
    # Optimizer
    w_optimizer = torch.optim.SGD(
        weight,
        config.w_lr,
        momentum=config.w_momentum,
        weight_decay=config.w_weight_decay)

    alpha_optimizer = torch.optim.Adam(
        alpha_weight,
        lr=config.alpha_lr,
        betas=(config.arch_adam_beta1, config.arch_adam_beta2),
        eps=config.arch_adam_eps,
        weight_decay=config.alpha_weight_decay)

    train_data = get_imagenet_iter_torch(
        type='train',
        # image_dir="/googol/atlas/public/cv/ILSVRC/Data/"
        # use soft link `mkdir ./data/imagenet && ln -s /googol/atlas/public/cv/ILSVRC/Data/CLS-LOC/* ./data/imagenet/`
        image_dir=os.path.join(config.data_path, config.dataset.lower()),
        batch_size=config.batch_size,
        num_threads=config.workers,
        world_size=config.world_size,
        local_rank=config.local_rank,
        crop=224,
        device_id=config.local_rank,
        num_gpus=config.gpus,
        portion=config.train_portion)
    valid_data = get_imagenet_iter_torch(
        type='val',
        image_dir=os.path.join(config.data_path, config.dataset.lower()),
        batch_size=config.batch_size,
        num_threads=config.workers,
        world_size=config.world_size,
        local_rank=config.local_rank,
        crop=224,
        device_id=config.local_rank,
        num_gpus=config.gpus,
        portion=config.val_portion)

    
    best_top1 = 0.
    best_genotype = list()
    lr = 0

    config.start_epoch = -1
    config.warmup_epoch = 0
    config.warmup = True
    ### Resume from warmup model or trained model ###
    if config.resume:
        try:
            model_path = config.path + '/checkpoint.pth.tar'
            model, w_optimizer, alpha_optimizer = load_model(
                model, model_fname=model_path,
                optimizer=w_optimizer, arch_optimizer=alpha_optimizer)
        except Exception:
            warmup_path = config.path + '/warmup.pth.tar'
            if os.path.exists(warmup_path):
                print('load warmup weights')
                model, w_optimizer, alpha_optimizer = load_model(
                    model, model_fname=warmup_path,
                    optimizer=w_optimizer, arch_optimizer=alpha_optimizer)
            else:
                print('failed to load models')

    w_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        w_optimizer, float(config.epochs), eta_min=config.w_lr_min)
    if config.start_epoch < 0 and config.warmup:
        for epoch in range(config.warmup_epoch, config.warmup_epochs):
            # warmup: train the weights only, before architecture search starts
            train_top1, train_loss = warm_up(train_data, valid_data, model,
                                             normal_criterion, criterion,
                                             w_optimizer, epoch, writer)
            config.start_epoch = epoch

    update_schedule = utils.get_update_schedule_grad(len(train_data), config)
    for epoch in range(config.start_epoch + 1, config.epochs):
        if epoch > config.warmup_epochs:
            w_scheduler.step()
            lr = w_scheduler.get_lr()[0]
            logger.info('epoch %d lr %e', epoch, lr)
        # training
        train_top1, train_loss = train(train_data, valid_data, model,
                                       normal_criterion, criterion,
                                       w_optimizer, alpha_optimizer, lr,
                                       epoch, writer, update_schedule)
        logger.info('Train top1 %f', train_top1)

        # validation
        top1 = 0
        if epoch % 10 == 0:
            top1, loss = infer(valid_data, model, epoch, criterion, normal_criterion, writer)
            logger.info('valid top1 %f', top1)

        genotype = model.module.genotype()
        logger.info("genotype = {}".format(genotype))

        # save
        if best_top1 < top1:
            best_top1 = top1
            best_genotype = genotype
            is_best = True
        else:
            is_best = False
        save_model(model, {
            'warmup': False,
            'epoch': epoch,
            'w_optimizer': w_optimizer.state_dict(),
            'alpha_optimizer': alpha_optimizer.state_dict(),
            'state_dict': model.state_dict()
        }, is_best=is_best)

    utils.time(time.time() - start)
    logger.info("Final best Prec@1 = {:.4%}".format(best_top1))
    logger.info("Best Genotype = {}".format(best_genotype))