Example #1
def download_model(saving_path='.'):
    # inception net
    # model = models.Inception3()
    # model.load_state_dict(model_zoo.load_url(model_urls['inception_v3_google'], model_dir=saving_path, progress=True))

    # resnet
    model = models.ResNet(_Bottleneck, [3, 8, 36, 3])
    model.load_state_dict(model_zoo.load_url(model_urls['resnet152'], model_dir=saving_path, progress=True))
    # save_model(model, 'resnet152.pkl', saving_path)

    # alex net
    model = models.AlexNet()
    model.load_state_dict(model_zoo.load_url(model_urls['alexnet'], model_dir=saving_path, progress=True))
    # save_model(model, 'alexnet.pkl', saving_path)

    # vgg
    model = models.VGG(_vgg_make_layers(_vgg_cfg['E'], batch_norm=True), init_weights=False)
    model.load_state_dict(model_zoo.load_url(model_urls['vgg19_bn'], model_dir=saving_path, progress=True))
    # save_model(model, 'vgg19.pkl', saving_path)

    # squeeze net
    model = models.SqueezeNet(version=1.1)
    model.load_state_dict(model_zoo.load_url(model_urls['squeezenet1_1'], model_dir=saving_path, progress=True))
    # save_model(model, 'squeezenet1_1.pkl', saving_path)

    # dense net
    model = models.DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32))
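    # Older torchvision checkpoints save DenseNet batch-norm layers with an
    # extra dot in the key (e.g. 'norm.1'); the pattern below rewrites such
    # keys to the current naming ('norm1') so load_state_dict() succeeds.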
    pattern = re.compile(
        r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
    state_dict = model_zoo.load_url(model_urls['densenet201'], model_dir=saving_path, progress=True)
    for key in list(state_dict.keys()):
        res = pattern.match(key)
        if res:
            new_key = res.group(1) + res.group(2)
            state_dict[new_key] = state_dict[key]
            del state_dict[key]
    model.load_state_dict(state_dict)
    # save_model(model, 'densenet201.pkl', saving_path)

    # googlenet
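    # Mirrors torchvision's googlenet() factory: the checkpoint contains the
    # auxiliary classifiers, so the model is built with aux_logits=True for
    # loading and the aux heads are stripped afterwards if not requested.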
    kwargs = dict()
    kwargs['transform_input'] = True
    kwargs['aux_logits'] = False
    # if kwargs['aux_logits']:
    #     warnings.warn('auxiliary heads in the pretrained googlenet model are NOT pretrained, '
    #                   'so make sure to train them')
    original_aux_logits = kwargs['aux_logits']
    kwargs['aux_logits'] = True
    kwargs['init_weights'] = False
    model = models.GoogLeNet(**kwargs)
    model.load_state_dict(model_zoo.load_url(model_urls['googlenet'], model_dir=saving_path, progress=True))
    if not original_aux_logits:
        model.aux_logits = False
        del model.aux1, model.aux2
        # save_model(model, 'googlenet.pkl', saving_path)

    # resnext
    model = models.resnext101_32x8d(pretrained=False)
    model.load_state_dict(model_zoo.load_url(model_urls['resnext101_32x8d'], model_dir=saving_path, progress=True))
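
A minimal usage sketch for the downloader above; the cache directory name is an illustrative assumption, not part of the original snippet:

import os

# Hypothetical cache directory; download_model() is the function defined above.
cache_dir = './pretrained_models'
os.makedirs(cache_dir, exist_ok=True)
download_model(saving_path=cache_dir)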
Example #2
args = parser.parse_args()
print(args)

# /////////////// Model Setup ///////////////

if args.model_name == 'alexnet':
    net = models.AlexNet()
    net.load_state_dict(
        model_zoo.load_url(
            'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
            # model_dir='/share/data/lang/users/dan/.torch/models'))
            model_dir='/share/data/vision-greg2/pytorch_models/alexnet'))
    args.test_bs = 6

elif args.model_name == 'squeezenet1.0':
    net = models.SqueezeNet(version=1.0)
    net.load_state_dict(
        model_zoo.load_url(
            'https://download.pytorch.org/models/squeezenet1_0-a815701f.pth',
            # model_dir='/share/data/lang/users/dan/.torch/models'))
            model_dir='/share/data/vision-greg2/pytorch_models/squeezenet'))
    args.test_bs = 6

elif args.model_name == 'squeezenet1.1':
    net = models.SqueezeNet(version=1.1)
    net.load_state_dict(
        model_zoo.load_url(
            'https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth',
            # model_dir='/share/data/lang/users/dan/.torch/models'))
            model_dir='/share/data/vision-greg2/pytorch_models/squeezenet'))
    args.test_bs = 6
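
The parser itself is not shown in this snippet. A minimal sketch of the assumed argparse setup (argument name, default, and choices are inferred from the branches above):

import argparse

# Assumed parser definition; only model_name is read by the branches shown.
parser = argparse.ArgumentParser(description='Evaluate pretrained torchvision models')
parser.add_argument('--model_name', type=str, default='alexnet',
                    choices=['alexnet', 'squeezenet1.0', 'squeezenet1.1'])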
Example #3
def train_dl():

    # Data preprocessing and split
    training_amount = 300

    training_u_amount = 30000

    validation_amount = 10000

    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    trainset = torchvision.datasets.CIFAR10(root='./data',
                                            train=True,
                                            download=True,
                                            transform=transform)

    testset = torchvision.datasets.CIFAR10(root='./data',
                                           train=False,
                                           download=True,
                                           transform=transform)

    X_train = np.array(trainset.data)
    y_train = np.array(trainset.targets)

    X_test = np.array(testset.data)
    y_test = np.array(testset.targets)

    # Train set / Validation set split
    X_train, X_val, y_train, y_val = train_test_split(
        X_train,
        y_train,
        test_size=validation_amount,
        random_state=1,
        shuffle=True,
        stratify=y_train)

    # Train unsupervised / Train supervised split
    # Train set / Validation set split
    X_train, X_u_train, y_train, y_u_train = train_test_split(
        X_train,
        y_train,
        test_size=training_u_amount,
        random_state=1,
        shuffle=True,
        stratify=y_train)

    X_remain, X_train, y_remain, y_train = train_test_split(
        X_train,
        y_train,
        test_size=training_amount,
        random_state=1,
        shuffle=True,
        stratify=y_train)

    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',
               'ship', 'truck')

    # DL related init variables

    # Hyper parameters
    epochs = int(1e3)
    num_classes = 10
    batch_size = 128
    # learning_rate = 1e-4
    learning_rate = 0.002
    min_lr = 1e-12

    # Device configuration
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Model definition
    model_filepath = 'weights.best.pt'
    # model = WideResNet(num_classes=num_classes).to(device)

    model = (models.SqueezeNet(num_classes=num_classes)).to(
        device)  # TODO: Define and use "Wide ResNet-28"

    # Training data generators
    # Data
    train_data_gen = mixmatch_wrapper(X_train,
                                      y_train,
                                      X_u_train,
                                      model,
                                      batch_size=batch_size)

    # Validation data generators
    val_data_gen = basic_generator(X_val,
                                   y_val,
                                   batch_size=batch_size,
                                   shuffle=True)

    # Model summary Keras style
    summary(model, (3, 32, 32))  # CIFAR-10 images are 32x32

    # Optimization parameters
    # Used criterions
    supervised_criterion = torch.nn.CrossEntropyLoss(reduction='mean')
    consistency_criterion = L_u
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=learning_rate,
                                 weight_decay=0.02)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='max',
        factor=0.5,
        patience=10,
        verbose=True,
        threshold=0.001,
        threshold_mode='rel',
        cooldown=0,
        min_lr=min_lr,
        eps=1e-08)

    # initialize the early_stopping object
    early_stopping = EarlyStopping(patience=30, mode='max', verbose=True)
    checkpoint = ModelCheckpoint(checkpoint_fn=model_filepath,
                                 mode='max',
                                 verbose=True)

    # Training level variables

    # to track the average training loss per epoch as the model trains
    avg_train_losses = []
    # to track the average validation loss per epoch as the model trains
    avg_val_losses = []

    # to track the average training acc per epoch as the model trains
    avg_train_acc_s = []
    # to track the average validation acc per epoch as the model trains
    avg_val_acc_s = []

    # Prepare the model for train
    model.train()

    print('\n===== TRAINING =====\n')
    for epoch in range(epochs):

        # Epoch training

        # Epoch level variables
        # to track the training loss as the model trains
        train_losses_item = []
        # to track the validation loss as the model trains
        val_losses_item = []

        # to track the training acc as the model trains
        train_acc_s_item = []
        # to track the validation acc as the model trains
        val_acc_s_item = []

        avg_train_loss = 0
        avg_train_acc = 0
        total_train = 0
        correct_train = 0

        avg_val_loss = 0
        avg_val_acc = 0
        total_val = 0
        correct_val = 0

        ###################
        # train the model #
        ###################

        train_iters = int(training_u_amount / batch_size)
        # TQDM progress bar definition, for visualization purposes
        pbar_train = tqdm(enumerate(range(train_iters)),
                          total=train_iters,
                          unit=" iter",
                          leave=False,
                          file=sys.stdout,
                          desc='Train epoch ' + str(epoch + 1) + '/' +
                          str(epochs) + '   Loss: %.4f   Accuracy: %.3f  ' %
                          (avg_train_loss, avg_train_acc))

        for i, ii in pbar_train:
            # for i, batch in enumerate(train_data_gen.flow()):
            # Epoch Training

            batch = next(train_data_gen)

            X = np.asarray(batch[0])
            p = np.argmax(np.asarray(batch[1]), axis=1)

            U = np.asarray(batch[2])
            q = np.asarray(batch[3])

            # Converting in PyTorch tensors
            X = torch.from_numpy(X.transpose(
                (0, 3, 1, 2))).to(device, dtype=torch.float)
            p = torch.from_numpy(p).to(device, dtype=torch.long)
            U = torch.from_numpy(U.transpose(
                (0, 3, 1, 2))).to(device, dtype=torch.float)
            q = torch.from_numpy(q).to(device, dtype=torch.float)

            model.train()
            # Forward pass: compute predicted y by passing x to the model.
            p_pred = model(X)
            q_pred = model(U)
            q_pred = torch.softmax(q_pred, dim=1)

            # Compute loss
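            # Supervised: cross-entropy on the labeled batch; consistency:
            # L_u on the unlabeled guesses, weighted below by lambda_u = 75
            # (the MixMatch paper's value for CIFAR-10).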
            supervised_loss = supervised_criterion(p_pred, p)
            consistency_loss = consistency_criterion(q_pred, q)

            loss = supervised_loss + 75 * consistency_loss

            # Before the backward pass, use the optimizer object to zero all of the
            # gradients for the variables it will update (which are the learnable
            # weights of the model). This is because by default, gradients are
            # accumulated in buffers( i.e, not overwritten) whenever .backward()
            # is called.
            optimizer.zero_grad()

            # Backward pass: compute gradient of the loss with respect to model
            # parameters
            loss.backward()

            # Calling the step function on an Optimizer makes an update to its
            # parameters
            optimizer.step()

            # Saving losses
            train_losses_item.append(loss.cpu().item())
            avg_train_loss = np.mean(train_losses_item)

            # Accuracy calculation
            _, predicted = torch.max(p_pred.data, 1)
            total_train += p.size(0)
            correct_train += (predicted == p).sum().item()
            train_acc_s_item.append(correct_train / total_train)
            avg_train_acc = np.mean(train_acc_s_item)

            # Update progress bar loss values
            pbar_train.set_description('Train epoch ' + str(epoch + 1) + '/' +
                                       str(epochs) +
                                       '   Loss: %.4f   Accuracy: %.3f  ' %
                                       (avg_train_loss, avg_train_acc))

        # Saving train avg metrics
        avg_train_losses.append(avg_train_loss)
        avg_train_acc_s.append(avg_train_acc)

        ######################
        # validate the model #
        ######################

        model.eval()  # prep model for evaluation

        # TQDM progress bar definition, for visualization purposes
        val_iters = int(validation_amount / batch_size)
        pbar_val = tqdm(enumerate(range(val_iters)),
                        total=val_iters,
                        unit=" iter",
                        leave=False,
                        file=sys.stdout,
                        desc='Validation epoch ' + str(epoch + 1) + '/' +
                        str(epochs) + '   Loss: %.4f   Accuracy: %.3f  ' %
                        (avg_val_loss, avg_val_acc))

        for i, _ in pbar_val:
            # Epoch validation

            X, p = next(val_data_gen)
            X = torch.from_numpy(np.asarray(X).transpose(
                (0, 3, 1, 2))).to(device, dtype=torch.float)
            p = torch.from_numpy(np.asarray(p)).to(device, dtype=torch.long)

            # forward pass: compute predicted outputs by passing inputs to the model
            y_pred = model(X)
            # Compute loss
            loss = supervised_criterion(y_pred, p)

            # Saving losses
            val_losses_item.append(loss.cpu().item())
            avg_val_loss = np.mean(val_losses_item)

            # Accuracy calculation
            _, predicted = torch.max(y_pred.data, 1)
            total_val += p.size(0)
            correct_val += (predicted == p).sum().item()
            val_acc_s_item.append(correct_val / total_val)
            avg_val_acc = np.mean(val_acc_s_item)

            # Update progress bar loss values
            pbar_val.set_description('Validation epoch ' + str(epoch + 1) +
                                     '/' + str(epochs) +
                                     '   Loss: %.4f   Accuracy: %.3f  ' %
                                     (avg_val_loss, avg_val_acc))

        # Saving val avg metrics
        avg_val_losses.append(avg_val_loss)
        avg_val_acc_s.append(avg_val_acc)

        # early_stopping monitors the validation accuracy to check whether it
        # has improved, and checkpoint saves the current model when it has
        early_stopping(avg_val_acc)
        checkpoint(avg_val_acc, model)

        # Reduce lr on plateau
        scheduler.step(avg_val_acc)

        if early_stopping.early_stop:
            print("Early stopping")
            break
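
mixmatch_wrapper, basic_generator, L_u, EarlyStopping, and ModelCheckpoint are project-local helpers that this snippet does not define. In the MixMatch paper the unlabeled loss L_u is the mean squared error between the model's softmax predictions and the guessed label distributions; a minimal sketch under that assumption:

import torch

def L_u(q_pred: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    # MixMatch-style consistency loss: mean squared error between softmax
    # predictions on unlabeled data and their guessed label distributions.
    return torch.mean((q_pred - q) ** 2)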
Example #4
 def load_squeezenet(self, model_path, version):
     """加载SqueezeNet1.0预训练模型;"""
     model = models.SqueezeNet(version=version)
     model.load_state_dict(torch.load(model_path))
     return model.features
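
A hypothetical call to the method above; the instance name is illustrative only, and the weight-file name follows torchvision's published SqueezeNet 1.1 checkpoint:

# 'extractor' stands for an instance of the (unshown) owning class.
features = extractor.load_squeezenet('squeezenet1_1-f364aa15.pth', version=1.1)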
 pytest.param(
     "models.densenet",
     "DenseNet",
     {},
     [],
     {},
     models.DenseNet(),
     id="DenseNetConf",
 ),
 pytest.param(
     "models.squeezenet",
     "SqueezeNet",
     {},
     [],
     {},
     models.SqueezeNet(),
     id="SqueezeNetConf",
 ),
 pytest.param(
     "models.mnasnet",
     "MNASNet",
     {"alpha": 1.0},
     [],
     {},
     models.MNASNet(alpha=1.0),
     id="MNASNetConf",
 ),
 pytest.param(
     "models.googlenet",
     "GoogLeNet",
     {},