Example 1
    def __init__(self, backbone, lr_channels=17, hr_channels=5):
        super(isCNN, self).__init__()

        self.model_name = backbone
        self.lr_backbone = EfficientNet.from_name(backbone)
        self.hr_backbone = EfficientNet.from_name(backbone)
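        # Note: self.n_channels and self.size are not defined in this snippet; they are
        # assumed to be set elsewhere in the class, derived from the chosen EfficientNet backbone.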

        self.lr_final = nn.Sequential(double_conv(self.n_channels, 64), nn.Conv2d(64, lr_channels, 1))

        self.up_conv1 = up_conv(2*self.n_channels+lr_channels, 512)
        self.double_conv1 = double_conv(self.size[0], 512)
        self.up_conv2 = up_conv(512, 256)
        self.double_conv2 = double_conv(self.size[1], 256)
        self.up_conv3 = up_conv(256, 128)
        self.double_conv3 = double_conv(self.size[2], 128)
        self.up_conv4 = up_conv(128, 64)
        self.double_conv4 = double_conv(self.size[3], 64)
        self.up_conv_input = up_conv(64, 32)
        self.double_conv_input = double_conv(self.size[4], 32)
        self.hr_final = nn.Conv2d(self.size[5], hr_channels, kernel_size=1)
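
The decoder layers above consume feature maps produced by the EfficientNet backbones. As a point of reference, here is a minimal sketch of building a backbone with EfficientNet.from_name and extracting its final feature map; it assumes the pip-installable efficientnet_pytorch package (the examples on this page mostly import a project-local EfficientNet, but extract_features is used the same way in Example 11), and the input size and shapes are illustrative only.

import torch
from efficientnet_pytorch import EfficientNet

backbone = EfficientNet.from_name('efficientnet-b0')
backbone.eval()
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    feats = backbone.extract_features(x)  # final convolutional feature map
print(feats.shape)  # roughly torch.Size([1, 1280, 7, 7]) for efficientnet-b0 at 224x224 input
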
Example 2
def EfficientNet_B8(pretrained=True, num_class=5, onehot=1, onehot2=0):
    if pretrained:
        model = EfficientNet.from_pretrained('efficientnet-b8', num_classes=num_class, onehot=onehot, onehot2=onehot2)
        for name, param in model.named_parameters():
            if 'fc' not in name:
                param.requires_grad = False
    else:
        model = EfficientNet.from_name('efficientnet-b8', onehot=onehot, onehot2=onehot2)
    model.name = "EfficientNet_B8"    
    print("EfficientNet B7 Loaded!")

    return model
Example 3
def EfficientNet_B6(pretrained=True, num_class=5, advprop=False, onehot=1, onehot2=0):
    if pretrained:
        model = EfficientNet.from_pretrained('efficientnet-b6', num_classes=num_class, onehot=onehot, onehot2=onehot2)
        for name, param in model.named_parameters():
            if 'fc' not in name:  # and 'blocks.24' not in name and 'blocks.25' not in name
                param.requires_grad = False
    else:
        model = EfficientNet.from_name('efficientnet-b6', onehot=onehot, onehot2=onehot2)
    
    model.name = "EfficientNet_B6"
    print("EfficientNet B6 Loaded!")

    return model
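
Both factory functions above freeze every parameter whose name does not contain 'fc', so only the classifier head is updated when fine-tuning the pretrained weights. A minimal sanity check (hypothetical code, assuming the project's customised EfficientNet with the onehot arguments is importable):

# list the parameters that remain trainable after the freeze loop
model = EfficientNet_B6(pretrained=True, num_class=5)
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)  # expected to contain only the final fully connected (fc) parameters
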
Example 4
def main(args):
    # Step 1: parse args config
    logging.basicConfig(
        format=
        '[%(asctime)s] [p%(process)s] [%(pathname)s:%(lineno)d] [%(levelname)s] %(message)s',
        level=logging.INFO,
        handlers=[
            logging.FileHandler(args.log_file, mode='w'),
            logging.StreamHandler()
        ])
    print_args(args)

    # Step 2: model, criterion, optimizer, scheduler
    # model = MobileNetV3(mode='large').to(args.device)

    model = EfficientNet.from_name(args.arch).to(args.device)
    # auxiliarynet = AuxiliaryNet().to(args.device)
    auxiliarynet = None

    checkpoint = torch.load(args.model_path)
    model.load_state_dict(checkpoint['model'])

    # step 3: data
    # augmentation
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    # mpiidataset = MPIIDatasets(args.dataroot, train=True, transforms=transform)
    # train_dataset = GazeCaptureDatasets(args.dataroot, train=True, transforms=transform)

    # mpii_val_dataset = MPIIDatasets(args.val_dataroot, train=False, transforms=transform)
    val_dataset = GazeCaptureDatasets(args.val_dataroot,
                                      train=True,
                                      transforms=transform)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=args.val_batchsize,
                                shuffle=False,
                                num_workers=args.workers)

    # step 4: run
    val_loss, val_error = validate(args, val_dataloader, model, auxiliarynet,
                                   1)
    print("val_loss: '{}' val_error: '{}'".format(val_loss, val_error))
Example 5
    def __init__(self,
                 num_classes,
                 network='efficientdet-d0',
                 D_bifpn=3,
                 W_bifpn=88,
                 D_class=3,
                 is_training=True,
                 threshold=0.5,
                 iou_threshold=0.5):
        super(EfficientDet, self).__init__()
        # self.efficientnet = EfficientNet.from_pretrained(MODEL_MAP[network])
        self.efficientnet = EfficientNet.from_name(
            MODEL_MAP[network], override_params={'num_classes': num_classes})
        self.is_training = is_training
        self.BIFPN = BIFPN(
            in_channels=self.efficientnet.get_list_features()[-5:],
            out_channels=W_bifpn,
            stack=D_bifpn,
            num_outs=5)
        self.regressionModel = RegressionModel(W_bifpn)
        self.classificationModel = ClassificationModel(W_bifpn,
                                                       num_classes=num_classes)
        self.anchors = Anchors()
        self.regressBoxes = BBoxTransform()
        self.clipBoxes = ClipBoxes()
        self.threshold = threshold
        self.iou_threshold = iou_threshold

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
        prior = 0.01

        self.classificationModel.output.weight.data.fill_(0)
        self.classificationModel.output.bias.data.fill_(-math.log(
            (1.0 - prior) / prior))
        self.regressionModel.output.weight.data.fill_(0)
        self.regressionModel.output.bias.data.fill_(0)
        self.freeze_bn()
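
The classification head bias above is initialised from a prior foreground probability of 0.01, the focal-loss initialisation from RetinaNet: with bias = -log((1 - prior) / prior), the sigmoid of the initial output equals the prior. A quick check of the arithmetic:

import math

prior = 0.01
bias = -math.log((1.0 - prior) / prior)
print(bias)                           # about -4.595
print(1.0 / (1.0 + math.exp(-bias)))  # about 0.01
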
Example 6
    def __init__(self, backbone, out_channels=2, concat_input=True):
        super().__init__()

        self.model_name = backbone
        self.backbone = EfficientNet.from_name(backbone)
        self.concat_input = concat_input

        self.up_conv1 = up_conv(self.n_channels, 512)
        self.double_conv1 = double_conv(self.size[0], 512)
        self.up_conv2 = up_conv(512, 256)
        self.double_conv2 = double_conv(self.size[1], 256)
        self.up_conv3 = up_conv(256, 128)
        self.double_conv3 = double_conv(self.size[2], 128)
        self.up_conv4 = up_conv(128, 64)
        self.double_conv4 = double_conv(self.size[3], 64)

        if self.concat_input:
            self.up_conv_input = up_conv(64, 32)
            self.double_conv_input = double_conv(self.size[4], 32)

        self.final_conv = nn.Conv2d(self.size[5], out_channels, kernel_size=1)
Example 7
    def __init__(self,
                 num_classes,
                 network='efficientdet-d0',
                 D_bifpn=3,
                 W_bifpn=88,
                 D_class=3,
                 is_training=True,
                 threshold=0.01,
                 iou_threshold=0.5):
        super(EfficientDet, self).__init__()
        try:
            self.backbone = EfficientNet.from_pretrained(MODEL_MAP[network])
        except Exception:
            print("pretrained model is not available:", MODEL_MAP[network])
            print("building backbone from name")
            self.backbone = EfficientNet.from_name(MODEL_MAP[network])
        self.is_training = is_training
        self.neck = BIFPN(in_channels=self.backbone.get_list_features()[-5:],
                          out_channels=W_bifpn,
                          stack=D_bifpn,
                          num_outs=5)
        self.bbox_head = RetinaHead(num_classes=num_classes,
                                    in_channels=W_bifpn)

        self.anchors = Anchors()
        self.regressBoxes = BBoxTransform()
        self.clipBoxes = ClipBoxes()
        self.threshold = threshold
        self.iou_threshold = iou_threshold
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
        self.freeze_bn()
        self.criterion = FocalLoss()
Example 8
    def __init__(self):
        super(EfficientBase, self).__init__()
        self.model_name = 'efficientnet-b0'
        self.encoder = EfficientNet.from_name('efficientnet-b0')
        self.size = [1280, 80, 40, 24, 16]
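        # These widths appear to correspond to efficientnet-b0 feature maps from the deepest
        # level outwards: 1280 for the final conv head, then intermediate block channel widths.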

        # sum
        self.decoder5 = self._make_deconv_layer(DecoderBlock,
                                                self.size[0],
                                                self.size[1],
                                                stride=2)
        self.decoder4 = self._make_deconv_layer(DecoderBlock,
                                                self.size[1],
                                                self.size[2],
                                                stride=2)
        self.decoder3 = self._make_deconv_layer(DecoderBlock,
                                                self.size[2],
                                                self.size[3],
                                                stride=2)
        self.decoder2 = self._make_deconv_layer(DecoderBlock,
                                                self.size[3],
                                                self.size[4],
                                                stride=2)
        self.decoder1 = self._make_deconv_layer(DecoderBlock, self.size[4],
                                                self.size[4])

        self.final = nn.Sequential(
            nn.ConvTranspose2d(16,
                               8,
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(8, 8, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(8, 2, kernel_size=1))
Example 9
def test_model(args):
    # create model
    num_classes = 2
    if args.arch == 'efficientnet_b0':
        if args.pretrained:
            model = EfficientNet.from_pretrained("efficientnet-b0",
                                                 quantize=args.quantize,
                                                 num_classes=num_classes)
        else:
            model = EfficientNet.from_name(
                "efficientnet-b0",
                quantize=args.quantize,
                override_params={'num_classes': num_classes})
        model = torch.nn.DataParallel(model).cuda()

    elif args.arch == 'mobilenet_v1':
        model = mobilenet_v1(quantize=args.quantize, num_classes=num_classes)
        model = torch.nn.DataParallel(model).cuda()

        if args.pretrained:
            checkpoint = torch.load(args.resume)
            state_dict = checkpoint['state_dict']

            if num_classes != 1000:
                new_dict = {
                    k: v
                    for k, v in state_dict.items() if 'fc' not in k
                }
                state_dict = new_dict

            res = model.load_state_dict(state_dict, strict=False)

            for missing_key in res.missing_keys:
                assert 'quantize' in missing_key or 'fc' in missing_key

    elif args.arch == 'mobilenet_v2':
        model = mobilenet_v2(pretrained=args.pretrained,
                             num_classes=num_classes,
                             quantize=args.quantize)
        model = torch.nn.DataParallel(model).cuda()

    elif args.arch == 'resnet18':
        model = resnet18(pretrained=args.pretrained,
                         num_classes=num_classes,
                         quantize=args.quantize)
        model = torch.nn.DataParallel(model).cuda()

    elif args.arch == 'resnet50':
        model = resnet50(pretrained=args.pretrained,
                         num_classes=num_classes,
                         quantize=args.quantize)
        model = torch.nn.DataParallel(model).cuda()

    elif args.arch == 'resnet152':
        model = resnet152(pretrained=args.pretrained,
                          num_classes=num_classes,
                          quantize=args.quantize)
        model = torch.nn.DataParallel(model).cuda()

    elif args.arch == 'resnet164':
        model = resnet_164(num_classes=num_classes, quantize=args.quantize)
        model = torch.nn.DataParallel(model).cuda()

    elif args.arch == 'vgg11':
        model = vgg11(pretrained=args.pretrained,
                      num_classes=num_classes,
                      quantize=args.quantize)
        model = torch.nn.DataParallel(model).cuda()

    elif args.arch == 'vgg19':
        model = vgg19(pretrained=args.pretrained,
                      num_classes=num_classes,
                      quantize=args.quantize)
        model = torch.nn.DataParallel(model).cuda()

    else:
        logging.info('No such model.')
        sys.exit()

    if args.resume and not args.pretrained:
        if os.path.isfile(args.resume):
            logging.info('=> loading checkpoint `{}`'.format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            logging.info('=> loaded checkpoint `{}` (epoch: {})'.format(
                args.resume, checkpoint['epoch']))
        else:
            logging.info('=> no checkpoint found at `{}`'.format(args.resume))

    cudnn.benchmark = False
    test_loader = prepare_test_data(dataset=args.dataset,
                                    datadir=args.datadir,
                                    batch_size=args.batch_size,
                                    shuffle=False,
                                    num_workers=args.workers)
    criterion = nn.CrossEntropyLoss().cuda()

    with torch.no_grad():
        prec1 = validate(args, test_loader, model, criterion, 0)
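
Note that this example passes num_classes directly to from_pretrained but wraps it in override_params for from_name, while Example 13 passes num_classes to both constructors. Which call signature is accepted depends on the EfficientNet implementation and version in use; the sketch below simply restates the two patterns seen on this page and should be adapted to the installed API:

model_a = EfficientNet.from_name('efficientnet-b0',
                                 override_params={'num_classes': 2})
model_b = EfficientNet.from_pretrained('efficientnet-b0', num_classes=2)
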
Example 10
    return total_flops


if __name__ == '__main__':
    if 'cifar10' in args.dataset:
        num_classes = 10
        input_res = 32
    elif 'cifar100' in args.dataset:
        num_classes = 100
        input_res = 32
    elif 'imagenet' in args.dataset:
        num_classes = 1000
        input_res = 224

    if args.arch == 'efficientnet_b0':
        model = EfficientNet.from_name(
            "efficientnet-b0", override_params={'num_classes': num_classes})

    elif args.arch == 'mobilenet_v1':
        model = mobilenet_v1(num_classes=num_classes)

    elif args.arch == 'mobilenet_v2':
        model = mobilenet_v2(num_classes=num_classes)

    elif args.arch == 'resnet18':
        model = resnet18(num_classes=num_classes)

    elif args.arch == 'resnet50':
        model = resnet50(num_classes=num_classes)

    elif args.arch == 'resnet152':
        model = resnet152(num_classes=num_classes)
Example 11
import os

import torch
import cv2
import numpy as np
import torchvision.transforms as tfs
import matplotlib.pyplot as plt
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

from models.efficientnet import EfficientNet
from models.efficientdet import EfficientDet
from models.loss import FocalLoss

mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
#model = EfficientNet.from_pretrained('efficientnet-b0',False,True)
parameters = torch.load('weights/efficientnet-b0.pth')
model = EfficientNet.from_name('efficientnet-b0')
parameters = [v for _, v in parameters.items()]
model_state_dict = model.state_dict()
for i, (k, v) in enumerate(model_state_dict.items()):
    model_state_dict[k] = parameters[i]
torch.save(model_state_dict, 'weights/efficientnet-b0.pth')
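# Note: the copy above matches parameters purely by enumeration order, so it only works when
# the downloaded weight file and model.state_dict() list their entries in the same order.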
'''
data_dir = 'data'
imgs_path = [os.path.join(data_dir, f) for f in os.listdir(data_dir)]

imgs = [cv2.resize((cv2.imread(i)[...,::-1]/255 - mean)/std,(608,608)) for i in imgs_path]
imgs = torch.stack([torch.from_numpy(i.astype(np.float32)) for i in imgs], 0).permute(0, 3, 1, 2)

imgs = imgs[:4].cuda()
features = model.extract_features(imgs)
print(features.size())
'''
Example 12
def main(args):
    # Step 1: parse args config
    logging.basicConfig(
        format=
        '[%(asctime)s] [p%(process)s] [%(pathname)s:%(lineno)d] [%(levelname)s] %(message)s',
        level=logging.INFO,
        handlers=[
            logging.FileHandler(args.log_file, mode='w'),
            logging.StreamHandler()
        ])
    print_args(args)

    # Step 2: model, criterion, optimizer, scheduler
    # model = MobileNetV3(mode='large').to(args.device)

    if args.pretrained:
        model = EfficientNet.from_pretrained(args.arch).to(args.device)
        print("=> using pre-trained model '{}'".format(args.arch))
    else:
        print("=> creating model '{}'".format(args.arch))
        model = EfficientNet.from_name(args.arch).to(args.device)
    auxiliarynet = AuxiliaryNet().to(args.device)
    # auxiliarynet = AuxiliaryNet()
    criterion = GazeLoss()
    optimizer = torch.optim.Adam([{
        'params': model.parameters()
    }, {
        'params': auxiliarynet.parameters()
    }],
                                 lr=args.base_lr,
                                 weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        patience=args.lr_patience,
        verbose=True,
        min_lr=args.min_lr)

    # optionally resume from a checkpoint
    min_error = 1e6
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            min_error = checkpoint['error']
            model.load_state_dict(checkpoint['model'])
            auxiliarynet.load_state_dict(checkpoint['aux'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {}) {:.3f}".format(
                args.resume, checkpoint['epoch'], min_error))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # step 3: data
    # augmentation
    # transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    # train_dataset = MPIIDatasets(args.dataroot, train=True, transforms=transform)
    train_dataset = GazeCaptureDatasets(args.dataroot,
                                        train=True,
                                        transforms=transform)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.train_batchsize,
                                  shuffle=True,
                                  num_workers=args.workers,
                                  drop_last=True)

    # val_dataset = MPIIDatasets(args.val_dataroot, train=False, transforms=transform)
    val_dataset = GazeCaptureDatasets(args.val_dataroot,
                                      train=False,
                                      transforms=transform)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=args.val_batchsize,
                                shuffle=False,
                                num_workers=args.workers)

    # step 4: run
    writer = SummaryWriter(args.tensorboard)
    for epoch in range(args.start_epoch, args.end_epoch + 1):
        train_loss, train_error = train(args, train_dataloader, model,
                                        auxiliarynet, criterion, optimizer,
                                        epoch)

        val_loss, val_error = validate(args, val_dataloader, model,
                                       auxiliarynet, criterion, epoch)
        filename = os.path.join(str(args.snapshot),
                                "checkpoint_epoch_" + str(epoch) + '.pth.tar')
        is_best = min_error > val_error
        min_error = min(min_error, val_error)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': model.state_dict(),
                'aux': auxiliarynet.state_dict(),
                'optimizer': optimizer.state_dict(),
                'error': min_error,
            }, is_best, filename)
        scheduler.step(val_loss)
        writer.add_scalars('data/error', {
            'val error': val_error,
            'train error ': train_error
        }, epoch)
        writer.add_scalars('data/loss', {
            'val loss': val_loss,
            'train loss': train_loss
        }, epoch)
    writer.close()
Example 13
def main():
    args = parser.parse_args()
    log = logger(args)
    log.write('V' * 50 + " configs " + 'V' * 50 + '\n')
    log.write(args)
    log.write('')
    log.write('Λ' * 50 + " configs " + 'Λ' * 50 + '\n')

    # load data
    input_size = (224, 224)
    dataset = DataLoader(args, input_size)
    train_data, val_data = dataset.load_data()

    num_classes = dataset.num_classes
    classes = dataset.classes
    log.write('\n\n')
    log.write('V' * 50 + " data " + 'V' * 50 + '\n')
    log.info('successfully loaded data.')
    log.info('num classes: %s' % num_classes)
    log.info('classes: ' + str(classes) + '\n')
    log.write('Λ' * 50 + " data " + 'Λ' * 50 + '\n')

    # Random seed
    if args.manual_seed is None:
        args.manual_seed = random.randint(1, 10000)
    random.seed(args.manual_seed)
    torch.manual_seed(args.manual_seed)
    np.random.seed(args.manual_seed)
    log.write('random seed is %s' % args.manual_seed)

    # pretrained or not
    log.write('\n\n')
    log.write('V' * 50 + " model " + 'V' * 50 + '\n')
    if args.pretrained:
        log.info("using pre-trained model")
    else:
        log.info("creating model from initial")

    # model
    log.info('using model: %s' % args.arch)
    log.write('')
    log.write('Λ' * 50 + " model " + 'Λ' * 50 + '\n')

    # resume model
    if args.resume:
        log.info('using resume model: %s' % args.resume)
        states = torch.load(args.resume)
        model = states['model']
        model.load_state_dict(states['state_dict'])
    else:
        log.info('not using resume model')
        if args.arch.startswith('dla'):
            model = eval(args.arch)(args.pretrained, num_classes)

        elif args.arch.startswith('efficientnet'):
            if args.pretrained:
                model = EfficientNet.from_pretrained(args.arch,
                                                     num_classes=num_classes)
            else:
                model = EfficientNet.from_name(args.arch,
                                               num_classes=num_classes)
        else:
            model = make_model(model_name=args.arch,
                               num_classes=num_classes,
                               pretrained=args.pretrained,
                               pool=nn.AdaptiveAvgPool2d(output_size=1),
                               classifier_factory=None,
                               input_size=input_size,
                               original_model_state_dict=None,
                               catch_output_size_exception=True)

    # cuda
    have_cuda = torch.cuda.is_available()
    use_cuda = args.use_gpu and have_cuda
    log.info('using cuda: %s' % use_cuda)
    if have_cuda and not use_cuda:
        log.info(
            '\nWARNING: a GPU was found but is not being used; enable it with -ug or --use-gpu\n'
        )

    multi_gpus = False
    if use_cuda:
        torch.backends.cudnn.benchmark = True
        if args.multi_gpus:
            gpus = torch.cuda.device_count()
            multi_gpus = gpus > 1

    if multi_gpus:
        log.info('using multi gpus, found %d gpus.' % gpus)
        model = torch.nn.DataParallel(model).cuda()
    elif use_cuda:
        model = model.cuda()

    # criterion
    log.write('\n\n')
    log.write('V' * 50 + " criterion " + 'V' * 50 + '\n')
    if args.label_smoothing > 0 and args.mixup == 1:
        criterion = CrossEntropyWithLabelSmoothing()
        log.info('using label smoothing criterion')

    elif args.label_smoothing > 0 and args.mixup < 1:
        criterion = CrossEntropyWithMixup()
        log.info('using label smoothing and mixup criterion')

    elif args.mixup < 1 and args.label_smoothing == 0:
        criterion = CrossEntropyWithMixup()
        log.info('using mixup criterion')

    else:
        criterion = nn.CrossEntropyLoss()
        log.info('using normal cross entropy criterion')

    if use_cuda:
        criterion = criterion.cuda()

    log.write('using criterion: %s' % str(criterion))
    log.write('')
    log.write('Λ' * 50 + " criterion " + 'Λ' * 50 + '\n')
    # optimizer
    log.write('\n\n')
    log.write('V' * 50 + " optimizer " + 'V' * 50 + '\n')
    if args.linear_scaling:
        args.lr = 0.1 * args.train_batch / 256
    log.write('initial lr: %4f\n' % args.lr)
    if args.no_bias_decay:
        log.info('using no bias weight decay')
        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            args.weight_decay
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.0
        }]
        optimizer = optim.SGD(optimizer_grouped_parameters,
                              lr=args.lr,
                              momentum=args.momentum)

    else:
        log.info('using bias weight decay')
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)

    if args.resume:
        optimizer.load_state_dict(states['optimizer'])
    log.write('using optimizer: %s' % str(optimizer))
    log.write('')
    log.write('Λ' * 50 + " optimizer " + 'Λ' * 50 + '\n')
    # low precision
    use_low_precision_training = args.low_precision_training
    if use_low_precision_training:
        from apex import amp
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    # lr scheduler
    iters_per_epoch = int(np.ceil(len(train_data) / args.train_batch))
    total_iters = iters_per_epoch * args.epochs
    log.write('\n\n')
    log.write('V' * 50 + " lr_scheduler " + 'V' * 50 + '\n')
    if args.warmup:
        log.info('using warmup scheduler, warmup epochs: %d' %
                 args.warmup_epochs)
        scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, iters_per_epoch * args.warmup_epochs, eta_min=1e-6)
    elif args.cosine:
        log.info('using cosine lr scheduler')
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                         T_max=total_iters)

    else:
        log.info('using normal lr decay scheduler')
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                         factor=0.5,
                                                         patience=10,
                                                         min_lr=1e-6,
                                                         mode='min')

    log.write('using lr scheduler: %s' % str(scheduler))
    log.write('')
    log.write('Λ' * 50 + " lr_scheduler " + 'Λ' * 50 + '\n')

    log.write('\n\n')
    log.write('V' * 50 + " training start " + 'V' * 50 + '\n')
    best_acc = 0
    start = time.time()
    log.info('\nstart training ...')
    for epoch in range(1, args.epochs + 1):
        lr = optimizer.param_groups[-1]['lr']
        train_loss, train_acc = train_one_epoch(
            log, scheduler, train_data, model, criterion, optimizer, use_cuda,
            use_low_precision_training, args.label_smoothing, args.mixup)
        test_loss, test_acc = val_one_epoch(log, val_data, model, criterion,
                                            use_cuda)
        end = time.time()
        log.info(
            'epoch: [%d / %d], time spent(s): %.2f, mean time: %.2f, lr: %.4f, train loss: %.4f, train acc: %.4f, '
            'test loss: %.4f, test acc: %.4f' %
            (epoch, args.epochs, end - start, (end - start) / epoch, lr,
             train_loss, train_acc, test_loss, test_acc))
        states = dict()
        states['arch'] = args.arch
        if multi_gpus:
            states['model'] = model.module
            states['state_dict'] = model.module.state_dict()
        else:
            states['model'] = model
            states['state_dict'] = model.state_dict()
        states['optimizer'] = optimizer.state_dict()
        states['test_acc'] = test_acc
        states['train_acc'] = train_acc
        states['epoch'] = epoch
        states['classes'] = classes
        is_best = test_acc > best_acc
        if is_best:
            best_acc = test_acc
        log.save_checkpoint(states, is_best)

    log.write('\ntraining finished.')
    log.write('Λ' * 50 + " training finished " + 'Λ' * 50 + '\n')
    log.log_file.close()
    log.writer.close()
Example 14
def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu

    # suppress printing if not master
    if args.multiprocessing_distributed and args.gpu != 0:

        def print_pass(*args):
            pass

        builtins.print = print_pass

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    # load teacher
    print("==> loading teacher '{}'".format(args.teacher))
    teacher = model_dict[args.t_arch]()
    for name, param in teacher.named_parameters():
        param.requires_grad = False

    checkpoint = torch.load(args.teacher, map_location="cpu")
    # rename moco pre-trained keys
    state_dict = checkpoint['state_dict']
    for k in list(state_dict.keys()):
        # retain only encoder_q up to before the embedding layer
        if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
            # remove prefix
            state_dict[k[len("module.encoder_q."):]] = state_dict[k]
        # delete renamed or unused k
        del state_dict[k]

    args.start_epoch = 0
    msg = teacher.load_state_dict(state_dict, strict=False)
    assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
    print("==> done")

    # create model
    print("==> creating model '{}'".format(args.arch))
    if args.arch not in [
            'efficientnet-b0', 'efficientnet-b1', 'mobilenetv3-large'
    ]:
        model = Distiller(base_encoder=model_dict[args.arch],
                          teacher=teacher,
                          h_dim=args.pred_h_dim,
                          args=args)

    elif 'efficientnet' in args.arch:
        model = Distiller(base_encoder=EfficientNet.from_name(args.arch),
                          teacher=teacher,
                          h_dim=args.pred_h_dim,
                          args=args)
    else:
        # mobile net v3 - large
        model = Distiller(base_encoder=mobilenetv3_large(),
                          teacher=teacher,
                          h_dim=args.pred_h_dim,
                          args=args)
    print("==> done")

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(
                (args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu], find_unused_parameters=True)
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
        # comment out the following line for debugging
        raise NotImplementedError("Only DistributedDataParallel is supported.")
    else:
        # AllGather implementation (batch shuffle, queue update, etc.) in
        # this code only supports DistributedDataParallel.
        raise NotImplementedError("Only DistributedDataParallel is supported.")

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # MoCo v2's aug: similar to SimCLR https://arxiv.org/abs/2002.05709
    augmentation = [
        transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
        transforms.RandomApply(
            [
                transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened
            ],
            p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomApply([moco.loader.GaussianBlur([.1, 2.])], p=0.5),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ]

    train_dataset = datasets.ImageFolder(
        traindir,
        moco.loader.TwoCropsTransform(transforms.Compose(augmentation)))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               drop_last=True)

    logger = SummaryWriter(logdir=args.tb_folder)

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train(train_loader, model, optimizer, epoch + 1, args, logger)

        if not args.multiprocessing_distributed or (
                args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            if (epoch + 1) % args.save_freq == 0:
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict()
                    },
                    filename=os.path.join(
                        args.save_folder,
                        'checkpoint_{:04d}epoch.pth.tar'.format(epoch + 1)))

            # for resume
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                },
                filename=os.path.join(args.save_folder, 'latest.pth.tar'))