예제 #1
0
def main():
    """Benchmark the EfficientNet classifier on four per-class image folders.

    For every folder in ``image_path`` the images ``0.jpg`` .. ``99.jpg`` are
    classified one at a time. For each folder the number of images predicted
    as each class and the mean per-image frames-per-second are printed. FPS is
    timed with ``getTickCount``/``getTickFrequency`` around the forward pass
    only.
    """
    model = EfficientNet(1.0, 1.0)
    weights = torch.load('efficientnet_torch.pth',
                         map_location=torch.device('cpu'))
    model.load_state_dict(weights)
    model.eval()

    freq = getTickFrequency()
    # Built once rather than once per image.
    trans = torchvision.transforms.ToTensor()

    image_path = ["./others/", "./phoneWithHand/", "./sleep/", "./writing/"]
    # NOTE(review): assumes model output index k corresponds to class_names[k]
    # (alphabetical, same order as the folders above) — confirm against the
    # label order used at training time.
    class_names = ['others', 'phoneWithHand', 'sleep', 'writing']
    num_images = 100

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for folder in image_path:
            counts = dict.fromkeys(class_names, 0)
            sumsum = 0
            for i in range(num_images):
                # Close each file once its pixels have been converted.
                with Image.open(folder + str(i) + '.jpg') as image:
                    data = trans(image).unsqueeze(0)
                t1 = getTickCount()
                scores = model(data)
                t2 = getTickCount()
                time1 = (t2 - t1) / freq
                sumsum += 1 / time1
                pred = int(scores.argmax(dim=1)[0])
                # Predictions outside the four known classes are not counted.
                if pred < len(class_names):
                    counts[class_names[pred]] += 1

            for label in ('others', 'phoneWithHand', 'writing', 'sleep'):
                print(label + ': ', end='')
                print(counts[label])
            print('Average FPS: ', end='')
            print(sumsum / num_images)
예제 #2
0
def main():
    """Train one W&B-tracked EfficientNet per grid point and print its best score.

    ``best_val_score`` is a module-level global. It is reset to 0.0 before each
    grid point and printed after ``train`` returns; presumably ``train``
    updates it.
    """
    global best_val_score
    for grid_point in get_grid_list():
        best_val_score = 0.0
        depth, width, resolution = grid_point[0], grid_point[1], grid_point[2]
        conf = EfficientNetConfig(depth=depth, width=width,
                                  resolution=resolution, num_classes=10)
        run_name = conf_to_name(conf)

        # One fresh W&B run per configuration, named after it.
        run = wandb.init(project='EfficientNet_small', reinit=True)
        run.name = run_name
        run.save()

        model = EfficientNet(conf)
        model.cuda()
        wandb.watch(model)

        train(model, run_name)
        print(run_name, best_val_score)
        run.finish()
