def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu
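    # global rank = node rank * number of GPUs per node + local GPU index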
    args.rank = args.rank * ngpus_per_node + gpu
    print('rank: {} / {}'.format(args.rank, args.world_size))
    dist.init_process_group(backend=args.dist_backend,
                            init_method=args.dist_url,
                            world_size=args.world_size,
                            rank=args.rank)
    torch.cuda.set_device(args.gpu)
    # init the args
    global best_pred, acclist_train, acclist_val

    if args.gpu == 0:
        print(args)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    # init dataloader
    transform_train, transform_val = encoding.transforms.get_transform(
        args.dataset, args.base_size, args.crop_size, args.rand_aug)
    trainset = encoding.datasets.get_dataset(
        args.dataset,
        root=os.path.expanduser('~/.encoding/data'),
        transform=transform_train,
        split='train')
    train_extraset = encoding.datasets.get_dataset(
        args.dataset,
        root=os.path.expanduser('~/.encoding/data'),
        transform=transform_train,
        split='train_extra')
    valset = encoding.datasets.get_dataset(
        args.dataset,
        root=os.path.expanduser('~/.encoding/data'),
        transform=transform_val,
        split='val')
    if args.trainval:
        trainset = torch.utils.data.ConcatDataset(
            [trainset, train_extraset, valset])
    else:
        trainset = torch.utils.data.ConcatDataset([trainset, train_extraset])

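    # DistributedSampler gives each process a disjoint shard of the training
    # set; shuffling is handled by the sampler, so the DataLoader itself keeps
    # shuffle=False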
    train_sampler = torch.utils.data.distributed.DistributedSampler(trainset)
    train_loader = torch.utils.data.DataLoader(trainset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_sampler = torch.utils.data.distributed.DistributedSampler(
        valset, shuffle=False)
    val_loader = torch.utils.data.DataLoader(valset,
                                             batch_size=args.test_batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True,
                                             sampler=val_sampler)

    # init the model
    model_kwargs = {}
    if args.pretrained:
        model_kwargs['pretrained'] = True

    if args.final_drop > 0.0:
        model_kwargs['final_drop'] = args.final_drop

    if args.dropblock_prob > 0.0:
        model_kwargs['dropblock_prob'] = args.dropblock_prob

    if args.last_gamma:
        model_kwargs['last_gamma'] = True

    if args.rectify:
        model_kwargs['rectified_conv'] = True
        model_kwargs['rectify_avg'] = args.rectify_avg

    model = encoding.models.get_model(args.model, **model_kwargs)
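    # replace the classification head: 2048-d backbone features -> 19 labels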
    model.fc = nn.Linear(2048, 19, bias=True)

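    # schedule DropBlock: keep the drop probability at 0.0 during warmup, then
    # ramp it to args.dropblock_prob over the remaining training iterations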
    if args.dropblock_prob > 0.0:
        from functools import partial
        from encoding.nn import reset_dropblock
        nr_iters = (args.epochs - args.warmup_epochs) * len(train_loader)
        apply_drop_prob = partial(reset_dropblock,
                                  args.warmup_epochs * len(train_loader),
                                  nr_iters, 0.0, args.dropblock_prob)
        model.apply(apply_drop_prob)

    if args.gpu == 0:
        print(model)

    if args.mixup > 0:
        train_loader = MixUpWrapper(args.mixup, 1000, train_loader, args.gpu)
        criterion = NLLMultiLabelSmooth(args.label_smoothing)
    elif args.label_smoothing > 0.0:
        criterion = LabelSmoothing(args.label_smoothing)
    else:
        criterion = nn.BCEWithLogitsLoss()

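    # move model and loss to this process's GPU, then wrap with
    # DistributedDataParallel so gradients are averaged across processes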
    model.cuda(args.gpu)
    criterion.cuda(args.gpu)
    model = DistributedDataParallel(model, device_ids=[args.gpu])

    # criterion and optimizer
    if args.no_bn_wd:
        parameters = model.named_parameters()
        param_dict = {}
        for k, v in parameters:
            param_dict[k] = v
        bn_params = [
            v for n, v in param_dict.items() if ('bn' in n or 'bias' in n)
        ]
        rest_params = [
            v for n, v in param_dict.items() if not ('bn' in n or 'bias' in n)
        ]
        if args.gpu == 0:
            print(" Weight decay NOT applied to BN parameters ")
            print(
                f'len(parameters): {len(list(model.parameters()))} = {len(bn_params)} + {len(rest_params)}'
            )
        optimizer = torch.optim.SGD([{
            'params': bn_params,
            'weight_decay': 0
        }, {
            'params': rest_params,
            'weight_decay': args.weight_decay
        }],
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    # check point
    if args.resume is not None:
        if os.path.isfile(args.resume):
            if args.gpu == 0:
                print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint[
                'epoch'] + 1 if args.start_epoch == 0 else args.start_epoch
            best_pred = checkpoint['best_pred']
            acclist_train = checkpoint['acclist_train']
            acclist_val = checkpoint['acclist_val']
            model.module.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            if args.gpu == 0:
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
        else:
            raise RuntimeError(
                "=> no resume checkpoint found at '{}'".format(args.resume))
    scheduler = LR_Scheduler(args.lr_scheduler,
                             base_lr=args.lr,
                             num_epochs=args.epochs,
                             iters_per_epoch=len(train_loader),
                             warmup_epochs=args.warmup_epochs)

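    # one training epoch: the LR scheduler is stepped every iteration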
    def train(epoch):
        train_sampler.set_epoch(epoch)
        model.train()
        losses = AverageMeter()
        top1 = AverageMeter()
        global best_pred, acclist_train
        for batch_idx, (data, target) in enumerate(train_loader):
            scheduler(optimizer, batch_idx, epoch, best_pred)
            if not args.mixup:
                data, target = data.cuda(args.gpu), target.cuda(args.gpu)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            if not args.mixup:
                acc1 = accuracy_multilabel(output, target)
                top1.update(acc1, data.size(0))

            losses.update(loss.item(), data.size(0))
            if batch_idx % 100 == 0 and args.gpu == 0:
                if args.mixup:
                    print('Batch: %d| Loss: %.3f' % (batch_idx, losses.avg))
                else:
                    print('Batch: %d| Loss: %.3f | Top1: %.3f' %
                          (batch_idx, losses.avg, top1.avg))

        acclist_train += [top1.avg]

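    # validation: each process evaluates its shard, per-process sums and counts
    # are reduced with torch_dist_sum, and only rank 0 reports and checkpoints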
    def validate(epoch):
        model.eval()
        top1 = AverageMeter()
        global best_pred, acclist_train, acclist_val
        is_best = False
        for batch_idx, (data, target) in enumerate(val_loader):
            data, target = data.cuda(args.gpu), target.cuda(args.gpu)
            with torch.no_grad():
                output = model(data)
                acc1 = accuracy_multilabel(output, target)
                top1.update(acc1, data.size(0))

        # sum all
        sum1, cnt1 = torch_dist_sum(args.gpu, top1.sum, top1.count)

        if args.eval:
            if args.gpu == 0:
                top1_acc = sum(sum1) / sum(cnt1)
                print('Validation: Top1: %.3f' % (top1_acc))
            return

        if args.gpu == 0:
            top1_acc = sum(sum1) / sum(cnt1)
            print('Validation: Top1: %.3f' % (top1_acc))

            # save checkpoint
            acclist_val += [top1_acc]
            if top1_acc > best_pred:
                best_pred = top1_acc
                is_best = True
            encoding.utils.save_checkpoint(
                {
                    'epoch': epoch,
                    'state_dict': model.module.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_pred': best_pred,
                    'acclist_train': acclist_train,
                    'acclist_val': acclist_val,
                },
                args=args,
                is_best=is_best)

    if args.export:
        if args.gpu == 0:
            torch.save(model.module.state_dict(), args.export + '.pth')
        return

    if args.eval:
        validate(args.start_epoch)
        return

    for epoch in range(args.start_epoch, args.epochs):
        tic = time.time()
        train(epoch)
        if epoch % 10 == 0:  # or epoch == args.epochs-1:
            validate(epoch)
        elapsed = time.time() - tic
        if args.gpu == 0:
            print(f'Epoch: {epoch}, Time cost: {elapsed}')

    if args.gpu == 0:
        encoding.utils.save_checkpoint(
            {
                'epoch': args.epochs - 1,
                'state_dict': model.module.state_dict(),
                'optimizer': optimizer.state_dict(),
                'best_pred': best_pred,
                'acclist_train': acclist_train,
                'acclist_val': acclist_val,
            },
            args=args,
            is_best=False)
Example #2
def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu
    args.rank = args.rank * ngpus_per_node + gpu
    # model name for checkpoint
    args.model = "{}-{}".format(
        args.arch,
        os.path.splitext(os.path.basename(args.config_file))[0])
    if args.gpu == 0:
        print('model:', args.model)
    print('rank: {} / {}'.format(args.rank, args.world_size))
    dist.init_process_group(backend=args.dist_backend,
                            init_method=args.dist_url,
                            world_size=args.world_size,
                            rank=args.rank)
    torch.cuda.set_device(args.gpu)
    # init the args
    global best_pred, acclist_train, acclist_val

    if args.gpu == 0:
        print(args)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
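    # let cudnn benchmark convolution algorithms for the fixed input size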
    cudnn.benchmark = True
    # init dataloader
    transform_train, transform_val = encoding.transforms.get_transform(
        args.dataset, args.base_size, args.crop_size)
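    # optionally prepend a learned augmentation policy to the train transforms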
    if args.auto_policy is not None:
        print(f'Using auto_policy: {args.auto_policy}')
        from augment import Augmentation
        auto_policy = Augmentation(at.load(args.auto_policy))
        transform_train.transforms.insert(0, auto_policy)

    trainset = encoding.datasets.get_dataset(args.dataset,
                                             root=args.data_dir,
                                             transform=transform_train,
                                             train=True,
                                             download=True)
    valset = encoding.datasets.get_dataset(args.dataset,
                                           root=args.data_dir,
                                           transform=transform_val,
                                           train=False,
                                           download=True)

    train_sampler = torch.utils.data.distributed.DistributedSampler(trainset)
    train_loader = torch.utils.data.DataLoader(trainset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_sampler = torch.utils.data.distributed.DistributedSampler(
        valset, shuffle=False)
    val_loader = torch.utils.data.DataLoader(valset,
                                             batch_size=args.test_batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True,
                                             sampler=val_sampler)

    # init the model
    arch = importlib.import_module('arch.' + args.arch)
    model = arch.config_network(args.config_file)
    if args.gpu == 0:
        print(model)

    if args.mixup > 0:
        train_loader = MixUpWrapper(args.mixup, 1000, train_loader, args.gpu)
        criterion = NLLMultiLabelSmooth(args.label_smoothing)
    elif args.label_smoothing > 0.0:
        criterion = LabelSmoothing(args.label_smoothing)
    else:
        criterion = nn.CrossEntropyLoss()

    model.cuda(args.gpu)
    criterion.cuda(args.gpu)
    # criterion and optimizer
    if args.no_bn_wd:
        parameters = model.named_parameters()
        param_dict = {}
        for k, v in parameters:
            param_dict[k] = v
        bn_params = [
            v for n, v in param_dict.items() if ('bn' in n or 'bias' in n)
        ]
        rest_params = [
            v for n, v in param_dict.items() if not ('bn' in n or 'bias' in n)
        ]
        if args.gpu == 0:
            print(" Weight decay NOT applied to BN parameters ")
            print(
                f'len(parameters): {len(list(model.parameters()))} = {len(bn_params)} + {len(rest_params)}'
            )
        optimizer = torch.optim.SGD([{
            'params': bn_params,
            'weight_decay': 0
        }, {
            'params': rest_params,
            'weight_decay': args.wd
        }],
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.wd)
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.wd)
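    # mixed precision via apex amp (opt level O2) with apex's DDP; otherwise
    # use plain torch DistributedDataParallel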
    if args.amp:
        #optimizer = amp_handle.wrap_optimizer(optimizer)
        model, optimizer = amp.initialize(model, optimizer, opt_level='O2')
        #from apex import amp
        DDP = apex.parallel.DistributedDataParallel
        model = DDP(model, delay_allreduce=True)
    else:
        DDP = DistributedDataParallel
        model = DDP(model, device_ids=[args.gpu])

    # check point
    if args.resume is not None:
        if os.path.isfile(args.resume):
            if args.gpu == 0:
                print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint[
                'epoch'] + 1 if args.start_epoch == 0 else args.start_epoch
            best_pred = checkpoint['best_pred']
            acclist_train = checkpoint['acclist_train']
            acclist_val = checkpoint['acclist_val']
            model.module.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            if args.amp:
                amp.load_state_dict(checkpoint['amp'])
            if args.gpu == 0:
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
        else:
            raise RuntimeError(
                "=> no resume checkpoint found at '{}'".format(args.resume))
    scheduler = LR_Scheduler(args.lr_scheduler,
                             base_lr=args.lr,
                             num_epochs=args.epochs,
                             iters_per_epoch=len(train_loader),
                             warmup_epochs=args.warmup_epochs)

    def train(epoch):
        train_sampler.set_epoch(epoch)
        model.train()
        losses = AverageMeter()
        top1 = AverageMeter()
        global best_pred, acclist_train
        tic = time.time()
        for batch_idx, (data, target) in enumerate(train_loader):
            scheduler(optimizer, batch_idx, epoch, best_pred)
            if not args.mixup:
                data, target = data.cuda(args.gpu), target.cuda(args.gpu)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            if args.amp:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            optimizer.step()

            if not args.mixup:
                acc1 = accuracy(output, target, topk=(1, ))
                top1.update(acc1[0], data.size(0))

            losses.update(loss.item(), data.size(0))
            if batch_idx % 100 == 0 and args.gpu == 0:
                iter_per_sec = (100.0 if batch_idx != 0 else 1.0) / (time.time() - tic)
                tic = time.time()
                if args.mixup:
                    #print('Batch: %d| Loss: %.3f'%(batch_idx, losses.avg))
                    print('Epoch: {}, Iter: {}, Speed: {:.3f} iter/sec, Train loss: {:.3f}'.format(
                        epoch, batch_idx, iter_per_sec, losses.avg))
                else:
                    #print('Batch: %d| Loss: %.3f | Top1: %.3f'%(batch_idx, losses.avg, top1.avg))
                    print('Epoch: {}, Iter: {}, Speed: {:.3f} iter/sec, Top1: {:.3f}'. \
                          format(epoch, batch_idx, iter_per_sec, top1.avg.item()))

        acclist_train += [top1.avg]

    def validate(epoch):
        model.eval()
        top1 = AverageMeter()
        top5 = AverageMeter()
        global best_pred, acclist_train, acclist_val
        is_best = False
        for batch_idx, (data, target) in enumerate(val_loader):
            data, target = data.cuda(args.gpu), target.cuda(args.gpu)
            with torch.no_grad():
                output = model(data)
                acc1, acc5 = accuracy(output, target, topk=(1, 5))
                top1.update(acc1[0], data.size(0))
                top5.update(acc5[0], data.size(0))

        # sum all
        sum1, cnt1, sum5, cnt5 = torch_dist_sum(args.gpu, top1.sum, top1.count,
                                                top5.sum, top5.count)

        if args.eval:
            if args.gpu == 0:
                top1_acc = sum(sum1) / sum(cnt1)
                top5_acc = sum(sum5) / sum(cnt5)
                print('Validation: Top1: %.3f | Top5: %.3f' %
                      (top1_acc, top5_acc))
            return

        if args.gpu == 0:
            top1_acc = sum(sum1) / sum(cnt1)
            top5_acc = sum(sum5) / sum(cnt5)
            print('Validation: Top1: %.3f | Top5: %.3f' % (top1_acc, top5_acc))

            # save checkpoint
            acclist_val += [top1_acc]
            if top1_acc > best_pred:
                best_pred = top1_acc
                is_best = True
            state_dict = {
                'epoch': epoch,
                'state_dict': model.module.state_dict(),
                'optimizer': optimizer.state_dict(),
                'best_pred': best_pred,
                'acclist_train': acclist_train,
                'acclist_val': acclist_val,
            }
            if args.amp:
                state_dict['amp'] = amp.state_dict()
            encoding.utils.save_checkpoint(state_dict,
                                           args=args,
                                           is_best=is_best)

    if args.export:
        if args.gpu == 0:
            torch.save(model.module.state_dict(), args.export + '.pth')
        return

    if args.eval:
        validate(args.start_epoch)
        return

    for epoch in range(args.start_epoch, args.epochs):
        tic = time.time()
        train(epoch)
        if epoch % 10 == 0:  # or epoch == args.epochs-1:
            validate(epoch)
        elapsed = time.time() - tic
        if args.gpu == 0:
            print(f'Epoch: {epoch}, Time cost: {elapsed}')

    if args.gpu == 0:
        encoding.utils.save_checkpoint(
            {
                'epoch': args.epochs - 1,
                'state_dict': model.module.state_dict(),
                'optimizer': optimizer.state_dict(),
                'best_pred': best_pred,
                'acclist_train': acclist_train,
                'acclist_val': acclist_val,
            },
            args=args,
            is_best=False)
Example #3
def main():
    args = Options().parse()
    # ngpus_per_node = torch.cuda.device_count()
    # gpu = args.gpu
    # args.rank = args.rank * ngpus_per_node + gpu
    # print('rank: {} / {}'.format(args.rank, args.world_size))
    # dist.init_process_group(backend=args.dist_backend,
    #                         init_method=args.dist_url,
    #                         world_size=args.world_size,
    #                         rank=args.rank)
    # torch.cuda.set_device(args.gpu)
    # init the args
    global best_pred, acclist_train, acclist_val

    print(args)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    # init dataloader
    transform_train, transform_val = encoding.transforms.get_transform(
        args.dataset, args.base_size, args.crop_size, args.rand_aug)
    trainset = encoding.datasets.get_dataset(
        args.dataset,
        root=os.path.expanduser('~/.encoding/data'),
        transform=transform_train,
        train=True,
        download=True)
    valset = encoding.datasets.get_dataset(
        args.dataset,
        root=os.path.expanduser('~/.encoding/data'),
        transform=transform_val,
        train=False,
        download=True)

    # train_sampler = torch.utils.data.distributed.DistributedSampler(trainset)
    train_loader = torch.utils.data.DataLoader(trainset,
                                               batch_size=args.batch_size,
                                               shuffle=False)

    # val_sampler = torch.utils.data.distributed.DistributedSampler(valset, shuffle=False)
    val_loader = torch.utils.data.DataLoader(valset,
                                             batch_size=args.test_batch_size,
                                             shuffle=False)

    # init the model
    model_kwargs = {}
    if args.pretrained:
        model_kwargs['pretrained'] = True

    if args.final_drop > 0.0:
        model_kwargs['final_drop'] = args.final_drop

    if args.dropblock_prob > 0.0:
        model_kwargs['dropblock_prob'] = args.dropblock_prob
    if args.last_gamma:
        model_kwargs['last_gamma'] = True

    if args.rectify:
        model_kwargs['rectified_conv'] = True
        model_kwargs['rectify_avg'] = args.rectify_avg

    model = encoding.models.get_model(args.model, **model_kwargs)

    if args.dropblock_prob > 0.0:
        from functools import partial
        from encoding.nn import reset_dropblock
        nr_iters = (args.epochs - args.warmup_epochs) * len(train_loader)
        apply_drop_prob = partial(reset_dropblock,
                                  args.warmup_epochs * len(train_loader),
                                  nr_iters, 0.0, args.dropblock_prob)
        model.apply(apply_drop_prob)

    # if args.gpu == 0:
    print(model)
    # auxiliary triplet loss on the embedding branch returned by the model
    criterion_triplet = HardTripletLoss()
    if args.mixup > 0:
        # train_loader = MixUpWrapper(args.mixup, 1000, train_loader, args.gpu)
        criterion = NLLMultiLabelSmooth(args.label_smoothing)

    elif args.label_smoothing > 0.0:
        criterion = LabelSmoothing(args.label_smoothing)

    else:
        criterion = nn.CrossEntropyLoss()
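    # single-node multi-GPU training with DataParallel; the distributed setup
    # is kept commented out in this variant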
    model = torch.nn.DataParallel(model)
    # print(model)
    # model.cuda(args.gpu)
    model.cuda()
    criterion.cuda()
    # model = DistributedDataParallel(model,  device_ids=[args.gpu] ,find_unused_parameters=True)

    # criterion and optimizer
    if args.no_bn_wd:
        parameters = model.named_parameters()
        param_dict = {}
        for k, v in parameters:
            param_dict[k] = v
        bn_params = [
            v for n, v in param_dict.items() if ('bn' in n or 'bias' in n)
        ]
        rest_params = [
            v for n, v in param_dict.items() if not ('bn' in n or 'bias' in n)
        ]
        # if args.gpu == 0:
        print(" Weight decay NOT applied to BN parameters ")
        print(
            f'len(parameters): {len(list(model.parameters()))} = {len(bn_params)} + {len(rest_params)}'
        )
        optimizer = torch.optim.SGD([{
            'params': bn_params,
            'weight_decay': 0
        }, {
            'params': rest_params,
            'weight_decay': args.weight_decay
        }],
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    # check point
    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint[
                'epoch'] + 1 if args.start_epoch == 0 else args.start_epoch
            best_pred = checkpoint['best_pred']
            acclist_train = checkpoint['acclist_train']
            acclist_val = checkpoint['acclist_val']
            model.module.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            raise RuntimeError(
                "=> no resume checkpoint found at '{}'".format(args.resume))
    scheduler = LR_Scheduler(args.lr_scheduler,
                             base_lr=args.lr,
                             num_epochs=args.epochs,
                             iters_per_epoch=len(train_loader),
                             warmup_epochs=args.warmup_epochs)

    def train(epoch):
        # train_sampler.set_epoch(epoch)
        model.train()
        train_loss, correct, total = 0, 0, 0
        # losses = AverageMeter()
        top1 = AverageMeter()
        global best_pred, acclist_train
        for batch_idx, (data, target) in enumerate(train_loader):
            scheduler(optimizer, batch_idx, epoch, best_pred)
            if not args.mixup:
                data, target = data.cuda(), target.cuda()
            optimizer.zero_grad()
            output, embed = model(data)
            loss1 = criterion(output, target)
            loss2 = criterion_triplet(embed, target)

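            # total loss: cross-entropy plus a 0.1-weighted triplet loss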
            loss = loss1 + 0.1 * loss2

            loss.backward()
            optimizer.step()

            # ----- alternative accuracy bookkeeping -----
            train_loss += loss.item()
            pred = output.data.max(1)[1]
            correct += pred.eq(target.data).cpu().sum()
            total += target.size(0)

            # ------end------

            if not args.mixup:
                acc1 = accuracy(output, target, topk=(1, ))
                # print("acc1:")
                # print(acc1)
                top1.update(acc1[0], data.size(0))

            # losses.update(loss.item(), data.size(0))
            if batch_idx % 100 == 0:
                if args.mixup:
                    print('Batch: %d| Loss: %.3f' % (batch_idx, train_loss /
                                                     (batch_idx + 1)))
                else:
                    print('Batch: %d| Loss: %.3f | Top1: %.3f' %
                          (batch_idx, train_loss / (batch_idx + 1), top1.avg))
        print(' Train set, Accuracy:({:.0f}%)\n'.format(100. * correct /
                                                        total))
        print(' Top1: %.3f' % top1.avg)
        acclist_train += [top1.avg]

    def validate(epoch):
        model.eval()
        top1 = AverageMeter()
        top5 = AverageMeter()
        global best_pred, acclist_train, acclist_val
        is_best = False
        correct, total = 0, 0
        for batch_idx, (data, target) in enumerate(val_loader):
            data, target = data.cuda(), target.cuda()
            with torch.no_grad():
                output, _ = model(data)
                acc1, acc5 = accuracy(output, target, topk=(1, 5))
                top1.update(acc1[0], data.size(0))
                top5.update(acc5[0], data.size(0))

        # sum all
        # sum1, cnt1, sum5, cnt5 = torch_dist_sum( top1.sum, top1.count, top5.sum, top5.count)
        # sum1, cnt1, sum5, cnt5 = top1.sum, top1.count, top5.sum, top5.count

        # if args.eval:
        #     top1_acc = sum1 / cnt1
        #     top5_acc = sum5 / cnt5
        #     print('Validation: Top1: %.3f | Top5: %.3f' % (top1_acc, top5_acc))
        #     return

        top1_acc = top1.avg
        top5_acc = top5.avg
        print('Validation: Top1: %.3f | Top5: %.3f' % (top1_acc, top5_acc))
        # save checkpoint
        acclist_val += [top1_acc]
        if top1_acc > best_pred:
            best_pred = top1_acc
            is_best = True
        encoding.utils.save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': model.module.state_dict(),
                'optimizer': optimizer.state_dict(),
                'best_pred': best_pred,
                'acclist_train': acclist_train,
                'acclist_val': acclist_val,
            },
            args=args,
            is_best=is_best)

    if args.export:
        torch.save(model.module.state_dict(), args.export + '.pth')
        return

    if args.eval:
        validate(args.start_epoch)
        return

    for epoch in range(args.start_epoch, args.epochs):
        tic = time.time()
        train(epoch)
        if epoch % 10 == 0:  # or epoch == args.epochs-1:
            validate(epoch)
        elapsed = time.time() - tic
        print(f'Epoch: {epoch}, Time cost: {elapsed}')

    encoding.utils.save_checkpoint(
        {
            'epoch': args.epochs - 1,
            'state_dict': model.module.state_dict(),
            'optimizer': optimizer.state_dict(),
            'best_pred': best_pred,
            'acclist_train': acclist_train,
            'acclist_val': acclist_val,
        },
        args=args,
        is_best=False)
Example #4
def main_worker(args):
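    # args.local_rank (typically set by the distributed launcher) selects this
    # process's GPU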
    args.gpu = args.local_rank
    # args.rank = args.rank * ngpus_per_node + gpu
    print('rank: {} / {}'.format(args.local_rank, dist.get_world_size()))
    # init the args
    global best_pred, acclist_train, acclist_val

    if args.gpu == 0:
        print(args)

    # init dataloader
    transform_train, transform_val = encoding.transforms.get_transform(
        args.dataset, args.base_size, args.crop_size, args.rand_aug)
    trainset = encoding.datasets.get_dataset(args.dataset,
                                             root=os.path.expanduser('~/data'),
                                             transform=transform_train,
                                             train=True,
                                             download=True)
    valset = encoding.datasets.get_dataset(args.dataset,
                                           root=os.path.expanduser('~/data'),
                                           transform=transform_val,
                                           train=False,
                                           download=True)

    train_sampler = torch.utils.data.distributed.DistributedSampler(trainset)
    train_loader = torch.utils.data.DataLoader(trainset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_sampler = torch.utils.data.distributed.DistributedSampler(
        valset, shuffle=False)
    val_loader = torch.utils.data.DataLoader(valset,
                                             batch_size=args.test_batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True,
                                             sampler=val_sampler)
    print(len(train_loader), len(val_loader))
    # init the model
    model_kwargs = {}
    if args.pretrained:
        model_kwargs['pretrained'] = True

    if args.final_drop > 0.0:
        model_kwargs['final_drop'] = args.final_drop

    if args.dropblock_prob > 0.0:
        model_kwargs['dropblock_prob'] = args.dropblock_prob

    if args.last_gamma:
        model_kwargs['last_gamma'] = True

    if args.rectify:
        model_kwargs['rectified_conv'] = True
        model_kwargs['rectify_avg'] = args.rectify_avg

    model = encoding.models.get_model(args.model, **model_kwargs)

    if args.dropblock_prob > 0.0:
        from functools import partial
        from encoding.nn import reset_dropblock
        nr_iters = (args.epochs - args.warmup_epochs) * len(train_loader)
        apply_drop_prob = partial(reset_dropblock,
                                  args.warmup_epochs * len(train_loader),
                                  nr_iters, 0.0, args.dropblock_prob)
        model.apply(apply_drop_prob)

    if args.gpu == 0:
        print(model)

    if args.mixup > 0:
        train_loader = MixUpWrapper(args.mixup, 1000, train_loader, args.gpu)
        criterion = NLLMultiLabelSmooth(args.label_smoothing)
    elif args.label_smoothing > 0.0:
        criterion = LabelSmoothing(args.label_smoothing)
    else:
        criterion = nn.CrossEntropyLoss()

    model.cuda(args.gpu)
    criterion.cuda(args.gpu)
    model = DistributedDataParallel(model, device_ids=[args.gpu])

    # criterion and optimizer
    if args.no_bn_wd:
        parameters = model.named_parameters()
        param_dict = {}
        for k, v in parameters:
            param_dict[k] = v
        bn_params = [
            v for n, v in param_dict.items() if ('bn' in n or 'bias' in n)
        ]
        rest_params = [
            v for n, v in param_dict.items() if not ('bn' in n or 'bias' in n)
        ]
        if args.gpu == 0:
            print(" Weight decay NOT applied to BN parameters ")
            print(
                f'len(parameters): {len(list(model.parameters()))} = {len(bn_params)} + {len(rest_params)}'
            )
        optimizer = torch.optim.SGD([{
            'params': bn_params,
            'weight_decay': 0
        }, {
            'params': rest_params,
            'weight_decay': args.weight_decay
        }],
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    # scheduler = LR_Scheduler(args.lr_scheduler,
    #                          base_lr=args.lr,
    #                          num_epochs=args.epochs,
    #                          iters_per_epoch=len(train_loader),
    #                          warmup_epochs=args.warmup_epochs)
    directory = "runs/%s/%s/%s/" % (args.dataset, args.model, args.checkname)

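    # mmcv-style Runner drives training and validation; the hooks below
    # configure the LR schedule, gradient clipping, checkpointing, and logging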
    runner = Runner(model,
                    batch_processor,
                    optimizer,
                    directory,
                    log_level='INFO')

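    # cosine LR decay with a linear warmup over warmup_epochs worth of iterations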
    lr_config = dict(policy='cosine',
                     warmup_ratio=0.01,
                     warmup='linear',
                     warmup_iters=len(train_loader) * args.warmup_epochs)

    log_config = dict(interval=20,
                      hooks=[
                          dict(type='TextLoggerHook'),
                          dict(type='TensorboardLoggerHook')
                      ])

    runner.register_training_hooks(
        lr_config=lr_config,
        optimizer_config=dict(grad_clip=dict(max_norm=40, norm_type=2)),
        checkpoint_config=dict(interval=5),
        log_config=log_config)

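    # re-seed the DistributedSampler each epoch so shards are re-shuffled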
    runner.register_hook(DistSamplerSeedHook())
    if args.resume is not None:
        runner.resume(args.resume)

    runner.run([train_loader, val_loader], [('train', 1), ('val', 1)],
               args.epochs,
               criterion=criterion,
               mixup=args.mixup)