Example #1
0
def test_transforms():
    """A joint (img, label) transform must be applied on both dataset splits."""
    def transforms(img, label):
        # Ignore the real sample; return fixed, easily-checked values.
        return np.zeros([3, 3, 3]), 1

    for split in ('train', 'val'):
        ds = datasets.ImageNet(split, transforms=transforms)
        assert (ds[0][0] == np.zeros([3, 3, 3])).all()
        assert ds[0][1] == 1
Example #2
0
def test_imagenet_class_to_idx_train():
    """`classes` tuples and the flattened `class_to_idx` map agree on both splits."""
    # index -> full synset name tuple, spot-checked at both ends and the middle.
    expected = {
        0: ('tench', 'Tinca tinca'),
        312: ('cricket', ),
        999: ('toilet tissue', 'toilet paper', 'bathroom tissue'),
    }
    for ds in [datasets.ImageNet('train'), datasets.ImageNet('val')]:
        for idx, names in expected.items():
            assert ds.classes[idx] == names
            # Every alias in the tuple maps back to the same class index.
            for alias in names:
                assert ds.class_to_idx[alias] == idx
Example #3
0
def test_transform_splits():
    """Separate image/label transforms must be honoured on both splits."""
    def transform1(img):
        # Replace every image with a fixed all-zeros array.
        return np.zeros([3, 3, 3])

    def transform2(label):
        # Replace every label with the constant 1.
        return 1

    for split in ('train', 'val'):
        ds = datasets.ImageNet(split,
                               transform=transform1,
                               target_transform=transform2)
        assert (ds[0][0] == np.zeros([3, 3, 3])).all()
        assert ds[0][1] == 1
Example #4
0
    def set_dataset_images(self):
        """Build ``self.dataset_images`` from flag-dependent preprocessing.

        Reads ``self.resize`` (when True: resize to 256, center-crop 224),
        ``self.rgb`` (when False: expand grayscale to 3 channels) and loads
        ImageNet from ``self.location``.
        """
        pipeline = [transforms.ToTensor()]
        if self.resize:
            pipeline += [transforms.Resize(256), transforms.CenterCrop(224)]
        if not self.rgb:
            pipeline.append(transforms.Grayscale(num_output_channels=3))

        composed = transforms.Compose(pipeline)
        self.dataset_images = datasets.ImageNet(location=self.location, transform=composed)
Example #5
0
# Configure run logging from the parsed CLI arguments.
misc.prepare_logging(args)

print('==> Preparing data..')

# Training augmentation: aggressive random crop (25%-100% of image area) plus
# horizontal flips. No ToTensor/Normalize here -- presumably
# datasets.fast_collate performs the PIL->tensor conversion in the collate
# step; TODO confirm against the fast_collate implementation.
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.25, 1.0)),
    transforms.RandomHorizontalFlip(),
])

# Evaluation preprocessing: resize the short side to 256, center-crop to 224.
transform_test = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
])

# ImageNet train loader: shuffled, 32 worker processes, pinned host memory.
train_loader = torch.utils.data.DataLoader(datasets.ImageNet(
    args.data, 'train', transform_train),
                                           batch_size=args.train_batch_size,
                                           shuffle=True,
                                           num_workers=32,
                                           pin_memory=True,
                                           collate_fn=datasets.fast_collate)
# Validation loader: fixed batch size of 50, no shuffling.
test_loader = torch.utils.data.DataLoader(datasets.ImageNet(
    args.data, 'val', transform_test),
                                          batch_size=50,
                                          shuffle=False,
                                          num_workers=32,
                                          pin_memory=True,
                                          collate_fn=datasets.fast_collate)
print('==> Initializing model...')
# Instantiate the requested architecture by name from the models registry.
model = models.__dict__[args.arch](args.num_classes, args.expanded_inchannel,
                                   args.multiplier)