예제 #3
0
def model_train(fold: int) -> None:
    """Train one cross-validation fold of the MNIST model with PyTorch Lightning.

    Rows of ``split_kfold.csv`` whose ``kfold`` equals ``fold`` form the
    validation split, and all other rows form the training split. Both splits
    are written (without the ``kfold`` column) to ``train-kfold-{fold}.csv`` and
    ``val-kfold-{fold}.csv`` under ``config.save_dir``, then read back by
    ``MnistDataset``. Checkpoints are written to ``config.save_dir/<fold>``.

    Args:
        fold: Index of the fold held out for validation.
    """
    # Prepare data
    df = pd.read_csv(os.path.join(config.save_dir, 'split_kfold.csv'))
    df_train = df[df['kfold'] != fold].reset_index(drop=True)
    df_val = df[df['kfold'] == fold].reset_index(drop=True)

    train_csv = os.path.join(config.save_dir, f'train-kfold-{fold}.csv')
    val_csv = os.path.join(config.save_dir, f'val-kfold-{fold}.csv')
    df_train.drop(['kfold'], axis=1).to_csv(train_csv, index=False)
    df_val.drop(['kfold'], axis=1).to_csv(val_csv, index=False)

    image_dir = os.path.join(config.data_dir, 'train')
    train_dataset = MnistDataset(image_dir, train_csv, transforms_train)
    val_dataset = MnistDataset(image_dir, val_csv, transforms_test)

    model = MnistModel(EfficientNet())
    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        # Rooted at config.save_dir, like every other path in this function.
        dirpath=os.path.join(config.save_dir, f'{fold}'),
        filename='{epoch:02d}-{val_loss:.2f}.pth',
        save_top_k=5,
        mode='min',
    )
    early_stopping = EarlyStopping(
        monitor='val_loss',
        mode='min',
    )

    # Only the accelerator flag and the validation batch size depend on the device.
    if config.device == 'tpu':
        accelerator_kwargs = {'tpu_cores': 8}
        val_batch_size = 2
    else:
        accelerator_kwargs = {'gpus': 1}
        val_batch_size = 8

    train_loader = DataLoader(train_dataset, batch_size=16, num_workers=10, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=val_batch_size, num_workers=10, shuffle=False)
    trainer = Trainer(
        **accelerator_kwargs,
        num_sanity_val_steps=-1,
        deterministic=True,
        max_epochs=config.epochs,
        callbacks=[checkpoint_callback, early_stopping],
    )

    trainer.fit(model, train_loader, val_loader)
    def __init__(self, args):
        """Assemble backbone, extra downsampling tail, channel adapters and BiFPN stack.

        Args:
            args: Passed unchanged to ``EfficientNet.from_pretrained``; its
                ``backbone`` attribute is printed.
        """
        super().__init__()

        oup = 64
        self.inp = 64
        self.oup = oup
        # Total BiFPN layers: one "first" layer plus (bifpn_repeat - 1) more.
        self.bifpn_repeat = 2

        print(args.backbone)
        self.backbone = EfficientNet.from_pretrained(args)

        # Two extra ConvBlocks after the backbone; 320 is the input channel count of
        # the first. NOTE(review): positional args presumably (kernel, stride,
        # padding) = (3, 2, 1) — confirm against ConvBlock.
        self.tail = nn.ModuleList([
            ConvBlock(320, oup, 3, 2, 1),
            ConvBlock(oup, oup, 3, 2, 1),
        ])
        last_three_features = self.backbone.get_list_feature()[-3:]
        self.channel_same = self.change_channel(last_three_features)

        self.BiFPN_first = BiFPN(oup=oup, first=True)
        self.BiFPN = nn.ModuleList(
            [BiFPN(oup=oup, first=False) for _ in range(self.bifpn_repeat - 1)])
예제 #5
0
        
if __name__ == '__main__':

    import argparse
    # NOTE(review): EfficientNet is presumably the project's own implementation;
    # the efficientnet_pytorch import was disabled here.

    parser = argparse.ArgumentParser(
        description='Convert TF model to PyTorch model and save for easier future loading')
    parser.add_argument('--model_name', type=str, default='efficientnet-b0',
                        help='efficientnet-b{N}, where N is an integer 0 <= N <= 7')
    parser.add_argument('--pth_file', type=str, default='efficientnet-b0.pth',
                        help='input PyTorch model file name')
    args = parser.parse_args()

    # Build model
    model = EfficientNet.from_name(args.model_name)

    # map_location='cpu' lets a checkpoint saved from GPU tensors load on a
    # CPU-only machine; the freshly built model lives on the CPU anyway.
    pretrained_weights = torch.load(args.pth_file, map_location='cpu')
    # Loading the raw checkpoint fails with a key mismatch; the classifier-head
    # entries (_fc.*) are dropped so the strict load below matches the model.
    del pretrained_weights['_fc.weight']
    del pretrained_weights['_fc.bias']
    model.load_state_dict(pretrained_weights)
    
예제 #6
0
                                        transform=transform_train)
# NOTE(review): batch_size=2 is unusually small for CIFAR-10 training — confirm intended.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=2, shuffle=True)

# CIFAR-10 test split, iterated in a fixed (unshuffled) order.
testset = torchvision.datasets.CIFAR10(root='./data',
                                       train=False,
                                       download=True,
                                       transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=2, shuffle=False)

# Class names, presumably in CIFAR-10 label-index order.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',
           'ship', 'truck')

# NOTE(review): hard-coded to the first CUDA device; there is no CPU fallback.
device = torch.device("cuda:0")
# Two parallel backbone/head pairs: the efficientnet_pytorch backbone with
# BPnetwork5, and the separately imported EfficientNet with BPnetwork5q.
model = efficientnet_pytorch.EfficientNet.from_pretrained('efficientnet-b0')
BPnet = BPnetwork5()
mymodel = EfficientNet.from_pretrained('efficientnet-b0')
myBPnet = BPnetwork5q()
model.to(device)
BPnet.to(device)
mymodel.to(device)
myBPnet.to(device)

