Example No. 1
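Distributed ImageNet training entry point. Command-line arguments are merged with a YAML config, an NCCL process group is initialized through init_dist, and the selected model (the resnetv1sn variants take an extra last_gamma flag) is built, moved to the GPU, and broadcast across ranks. Training and validation run over DistributedSampler-backed loaders with SGD, and the rank-0 process checkpoints the best top-1 accuracy after every epoch.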
def main():
    global args, best_prec1
    args = parser.parse_args()

    with open(args.config) as f:
        config = yaml.safe_load(f)

    for key in config:
        for k, v in config[key].items():
            setattr(args, k, v)

    print('Enabled distributed training.')

    rank, world_size = init_dist(backend='nccl', port=args.port)
    args.rank = rank
    args.world_size = world_size

    # create model
    print("=> creating model '{}'".format(args.model))
    if 'resnetv1sn' in args.model:
        model = models.__dict__[args.model](
            using_moving_average=args.using_moving_average,
            using_bn=args.using_bn,
            last_gamma=args.last_gamma)
    else:
        model = models.__dict__[args.model](
            using_moving_average=args.using_moving_average,
            using_bn=args.using_bn)

    model.cuda()
    broadcast_params(model)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.base_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # auto resume from a checkpoint
    model_dir = args.model_dir
    start_epoch = 0
    if args.rank == 0 and not os.path.exists(model_dir):
        os.makedirs(model_dir)
    if args.evaluate:
        load_state_ckpt(args.checkpoint_path, model)
    else:
        best_prec1, start_epoch = load_state(model_dir,
                                             model,
                                             optimizer=optimizer)
    if args.rank == 0:
        writer = SummaryWriter(model_dir)
    else:
        writer = None

    cudnn.benchmark = True

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = ImagenetDataset(
        args.train_root, args.train_source,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            ColorAugmentation(),
            normalize,
        ]))
    val_dataset = ImagenetDataset(
        args.val_root, args.val_source,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))

    train_sampler = DistributedSampler(train_dataset)
    val_sampler = DistributedSampler(val_dataset)

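    # Each rank loads batch_size // world_size samples per step, so the
    # effective global batch size remains args.batch_size.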
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size // args.world_size,
                              shuffle=False,
                              num_workers=args.workers,
                              pin_memory=False,
                              sampler=train_sampler)

    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size // args.world_size,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=False,
                            sampler=val_sampler)

    if args.evaluate:
        validate(val_loader, model, criterion, 0, writer)
        return

    niters = len(train_loader)

    lr_scheduler = LRScheduler(optimizer, niters, args)

    for epoch in range(start_epoch, args.epochs):
        train_sampler.set_epoch(epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, lr_scheduler, epoch,
              writer)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, epoch, writer)

        if rank == 0:
            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                model_dir, {
                    'epoch': epoch + 1,
                    'model': args.model,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict(),
                }, is_best)
Example No. 2
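Retraining of a child network derived from a single-path one-shot search. The saved log_alpha logits from a search checkpoint are converted into one-hot operator weights for ShuffleNetV2_OneShot, BatchNorm parameters are placed in a parameter group without weight decay, and the network is trained with label-smoothed cross-entropy and distributed SGD, optionally switching to a Resize/CenterCrop training pipeline (no random resized crops) for the last five epochs.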
def main():
    global args, best_prec1
    args = parser.parse_args()

    with open(args.config) as f:
        config = yaml.safe_load(f)

    for key in config:
        for k, v in config[key].items():
            setattr(args, k, v)

    print('Enabled distributed training.')

    rank, world_size = init_dist(
        backend='nccl', port=args.port)
    args.rank = rank
    args.world_size = world_size


    np.random.seed(args.seed*args.rank)
    torch.manual_seed(args.seed*args.rank)
    torch.cuda.manual_seed(args.seed*args.rank)
    torch.cuda.manual_seed_all(args.seed*args.rank)

    # create model
    print("=> creating model '{}'".format(args.model))
    if args.SinglePath:
        architecture = 20*[0]
        channels_scales = 20*[1.0]
        #load derived child network
        log_alpha = torch.load(args.checkpoint_path, map_location='cuda:{}'.format(torch.cuda.current_device()))['state_dict']['log_alpha']
        weights = torch.zeros_like(log_alpha).scatter_(1, torch.argmax(log_alpha, dim = -1).view(-1,1), 1)
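        # One-hot over candidate ops: each row of weights puts a 1 at the argmax
        # of the corresponding row of log_alpha, fixing the searched architecture.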
        model = ShuffleNetV2_OneShot(args=args, architecture=architecture, channels_scales=channels_scales, weights=weights)
        model.cuda()
        broadcast_params(model)
        for v in model.parameters():
            if v.requires_grad:
                if v.grad is None:
                    v.grad = torch.zeros_like(v)
        model.log_alpha.grad = torch.zeros_like(model.log_alpha)   
        if not args.retrain:
            load_state_ckpt(args.checkpoint_path, model)
            checkpoint = torch.load(args.checkpoint_path, map_location='cuda:{}'.format(torch.cuda.current_device()))
            args.base_lr = checkpoint['optimizer']['param_groups'][0]['lr']
        if args.reset_bn_stat:
            model._reset_bn_running_stats()

    # define loss function (criterion) and optimizer
    criterion = CrossEntropyLoss(smooth_eps=0.1, smooth_dist=(torch.ones(1000)*0.001).cuda()).cuda()

    wo_wd_params = []
    wo_wd_param_names = []
    network_params = []
    network_param_names = []

    for name, mod in model.named_modules():
        #if isinstance(mod, (nn.BatchNorm2d, SwitchNorm2d)):
        if isinstance(mod, nn.BatchNorm2d):
            for key, value in mod.named_parameters():
                wo_wd_param_names.append(name+'.'+key)
        
    for key, value in model.named_parameters():
        if key != 'log_alpha':
            if value.requires_grad:
                if key in wo_wd_param_names:
                    wo_wd_params.append(value)
                else:
                    network_params.append(value)
                    network_param_names.append(key)

    params = [
        {'params': network_params,
         'lr': args.base_lr,
         'weight_decay': args.weight_decay },
        {'params': wo_wd_params,
         'lr': args.base_lr,
         'weight_decay': 0.},
    ]
    param_names = [network_param_names, wo_wd_param_names]
    if args.rank == 0:
        print('>>> params w/o weight decay: ', wo_wd_param_names)
    optimizer = torch.optim.SGD(params, momentum=args.momentum)
    arch_optimizer=None

    # auto resume from a checkpoint
    remark = 'imagenet_'
    remark += 'epo_' + str(args.epochs) + '_layer_' + str(args.layers) + '_batch_' + str(args.batch_size) + '_lr_' + str(float("{0:.2f}".format(args.base_lr))) + '_seed_' + str(args.seed)

    if args.remark != 'none':
        remark += '_'+args.remark

    args.save = 'search-{}-{}-{}'.format(args.save, time.strftime("%Y%m%d-%H%M%S"), remark)
    args.save_log = 'nas-{}-{}'.format(time.strftime("%Y%m%d-%H%M%S"), remark)
    generate_date = str(datetime.now().date())

    path = os.path.join(generate_date, args.save)
    if args.rank == 0:
        log_format = '%(asctime)s %(message)s'
        utils.create_exp_dir(generate_date, path, scripts_to_save=glob.glob('*.py'))
        logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                            format=log_format, datefmt='%m/%d %I:%M:%S %p')
        fh = logging.FileHandler(os.path.join(path, 'log.txt'))
        fh.setFormatter(logging.Formatter(log_format))
        logging.getLogger().addHandler(fh)
        logging.info("args = %s", args)
        writer = SummaryWriter('./runs/' + generate_date + '/' + args.save_log)
    else:
        writer = None

    #model_dir = args.model_dir
    model_dir = path
    start_epoch = 0
    
    if args.evaluate:
        load_state_ckpt(args.checkpoint_path, model)
    else:
        best_prec1, start_epoch = load_state(model_dir, model, optimizer=optimizer)

    cudnn.benchmark = True

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = ImagenetDataset(
        args.train_root,
        args.train_source,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    train_dataset_wo_ms = ImagenetDataset(
        args.train_root,
        args.train_source,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    val_dataset = ImagenetDataset(
        args.val_root,
        args.val_source,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))

    train_sampler = DistributedSampler(train_dataset)
    val_sampler = DistributedSampler(val_dataset)

    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size//args.world_size, shuffle=False,
        num_workers=args.workers, pin_memory=False, sampler=train_sampler)

    train_loader_wo_ms = DataLoader(
        train_dataset_wo_ms, batch_size=args.batch_size//args.world_size, shuffle=False,
        num_workers=args.workers, pin_memory=False, sampler=train_sampler)

    val_loader = DataLoader(
        val_dataset, batch_size=50, shuffle=False,
        num_workers=args.workers, pin_memory=False, sampler=val_sampler)

    if args.evaluate:
        validate(val_loader, model, criterion, 0, writer, logging)
        return

    niters = len(train_loader)

    lr_scheduler = LRScheduler(optimizer, niters, args)

    for epoch in range(start_epoch, args.epochs):
        train_sampler.set_epoch(epoch)
        
        if args.rank == 0 and args.SinglePath:
            logging.info('epoch %d', epoch)
        
        # evaluate on validation set after loading the model
        if epoch == 0 and not args.reset_bn_stat:
            prec1 = validate(val_loader, model, criterion, epoch, writer, logging)
       
        # train for one epoch
        if epoch >= args.epochs - 5 and args.lr_mode == 'step' and args.off_ms and args.retrain:
            train(train_loader_wo_ms, model, criterion, optimizer, arch_optimizer, lr_scheduler, epoch, writer, logging)
        else:
            train(train_loader, model, criterion, optimizer, arch_optimizer, lr_scheduler, epoch, writer, logging)


        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, epoch, writer, logging)

        if rank == 0:
            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(model_dir, {
                'epoch': epoch + 1,
                'model': args.model,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best)
Example No. 3
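The architecture search itself. ShuffleNetV2_OneShot is trained with SGD on the network weights and a separate Adam optimizer on the log_alpha architecture logits, with optional early fixing of decided layers (early_fix_arch) and a FLOPs penalty (flops_loss); the rank-0 process logs log_alpha, its softmax, and the estimated FLOPs every epoch and checkpoints the best top-1 accuracy.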
def main():
    global args, best_prec1
    args = parser.parse_args()

    with open(args.config) as f:
        config = yaml.safe_load(f)

    for key in config:
        for k, v in config[key].items():
            setattr(args, k, v)

    print('Enabled distributed training.')

    rank, world_size = init_dist(
        backend='nccl', port=args.port)
    args.rank = rank
    args.world_size = world_size

    np.random.seed(args.seed*args.rank)
    torch.manual_seed(args.seed*args.rank)
    torch.cuda.manual_seed(args.seed*args.rank)
    torch.cuda.manual_seed_all(args.seed*args.rank)
    print('random seed: ', args.seed*args.rank)

    # create model
    print("=> creating model '{}'".format(args.model))
    if args.SinglePath:
        architecture = 20*[0]
        channels_scales = 20*[1.0]
        model = ShuffleNetV2_OneShot(args=args, architecture=architecture, channels_scales=channels_scales)
        model.cuda()
        broadcast_params(model)
        for v in model.parameters():
            if v.requires_grad:
                if v.grad is None:
                    v.grad = torch.zeros_like(v)
        model.log_alpha.grad = torch.zeros_like(model.log_alpha)   
    
    criterion = CrossEntropyLoss(smooth_eps=0.1, smooth_dist=(torch.ones(1000)*0.001).cuda()).cuda()


    wo_wd_params = []
    wo_wd_param_names = []
    network_params = []
    network_param_names = []

    for name, mod in model.named_modules():
        if isinstance(mod, nn.BatchNorm2d):
            for key, value in mod.named_parameters():
                wo_wd_param_names.append(name+'.'+key)
        
    for key, value in model.named_parameters():
        if key != 'log_alpha':
            if value.requires_grad:
                if key in wo_wd_param_names:
                    wo_wd_params.append(value)
                else:
                    network_params.append(value)
                    network_param_names.append(key)

    params = [
        {'params': network_params,
         'lr': args.base_lr,
         'weight_decay': args.weight_decay },
        {'params': wo_wd_params,
         'lr': args.base_lr,
         'weight_decay': 0.},
    ]
    param_names = [network_param_names, wo_wd_param_names]
    if args.rank == 0:
        print('>>> params w/o weight decay: ', wo_wd_param_names)

    optimizer = torch.optim.SGD(params, momentum=args.momentum)
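    # The architecture logits (log_alpha) get their own Adam optimizer,
    # separate from the SGD used for the network weights.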
    if args.SinglePath:
        arch_optimizer = torch.optim.Adam(
            [param for name, param in model.named_parameters() if name == 'log_alpha'],
            lr=args.arch_learning_rate,
            betas=(0.5, 0.999),
            weight_decay=args.arch_weight_decay
        )

    # auto resume from a checkpoint
    remark = 'imagenet_'
    remark += 'epo_' + str(args.epochs) + '_layer_' + str(args.layers) + '_batch_' + str(args.batch_size) + '_lr_' + str(args.base_lr)  + '_seed_' + str(args.seed) + '_pretrain_' + str(args.pretrain_epoch)

    if args.early_fix_arch:
        remark += '_early_fix_arch'  

    if args.flops_loss:
        remark += '_flops_loss_' + str(args.flops_loss_coef)

    if args.remark != 'none':
        remark += '_'+args.remark

    args.save = 'search-{}-{}-{}'.format(args.save, time.strftime("%Y%m%d-%H%M%S"), remark)
    args.save_log = 'nas-{}-{}'.format(time.strftime("%Y%m%d-%H%M%S"), remark)
    generate_date = str(datetime.now().date())

    path = os.path.join(generate_date, args.save)
    if args.rank == 0:
        log_format = '%(asctime)s %(message)s'
        utils.create_exp_dir(generate_date, path, scripts_to_save=glob.glob('*.py'))
        logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                            format=log_format, datefmt='%m/%d %I:%M:%S %p')
        fh = logging.FileHandler(os.path.join(path, 'log.txt'))
        fh.setFormatter(logging.Formatter(log_format))
        logging.getLogger().addHandler(fh)
        logging.info("args = %s", args)
        writer = SummaryWriter('./runs/' + generate_date + '/' + args.save_log)
    else:
        writer = None

    model_dir = path
    start_epoch = 0
    
    if args.evaluate:
        load_state_ckpt(args.checkpoint_path, model)
    else:
        best_prec1, start_epoch = load_state(model_dir, model, optimizer=optimizer)

    cudnn.benchmark = True
    cudnn.enabled = True

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = ImagenetDataset(
        args.train_root,
        args.train_source,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    train_dataset_wo_ms = ImagenetDataset(
        args.train_root,
        args.train_source,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    val_dataset = ImagenetDataset(
        args.val_root,
        args.val_source,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))

    train_sampler = DistributedSampler(train_dataset)
    val_sampler = DistributedSampler(val_dataset)

    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size//args.world_size, shuffle=False,
        num_workers=args.workers, pin_memory=False, sampler=train_sampler)

    train_loader_wo_ms = DataLoader(
        train_dataset_wo_ms, batch_size=args.batch_size//args.world_size, shuffle=False,
        num_workers=args.workers, pin_memory=False, sampler=train_sampler)

    val_loader = DataLoader(
        val_dataset, batch_size=50, shuffle=False,
        num_workers=args.workers, pin_memory=False, sampler=val_sampler)

    if args.evaluate:
        validate(val_loader, model, criterion, 0, writer, logging)
        return

    niters = len(train_loader)

    lr_scheduler = LRScheduler(optimizer, niters, args)

    for epoch in range(start_epoch, 85):
        train_sampler.set_epoch(epoch)
        
        if args.early_fix_arch:
            if len(model.fix_arch_index.keys()) > 0:
                for key, value_lst in model.fix_arch_index.items():
                    model.log_alpha.data[key, :] = value_lst[1]
            sort_log_alpha = torch.topk(F.softmax(model.log_alpha.data, dim=-1), 2)
            argmax_index = (sort_log_alpha[0][:,0] - sort_log_alpha[0][:,1] >= 0.3)
            for id in range(argmax_index.size(0)):
                if argmax_index[id] == 1 and id not in model.fix_arch_index.keys():
                    model.fix_arch_index[id] = [sort_log_alpha[1][id,0].item(), model.log_alpha.detach().clone()[id, :]]
            
        if args.rank == 0 and args.SinglePath:
            logging.info('epoch %d', epoch)
            logging.info(model.log_alpha)         
            logging.info(F.softmax(model.log_alpha, dim=-1))         
            logging.info('flops %fM', model.cal_flops())  

        # train for one epoch
        if epoch >= args.epochs - 5 and args.lr_mode == 'step' and args.off_ms:
            train(train_loader_wo_ms, model, criterion, optimizer, arch_optimizer, lr_scheduler, epoch, writer, logging)
        else:
            train(train_loader, model, criterion, optimizer, arch_optimizer, lr_scheduler, epoch, writer, logging)


        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, epoch, writer, logging)
        if args.gen_max_child:
            args.gen_max_child_flag = True
            prec1 = validate(val_loader, model, criterion, epoch, writer, logging)        
            args.gen_max_child_flag = False

        if rank == 0:
            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(model_dir, {
                'epoch': epoch + 1,
                'model': args.model,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best)
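Example No. 4
Per-GPU worker for training with transposable block sparsity. Masks are either loaded or computed by PruningMethodTransposableBlockL1 and attached to the weights of SparseConvTranspose/SparseLinearTranspose modules, the model is optionally wrapped in DistributedDataParallel, and the rest of the loop (SGD or a sparse SGD variant, DistributedSampler loaders, per-epoch checkpointing) mirrors the examples above.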
def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu
    args.rank = args.rank * ngpus_per_node + gpu

    if args.distributed:
        print(args.backend, args.world_size, args.rank)
        dist.init_process_group(backend=args.backend,
                                init_method='tcp://127.0.0.1:6668',
                                world_size=args.world_size,
                                rank=args.rank)
    print('Enabled distributed training.')
    # create model
    print("=> creating model '{}'".format(args.model))
    model = models.__dict__[args.model](N=args.N, M=args.M)
    torch.cuda.set_device(args.gpu)

    ipClass = PruningMethodTransposableBlockL1(block_size=args.M, topk=args.N)
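    # Block-wise pruning: keep the args.N largest-magnitude weights in every
    # block of size args.M (transposable N:M-style sparsity masks).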
    if args.load_mask:
        load_state_and_masks(model, args)
        print("Masks loaded!")
    else:
        for n, m in model.named_modules():
            if isinstance(m, SparseConvTranspose) or isinstance(
                    m, SparseLinearTranspose):
                #    m.maskBuff.data = ipClass.compute_mask(m.weight, torch.ones_like(m.weight))
                setattr(
                    m.weight, "mask",
                    ipClass.compute_mask(m.weight, torch.ones_like(m.weight)))

        if args.save_mask:
            save_masks(model, args)
            print("Masks saved!")

    model.cuda(args.gpu)
    #args.batch_size = int(args.batch_size / ngpus_per_node)
    args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
    #broadcast_params(model)
    print(model)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    if args.sparse_optimizer:
        optimizer = sparse_optimizer.SGD(model.parameters(),
                                         args.base_lr,
                                         momentum=args.momentum,
                                         weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.base_lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    # auto resume from a checkpoint
    model_dir = args.model_dir
    start_epoch = 0
    best_prec1 = 0
    if args.rank == 0 and not os.path.exists(model_dir):
        os.makedirs(model_dir)
    if args.evaluate:
        load_state_ckpt(args.checkpoint_path, model)
    else:
        best_prec1, start_epoch = load_state(model_dir,
                                             model,
                                             optimizer=optimizer)
    if args.rank == 0 or not args.distributed:
        writer = SummaryWriter(model_dir)
    else:
        writer = None

    cudnn.benchmark = True

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        args.train_root,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            ColorAugmentation(),
            normalize,
        ]))
    val_dataset = datasets.ImageFolder(
        args.val_root,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = DistributedSampler(train_dataset)
        val_sampler = DistributedSampler(val_dataset)
    else:
        train_sampler = None
        val_sampler = None

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size // args.world_size,
                              shuffle=False,
                              num_workers=args.workers,
                              pin_memory=False,
                              sampler=train_sampler)

    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size // args.world_size,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=False,
                            sampler=val_sampler)

    if args.evaluate:
        validate(val_loader, model, criterion, 0, writer)
        return

    niters = len(train_loader)

    lr_scheduler = LRScheduler(optimizer, niters, args)

    for epoch in range(start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, lr_scheduler, epoch,
              writer, args)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, epoch, writer, args)

        if args.rank == 0:
            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                model_dir, {
                    'epoch': epoch + 1,
                    'model': args.model,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict(),
                }, is_best)
Example No. 5
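Distributed ArcFace training for face recognition. An ArcFaceWithLoss model (backbone, class count, norm function, embedding size, and optional SE blocks come from the config) is trained with SGD over a FaceDataset sampled by BigdataSampler, and the rank-0 process writes a checkpoint after every epoch; there is no validation loop.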
def main():
    global args, best_prec1
    args = parser.parse_args()

    with open(args.config) as f:
        config = yaml.safe_load(f)

    for key in config:
        for k, v in config[key].items():
            setattr(args, k, v)

    rank, world_size = init_dist(backend='nccl', port=args.port)
    args.rank = rank
    args.world_size = world_size

    # create model

    model = ArcFaceWithLoss(args.backbone, args.class_num, args.norm_func,
                            args.embedding_size, args.use_se)

    model.cuda()
    broadcast_params(model)

    optimizer = torch.optim.SGD(model.parameters(),
                                args.base_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # auto resume from a checkpoint
    model_dir = args.model_dir
    start_epoch = 0
    if args.rank == 0 and not os.path.exists(model_dir):
        os.makedirs(model_dir)

    best_prec1, start_epoch = load_state(model_dir, model, optimizer=optimizer)
    if args.rank == 0:
        writer = SummaryWriter(model_dir)
    else:
        writer = None

    cudnn.benchmark = True

    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    train_dataset = FaceDataset(
        True, args,
        transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    train_sampler = BigdataSampler(
        train_dataset,
        num_sub_epochs=2,
        finegrain_factor=10000,
        seed=1000,
    )

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size // args.world_size,
                              shuffle=False,
                              num_workers=args.workers,
                              pin_memory=False,
                              sampler=train_sampler)

    niters = len(train_loader)

    lr_scheduler = LRScheduler(optimizer, niters, args)

    for epoch in range(start_epoch, args.epochs):
        train(train_loader, model, optimizer, lr_scheduler, epoch, writer)
        if rank == 0:
            save_checkpoint(
                model_dir, {
                    'epoch': epoch + 1,
                    'model': args.backbone,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, False)
Example No. 6
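A step-based (rather than epoch-based) training loop for an ImageNet classifier. The learning rate decays linearly over args.steps, the loss is label-smoothed cross-entropy, loss and accuracy are averaged across ranks with all-reduce, and a validation pass plus checkpoint runs every 10,000 steps. The distributed and functional calls (dist.init_process_group(master_ip=...), optimizer.backward(loss), F.cross_entropy_with_softmax) follow a MegEngine-style interface (note the commented-out mge.load) even though the data pipeline uses torch.utils.data.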
def worker(rank, world_size, args):
    # pylint: disable=too-many-statements
    if rank == 0:
        save_dir = os.path.join(args.save, args.arch,
                                "b{}".format(args.batch_size * world_size))
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        log_format = '%(asctime)s %(message)s'
        logging.basicConfig(stream=sys.stdout,
                            level=logging.INFO,
                            format=log_format,
                            datefmt='%m/%d %I:%M:%S %p')
        fh = logging.FileHandler(os.path.join(save_dir, 'log.txt'))
        fh.setFormatter(logging.Formatter(log_format))
        logging.getLogger().addHandler(fh)

    if world_size > 1:
        # Initialize distributed process group
        logging.info("init distributed process group {} / {}".format(
            rank, world_size))
        dist.init_process_group(
            master_ip="localhost",
            master_port=23456,
            world_size=world_size,
            rank=rank,
            dev=rank,
        )

    save_dir = os.path.join(args.save, args.arch)

    if rank == 0:
        prefixs = ['train', 'valid']
        writers = {
            prefix: SummaryWriter(os.path.join(args.output, prefix))
            for prefix in prefixs
        }

    model = getattr(M, args.arch)()
    step_start = 0
    # if args.model:
    #     logging.info("load weights from %s", args.model)
    #     model.load_state_dict(mge.load(args.model))
    #     step_start = int(args.model.split("-")[1].split(".")[0])

    optimizer = optim.SGD(
        get_parameters(model),
        lr=args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )

    # Define train and valid graph
    def train_func(image, label):
        model.train()
        logits = model(image)
        loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.1)
        acc1, acc5 = F.accuracy(logits, label, (1, 5))
        optimizer.backward(loss)  # compute gradients
        if dist.is_distributed():  # all_reduce_mean
            loss = dist.all_reduce_sum(loss) / dist.get_world_size()
            acc1 = dist.all_reduce_sum(acc1) / dist.get_world_size()
            acc5 = dist.all_reduce_sum(acc5) / dist.get_world_size()
        return loss, acc1, acc5

    def valid_func(image, label):
        model.eval()
        logits = model(image)
        loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.1)
        acc1, acc5 = F.accuracy(logits, label, (1, 5))
        if dist.is_distributed():  # all_reduce_mean
            loss = dist.all_reduce_sum(loss) / dist.get_world_size()
            acc1 = dist.all_reduce_sum(acc1) / dist.get_world_size()
            acc5 = dist.all_reduce_sum(acc5) / dist.get_world_size()
        return loss, acc1, acc5

    # Build train and valid datasets
    logging.info("preparing dataset..")

    transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    train_dataset = datasets.ImageNet(split='train', transform=transform)
    train_sampler = torch.utils.data.RandomSampler(train_dataset)
    train_queue = torch.utils.data.DataLoader(train_dataset,
                                              batch_size=args.batch_size,
                                              sampler=train_sampler,
                                              shuffle=False,
                                              drop_last=True,
                                              pin_memory=True,
                                              num_workers=args.workers)

    train_queue = iter(train_queue)
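    # NOTE: iterating the DataLoader this way exhausts it after one pass over the
    # dataset, so args.steps is assumed to stay within one epoch's worth of batches.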

    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    valid_dataset = datasets.ImageNet(split='val', transform=transform)
    valid_sampler = torch.utils.data.SequentialSampler(valid_dataset)
    valid_queue = torch.utils.data.DataLoader(valid_dataset,
                                              batch_size=100,
                                              sampler=valid_sampler,
                                              shuffle=False,
                                              drop_last=False,
                                              num_workers=args.workers)

    # Start training
    objs = AverageMeter("Loss")
    top1 = AverageMeter("Acc@1")
    top5 = AverageMeter("Acc@5")
    total_time = AverageMeter("Time")

    t = time.time()

    best_valid_acc = 0
    for step in range(step_start, args.steps + 1):
        # Linear learning rate decay
        decay = 1 - float(step) / args.steps if step < args.steps else 0
        for param_group in optimizer.param_groups:
            param_group["lr"] = args.learning_rate * decay

        image, label = next(train_queue)
        time_data = time.time() - t
        # image = image.astype("float32")
        # label = label.astype("int32")

        n = image.shape[0]

        optimizer.zero_grad()
        loss, acc1, acc5 = train_func(image, label)
        optimizer.step()

        top1.update(100 * acc1.numpy()[0], n)
        top5.update(100 * acc5.numpy()[0], n)
        objs.update(loss.numpy()[0], n)
        total_time.update(time.time() - t)
        time_iter = time.time() - t
        t = time.time()
        if step % args.report_freq == 0 and rank == 0:
            logging.info(
                "TRAIN Iter %06d: lr = %f,\tloss = %f,\twc_loss = 1,\tTop-1 err = %f,\tTop-5 err = %f,\tdata_time = %f,\ttrain_time = %f,\tremain_hours=%f",
                step,
                args.learning_rate * decay,
                float(objs.__str__().split()[1]),
                1 - float(top1.__str__().split()[1]) / 100,
                1 - float(top5.__str__().split()[1]) / 100,
                time_data,
                time_iter - time_data,
                time_iter * (args.steps - step) / 3600,
            )

            writers['train'].add_scalar('loss',
                                        float(objs.__str__().split()[1]),
                                        global_step=step)
            writers['train'].add_scalar('top1_err',
                                        1 -
                                        float(top1.__str__().split()[1]) / 100,
                                        global_step=step)
            writers['train'].add_scalar('top5_err',
                                        1 -
                                        float(top5.__str__().split()[1]) / 100,
                                        global_step=step)

            objs.reset()
            top1.reset()
            top5.reset()
            total_time.reset()

        if step % 10000 == 0 and step != 0:
            loss, valid_acc, valid_acc5 = infer(valid_func, valid_queue, args)
            logging.info(
                "TEST Iter %06d: loss = %f,\tTop-1 err = %f,\tTop-5 err = %f",
                step, loss, 1 - valid_acc / 100, 1 - valid_acc5 / 100)

            is_best = valid_acc > best_valid_acc
            best_valid_acc = max(valid_acc, best_valid_acc)

            if rank == 0:
                writers['valid'].add_scalar('loss', loss, global_step=step)
                writers['valid'].add_scalar('top1_err',
                                            1 - valid_acc / 100,
                                            global_step=step)
                writers['valid'].add_scalar('top5_err',
                                            1 - valid_acc5 / 100,
                                            global_step=step)

                logging.info("SAVING %06d", step)

                save_checkpoint(
                    save_dir, {
                        'step': step + 1,
                        'model': args.arch,
                        'state_dict': model.state_dict(),
                        'best_prec1': best_valid_acc,
                        'optimizer': optimizer.state_dict(),
                    }, is_best)