Example #6
0
def main():
    """Train or evaluate a single-path ShuffleNetV2 one-shot network on ImageNet.

    Loads hyper-parameters from a YAML config file, initializes NCCL
    distributed training, builds the model from a derived-architecture
    checkpoint, then runs the train/validate loop, checkpointing the best
    top-1 accuracy on rank 0.
    """
    global args, best_prec1
    args = parser.parse_args()

    # Flatten the two-level YAML config into attributes on `args`.
    # NOTE(review): yaml.load without an explicit Loader is deprecated and a
    # TypeError on PyYAML >= 6 -- confirm the pinned version or use safe_load.
    with open(args.config) as f:
        config = yaml.load(f)

    for key in config:
        for k, v in config[key].items():
            setattr(args, k, v)

    print('Enabled distributed training.')

    rank, world_size = init_dist(
        backend='nccl', port=args.port)
    args.rank = rank
    args.world_size = world_size


    # Per-rank seeding so each worker draws different random numbers.
    # NOTE(review): seed*rank is 0 on rank 0 regardless of args.seed.
    np.random.seed(args.seed*args.rank)
    torch.manual_seed(args.seed*args.rank)
    torch.cuda.manual_seed(args.seed*args.rank)
    torch.cuda.manual_seed_all(args.seed*args.rank)

    # create model
    print("=> creating model '{}'".format(args.model))
    if args.SinglePath:
        architecture = 20*[0]
        channels_scales = 20*[1.0]
        # load derived child network: one-hot `weights` built from the argmax
        # of the searched log_alpha distribution stored in the checkpoint.
        log_alpha = torch.load(args.checkpoint_path, map_location='cuda:{}'.format(torch.cuda.current_device()))['state_dict']['log_alpha']
        weights = torch.zeros_like(log_alpha).scatter_(1, torch.argmax(log_alpha, dim = -1).view(-1,1), 1)
        model = ShuffleNetV2_OneShot(args=args, architecture=architecture, channels_scales=channels_scales, weights=weights)
        model.cuda()
        broadcast_params(model)
        # Pre-allocate zero gradients so every rank has an identical
        # parameter/gradient layout before any collective operation.
        for v in model.parameters():
            if v.requires_grad:
                if v.grad is None:
                    v.grad = torch.zeros_like(v)
        model.log_alpha.grad = torch.zeros_like(model.log_alpha)
        if not args.retrain:
            # Resume: restore weights and the last learning rate from checkpoint.
            load_state_ckpt(args.checkpoint_path, model)
            checkpoint = torch.load(args.checkpoint_path, map_location='cuda:{}'.format(torch.cuda.current_device()))
            args.base_lr = checkpoint['optimizer']['param_groups'][0]['lr']
        if args.reset_bn_stat:
            model._reset_bn_running_stats()

    # define loss function (criterion) and optimizer
    # Label-smoothed cross entropy over the 1000 ImageNet classes.
    criterion = CrossEntropyLoss(smooth_eps=0.1, smooth_dist=(torch.ones(1000)*0.001).cuda()).cuda()

    # Split parameters into "with weight decay" (network weights) and
    # "without" (BatchNorm affine params); log_alpha is excluded from both.
    wo_wd_params = []
    wo_wd_param_names = []
    network_params = []
    network_param_names = []

    for name, mod in model.named_modules():
        #if isinstance(mod, (nn.BatchNorm2d, SwitchNorm2d)):
        if isinstance(mod, nn.BatchNorm2d):
            for key, value in mod.named_parameters():
                wo_wd_param_names.append(name+'.'+key)

    for key, value in model.named_parameters():
        if key != 'log_alpha':
            if value.requires_grad:
                if key in wo_wd_param_names:
                    wo_wd_params.append(value)
                else:
                    network_params.append(value)
                    network_param_names.append(key)

    params = [
        {'params': network_params,
         'lr': args.base_lr,
         'weight_decay': args.weight_decay },
        {'params': wo_wd_params,
         'lr': args.base_lr,
         'weight_decay': 0.},
    ]
    param_names = [network_param_names, wo_wd_param_names]
    if args.rank == 0:
        print('>>> params w/o weight decay: ', wo_wd_param_names)
    optimizer = torch.optim.SGD(params, momentum=args.momentum)
    arch_optimizer=None  # no architecture optimizer here; passed through to train()

    # auto resume from a checkpoint
    # Build a human-readable run name from the key hyper-parameters.
    remark = 'imagenet_'
    remark += 'epo_' + str(args.epochs) + '_layer_' + str(args.layers) + '_batch_' + str(args.batch_size) + '_lr_' + str(float("{0:.2f}".format(args.base_lr))) + '_seed_' + str(args.seed)

    if args.remark != 'none':
        remark += '_'+args.remark

    args.save = 'search-{}-{}-{}'.format(args.save, time.strftime("%Y%m%d-%H%M%S"), remark)
    args.save_log = 'nas-{}-{}'.format(time.strftime("%Y%m%d-%H%M%S"), remark)
    generate_date = str(datetime.now().date())

    # Rank 0 owns the experiment directory, file logging and TensorBoard.
    path = os.path.join(generate_date, args.save)
    if args.rank == 0:
        log_format = '%(asctime)s %(message)s'
        utils.create_exp_dir(generate_date, path, scripts_to_save=glob.glob('*.py'))
        logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                            format=log_format, datefmt='%m/%d %I:%M:%S %p')
        fh = logging.FileHandler(os.path.join(path, 'log.txt'))
        fh.setFormatter(logging.Formatter(log_format))
        logging.getLogger().addHandler(fh)
        logging.info("args = %s", args)
        writer = SummaryWriter('./runs/' + generate_date + '/' + args.save_log)
    else:
        writer = None

    #model_dir = args.model_dir
    model_dir = path
    start_epoch = 0

    if args.evaluate:
        load_state_ckpt(args.checkpoint_path, model)
    else:
        # Resume the best accuracy / epoch counter if a checkpoint exists.
        best_prec1, start_epoch = load_state(model_dir, model, optimizer=optimizer)

    cudnn.benchmark = True

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # Training pipeline with multi-scale random crops.
    transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize])
    train_dataset = datasets.ImageNet(split='train', transform=transform)

    # Training pipeline without multi-scale cropping (used for final epochs).
    transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize])
    train_dataset_wo_ms = datasets.ImageNet(split='train', transform=transform)

    # Deterministic validation pipeline.
    transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize])
    val_dataset = datasets.ImageNet(split='val', transform=transform)

    train_sampler = DistributedSampler(train_dataset)
    val_sampler = DistributedSampler(val_dataset)

    # Global batch is split evenly across ranks; the samplers do the shuffling,
    # hence shuffle=False on the loaders.
    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size//args.world_size, shuffle=False,
        num_workers=args.workers, pin_memory=False, sampler=train_sampler)

    train_loader_wo_ms = DataLoader(
        train_dataset_wo_ms, batch_size=args.batch_size//args.world_size, shuffle=False,
        num_workers=args.workers, pin_memory=False, sampler=train_sampler)

    val_loader = DataLoader(
        val_dataset, batch_size=50, shuffle=False,
        num_workers=args.workers, pin_memory=False, sampler=val_sampler)

    if args.evaluate:
        validate(val_loader, model, criterion, 0, writer, logging)
        return

    niters = len(train_loader)

    lr_scheduler = LRScheduler(optimizer, niters, args)

    for epoch in range(start_epoch, args.epochs):
        train_sampler.set_epoch(epoch)

        if args.rank == 0 and args.SinglePath:
            logging.info('epoch %d', epoch)

        # evaluate on validation set after loading the model
        if epoch == 0 and not args.reset_bn_stat:
            prec1 = validate(val_loader, model, criterion, epoch, writer, logging)

        # train for one epoch
        # The last 5 epochs optionally drop multi-scale augmentation
        # (step LR mode + off_ms + retrain only).
        if epoch >= args.epochs - 5 and args.lr_mode == 'step' and args.off_ms and args.retrain:
            train(train_loader_wo_ms, model, criterion, optimizer, arch_optimizer, lr_scheduler, epoch, writer, logging)
        else:
            train(train_loader, model, criterion, optimizer, arch_optimizer, lr_scheduler, epoch, writer, logging)


        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, epoch, writer, logging)

        if rank == 0:
            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(model_dir, {
                'epoch': epoch + 1,
                'model': args.model,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best)
def main():
    """Train a torchvision ResNet (18/34/50/101) on ImageNet.

    Supports single-process multi-GPU (DataParallel), multi-process
    distributed (DDP), optional FP16 with static loss scaling, an EMA copy
    of the model for evaluation, and TensorBoard summaries on rank 0.
    """
    timer = skeleton.utils.Timer()
    args = parse_args()
    if args.checkpoint is None:
        raise ValueError('must be a set --checkpoint')

    # Route logs to stderr or to --log-filename.
    log_format = '[%(asctime)s] [%(levelname)s] [%(filename)s:%(lineno)03d] %(message)s'
    level = logging.DEBUG if args.debug else logging.INFO
    if not args.log_filename:
        logging.basicConfig(level=level, format=log_format, stream=sys.stderr)
    else:
        logging.basicConfig(level=level,
                            format=log_format,
                            filename=args.log_filename)
    torch.backends.cudnn.benchmark = True
    if args.seed is not None:
        skeleton.utils.set_random_seed_all(args.seed, deterministic=False)

    # Only "resnet-{18,34,50,101}"-style architecture strings are supported.
    assert 'resnet' in args.architecture
    assert args.architecture.split('-')[1] in ['18', '34', '50', '101']

    if args.local_rank >= 0:
        # Multi-process mode: one process per GPU, env:// rendezvous.
        logging.info('Distributed: wait dist process group:%d',
                     args.local_rank)
        dist.init_process_group(backend=args.dist_backend,
                                init_method='env://',
                                world_size=int(os.environ['WORLD_SIZE']))
        assert (int(os.environ['WORLD_SIZE']) == dist.get_world_size())
        logging.info('Distributed: success device:%d (%d/%d)', args.local_rank,
                     dist.get_rank(), dist.get_world_size())

        world_size = dist.get_world_size()
        world_rank = dist.get_rank()
        num_gpus = 1
    else:
        # Single-process mode: DataParallel over all visible GPUs.
        logging.info('Single process')
        args.local_rank = 0
        world_size = 1
        world_rank = 0
        num_gpus = torch.cuda.device_count()

    # TensorBoard writers exist only on rank 0.
    if world_rank == 0:
        summary_writers = {
            'train': SummaryWriter('%s/train' % args.checkpoint),
            'valid': SummaryWriter('%s/valid' % args.checkpoint),
            'valid_ema': SummaryWriter('%s/valid_ema' % args.checkpoint),
        }

    environments = skeleton.utils.Environments()
    device = torch.device('cuda', args.local_rank)
    torch.cuda.set_device(device)
    LOGGER.info('environemtns\n%s', environments)
    LOGGER.info('args\n%s', args)

    # Per-process batch is doubled under FP16; 1281167 = ImageNet train size.
    batch = args.batch * num_gpus * (2 if args.half else 1)
    total_batch = batch * world_size
    steps_per_epoch = int(1281167 // total_batch)
    LOGGER.info(
        'other stats\nnum_gpus:%d\nbatch:%d\ntotal_batch:%d\nsteps_per_epoch:%d',
        num_gpus, batch, total_batch, steps_per_epoch)

    input_size = 224
    num_classes = 1000
    # SyncBatchNorm only with >1 process and when explicitly requested.
    norm_layer = torch.nn.SyncBatchNorm if world_size > 1 and args.sync_bn else torch.nn.BatchNorm2d
    # Select the ResNet variant matching the architecture string.
    model = torchvision.models.resnet18(
        norm_layer=norm_layer,
        zero_init_residual=True) if '18' in args.architecture else None
    model = torchvision.models.resnet34(
        norm_layer=norm_layer,
        zero_init_residual=True) if '34' in args.architecture else model
    model = torchvision.models.resnet50(
        norm_layer=norm_layer,
        zero_init_residual=True) if '50' in args.architecture else model
    model = torchvision.models.resnet101(
        norm_layer=norm_layer,
        zero_init_residual=True) if '101' in args.architecture else model
    model = model.to(device=device)

    def kernel_initializer(module):
        # Re-initialize every conv layer with Kaiming-uniform weights.
        if isinstance(module, torch.nn.Conv2d):
            torch.nn.init.kaiming_uniform_(module.weight)
            if module.bias is not None:
                torch.nn.init.constant_(module.bias, 0)

    model.apply(kernel_initializer)

    # FP16: halve everything except BatchNorm layers (kept in FP32).
    if args.half:
        for module in model.modules():
            if not isinstance(module,
                              (torch.nn.BatchNorm2d, torch.nn.SyncBatchNorm)):
                module.half()
    # EMA shadow copy of the model, evaluated alongside the raw weights.
    model_ema_eval = skeleton.nn.ExponentialMovingAverage(
        copy.deepcopy(model), mu=0.9999,
        data_parallel=world_size == 1).to(device=device).eval()

    epochs = args.epoch
    learning_rate = 0.1
    weight_decay = 1e-4
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=learning_rate,
                                momentum=0.9,
                                weight_decay=weight_decay,
                                nesterov=True)
    if args.schedule == 'multistep':
        # Milestones shifted by the 5 warmup epochs added below.
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, [(l + 5) * steps_per_epoch for l in [30, 60, 80]],
            gamma=0.1)
    elif args.schedule == 'cosine':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                               T_max=epochs *
                                                               steps_per_epoch,
                                                               eta_min=1e-6)
    else:
        raise ValueError('not support schedule: %s', args.schedule)

    # 5-epoch gradual warmup; LR scaled linearly with the global batch size.
    scheduler = skeleton.optim.GradualWarmup(optimizer,
                                             scheduler,
                                             steps=5 * steps_per_epoch,
                                             multiplier=total_batch / 256)

    # Static loss scaling for FP16 training; undone after backward().
    loss_scaler = 1.0 if not args.half else 1024.0
    if args.loss_label_smooth > 0.0:
        loss_fn = skeleton.nn.CrossEntropyLabelSmooth(num_classes=1000,
                                                      epsilon=0.1,
                                                      reduction='mean')
    else:
        loss_fn = torch.nn.CrossEntropyLoss(reduction='mean')

    def criterion(logits, targets):
        # Scaled loss; gradients are divided back down before optimizer.step().
        return loss_scaler * loss_fn(logits, targets)

    metricer = skeleton.nn.AccuracyMany((1, 5))
    meter = AverageMeter('loss', 'accuracy', 'accuracy_top5').to(device=device)

    def metrics_calculator(logits, targets):
        # Un-scaled loss plus top-1/top-5 accuracy, all detached.
        with torch.no_grad():
            loss = criterion(logits, targets)
            top1, top5 = metricer(logits, targets)
            return {
                'loss': loss.detach() / loss_scaler,
                'accuracy': top1.detach(),
                'accuracy_top5': top5.detach()
            }

    # profiler = skeleton.nn.Profiler(model)
    # params = profiler.params()
    # flops = profiler.flops(torch.ones(1, 3, input_size, input_size, dtype=torch.float, device=device))
    # LOGGER.info('arechitecture\n%s\ninput:%d\nprarms:%.2fM\nGFLOPs:%.3f', args.architecture, input_size, params / (1024 * 1024), flops / (1024 * 1024 * 1024))

    LOGGER.info('arechitecture:%s\ninput:%d\nnum_classes:%d',
                args.architecture, input_size, num_classes)
    LOGGER.info(
        'optimizers\nloss:%s\noptimizer:%s\nscheduler:%s\nloss_scaler:%f',
        str(criterion), str(optimizer), str(scheduler), loss_scaler)
    LOGGER.info(
        'hyperparams\nbatch:%d\ninput_size:%d\nsteps_per_epoch:%d\nlearning_rate_init:%.4f',
        batch, input_size, steps_per_epoch, learning_rate)

    # Training data: heavy augmentation, normalized to mean/std 0.5.
    dataset = datasets.ImageNet(
        split='train',
        transform=torchvision.transforms.Compose([
            torchvision.transforms.RandomResizedCrop(
                input_size, scale=(0.05, 1.0), interpolation=Image.BICUBIC),
            torchvision.transforms.RandomHorizontalFlip(),
            torchvision.transforms.ColorJitter(brightness=0.12,
                                               contrast=0.5,
                                               saturation=0.5,
                                               hue=0.2),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
            # torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]))
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=world_rank, shuffle=True)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch,
                                             sampler=sampler,
                                             num_workers=args.workers,
                                             drop_last=True,
                                             pin_memory=True)
    steps = len(dataloader)

    # Validation data: resize to input_size+32 then center-crop.
    resize_image = input_size + 32
    dataset = datasets.ImageNet(
        split='val',
        transform=torchvision.transforms.Compose([
            torchvision.transforms.Resize(resize_image,
                                          interpolation=Image.BICUBIC),
            torchvision.transforms.CenterCrop(input_size),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
            # torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]))
    dataloader_val = torch.utils.data.DataLoader(dataset,
                                                 batch_size=batch,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 drop_last=False,
                                                 pin_memory=True)

    # Wrap the model; in DDP mode also broadcast rank-0 weights to all ranks.
    if world_size > 1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank)
        for param in model.parameters():
            dist.broadcast(param.data, 0)
    else:
        model = torch.nn.parallel.DataParallel(model)
    torch.cuda.synchronize()

    best_accuracy = 0.0
    timer('init', reset_step=True, exclude_total=True)
    for epoch in range(epochs):
        model.train()
        sampler.set_epoch(epoch)  # re-shuffled dataset per node

        for step, (inputs, targets) in zip(range(steps), dataloader):
            timer('init', reset_step=True, exclude_total=True)
            inputs = inputs.to(device=device,
                               dtype=torch.half if args.half else torch.float,
                               non_blocking=True)
            targets = targets.to(device=device, non_blocking=True)

            # Forward in model dtype; loss computed on FP32 logits.
            logits = model(inputs).to(dtype=torch.float)
            loss = criterion(logits, targets)
            timer('forward')

            optimizer.zero_grad()
            loss.backward()
            # Undo the static loss scaling before the optimizer step.
            if loss_scaler != 1.0:
                for param in model.parameters():
                    param.grad.data /= loss_scaler
            timer('backward')

            optimizer.step()
            scheduler.step()

            # model is wrapped (DataParallel/DDP) above, so .module is the
            # raw network the EMA tracks.
            model_ema_eval.update(model.module, step=epoch * steps + step)
            timer('optimize')

            meter.updates(metrics_calculator(logits, targets))

            # NOTE(review): steps // 100 is 0 when steps < 100, which would
            # raise ZeroDivisionError -- confirm dataset-size assumptions.
            if step % (steps // 100) == 0:
                LOGGER.info(
                    '[train] [rank:%03d] %03d/%03d epoch (%02d%%) | lr:%.4f | %s',
                    world_rank, epoch, epochs, 100.0 * step / steps,
                    scheduler.get_lr()[0], str(meter))
            timer('remain')

        metrics_train = meter.get()
        logging.info('[train] [rank:%03d] %03d/%03d epoch | %s', world_rank,
                     epoch, epochs, str(meter))
        if world_rank == 0:
            print('[train] [rank:%03d] %03d/%03d epoch | %s' %
                  (world_rank, epoch, epochs, str(meter)))
        meter.reset(step=epoch)

        is_best = False
        metrics_valid = {}
        metrics_valid_ema = {}
        # Only rank 0 validates (every valid_skip epochs, or always in the
        # final 10% of training); other ranks just log and continue.
        if not world_rank in [0]:
            LOGGER.info('[valid] [rank:%03d] wait master', world_rank)
        elif epoch % args.valid_skip == 0 or epoch > (epochs * 0.9):
            # Evaluate both the live model and its EMA shadow.
            for name, m in [('valid', model), ('valid_ema', model_ema_eval)]:
                m.eval()
                with torch.no_grad():
                    for inputs, targets in dataloader_val:
                        num_samples = inputs.shape[0]
                        inputs = inputs.to(
                            device=device,
                            dtype=torch.half if args.half else torch.float,
                            non_blocking=True)
                        targets = targets.to(device=device, non_blocking=True)

                        logits = m(inputs)
                        meter.updates(metrics_calculator(logits, targets),
                                      n=num_samples)

                if name == 'valid':
                    metrics_valid = meter.get()
                else:
                    metrics_valid_ema = meter.get()

                print('[%s] [rank:%03d] %03d/%03d epoch | %s' %
                      (name, world_rank, epoch, epochs, str(meter)))
                # NOTE(review): is_best always reads metrics_valid, even on
                # the 'valid_ema' pass -- presumably best tracks the raw
                # model only, but the in-loop placement repeats the same
                # comparison; confirm intent.
                is_best = best_accuracy < metrics_valid['accuracy']
                best_accuracy = max(best_accuracy, metrics_valid['accuracy'])

                meter.reset(step=epoch)
        else:
            LOGGER.info('[valid] [rank:%03d] skip master', world_rank)

        if world_rank == 0:
            # Approximate images/sec from the cumulative step throughput.
            throughput = (epoch +
                          1) * steps * batch * world_size * timer.throughput()
            summary_writers['train'].add_scalar('hyperparams/lr',
                                                scheduler.get_lr()[0],
                                                global_step=epoch)
            summary_writers['train'].add_scalar('performance/throughput',
                                                throughput,
                                                global_step=epoch)
            for key, value in metrics_train.items():
                summary_writers['train'].add_scalar('metrics/%s' % key,
                                                    value,
                                                    global_step=epoch)
            for key, value in metrics_valid.items():
                summary_writers['valid'].add_scalar('metrics/%s' % key,
                                                    value,
                                                    global_step=epoch)
            for key, value in metrics_valid_ema.items():
                summary_writers['valid_ema'].add_scalar('metrics/%s' % key,
                                                        value,
                                                        global_step=epoch)

            LOGGER.info(
                '[train] [rank:%03d] %03d/%03d epoch | throughput:%.4f images/sec, %.4f sec/epoch',
                world_rank, epoch, epochs, throughput,
                timer.total_time / (epoch + 1))
            skeleton.utils.save_checkpoints(
                epoch,
                '%s/checkpoints' % args.checkpoint, {
                    'epoch': epoch,
                    'model': model.module.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'metrics': {
                        'train': metrics_train,
                        'valid': metrics_valid
                    }
                },
                is_best=is_best,
                keep_last=30)
Example #8
0
def fetch_dataset(data_name, subset):
    """Build train/test datasets for ``data_name`` and register transforms.

    Looks the dataset class up on the ``datasets`` module by name — replacing
    the previous ``eval()`` on generated source text with ``getattr``, which
    performs the identical constructor calls without dynamic code execution.

    Args:
        data_name: dataset identifier; one of MNIST, FashionMNIST, SVHN,
            EMNIST, CIFAR10, CIFAR100, ImageNet, Kodak.
        subset: subset selector forwarded to the dataset constructor
            (not used for Kodak's ImageFolder).

    Returns:
        dict with 'train' and 'test' dataset instances, whose ``transform``
        attributes are replaced by the pipelines registered in
        ``config.PARAM['transform']``.

    Raises:
        ValueError: if ``data_name`` is not a supported dataset name.
    """
    dataset = {}
    print('fetching data {}...'.format(data_name))
    root = './data/{}'.format(data_name)
    if data_name in ['MNIST', 'FashionMNIST', 'SVHN']:
        # getattr replaces the old eval()-built constructor calls.
        dataset_class = getattr(datasets, data_name)
        dataset['train'] = dataset_class(root=root, split='train', subset=subset,
                                         transform=datasets.Compose(
                                             [transforms.ToTensor()]))
        dataset['test'] = dataset_class(root=root, split='test', subset=subset,
                                        transform=datasets.Compose(
                                            [transforms.ToTensor()]))
        # These datasets are resized to 32x32 at use time.
        config.PARAM['transform'] = {
            'train':
            datasets.Compose(
                [transforms.Resize((32, 32)),
                 transforms.ToTensor()]),
            'test':
            datasets.Compose(
                [transforms.Resize((32, 32)),
                 transforms.ToTensor()])
        }
    elif data_name == 'EMNIST':
        dataset['train'] = datasets.EMNIST(root=root,
                                           split='train',
                                           subset=subset,
                                           transform=datasets.Compose(
                                               [transforms.ToTensor()]))
        dataset['test'] = datasets.EMNIST(root=root,
                                          split='test',
                                          subset=subset,
                                          transform=datasets.Compose(
                                              [transforms.ToTensor()]))
        config.PARAM['transform'] = {
            'train': datasets.Compose([transforms.ToTensor()]),
            'test': datasets.Compose([transforms.ToTensor()])
        }
    elif data_name in ['CIFAR10', 'CIFAR100']:
        dataset_class = getattr(datasets, data_name)
        dataset['train'] = dataset_class(root=root, split='train', subset=subset,
                                         transform=datasets.Compose(
                                             [transforms.ToTensor()]))
        dataset['test'] = dataset_class(root=root, split='test', subset=subset,
                                        transform=datasets.Compose(
                                            [transforms.ToTensor()]))
        config.PARAM['transform'] = {
            'train': datasets.Compose([transforms.ToTensor()]),
            'test': datasets.Compose([transforms.ToTensor()])
        }
    elif data_name == 'ImageNet':
        dataset['train'] = datasets.ImageNet(root,
                                             split='train',
                                             subset=subset,
                                             transform=datasets.Compose(
                                                 [transforms.ToTensor()]))
        dataset['test'] = datasets.ImageNet(root,
                                            split='test',
                                            subset=subset,
                                            transform=datasets.Compose(
                                                [transforms.ToTensor()]))
        # ImageNet images are resized to a fixed 224x224 at use time.
        config.PARAM['transform'] = {
            'train':
            datasets.Compose(
                [transforms.Resize((224, 224)),
                 transforms.ToTensor()]),
            'test':
            datasets.Compose(
                [transforms.Resize((224, 224)),
                 transforms.ToTensor()])
        }
    elif data_name == 'Kodak':
        dataset['train'] = datasets.ImageFolder(root,
                                                transform=datasets.Compose(
                                                    [transforms.ToTensor()]))
        dataset['test'] = datasets.ImageFolder(root,
                                               transform=datasets.Compose(
                                                   [transforms.ToTensor()]))
        config.PARAM['transform'] = {
            'train': datasets.Compose([transforms.ToTensor()]),
            'test': datasets.Compose([transforms.ToTensor()])
        }
    else:
        raise ValueError('Not valid dataset name')
    # Replace the constructor-time transforms with the registered pipelines.
    dataset['train'].transform = config.PARAM['transform']['train']
    dataset['test'].transform = config.PARAM['transform']['test']
    print('data ready')
    return dataset
Example #9
0
def main():
    """Train a single-path one-shot NAS super-net (ShuffleNetV2) on ImageNet.

    Reads a two-level YAML config and flattens it onto the global ``args``,
    builds the one-shot model, creates an SGD optimizer for network weights
    (BatchNorm parameters excluded from weight decay) plus an Adam optimizer
    for the architecture parameter ``log_alpha``, then runs the epoch loop,
    validating each epoch and checkpointing the best top-1 precision.
    """
    global args, best_prec1
    args = parser.parse_args()

    # safe_load avoids executing arbitrary YAML tags from the config file
    # (yaml.load without an explicit Loader is unsafe and deprecated).
    with open(args.config) as f:
        config = yaml.safe_load(f)

    # Flatten the two-level config dict into flat attributes on args.
    for key in config:
        for k, v in config[key].items():
            setattr(args, k, v)

    print('Enabled distributed training.')

    # Actual distributed init is disabled; fixed rank/world_size are assumed.
    # rank, world_size = init_dist(
    #     backend='nccl', port=args.port)
    # args.rank = rank
    # args.world_size = world_size

    args.rank = 0
    args.world_size = 8

    # Seed every RNG from a rank-dependent seed for reproducibility.
    np.random.seed(args.seed*args.rank)
    torch.manual_seed(args.seed*args.rank)
    torch.cuda.manual_seed(args.seed*args.rank)
    torch.cuda.manual_seed_all(args.seed*args.rank)
    print('random seed: ', args.seed*args.rank)

    # create model
    print("=> creating model '{}'".format(args.model))
    # Defined unconditionally so the train() calls below never hit a
    # NameError when args.SinglePath is false.
    arch_optimizer = None
    if args.SinglePath:
        architecture = 20*[0]
        channels_scales = 20*[1.0]
        model = ShuffleNetV2_OneShot(args=args, architecture=architecture, channels_scales=channels_scales)
        model.cuda()
        #broadcast_params(model)
        # Pre-allocate zero gradients so later in-place gradient ops are safe.
        for v in model.parameters():
            if v.requires_grad:
                if v.grad is None:
                    v.grad = torch.zeros_like(v)
        model.log_alpha.grad = torch.zeros_like(model.log_alpha)

    # Label-smoothed cross entropy: eps=0.1 spread uniformly over 1000 classes.
    criterion = CrossEntropyLoss(smooth_eps=0.1, smooth_dist=(torch.ones(1000)*0.001).cuda()).cuda()

    # Partition parameters: BatchNorm parameters get no weight decay.
    wo_wd_params = []
    wo_wd_param_names = []
    network_params = []
    network_param_names = []

    for name, mod in model.named_modules():
        if isinstance(mod, nn.BatchNorm2d):
            for key, value in mod.named_parameters():
                wo_wd_param_names.append(name+'.'+key)

    for key, value in model.named_parameters():
        if key != 'log_alpha':
            if value.requires_grad:
                if key in wo_wd_param_names:
                    wo_wd_params.append(value)
                else:
                    network_params.append(value)
                    network_param_names.append(key)

    params = [
        {'params': network_params,
         'lr': args.base_lr,
         'weight_decay': args.weight_decay },
        {'params': wo_wd_params,
         'lr': args.base_lr,
         'weight_decay': 0.},
    ]
    param_names = [network_param_names, wo_wd_param_names]
    if args.rank == 0:
        print('>>> params w/o weight decay: ', wo_wd_param_names)

    optimizer = torch.optim.SGD(params, momentum=args.momentum)
    if args.SinglePath:
        # Adam updates only the architecture parameter log_alpha.
        arch_optimizer = torch.optim.Adam(
            [param for name, param in model.named_parameters() if name == 'log_alpha'],
            lr=args.arch_learning_rate,
            betas=(0.5, 0.999),
            weight_decay=args.arch_weight_decay
        )

    # auto resume from a checkpoint
    remark = 'imagenet_'
    remark += 'epo_' + str(args.epochs) + '_layer_' + str(args.layers) + '_batch_' + str(args.batch_size) + '_lr_' + str(args.base_lr)  + '_seed_' + str(args.seed)

    if args.early_fix_arch:
        remark += '_early_fix_arch'

    if args.flops_loss:
        remark += '_flops_loss_' + str(args.flops_loss_coef)

    if args.remark != 'none':
        remark += '_'+args.remark

    args.save = 'search-{}-{}-{}'.format(args.save, time.strftime("%Y%m%d-%H%M%S"), remark)
    args.save_log = 'nas-{}-{}'.format(time.strftime("%Y%m%d-%H%M%S"), remark)
    generate_date = str(datetime.now().date())

    path = os.path.join(generate_date, args.save)
    # Only rank 0 writes logs and TensorBoard summaries.
    if args.rank == 0:
        log_format = '%(asctime)s %(message)s'
        utils.create_exp_dir(generate_date, path, scripts_to_save=glob.glob('*.py'))
        logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                            format=log_format, datefmt='%m/%d %I:%M:%S %p')
        fh = logging.FileHandler(os.path.join(path, 'log.txt'))
        fh.setFormatter(logging.Formatter(log_format))
        logging.getLogger().addHandler(fh)
        logging.info("args = %s", args)
        writer = SummaryWriter('./runs/' + generate_date + '/' + args.save_log)
    else:
        writer = None

    model_dir = path
    start_epoch = 0

    if args.evaluate:
        load_state_ckpt(args.checkpoint_path, model)
    else:
        best_prec1, start_epoch = load_state(model_dir, model, optimizer=optimizer)

    cudnn.benchmark = True
    cudnn.enabled = True

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # Training pipeline with multi-scale random crops.
    transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize])
    train_dataset = datasets.ImageNet(split='train', transform=transform)

    # Training pipeline without multi-scale crops (fixed resize + center crop),
    # used for the final epochs when args.off_ms is set.
    transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize])
    train_dataset_wo_ms = datasets.ImageNet(split='train', transform=transform)

    transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize])
    val_dataset = datasets.ImageNet(split='val', transform=transform)

    # Distributed samplers are disabled alongside the distributed init above.
    # train_sampler = DistributedSampler(train_dataset)
    # val_sampler = DistributedSampler(val_dataset)
    #
    # train_loader = DataLoader(
    #     train_dataset, batch_size=args.batch_size//args.world_size, shuffle=False,
    #     num_workers=args.workers, pin_memory=False, sampler=train_sampler)
    #
    # train_loader_wo_ms = DataLoader(
    #     train_dataset_wo_ms, batch_size=args.batch_size//args.world_size, shuffle=False,
    #     num_workers=args.workers, pin_memory=False, sampler=train_sampler)
    #
    # val_loader = DataLoader(
    #     val_dataset, batch_size=50, shuffle=False,
    #     num_workers=args.workers, pin_memory=False, sampler=val_sampler)

    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size//args.world_size, shuffle=False,
        num_workers=args.workers, pin_memory=False)

    train_loader_wo_ms = DataLoader(
        train_dataset_wo_ms, batch_size=args.batch_size//args.world_size, shuffle=False,
        num_workers=args.workers, pin_memory=False)

    val_loader = DataLoader(
        val_dataset, batch_size=50, shuffle=False,
        num_workers=args.workers, pin_memory=False)

    if args.evaluate:
        validate(val_loader, model, criterion, 0, writer, logging)
        return

    niters = len(train_loader)

    lr_scheduler = LRScheduler(optimizer, niters, args)

    for epoch in range(start_epoch, args.epochs):
        #train_sampler.set_epoch(epoch)

        if args.early_fix_arch:
            # Freeze architecture choices whose softmax margin between the
            # top-2 options exceeds 0.3; frozen rows of log_alpha are pinned.
            if len(model.fix_arch_index.keys()) > 0:
                for key, value_lst in model.fix_arch_index.items():
                    model.log_alpha.data[key, :] = value_lst[1]
            sort_log_alpha = torch.topk(F.softmax(model.log_alpha.data, dim=-1), 2)
            argmax_index = (sort_log_alpha[0][:,0] - sort_log_alpha[0][:,1] >= 0.3)
            for id in range(argmax_index.size(0)):
                if argmax_index[id] == 1 and id not in model.fix_arch_index.keys():
                    model.fix_arch_index[id] = [sort_log_alpha[1][id,0].item(), model.log_alpha.detach().clone()[id, :]]

        if args.rank == 0 and args.SinglePath:
            logging.info('epoch %d', epoch)
            logging.info(model.log_alpha)
            logging.info(F.softmax(model.log_alpha, dim=-1))
            logging.info('flops %fM', model.cal_flops())

        # train for one epoch; optionally switch off multi-scale crops
        # for the last 5 epochs under the 'step' lr schedule.
        if epoch >= args.epochs - 5 and args.lr_mode == 'step' and args.off_ms:
            train(train_loader_wo_ms, model, criterion, optimizer, arch_optimizer, lr_scheduler, epoch, writer, logging)
        else:
            train(train_loader, model, criterion, optimizer, arch_optimizer, lr_scheduler, epoch, writer, logging)


        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, epoch, writer, logging)
        if args.gen_max_child:
            # Extra validation pass with the argmax (max-probability) child net.
            args.gen_max_child_flag = True
            prec1 = validate(val_loader, model, criterion, epoch, writer, logging)
            args.gen_max_child_flag = False

        if args.rank == 0:
            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(model_dir, {
                'epoch': epoch + 1,
                'model': args.model,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best)
def main():
    """Train an EfficientNet or ResNet classifier on ImageNet.

    Parses CLI args, optionally joins a distributed process group, builds
    the model/optimizer/scheduler for the chosen architecture family, and
    runs the train/validate loop with optional fp16 compute (loss scaling)
    and manual L2 regularization that excludes BatchNorm parameters.
    Checkpoints are written by rank 0 after every epoch.
    """
    timer = skeleton.utils.Timer()
    args = parse_args()
    if args.checkpoint is None:
        raise ValueError('must be a set --checkpoint')

    log_format = '[%(asctime)s] [%(levelname)s] [%(filename)s:%(lineno)03d] %(message)s'
    level = logging.DEBUG if args.debug else logging.INFO
    if not args.log_filename:
        logging.basicConfig(level=level, format=log_format, stream=sys.stderr)
    else:
        logging.basicConfig(level=level,
                            format=log_format,
                            filename=args.log_filename)
    torch.backends.cudnn.benchmark = True
    if args.seed is not None:
        skeleton.utils.set_random_seed_all(args.seed, deterministic=False)

    # Only efficientnet-b0..b7 and resnet-18/34/50/101 are supported.
    assert 'efficientnet' in args.architecture or 'resnet' in args.architecture
    assert args.architecture.split('-')[1] in ['b0', 'b1', 'b2', 'b3', 'b4', 'b5', 'b6', 'b7'] or\
           args.architecture.split('-')[1] in ['18', '34', '50', '101']

    if args.local_rank >= 0:
        logging.info('Distributed: wait dist process group:%d',
                     args.local_rank)
        dist.init_process_group(backend=args.dist_backend,
                                init_method='env://',
                                world_size=int(os.environ['WORLD_SIZE']))
        assert (int(os.environ['WORLD_SIZE']) == dist.get_world_size())
        logging.info('Distributed: success device:%d (%d/%d)', args.local_rank,
                     dist.get_rank(), dist.get_world_size())

        world_size = dist.get_world_size()
        world_rank = dist.get_rank()
    else:
        logging.info('Single proces')
        args.local_rank = 0
        world_size = 1
        world_rank = 0

    environments = skeleton.utils.Environments()
    device = torch.device('cuda', args.local_rank)
    torch.cuda.set_device(device)

    # Per-GPU batch size defaults scale down with model size; doubled for
    # the current setup and again under fp16 (half memory per activation).
    if args.batch is None:
        if 'efficientnet' in args.architecture:
            batch = 128 if 'b0' in args.architecture else 1
            batch = 96 if 'b1' in args.architecture else batch
            batch = 64 if 'b2' in args.architecture else batch
            batch = 32 if 'b3' in args.architecture else batch
            batch = 16 if 'b4' in args.architecture else batch
            batch = 8 if 'b5' in args.architecture else batch
            batch = 6 if 'b6' in args.architecture else batch
            batch = 4 if 'b7' in args.architecture else batch
            batch *= 2
        else:
            batch = 256
        batch = batch * (2 if args.half else 1)
    else:
        batch = args.batch

    LOGGER.info('environemtns\n%s', environments)
    LOGGER.info('args\n%s', args)

    total_batch = batch * world_size
    # 1281167 = number of ImageNet (ILSVRC-2012) training images.
    steps_per_epoch = int(1281167 / total_batch)

    if 'efficientnet' in args.architecture:
        norm_layer = torch.nn.SyncBatchNorm if world_size > 1 and args.sync_bn else torch.nn.BatchNorm2d
        input_size = efficientnet.EfficientNet.get_image_size(
            args.architecture)
        model = efficientnet.EfficientNet.from_name(
            args.architecture,
            # override_params={'batch_norm_momentum': 0.9999},
            norm_layer=norm_layer).to(device=device)

        # model.set_swish(memory_efficient=False)

        def kernel_initializer(module):
            # He init for convs (fan_out/relu), uniform He for linear layers.
            if isinstance(module, torch.nn.Conv2d):
                torch.nn.init.kaiming_normal_(module.weight,
                                              mode='fan_out',
                                              nonlinearity='relu')
            elif isinstance(module, torch.nn.Linear):
                torch.nn.init.kaiming_uniform_(module.weight,
                                               mode='fan_in',
                                               nonlinearity='linear')

        model.apply(kernel_initializer)

        epoch_scale = 1  # if args.architecture.split('-')[1] in ['b5', 'b6', 'b7'] else 2
        epochs = 350 * epoch_scale
        # EfficientNet paper recipe: lr 0.256 at batch 4096, linearly scaled.
        learning_rate = 0.256 / (4096 / total_batch)
        weight_decay = 1e-5
        # optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate,
        #                                 alpha=0.9, momentum=0.9, weight_decay=0.0,
        #                                 eps=math.sqrt(0.001))
        optimizer = skeleton.optim.RMSprop(model.parameters(),
                                           lr=learning_rate,
                                           alpha=0.9,
                                           momentum=0.9,
                                           weight_decay=0.0,
                                           eps=0.001)
        # lr decays by 0.97 every 2.4 epochs.
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=int(2.4 * epoch_scale * steps_per_epoch),
            gamma=0.97)

        criterion = skeleton.nn.CrossEntropyLabelSmooth(num_classes=1000,
                                                        epsilon=0.1,
                                                        reduction='mean')
    else:
        input_size = 224
        norm_layer = torch.nn.SyncBatchNorm if world_size > 1 and args.sync_bn else torch.nn.BatchNorm2d
        model = torchvision.models.resnet18(
            norm_layer=norm_layer) if '18' in args.architecture else None
        model = torchvision.models.resnet34(
            norm_layer=norm_layer) if '34' in args.architecture else model
        model = torchvision.models.resnet50(
            norm_layer=norm_layer) if '50' in args.architecture else model
        model = torchvision.models.resnet101(
            norm_layer=norm_layer) if '101' in args.architecture else model
        model = model.to(device=device)

        epochs = 90
        # Goyal et al. linear scaling: lr 0.1 at batch 256.
        learning_rate = 0.1 / (256 / total_batch)
        weight_decay = 1e-5
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=learning_rate,
                                    momentum=0.9,
                                    weight_decay=0.0,
                                    nesterov=True)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, [l * steps_per_epoch for l in [30, 60, 80]], gamma=0.1)

        criterion = torch.nn.CrossEntropyLoss(reduction='mean')

    # 5-epoch linear warmup wrapped around the base schedule.
    scheduler = GradualWarmup(optimizer, scheduler, steps=5 * steps_per_epoch)
    metricer = skeleton.nn.AccuracyMany((1, 5))

    # profiler = skeleton.nn.Profiler(model)
    # params = profiler.params()
    # flops = profiler.flops(torch.ones(1, 3, input_size, input_size, dtype=torch.float, device=device))

    # LOGGER.info('arechitecture\n%s\ninput:%d\nprarms:%.2fM\nGFLOPs:%.3f', args.architecture, input_size, params / (1024 * 1024), flops / (1024 * 1024 * 1024))
    LOGGER.info('arechitecture:%s\ninput:%d', args.architecture, input_size)
    LOGGER.info('optimizers\nloss:%s\noptimizer:%s\nscheduler:%s',
                str(criterion), str(optimizer), str(scheduler))
    LOGGER.info(
        'hyperparams\nbatch:%d\ninput_size:%d\nsteps_per_epoch:%d\nlearning_rate_init:%.4f',
        batch, input_size, steps_per_epoch, learning_rate)

    # dataset = skeleton.data.ImageNet(root=args.datapath + '/imagenet', split='train', transform=torchvision.transforms.Compose([
    dataset = datasets.ImageNetCovered(
        split='train',
        special_transform=datasets.imagenet.RandomCropBBox(
            min_object_covered=0.1,
            scale=(0.08, 1.0),
            ratio=(3. / 4., 4. / 3.)),
        transform=torchvision.transforms.Compose([
            torchvision.transforms.Resize((input_size, input_size),
                                          interpolation=Image.BICUBIC),
            torchvision.transforms.RandomHorizontalFlip(),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])
        ]))
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=world_rank)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch,
                                             sampler=sampler,
                                             num_workers=args.workers,
                                             drop_last=True,
                                             pin_memory=True)
    steps = len(dataloader)

    # ResNet eval uses the usual resize(256)/crop(224) ratio (224 * 1.14).
    resize_image = input_size if 'efficientnet' in args.architecture else int(
        input_size * 1.14)
    # dataset = torchvision.datasets.ImageNet(root=args.datapath + '/imagenet', split='val', transform=torchvision.transforms.Compose([
    dataset = datasets.ImageNet(
        split='val',
        transform=torchvision.transforms.Compose([
            torchvision.transforms.Resize(resize_image,
                                          interpolation=Image.BICUBIC),
            torchvision.transforms.CenterCrop(input_size),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])
        ]))
    dataloader_val = torch.utils.data.DataLoader(dataset,
                                                 batch_size=batch,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 drop_last=False,
                                                 pin_memory=True)

    if args.half:
        # Keep BatchNorm in fp32 for numerical stability; everything else fp16.
        for module in model.modules():
            if not isinstance(module, torch.nn.BatchNorm2d):
                module.half()
    if world_size > 1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank)
        for param in model.parameters():
            dist.broadcast(param.data, 0)
    torch.cuda.synchronize()

    # Static loss scaling for fp16 to keep small gradients from underflowing.
    loss_scaler = 1.0 if not args.half else 1024.0

    # Parameters subject to manual L2 regularization (BatchNorm excluded).
    params_without_bn = [
        params for name, params in model.named_parameters()
        if not ('_bn' in name or '.bn' in name)
    ]

    best_accuracy = 0.0
    timer('init', reset_step=True, exclude_total=True)
    for epoch in range(epochs):
        model.train()
        sampler.set_epoch(epoch)  # re-shuffled dataset per node
        loss_sum = torch.zeros(1,
                               device=device,
                               dtype=torch.half if args.half else torch.float)
        accuracy_top1_sum = torch.zeros(1, device=device)
        accuracy_top5_sum = torch.zeros(1, device=device)
        for step, (inputs, targets) in enumerate(dataloader):
            timer('init', reset_step=True, exclude_total=True)
            inputs = inputs.to(device=device,
                               dtype=torch.half if args.half else torch.float,
                               non_blocking=True)
            targets = targets.to(device=device, non_blocking=True)

            logits = model(inputs).to(dtype=torch.float)
            loss = criterion(logits, targets)

            # l2 regularizer \wo batchnorm params
            loss_weight_l2 = sum(
                [p.to(dtype=torch.float).norm(2) for p in params_without_bn])
            loss = loss + (weight_decay * loss_weight_l2)
            timer('forward')

            optimizer.zero_grad()
            if loss_scaler == 1.0:
                loss.backward()
            else:
                (loss * loss_scaler).backward()
                # Un-scale gradients; skip parameters that received none
                # (a frozen/unused parameter would otherwise raise here).
                for param in model.parameters():
                    if param.grad is not None:
                        param.grad.data /= loss_scaler
            timer('backward')

            optimizer.step()
            scheduler.step()
            timer('optimize')

            with torch.no_grad():
                accuracies = metricer(logits, targets)

            loss_sum += loss.detach()
            accuracy_top1_sum += accuracies[0].detach()
            accuracy_top5_sum += accuracies[1].detach()

            # Log roughly 100 times per epoch; max(1, ...) guards the
            # modulus against epochs with fewer than 100 steps.
            if step % max(1, steps // 100) == 0:
                LOGGER.info(
                    '[train] [rank:%03d] %03d/%03d epoch (%02d%%) | loss:%.4f, top1:%.4f, top5:%.4f | lr:%.4f',
                    world_rank, epoch, epochs, 100.0 * step / steps,
                    loss_sum.item() / (step + 1),
                    accuracy_top1_sum.item() / (step + 1),
                    accuracy_top5_sum.item() / (step + 1),
                    scheduler.get_lr()[0])
            timer('remain')

        metrics = {
            'loss': loss_sum.item() / steps,
            'accuracy_top1': accuracy_top1_sum.item() / steps,
            'accuracy_top5': accuracy_top5_sum.item() / steps,
        }
        logging.info(
            '[train] [rank:%03d] %03d/%03d epoch | loss:%.5f, top1:%.4f, top5:%.4f',
            world_rank, epoch, epochs, metrics['loss'],
            metrics['accuracy_top1'], metrics['accuracy_top5'])

        is_best = False
        metrics_train = copy.deepcopy(metrics)
        if not world_rank == 0:
            LOGGER.info('[valid] [rank:%03d] wait master', world_rank)
        elif epoch % args.valid_skip == 0 or epoch > (epochs * 0.9):
            # Validation runs on rank 0 only, every valid_skip epochs and
            # always in the final 10% of training.
            model.eval()

            num_samples_sum = 0
            loss_sum = torch.zeros(
                1,
                device=device,
                dtype=torch.half if args.half else torch.float)
            accuracy_top1_sum = torch.zeros(1, device=device)
            accuracy_top5_sum = torch.zeros(1, device=device)
            with torch.no_grad():
                for inputs, targets in dataloader_val:
                    num_sampels = inputs.shape[0]
                    inputs = inputs.to(
                        device=device,
                        dtype=torch.half if args.half else torch.float,
                        non_blocking=True)
                    targets = targets.to(device=device, non_blocking=True)

                    logits = model(inputs)
                    loss = criterion(logits, targets)
                    accuracies = metricer(logits, targets)

                    num_samples_sum += num_sampels
                    loss_sum += loss.detach() * num_sampels
                    accuracy_top1_sum += accuracies[0].detach() * num_sampels
                    accuracy_top5_sum += accuracies[1].detach() * num_sampels

            metrics = {
                'loss': loss_sum.item() / num_samples_sum,
                'accuracy_top1': accuracy_top1_sum.item() / num_samples_sum,
                'accuracy_top5': accuracy_top5_sum.item() / num_samples_sum,
            }
            logging.info(
                '[valid] [rank:%03d] %02d/%02d epoch | loss:%.5f, top1:%.4f, top5:%.4f',
                world_rank, epoch, epochs, metrics['loss'],
                metrics['accuracy_top1'], metrics['accuracy_top5'])
            is_best = best_accuracy < metrics['accuracy_top1']
            best_accuracy = max(best_accuracy, metrics['accuracy_top1'])
        else:
            LOGGER.info('[valid] [rank:%03d] skip master', world_rank)

        if world_rank == 0:
            LOGGER.info(
                '[train] [rank:%03d] %03d/%03d epoch | throughput:%.4f images/sec, %.4f sec/epoch',
                world_rank, epoch, epochs,
                (epoch + 1) * steps * batch * world_size * timer.throughput(),
                timer.total_time / (epoch + 1))
            # In single-process runs the model is not DDP-wrapped, so it
            # has no `.module` attribute — unwrap only when distributed.
            model_state = (model.module.state_dict()
                           if world_size > 1 else model.state_dict())
            skeleton.utils.save_checkpoints(
                epoch,
                args.checkpoint, {
                    'epoch': epoch,
                    'model': model_state,
                    'optimizer': optimizer.state_dict(),
                    'metrics': {
                        'train': metrics_train,
                        'valid': metrics
                    }
                },
                is_best=is_best,
                keep_last=30)
Beispiel #11
0
def test_imagenet_dataset_count():
    """Each ILSVRC-2012 split exposes its official number of images."""
    expected_sizes = {'train': 1281167, 'val': 50000}
    for split, size in expected_sizes.items():
        assert len(datasets.ImageNet(split)) == size
Beispiel #12
0
def test_imagenet_val_loading_time():
    """Constructing the validation split should finish within 10 seconds."""
    start = time.time()
    datasets.ImageNet('val')
    elapsed = time.time() - start
    assert elapsed < 10.0
Beispiel #13
0
def test_imagenet_train_loading_time():
    """Constructing the training split should finish within 10 seconds."""
    start = time.time()
    datasets.ImageNet('train')
    elapsed = time.time() - start
    assert elapsed < 10.0
Beispiel #14
0
def worker(rank, world_size, args):
    # pylint: disable=too-many-statements
    if rank == 0:
        save_dir = os.path.join(args.save, args.arch,
                                "b{}".format(args.batch_size * world_size))
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        log_format = '%(asctime)s %(message)s'
        logging.basicConfig(stream=sys.stdout,
                            level=logging.INFO,
                            format=log_format,
                            datefmt='%m/%d %I:%M:%S %p')
        fh = logging.FileHandler(os.path.join(save_dir, 'log.txt'))
        fh.setFormatter(logging.Formatter(log_format))
        logging.getLogger().addHandler(fh)

    if world_size > 1:
        # Initialize distributed process group
        logging.info("init distributed process group {} / {}".format(
            rank, world_size))
        dist.init_process_group(
            master_ip="localhost",
            master_port=23456,
            world_size=world_size,
            rank=rank,
            dev=rank,
        )

    save_dir = os.path.join(args.save, args.arch)

    if rank == 0:
        prefixs = ['train', 'valid']
        writers = {
            prefix: SummaryWriter(os.path.join(args.output, prefix))
            for prefix in prefixs
        }

    model = getattr(M, args.arch)()
    step_start = 0
    # if args.model:
    #     logging.info("load weights from %s", args.model)
    #     model.load_state_dict(mge.load(args.model))
    #     step_start = int(args.model.split("-")[1].split(".")[0])

    optimizer = optim.SGD(
        get_parameters(model),
        lr=args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )

    # Define train and valid graph
    def train_func(image, label):
        model.train()
        logits = model(image)
        loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.1)
        acc1, acc5 = F.accuracy(logits, label, (1, 5))
        optimizer.backward(loss)  # compute gradients
        if dist.is_distributed():  # all_reduce_mean
            loss = dist.all_reduce_sum(loss) / dist.get_world_size()
            acc1 = dist.all_reduce_sum(acc1) / dist.get_world_size()
            acc5 = dist.all_reduce_sum(acc5) / dist.get_world_size()
        return loss, acc1, acc5

    def valid_func(image, label):
        model.eval()
        logits = model(image)
        loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.1)
        acc1, acc5 = F.accuracy(logits, label, (1, 5))
        if dist.is_distributed():  # all_reduce_mean
            loss = dist.all_reduce_sum(loss) / dist.get_world_size()
            acc1 = dist.all_reduce_sum(acc1) / dist.get_world_size()
            acc5 = dist.all_reduce_sum(acc5) / dist.get_world_size()
        return loss, acc1, acc5

    # Build train and valid datasets
    logging.info("preparing dataset..")

    transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    train_dataset = datasets.ImageNet(split='train', transform=transform)
    train_sampler = torch.utils.data.RandomSampler(train_dataset)
    train_queue = torch.utils.data.DataLoader(train_dataset,
                                              batch_size=args.batch_size,
                                              sampler=train_sampler,
                                              shuffle=False,
                                              drop_last=True,
                                              pin_memory=True,
                                              num_workers=args.workers)

    train_queue = iter(train_queue)

    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    valid_dataset = datasets.ImageNet(split='val', transform=transform)
    valid_sampler = torch.utils.data.SequentialSampler(valid_dataset)
    valid_queue = torch.utils.data.DataLoader(valid_dataset,
                                              batch_size=100,
                                              sampler=valid_sampler,
                                              shuffle=False,
                                              drop_last=False,
                                              num_workers=args.workers)

    # Start training
    objs = AverageMeter("Loss")
    top1 = AverageMeter("Acc@1")
    top5 = AverageMeter("Acc@5")
    total_time = AverageMeter("Time")

    t = time.time()

    best_valid_acc = 0
    for step in range(step_start, args.steps + 1):
        # Linear learning rate decay
        decay = 1.0
        decay = 1 - float(step) / args.steps if step < args.steps else 0
        for param_group in optimizer.param_groups:
            param_group["lr"] = args.learning_rate * decay

        image, label = next(train_queue)
        time_data = time.time() - t
        # image = image.astype("float32")
        # label = label.astype("int32")

        n = image.shape[0]

        optimizer.zero_grad()
        loss, acc1, acc5 = train_func(image, label)
        optimizer.step()

        top1.update(100 * acc1.numpy()[0], n)
        top5.update(100 * acc5.numpy()[0], n)
        objs.update(loss.numpy()[0], n)
        total_time.update(time.time() - t)
        time_iter = time.time() - t
        t = time.time()
        if step % args.report_freq == 0 and rank == 0:
            logging.info(
                "TRAIN Iter %06d: lr = %f,\tloss = %f,\twc_loss = 1,\tTop-1 err = %f,\tTop-5 err = %f,\tdata_time = %f,\ttrain_time = %f,\tremain_hours=%f",
                step,
                args.learning_rate * decay,
                float(objs.__str__().split()[1]),
                1 - float(top1.__str__().split()[1]) / 100,
                1 - float(top5.__str__().split()[1]) / 100,
                time_data,
                time_iter - time_data,
                time_iter * (args.steps - step) / 3600,
            )

            writers['train'].add_scalar('loss',
                                        float(objs.__str__().split()[1]),
                                        global_step=step)
            writers['train'].add_scalar('top1_err',
                                        1 -
                                        float(top1.__str__().split()[1]) / 100,
                                        global_step=step)
            writers['train'].add_scalar('top5_err',
                                        1 -
                                        float(top5.__str__().split()[1]) / 100,
                                        global_step=step)

            objs.reset()
            top1.reset()
            top5.reset()
            total_time.reset()

        if step % 10000 == 0 and step != 0:
            loss, valid_acc, valid_acc5 = infer(valid_func, valid_queue, args)
            logging.info(
                "TEST Iter %06d: loss = %f,\tTop-1 err = %f,\tTop-5 err = %f",
                step, loss, 1 - valid_acc / 100, 1 - valid_acc5 / 100)

            is_best = valid_acc > best_valid_acc
            best_valid_acc = max(valid_acc, best_valid_acc)

            if rank == 0:
                writers['valid'].add_scalar('loss', loss, global_step=step)
                writers['valid'].add_scalar('top1_err',
                                            1 - valid_acc / 100,
                                            global_step=step)
                writers['valid'].add_scalar('top5_err',
                                            1 - valid_acc5 / 100,
                                            global_step=step)

                logging.info("SAVING %06d", step)

                save_checkpoint(
                    save_dir, {
                        'step': step + 1,
                        'model': args.arch,
                        'state_dict': model.state_dict(),
                        'best_prec1': best_valid_acc,
                        'optimizer': optimizer.state_dict(),
                    }, is_best)