# Start both heads from identical fc1 parameters (copied by value via state_dict).
# myBPnet.fc1.weight = BPnet.fc1.weight.detach().clone()
# myBPnet.fc1.bias = BPnet.fc1.bias.detach().clone()
myBPnet.fc1.load_state_dict(BPnet.fc1.state_dict())

# Project-defined loss for the "my" pair; the standard loss for the reference pair.
mycriterion = CrossEntropy()
criterion = nn.CrossEntropyLoss()

# RMSprop over the reference backbone's parameters only (BPnet is not included here).
optimizer_cnn = optim.RMSprop(model.parameters(), lr=0.0000256, momentum=0.9)
myoptimizer_cnn = optim.RMSprop(mymodel.parameters(),
예제 #7
0
def main():
    """
    --------------------------------------------- MAIN --------------------------------------------------------

    Instantiates the model plus loss function and defines the dataloaders for several datasets including some
    data augmentation.
    Then runs a grid search over ``args.ini_c_divrs`` (dividers for the initial cluster-center values) and
    ``args.lambda_max_divrs`` (dividers for lambda_max). Both have a big influence on the sparsity (and
    respectively accuracy) of the resulting ternary networks.

    Sets the module-level globals ``args``, ``use_cuda`` and ``device``.

    Raises:
        NotImplementedError: if ``args.model`` or ``args.dataset`` is not a supported name.
    """

    # Manual seed for reproducibility
    torch.manual_seed(363636)

    # Global instances
    global args, use_cuda, device
    # Instantiating the parser
    args = parser.parse_args()
    # Global CUDA flag
    use_cuda = args.cuda and torch.cuda.is_available()
    # Defining device and device's map location
    device = torch.device("cuda" if use_cuda else "cpu")
    print('chosen device: ', device)

    # Building the model
    if args.model == 'cifar_micronet':
        d_mult, w_mult = args.dw_multps[0] ** args.phi, args.dw_multps[1] ** args.phi
        print('Building MicroNet for CIFAR-100 with depth multiplier {} and width multiplier {} ...'.format(
            d_mult, w_mult))
        model = micronet(d_mult, w_mult)

    elif args.model == 'imagenet_micronet':
        d_mult, w_mult = args.dw_multps[0] ** args.phi, args.dw_multps[1] ** args.phi
        print('Building MicroNet for ImageNet with depth multiplier {} and width multiplier {} ...'.format(
            d_mult, w_mult))
        model = image_micronet(d_mult, w_mult)

    elif args.model == 'efficientnet-b1':
        print('Building EfficientNet-B1 ...')
        model = EfficientNet.efficientnet_b1()

    elif args.model == 'efficientnet-b2':
        print('Building EfficientNet-B2 ...')
        model = EfficientNet.efficientnet_b2()

    elif args.model == 'efficientnet-b3':
        print('Building EfficientNet-B3 ...')
        model = EfficientNet.efficientnet_b3()

    elif args.model == 'efficientnet-b4':
        print('Building EfficientNet-B4 ...')
        model = EfficientNet.efficientnet_b4()

    else:
        # Without this, an unknown name surfaces later as an UnboundLocalError on `model`.
        raise NotImplementedError('Undefined model name %s' % args.model)

    for name, _ in model.named_parameters():
        print('\n', name)

    # Transfers model to device (GPU/CPU).
    model.to(device)

    # Defining loss function and printing CUDA information (if available)
    if use_cuda:
        print("PyTorch version: ")
        print(torch.__version__)
        print("CUDA Version: ")
        print(torch.version.cuda)
        print("cuDNN version is: ")
        print(cudnn.version())
        cudnn.benchmark = True
        loss_fct = nn.CrossEntropyLoss().cuda()
    else:
        loss_fct = nn.CrossEntropyLoss()

    # Dataloaders for CIFAR, ImageNet and MNIST
    if args.dataset == 'CIFAR100':

        print('Loading CIFAR-100 data ...')
        normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                                         std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

        kwargs = {'num_workers': args.workers, 'pin_memory': True} if use_cuda else {}

        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100(root=args.data_path, train=True, transform=transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, 4),
                transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.075),
                transforms.ToTensor(),
                normalize,
                Cutout(n_holes=1, length=16),
            ]), download=True),
            batch_size=args.batch_size, shuffle=True, **kwargs)

        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100(root=args.data_path, train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                normalize,
            ])),
            batch_size=args.val_batch_size, shuffle=False, **kwargs)

    elif args.dataset == 'ImageNet':

        print('Loading ImageNet data ...')
        traindir = os.path.join(args.data_path, 'train')
        valdir = os.path.join(args.data_path, 'val')
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(args.image_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=True,
            num_workers=args.workers, pin_memory=True)

        # EfficientNets are evaluated at their own native resolution.
        if model.__class__.__name__ == 'EfficientNet' or 'efficientnet' in str(args.model):
            image_size = EfficientNet.get_image_size(args.model)

        else:
            image_size = args.image_size

        val_dataset = datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(image_size, interpolation=PIL.Image.BICUBIC),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                normalize,
            ]))
        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=args.val_batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=True)

    elif args.dataset == 'MNIST':

        kwargs = {'num_workers': args.workers, 'pin_memory': True} if use_cuda else {}

        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(args.data_path, train=True, download=True,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307,), (0.3081,))
                           ])),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = torch.utils.data.DataLoader(
            datasets.MNIST(args.data_path, train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ])),
            batch_size=args.val_batch_size, shuffle=True, **kwargs)

    elif args.dataset == 'CIFAR10':

        normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                                         std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

        kwargs = {'num_workers': args.workers, 'pin_memory': True} if use_cuda else {}

        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(root=args.data_path, train=True, transform=transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, 4),
                transforms.ToTensor(),
                normalize,
            ]), download=True),
            batch_size=args.batch_size, shuffle=True, **kwargs)

        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(root=args.data_path, train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                normalize,
            ])),
            batch_size=args.val_batch_size, shuffle=False, **kwargs)

    else:
        raise NotImplementedError('Undefined dataset name %s' % args.dataset)

    # Gridsearch on dividers for lambda_max and initial cluster center values
    for initial_c_divr in args.ini_c_divrs:
        for lambda_max_divr in args.lambda_max_divrs:
            header = 'lambda_max_divr: {}, initial_c_divr: {}'.format(lambda_max_divr, initial_c_divr)
            print(header)
            # Close (and thereby flush) the log before grid_search runs, so the
            # header lands on disk ahead of anything written during the search.
            with open('./model_quantization/logfiles/logfile.txt', 'a+') as logfile:
                logfile.write(header + '\n')
            grid_search(train_loader, val_loader, model, loss_fct, lambda_max_divr, initial_c_divr)
예제 #8
0
    return pred + 1


def show_result(pred, image):
    """Show *image* in a single-axes figure titled with the prediction.

    Args:
        pred: Predicted label; rendered with ``str()`` into the title.
        image: An image accepted by ``plt.imshow``.
    """
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    ax.set_title('predict: ' + str(pred))
    ax.axis("off")
    # `ax` is the current axes of the current figure, so pyplot draws into it.
    plt.imshow(image)
    fig.tight_layout()
    plt.show()


# Run on the first GPU when available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model_name = 'efficientnet-b0'
# NOTE(review): image_size is not used in this block; presumably read as a
# global by test() — confirm.
image_size = EfficientNet.get_image_size(model_name)
# 15-class head. The downloaded pretrained weights are fully replaced by the
# fine-tuned checkpoint loaded below (load_state_dict is strict by default).
model = EfficientNet.from_pretrained(model_name, num_classes=15)

model = model.to(device)
# map_location=device lets a GPU-saved checkpoint load on a CPU-only host.
model.load_state_dict(torch.load('best.pt', map_location=device))
# NOTE(review): model.eval() is not called here; confirm test() switches to eval mode.

image = Image.open('sample.png').convert(
    'RGB')  # Input image path: replace 'sample.png' with your own image.
pred = test(image)
show_result(pred, image)
예제 #9
0
def BackBone_Unet(backbone_name):
    """Build a U-Net whose encoder is the named ResNet, DenseNet or EfficientNet.

    Args:
        backbone_name: Exact backbone name, one of the keys of ``up_parm_dict``
            below (e.g. ``'resnet50'``, ``'densenet121'``, ``'efficientnet-b3'``).

    Returns:
        A ``Res_Unet``, ``Dense_Unet`` or ``Efficient_Unet`` wrapping the
        backbone, configured with that backbone's decoder channel list.

    Raises:
        ValueError: if ``backbone_name`` is not a supported backbone.
    """
    # Decoder ("up") channel configuration per backbone.
    up_parm_dict = {
        'resnet18': [512, 256, 128, 64, 64, 64, 64, 3],
        'resnet34': [512, 256, 128, 64, 64, 64, 64, 3],
        'resnet50': [2048, 1024, 512, 256, 128, 64, 64, 3],
        'resnet101': [2048, 1024, 512, 256, 128, 64, 64, 3],
        'resnet152': [2048, 1024, 512, 256, 128, 64, 64, 3],
        'densenet121': [1024, 1024, 512, 256, 128, 64, 64, 3],
        'densenet161': [2204, 2104, 752, 352, 128, 64, 64, 3],
        'densenet201': [1920, 1792, 512, 256, 128, 64, 64, 3],
        'densenet169': [1664, 1280, 512, 256, 128, 64, 64, 3],
        'efficientnet-b0': [1280, 112, 40, 24, 16, 16, 64, 3],
        'efficientnet-b1': [1280, 112, 40, 24, 16, 16, 64, 3],
        'efficientnet-b2': [1280, 120, 48, 24, 16, 16, 64, 3],
        'efficientnet-b3': [1280, 136, 48, 32, 24, 24, 64, 3],
        'efficientnet-b4': [1280, 160, 56, 32, 24, 24, 64, 3],
        'efficientnet-b5': [1280, 176, 64, 40, 24, 24, 64, 3],
        'efficientnet-b6': [1280, 200, 72, 40, 32, 32, 64, 3],
        'efficientnet-b7': [1280, 224, 80, 48, 32, 32, 64, 3]
    }

    efficient_param = {
        # 'efficientnet type': (width_coef, depth_coef, resolution, dropout_rate)
        # NOTE(review): resolution is 224 for every variant here — confirm intended.
        'efficientnet-b0': (1.0, 1.0, 224, 0.2),
        'efficientnet-b1': (1.0, 1.1, 224, 0.2),
        'efficientnet-b2': (1.1, 1.2, 224, 0.3),
        'efficientnet-b3': (1.2, 1.4, 224, 0.3),
        'efficientnet-b4': (1.4, 1.8, 224, 0.4),
        'efficientnet-b5': (1.6, 2.2, 224, 0.4),
        'efficientnet-b6': (1.8, 2.6, 224, 0.5),
        'efficientnet-b7': (2.0, 3.1, 224, 0.5)
    }

    # Validate up front so unsupported names fail clearly instead of leaving
    # `model` / `net` unbound.
    if backbone_name not in up_parm_dict:
        raise ValueError('Unsupported backbone {!r}; expected one of {}'.format(
            backbone_name, sorted(up_parm_dict)))
    up_parm = up_parm_dict[backbone_name]

    if backbone_name.startswith('resnet'):
        # 'resnet101' -> ResNet.ResNet101(), etc.
        model = getattr(ResNet, 'ResNet' + backbone_name[len('resnet'):])()
        return Res_Unet(model=model, up_parm=up_parm)

    if backbone_name.startswith('densenet'):
        # 'densenet121' -> DenseNet.DenseNet121(seon=False), etc.
        model = getattr(DenseNet, 'DenseNet' + backbone_name[len('densenet'):])(seon=False)
        return Dense_Unet(model=model, up_parm=up_parm)

    model = EfficientNet.EfficientNet(efficient_param[backbone_name])
    return Efficient_Unet(model=model, up_parm=up_parm)
예제 #10
0
#
if __name__ == "__main__":
    from config import Config
    from model import EfficientNet
    opt = Config()
    torch.cuda.empty_cache()
    device = torch.device(opt.device)
    # Same device as the model, rather than unconditionally CUDA.
    criterion = torch.nn.CrossEntropyLoss().to(device)
    model_name = opt.backbone
    model_save_dir = os.path.join(opt.checkpoints_dir, model_name)
    os.makedirs(model_save_dir, exist_ok=True)
    logger = get_logger(os.path.join(model_save_dir, 'log.log'))
    logger.info('Using: {}'.format(model_name))
    logger.info('InputSize: {}'.format(opt.input_size))
    logger.info('optimizer: {}'.format(opt.optimizer))
    logger.info('lr_init: {}'.format(opt.lr))
    logger.info('batch size: {}'.format(opt.train_batch_size))
    logger.info('criterion: {}'.format(opt.loss))
    logger.info('Using label smooth: {}'.format(opt.use_smooth_label))
    logger.info('lr_scheduler: {}'.format(opt.lr_scheduler))
    logger.info('Using the GPU: {}'.format(str(opt.gpu_id)))

    # NOTE(review): the architecture, optimizer, lr and scheduler below are
    # hard-coded, so the values logged above from opt.backbone / opt.optimizer /
    # opt.lr / opt.lr_scheduler may not describe what is actually trained —
    # confirm which should win.
    model = EfficientNet.from_pretrained(model_name='efficientnet-b7',
                                         num_classes=2)
    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=5e-4)
    # Warm restarts: first cycle T_0=3 scheduler steps, each later cycle twice as long.
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=3, T_mult=2, eta_min=1e-6, last_epoch=-1)
    train_model(model, criterion, optimizer, lr_scheduler=lr_scheduler)
예제 #11
0
    file.write(str(label_no) + '\n')


if __name__ == '__main__':
    # Command-line interface: only the path of the weights file to evaluate.
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights',
                        type=str,
                        default=WEIGTH_PATH,
                        help='the weights file you want to test')
    args = parser.parse_args()

    # Earlier ResNet-50 variant, kept for reference:
    # net = resnet50()
    # num_ftrs = net.fc.in_features
    # net.fc = nn.Linear(num_ftrs, 38)

    # EfficientNet-B0 with a 39-way classifier head; its downloaded weights are
    # replaced by the checkpoint loaded below.
    net = EfficientNet.from_pretrained('efficientnet-b0', num_classes=39)

    net = net.cuda()

    # NOTE(review): the message says "Resnet", but the network is EfficientNet.
    print("__Resnet load success")

    # Second positional argument is strict=True: checkpoint keys must match exactly.
    net.load_state_dict(torch.load(args.weights), True)
    print(net)

    # Switch to inference mode (dropout / batch-norm use eval behaviour).
    net.eval()

    # __ Walk the exam folder
    folder_path = DATA_EXAM_PATH
    path_list = os.listdir(folder_path)

    # Output .dat file
예제 #12
0
    print('*' * 30)
    train_loader = get_loader(osp.join(args.data_path, 'images'),
                              osp.join(args.data_path, 'list_attr_celeba.txt'),
                              crop_size=178,
                              image_size=360,
                              batch_size=args.batch_size,
                              mode='all',
                              num_workers=args.num_workers)
    print('*' * 30)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print('Running on device:', device)
    print('*' * 30)

    print('Model Architecture')
    model = EfficientNet.from_name('efficientnet-b4')
    model = model.to(device)
    model_params = torchutils.get_model_param_count(model)
    print(model)
    print('Total model parameters:', model_params)
    print('*' * 30)

    # criterion = nn.BCEWithLogitsLoss()
    criterion = nn.MultiLabelSoftMarginLoss()
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=0.9,
                          weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     patience=3,
                                                     threshold=0.005,
            continue
        total_examples += len(target)
        data, target = Variable(data), Variable(target)
        data, target = data.cuda(), target.cuda()

        scores = model(data)
        pred = scores.data.max(1)[1]
        test_correct += pred.eq(target.data).cpu().sum()
    print("Predicted {} out of {} correctly".format(test_correct,
                                                    total_examples))
    return 100.0 * test_correct / (float(total_examples))


if __name__ == '__main__':
    # A bare `torch.cuda.device(0)` only builds a context-manager object and
    # selects nothing; set_device actually makes GPU 0 the current device.
    torch.cuda.set_device(0)
    model = EfficientNet(1.0, 1.0)

    config = CONFIG()

    model = model.cuda()

    # Training history and best validation accuracy seen so far
    # (presumably filled in per epoch by the code that follows).
    avg_loss = []
    best_accuracy = 0.0

    optimizer = optim.SGD(model.parameters(),
                          lr=LEARNING_RATE,
                          momentum=config.momentum,
                          weight_decay=config.weight_decay)

    train_acc, val_acc = [], []
예제 #14
0
def main():
    """Train an EfficientNet-B6 two-class classifier on the elbow X-ray dataset.

    Options:
        --dataset-dir: Content images directory.
        --chk-path: State dict to resume from, or ``'none'`` to start from the
            pretrained backbone.
        --device: Parsed but unused (see the note below).
        --batch-size: Training batch size.
        --epoch: Number of training epochs.

    A state dict is saved to ``ckpt/<epoch>.pkl`` after every epoch.
    """
    parser = argparse.ArgumentParser(description='Transfer images to styles.')
    parser.add_argument('--dataset-dir',
                        type=str,
                        dest='data_dir',
                        help='path to content images directory',
                        default='data/img_tiff')
    parser.add_argument('--chk-path',
                        type=str,
                        dest='weight_path',
                        help='path to model weights',
                        default='none')
    parser.add_argument(
        '--device',
        type=str,
        help='type of inference device',
        default='cuda',
    )
    parser.add_argument('--batch-size', type=int, dest='batch_size', default=4)
    parser.add_argument('--epoch', type=int, dest='num_epoch', default=10)

    # NOTE(review): --device is parsed but ignored; the device is chosen by CUDA availability.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    args = parser.parse_args()
    num_epoch = args.num_epoch

    elbow_xray_dataset = ElbowxrayDataset(xlsx_file='data/elbow.xlsx',
                                          root_dir='data/img_tiff/',
                                          transform=transforms.Compose([
                                              transforms.ToTensor(),
                                              transforms.Resize(size=(512,
                                                                      512))
                                          ]))

    dataloader_train = torch.utils.data.DataLoader(elbow_xray_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True)

    # Always build the architecture first; a checkpoint, when given, is loaded
    # into it (previously `model` was used before assignment on that path).
    model = EfficientNet.from_pretrained('efficientnet-b6', num_classes=2)
    if args.weight_path != 'none':
        state = torch.load(args.weight_path, map_location=device)
        model.load_state_dict(state)
    model = model.to(device)
    model.train()

    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    criterion = torch.nn.CrossEntropyLoss()  # subject due to change

    # torch.save does not create missing directories.
    os.makedirs('ckpt', exist_ok=True)
    log_every = 20

    for epoch in range(num_epoch):
        running_loss = 0.0

        for i, data in enumerate(dataloader_train, 0):
            images, labels = data['image'].to(device), data['label'].to(device)

            optimizer.zero_grad()

            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % log_every == log_every - 1:
                # Mean loss over the last `log_every` batches.
                print('[Epoch %d/%d, %5d] loss: %.3f' %
                      (epoch + 1, num_epoch, i + 1, running_loss / log_every))
                running_loss = 0.0

        torch.save(model.state_dict(), 'ckpt/' + str(epoch) + '.pkl')
def main():
    """Build the chosen model and dataloaders, then train with frozen assignment.

    Parses the module-level ``parser`` into the global ``args`` and sets the
    globals ``use_cuda`` and ``device``. It then builds the model named by
    ``args.model`` and the train/validation dataloaders for ``args.dataset``.
    Finally it calls ``train_w_frozen_assignment``.

    Raises:
        NotImplementedError: if ``args.model`` or ``args.dataset`` is not a
            supported name, or ``cifar_micronet`` is combined with a non-CIFAR
            dataset.
    """

    # Manual seed for reproducibility
    torch.manual_seed(363636)

    # Global instances
    global args, use_cuda, device
    # Instantiating the parser
    args = parser.parse_args()
    # Global CUDA flag
    use_cuda = args.cuda and torch.cuda.is_available()
    # Defining device and device's map location
    device = torch.device("cuda" if use_cuda else "cpu")
    print('chosen device: ', device)

    # Building the model
    if args.model == 'cifar_micronet':
        print(
            'Building MicroNet for CIFAR with depth multiplier {} and width multiplier {} ...'
            .format(args.dw_multps[0]**args.phi, args.dw_multps[1]**args.phi))
        if args.dataset == 'CIFAR100':
            num_classes = 100
        elif args.dataset == 'CIFAR10':
            num_classes = 10
        else:
            # Previously `num_classes` was left unbound here.
            raise NotImplementedError(
                'cifar_micronet requires CIFAR10 or CIFAR100, got %s' % args.dataset)
        model = micronet(args.dw_multps[0]**args.phi,
                         args.dw_multps[1]**args.phi, num_classes)

    elif args.model == 'image_micronet':
        print(
            'Building MicroNet for ImageNet with depth multiplier {} and width multiplier {} ...'
            .format(args.dw_multps[0]**args.phi, args.dw_multps[1]**args.phi))
        model = image_micronet(args.dw_multps[0]**args.phi,
                               args.dw_multps[1]**args.phi)

    elif args.model == 'efficientnet-b1':
        print('Building EfficientNet-B1 ...')
        model = EfficientNet.efficientnet_b1()

    elif args.model == 'efficientnet-b2':
        print('Building EfficientNet-B2 ...')
        model = EfficientNet.efficientnet_b2()

    elif args.model == 'efficientnet-b3':
        print('Building EfficientNet-B3 ...')
        model = EfficientNet.efficientnet_b3()

    elif args.model == 'efficientnet-b4':
        print('Building EfficientNet-B4 ...')
        model = EfficientNet.efficientnet_b4()

    elif args.model == 'lenet-5':
        print(
            'Building LeNet-5 with depth multiplier {} and width multiplier {} ...'
            .format(args.dw_multps[0]**args.phi, args.dw_multps[1]**args.phi))
        model = lenet5(d_multiplier=args.dw_multps[0]**args.phi,
                       w_multiplier=args.dw_multps[1]**args.phi)

    else:
        # Without this, an unknown name surfaces later as an UnboundLocalError on `model`.
        raise NotImplementedError('Undefined model name %s' % args.model)

    for name, _ in model.named_parameters():
        print('\n', name)

    # Transfers model to device (GPU/CPU).
    model.to(device)

    # Defining loss function and printing CUDA information (if available)
    if use_cuda:
        print("PyTorch version: ")
        print(torch.__version__)
        print("CUDA Version: ")
        print(torch.version.cuda)
        print("cuDNN version is: ")
        print(cudnn.version())
        cudnn.benchmark = True
        loss_fct = nn.CrossEntropyLoss().cuda()
    else:
        loss_fct = nn.CrossEntropyLoss()

    # Dataloaders for CIFAR, ImageNet and MNIST

    if args.dataset == 'CIFAR100':

        print('Loading CIFAR-100 data ...')
        normalize = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

        kwargs = {
            'num_workers': args.workers,
            'pin_memory': True
        } if use_cuda else {}

        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100(
                root=args.data_path,
                train=True,
                transform=transforms.Compose([
                    transforms.RandomHorizontalFlip(),
                    transforms.RandomCrop(32, 4),
                    transforms.ColorJitter(brightness=0.3,
                                           contrast=0.3,
                                           saturation=0.3,
                                           hue=0.075),
                    transforms.ToTensor(),
                    normalize,
                    Cutout(n_holes=1, length=16),
                ]),
                download=True),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs)

        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100(root=args.data_path,
                              train=False,
                              transform=transforms.Compose([
                                  transforms.ToTensor(),
                                  normalize,
                              ])),
            batch_size=args.val_batch_size,
            shuffle=False,
            **kwargs)

    elif args.dataset == 'ImageNet':

        print('Loading ImageNet data ...')
        traindir = os.path.join(args.data_path, 'train')
        valdir = os.path.join(args.data_path, 'val')
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(args.image_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   pin_memory=True)

        # EfficientNets are evaluated at their own native resolution.
        if model.__class__.__name__ == 'EfficientNet' or 'efficientnet' in str(
                args.model):
            image_size = EfficientNet.get_image_size(args.model)

        else:
            image_size = args.image_size

        val_dataset = datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(image_size, interpolation=PIL.Image.BICUBIC),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                normalize,
            ]))
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=args.val_batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True)

    elif args.dataset == 'MNIST':

        kwargs = {
            'num_workers': args.workers,
            'pin_memory': True
        } if use_cuda else {}

        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(
                args.data_path,
                train=True,
                download=True,
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.1307, ), (0.3081, ))
                ])),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs)
        val_loader = torch.utils.data.DataLoader(
            datasets.MNIST(args.data_path,
                           train=False,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307, ), (0.3081, ))
                           ])),
            batch_size=args.val_batch_size,
            shuffle=True,
            **kwargs)

    elif args.dataset == 'CIFAR10':

        print('Loading CIFAR-10 data ...')
        normalize = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

        kwargs = {
            'num_workers': args.workers,
            'pin_memory': True
        } if use_cuda else {}

        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(
                root=args.data_path,
                train=True,
                transform=transforms.Compose([
                    transforms.RandomHorizontalFlip(),
                    transforms.RandomCrop(32, 4),
                    transforms.ColorJitter(brightness=0.3,
                                           contrast=0.3,
                                           saturation=0.3,
                                           hue=0.075),
                    transforms.ToTensor(),
                    normalize,
                    Cutout(n_holes=1, length=16),
                ]),
                download=True),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs)

        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(root=args.data_path,
                             train=False,
                             transform=transforms.Compose([
                                 transforms.ToTensor(),
                                 normalize,
                             ])),
            batch_size=args.val_batch_size,
            shuffle=False,
            **kwargs)

    else:
        raise NotImplementedError('Undefined dataset name %s' % args.dataset)

    train_w_frozen_assignment(train_loader, val_loader, model, loss_fct)