def SENet154(pretrained, num_classes):
    if pretrained:
        model = senet154(pretrained='imagenet')
    else:
        model = senet154()

    model.avg_pool = nn.AdaptiveAvgPool2d(1)
    model.last_linear = nn.Linear(2048, num_classes, bias=True)

    return model
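A minimal usage sketch for the factory above; the 10-class head and the dummy batch are illustrative assumptions, and senet154 is taken to be Cadene's pretrainedmodels version:

import torch

model = SENet154(pretrained=True, num_classes=10)  # ImageNet backbone, new 10-way head
model.eval()
with torch.no_grad():
    logits = model(torch.randn(2, 3, 224, 224))  # senet154 takes 224x224 crops
print(logits.shape)  # torch.Size([2, 10])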
Example #2
    def __init__(self,
                 num_classes=1,
                 num_filters=32,
                 pretrained=True,
                 is_deconv=False):
        super().__init__()
        self.num_classes = num_classes

        self.pool = nn.MaxPool2d(2, 2)

        if pretrained:
            self.encoder = senet154(num_classes=1000, pretrained='imagenet')
        else:
            self.encoder = senet154(num_classes=1000, pretrained=None)

        self.conv1 = self.encoder.layer0

        self.conv2 = self.encoder.layer1

        self.conv3 = self.encoder.layer2

        self.conv4 = self.encoder.layer3

        self.conv5 = self.encoder.layer4

        self.center = DecoderBlock(2048,
                                   num_filters * 8,
                                   num_filters * 8,
                                   is_deconv=is_deconv)

        self.dec5 = DecoderBlock(2048 + num_filters * 8,
                                 num_filters * 8,
                                 num_filters * 8,
                                 is_deconv=is_deconv)
        self.dec4 = DecoderBlock(1024 + num_filters * 8,
                                 num_filters * 8,
                                 num_filters * 8,
                                 is_deconv=is_deconv)
        self.dec3 = DecoderBlock(512 + num_filters * 8,
                                 num_filters * 2,
                                 num_filters * 2,
                                 is_deconv=is_deconv)
        self.dec2 = DecoderBlock(256 + num_filters * 2,
                                 num_filters * 2,
                                 num_filters,
                                 is_deconv=is_deconv)
        self.dec1 = DecoderBlock(128 + num_filters,
                                 num_filters,
                                 num_filters,
                                 is_deconv=is_deconv)
        self.dec0 = ConvRelu(num_filters, num_filters)
        self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)
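ConvRelu and DecoderBlock are referenced above but not defined in this example; a common TernausNet-style pair that matches the channel arithmetic (an assumption, not the author's exact code):

import torch.nn as nn

class ConvRelu(nn.Module):
    # 3x3 conv + ReLU, the basic decoder building block
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)

class DecoderBlock(nn.Module):
    # Upsamples 2x, either via transposed conv (is_deconv) or interpolation
    def __init__(self, in_channels, middle_channels, out_channels,
                 is_deconv=False):
        super().__init__()
        if is_deconv:
            self.block = nn.Sequential(
                ConvRelu(in_channels, middle_channels),
                nn.ConvTranspose2d(middle_channels, out_channels,
                                   kernel_size=4, stride=2, padding=1),
                nn.ReLU(inplace=True),
            )
        else:
            self.block = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='nearest'),
                ConvRelu(in_channels, middle_channels),
                ConvRelu(middle_channels, out_channels),
            )

    def forward(self, x):
        return self.block(x)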
Example #3
def load(info, Continue=False):
    models = [None, None, None, None, None]
    for i in range(5):
        models[i] = pretrainedmodels.senet154(num_classes=1000,
                                              pretrained='imagenet')
        # Modify.
        num_fcin = models[i].last_linear.in_features
        models[i].last_linear = nn.Linear(num_fcin, len(info['classes']))

    # print (model)

    if Continue:
        models = globalconfig.loadmodels(models)
    else:
        for i in range(5):
            models[i].step = 0
            models[i].epochs = 0

    params = []
    for i in range(5):
        models[i] = models[i].to(device=os.environ['device'])
        params_to_update = []
        for name, param in models[i].named_parameters():
            if param.requires_grad:
                params_to_update.append(param)
        params.append(params_to_update)

    modelinfo = {'inputsize': (224, 224)}

    return (models, params, modelinfo)
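A hedged usage sketch for load(); the info dict, globalconfig, and the 'device' environment variable come from elsewhere in this repository, so the values below are illustrative assumptions:

import os

os.environ['device'] = 'cuda:0'              # read by load() when moving models
info = {'classes': ['benign', 'malignant']}  # illustrative class list
models, params, modelinfo = load(info, Continue=False)
print(len(models), modelinfo['inputsize'])   # 5 (224, 224)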
Example #4
    def __init__(self, num_classes):
        super(SENet_154, self).__init__()
        self.name = "SENET_154"
        self.model = model_zoo.senet154()
        self.fc = nn.Linear(2048, num_classes)

        nn.init.xavier_uniform_(self.fc.weight, gain=2)
Example #5
def set_model (model_name, num_class, neurons_reducer_block=0, comb_method=None, comb_config=None, pretrained=True,
         freeze_conv=False, p_dropout=0.5):

    if pretrained:
        pre_ptm = 'imagenet'
        pre_torch = True
    else:
        pre_torch = False
        pre_ptm = None

    if model_name not in _MODELS:
        raise Exception("The model {} is not available!".format(model_name))

    model = None
    if model_name == 'resnet-50':
        model = MyResnet(models.resnet50(pretrained=pre_torch), num_class, neurons_reducer_block, freeze_conv,
                         comb_method=comb_method, comb_config=comb_config)

    elif model_name == 'resnet-101':
        model = MyResnet(models.resnet101(pretrained=pre_torch), num_class, neurons_reducer_block, freeze_conv,
                         comb_method=comb_method, comb_config=comb_config)

    elif model_name == 'densenet-121':
        model = MyDensenet(models.densenet121(pretrained=pre_torch), num_class, neurons_reducer_block, freeze_conv,
                         comb_method=comb_method, comb_config=comb_config)

    elif model_name == 'vgg-13':
        model = MyVGGNet(models.vgg13_bn(pretrained=pre_torch), num_class, neurons_reducer_block, freeze_conv,
                         comb_method=comb_method, comb_config=comb_config)

    elif model_name == 'vgg-16':
        model = MyVGGNet(models.vgg16_bn(pretrained=pre_torch), num_class, neurons_reducer_block, freeze_conv,
                         comb_method=comb_method, comb_config=comb_config)

    elif model_name == 'vgg-19':
        model = MyVGGNet(models.vgg19_bn(pretrained=pre_torch), num_class, neurons_reducer_block, freeze_conv,
                         comb_method=comb_method, comb_config=comb_config)

    elif model_name == 'mobilenet':
        model = MyMobilenet(models.mobilenet_v2(pretrained=pre_torch), num_class, neurons_reducer_block, freeze_conv,
                         comb_method=comb_method, comb_config=comb_config)

    elif model_name == 'efficientnet-b4':
        if pretrained:
            model = MyEffnet(EfficientNet.from_pretrained(model_name), num_class, neurons_reducer_block, freeze_conv,
                             comb_method=comb_method, comb_config=comb_config)
        else:
            model = MyEffnet(EfficientNet.from_name(model_name), num_class, neurons_reducer_block, freeze_conv,
                             comb_method=comb_method, comb_config=comb_config)

    elif model_name == 'inceptionv4':
        model = MyInceptionV4(ptm.inceptionv4(num_classes=1000, pretrained=pre_ptm), num_class, neurons_reducer_block,
                              freeze_conv, comb_method=comb_method, comb_config=comb_config)

    elif model_name == 'senet':
        model = MySenet(ptm.senet154(num_classes=1000, pretrained=pre_ptm), num_class, neurons_reducer_block,
                        freeze_conv, comb_method=comb_method, comb_config=comb_config)

    return model
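An illustrative call (the My* wrapper classes and the _MODELS registry are defined elsewhere in this repo); 'senet' routes to the ptm.senet154 branch above:

model = set_model('senet', num_class=8, neurons_reducer_block=90,
                  pretrained=True, freeze_conv=False)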
Example #6
def build_model(model_name):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # load pretrained model

    # model_name could be fbresnet152 or inceptionresnetv2

    if model_name == 'senet154':
        model = pretrainedmodels.senet154(pretrained='imagenet')
    elif model_name == 'se_resnet152':
        model = pretrainedmodels.se_resnet152(pretrained='imagenet')
    elif model_name == 'se_resnext101_32x4d':
        model = pretrainedmodels.se_resnext101_32x4d(pretrained='imagenet')
    elif model_name == 'resnet152':
        model = pretrainedmodels.resnet152(pretrained='imagenet')
    elif model_name == 'resnet101':
        model = pretrainedmodels.resnet101(pretrained='imagenet')
    elif model_name == 'densenet201':
        model = pretrainedmodels.densenet201(pretrained='imagenet')

    model.to(device)
    for param in model.parameters():
        param.requires_grad = False

    num_ftrs = model.last_linear.in_features

    class CustomModel(nn.Module):
        def __init__(self, model):
            super(CustomModel, self).__init__()
            self.features = nn.Sequential(*list(model.children())[:-1])
            self.classifier = nn.Sequential(
                torch.nn.Linear(num_ftrs, 128),
                torch.nn.Dropout(0.3),  # drop 30% of the neurons
                torch.nn.Linear(128, 7))

        def forward(self, x):
            x = self.features(x)
            x = x.view(x.size(0), -1)
            x = self.classifier(x)
            return x

    model = CustomModel(model)
    freeze_layer(model.features)
    num_ftrs = list(model.classifier.children())[-1].out_features

    model.to(device)
    model.name = model_name
    PATH = os.path.abspath(os.path.dirname(__file__))

    PATH_par = os.path.abspath(os.path.join(PATH, os.pardir))
    path_to_model = os.path.join(PATH_par, 'pretrained_model', '128_7')

    model.load_state_dict(
        torch.load(os.path.join(path_to_model, '%s.pth' % (model_name))))
    model.to(device)
    for param in model.parameters():
        param.requires_grad = False

    return model, num_ftrs
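A hedged usage sketch; it assumes the '<model_name>.pth' checkpoint referenced above exists under pretrained_model/128_7:

model, num_ftrs = build_model('senet154')
print(num_ftrs)  # 7: out_features of the classifier's final Linear layer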
Example #7
 def __init__(self,
              num_classes,
              pretrained="senet154-c7b49a05.pth",
              dropout=False,
              arcface=False):
     super().__init__()
     self.net = senet154(pretrained=pretrained)
     self.net.last_linear = nn.Linear(2048, num_classes)
Example #8
 def __init__(self,
              num_classes,
              pretrained="senet154-c7b49a05.pth",
              dropout=False,
              arcface=False):
     super().__init__()
     self.net = senet154(pretrained=pretrained)
     self.net.avg_pool = AdaptiveConcatPool2d()
     self.net.last_linear = nn.Sequential(Flatten(), SEBlock(2048 * 2),
                                          nn.Linear(2048 * 2, num_classes))
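AdaptiveConcatPool2d, Flatten, and SEBlock are used above but not shown; plausible definitions (fastai-style concat pooling plus a 1-D squeeze-and-excitation gate, assumptions rather than the author's exact code):

import torch
import torch.nn as nn

class AdaptiveConcatPool2d(nn.Module):
    def __init__(self):
        super().__init__()
        self.avg = nn.AdaptiveAvgPool2d(1)
        self.max = nn.AdaptiveMaxPool2d(1)

    def forward(self, x):  # doubles the channel count, hence 2048 * 2 above
        return torch.cat([self.avg(x), self.max(x)], dim=1)

class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)

class SEBlock(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid(),
        )

    def forward(self, x):  # per-feature gating of the flattened vector
        return x * self.fc(x)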
Example #9
 def __init__(self, num_classes=100):
     super(Modified_SENet154, self).__init__()
     model = pretrainedmodels.senet154(num_classes=1000, pretrained='imagenet')
     self.num_classes = num_classes
     temp = []
     for i, m in enumerate(model.children()):
         if i <= 6:
             temp.append(m)
         else:
             self.classifier = nn.Linear(in_features=2048, out_features=num_classes)
     self.features = nn.Sequential(*temp)
Example #10
def nets(model, num_class):

    if model == 'inceptionv4':
        model = ptm.inceptionv4(num_classes=1000, pretrained='imagenet')
        model.last_linear = nn.Linear(model.last_linear.in_features, num_class)
        return model

    if model == 'senet154':
        model = ptm.senet154(num_classes=1000, pretrained='imagenet')
        model.last_linear = nn.Linear(model.last_linear.in_features, num_class)
        return model

    if model == 'pnasnet':
        model = ptm.pnasnet5large(num_classes=1000, pretrained='imagenet')
        model.last_linear = nn.Linear(model.last_linear.in_features, num_class)
        return model

    if model == 'xception':
        model = ptm.xception(num_classes=1000, pretrained='imagenet')
        model.last_linear = nn.Linear(model.last_linear.in_features, num_class)
        return model

    if model == 'incepresv2':
        model = ptm.inceptionresnetv2(num_classes=1000, pretrained='imagenet')
        model.last_linear = nn.Linear(model.last_linear.in_features, num_class)
        return model

    if model == 'resnet152':
        model = models.resnet152(pretrained=True)
        model.fc = nn.Linear(2048, num_class)
        return model
        
    if model == 'se_resxt101':
        model = ptm.se_resnext101_32x4d(num_classes=1000, pretrained='imagenet')
        model.last_linear = nn.Linear(model.last_linear.in_features, num_class)
        return model
        
    if model == 'nasnet':
        model = ptm.nasnetalarge(num_classes=1000, pretrained='imagenet')
        model.last_linear = nn.Linear(model.last_linear.in_features, num_class)
        return model
        
    if model == 'dpn': # 224 input size
        model = ptm.dpn107(num_classes=1000, pretrained='imagenet+5k')
        model.last_linear = nn.Conv2d(model.last_linear.in_channels, num_class,
                                      kernel_size=1, bias=True)
        return model
        
    if model == 'resnext101':# 320 input size
        model = torch.hub.load('facebookresearch/WSL-Images', 'resnext101_32x16d_wsl')
        model.fc = nn.Linear(2048, num_class)
        return model  
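An illustrative call; any key handled above yields an ImageNet backbone whose classifier is replaced with a num_class-way head:

model = nets('senet154', num_class=5)  # pretrained senet154, 5-way last_linear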
Example #11
def get_base_model(config):
    model_name = config.backbone
    pretrained = config.pretrained

    if pretrained is not None and pretrained != 'imagenet':
        weights_path = pretrained
        pretrained = None
    else:
        weights_path = None

    if config.multibranch:
        input_channels = config.multibranch_input_channels
    else:
        input_channels = config.num_slices
        if hasattr(config, 'append_masks') and config.append_masks:
            input_channels *= 2

    _available_models = ['senet154', 'se_resnext50', 'resnet34', 'resnet18']

    if model_name == 'senet154':
        cut_point = -3
        model = nn.Sequential(
            *list(pretrainedmodels.senet154(
                pretrained=pretrained).children())[:cut_point])
        num_features = 2048
    elif model_name == 'se_resnext50':
        cut_point = -2
        model = nn.Sequential(*list(
            pretrainedmodels.se_resnext50_32x4d(
                pretrained=pretrained).children())[:cut_point])
        num_features = 2048
    elif model_name == 'resnet34':
        cut_point = -2
        model = nn.Sequential(
            *list(pretrainedmodels.resnet34(
                pretrained=pretrained).children())[:cut_point])
        num_features = 512
    elif model_name == 'resnet18':
        cut_point = -2
        model = nn.Sequential(
            *list(pretrainedmodels.resnet18(
                pretrained=pretrained).children())[:cut_point])
        num_features = 512
    else:
        raise ValueError('Unavailable backbone, choose one from {}'.format(
            _available_models))

    if model_name in ['senet154', 'se_resnext50']:
        conv1 = model[0].conv1
    else:
        conv1 = model[0]

    if input_channels != 3:
        conv1_weights = deepcopy(conv1.weight)
        new_conv1 = nn.Conv2d(input_channels,
                              conv1.out_channels,
                              kernel_size=conv1.kernel_size,
                              stride=conv1.stride,
                              padding=conv1.padding,
                              bias=conv1.bias is not None)

        if weights_path is None:
            if input_channels == 1:
                new_conv1.weight.data.fill_(0.)
                new_conv1.weight[:, 0, :, :].data.copy_(conv1_weights[:,
                                                                      0, :, :])
            elif input_channels > 3:
                diff = (input_channels - 3) // 2

                new_conv1.weight.data.fill_(0.)
                new_conv1.weight[:,
                                 diff:diff + 3, :, :].data.copy_(conv1_weights)

        if model_name in ['senet154', 'se_resnext50']:
            model[0].conv1 = new_conv1
        else:
            model[0] = new_conv1

    if weights_path is not None:
        if model_name in ['senet154', 'se_resnext50']:
            conv1_str = '0.conv1.weight'
        else:
            conv1_str = '0.weight'
        weights = load_base_weights(weights_path, input_channels, conv1_str)
        model.load_state_dict(weights)

    return model, num_features
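The first-conv surgery above generalizes beyond these four backbones; a self-contained sketch against torchvision's resnet18 (a hypothetical stand-in; the original uses pretrainedmodels), copying the pretrained RGB kernels into the middle slice of a wider first conv:

from copy import deepcopy

import torch.nn as nn
from torchvision import models

def adapt_first_conv(model, input_channels):
    conv1 = model.conv1
    new_conv1 = nn.Conv2d(input_channels, conv1.out_channels,
                          kernel_size=conv1.kernel_size, stride=conv1.stride,
                          padding=conv1.padding, bias=conv1.bias is not None)
    new_conv1.weight.data.zero_()
    if input_channels == 1:
        # keep only the red-channel kernels
        new_conv1.weight[:, 0].data.copy_(conv1.weight[:, 0])
    elif input_channels > 3:
        # center the RGB kernels inside the wider input
        diff = (input_channels - 3) // 2
        new_conv1.weight[:, diff:diff + 3].data.copy_(conv1.weight)
    model.conv1 = new_conv1
    return model

model = adapt_first_conv(models.resnet18(pretrained=True), input_channels=5)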
Example #12
    def __init__(self,
                 in_channels,
                 out_channels,
                 num_classes,
                 model='seresnext50',
                 pretrained=False,
                 dropout=0,
                 scale=64):
        super().__init__()

        assert model in ['seresnext50', 'seresnext101', 'senet154']

        pretrained_dataset = 'imagenet' if pretrained else None
        if model == 'seresnext50':
            self.model = se_resnext50_32x4d(pretrained=pretrained_dataset)
            inplanes = 64
        elif model == 'seresnext101':
            self.model = se_resnext101_32x4d(pretrained=pretrained_dataset)
            inplanes = 64
        elif model == 'senet154':
            self.model = senet154(pretrained=pretrained_dataset)
            self.model.dropout = nn.Identity()
            inplanes = 128
        else:
            assert False

        layer0_modules = [('conv1',
                           nn.Conv2d(in_channels,
                                     64,
                                     3,
                                     stride=2,
                                     padding=1,
                                     bias=False)), ('bn1', nn.BatchNorm2d(64)),
                          ('relu1', nn.ReLU(inplace=True)),
                          ('conv2',
                           nn.Conv2d(64,
                                     64,
                                     3,
                                     stride=1,
                                     padding=1,
                                     bias=False)), ('bn2', nn.BatchNorm2d(64)),
                          ('relu2', nn.ReLU(inplace=True)),
                          ('conv3',
                           nn.Conv2d(64,
                                     inplanes,
                                     3,
                                     stride=1,
                                     padding=1,
                                     bias=False)),
                          ('bn3', nn.BatchNorm2d(inplanes)),
                          ('relu3', nn.ReLU(inplace=True)),
                          ('pool', nn.MaxPool2d(3, stride=2, ceil_mode=True))]

        self.model.layer0 = nn.Sequential(OrderedDict(layer0_modules))
        self.model.avg_pool = nn.Sequential(CatPool2d(), GaussianNoise(),
                                            nn.Flatten())

        #self.dist = DistanceLayer(self.model.last_linear.in_features, num_classes, middle_feature=None, scale=64, n_centers=5, dropout=0)
        #self.margin = ArcMarginProduct(self.model.last_linear.in_features, num_classes)
        self.mos = MoSLayer(2 * self.model.last_linear.in_features,
                            num_classes,
                            middle_feature=512,
                            prior_feature=1024,
                            scale=64,
                            n_softmax=10,
                            dropout=0)
        self.model.last_linear = nn.Identity()
        replace_relu(self.model)
        #self.out_proj = nn.Sequential(nn.Conv2d(self.model.layer4[-1].conv3.out_channels, out_channels, kernel_size=3, padding=1),
        #                              nn.Sigmoid())

        self.scale = scale
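CatPool2d, GaussianNoise, replace_relu, and MoSLayer come from elsewhere in this codebase; plausible sketches of the first two, consistent with the 2x feature width consumed by MoSLayer above:

import torch
import torch.nn as nn
import torch.nn.functional as F

class CatPool2d(nn.Module):
    def forward(self, x):  # concat global avg + max pooling -> 2x channels
        return torch.cat([F.adaptive_avg_pool2d(x, 1),
                          F.adaptive_max_pool2d(x, 1)], dim=1)

class GaussianNoise(nn.Module):
    def __init__(self, sigma=0.1):
        super().__init__()
        self.sigma = sigma

    def forward(self, x):  # additive noise applied only during training
        if self.training and self.sigma > 0:
            x = x + self.sigma * torch.randn_like(x)
        return x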
Example #13
def build_model(model_name):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # load pretrained model

    # model_name could be fbresnet152 or inceptionresnetv2

    if model_name == 'senet154':
        model = pretrainedmodels.senet154(pretrained='imagenet')
    elif model_name == 'se_resnet152':
        model = pretrainedmodels.se_resnet152(pretrained='imagenet')
    elif model_name == 'se_resnext101_32x4d':
        model = pretrainedmodels.se_resnext101_32x4d(pretrained='imagenet')
    elif model_name == 'resnet152':
        model = pretrainedmodels.resnet152(pretrained='imagenet')
    elif model_name == 'resnet101':
        model = pretrainedmodels.resnet101(pretrained='imagenet')
    elif model_name == 'densenet201':
        model = pretrainedmodels.densenet201(pretrained='imagenet')

    model.to(device)
    for param in model.parameters():
        param.requires_grad = False

    num_ftrs = model.last_linear.in_features

    class CustomModel(nn.Module):
        def __init__(self, model):
            super(CustomModel, self).__init__()
            self.features = nn.Sequential(*list(model.children())[:-1])
            self.classifier = nn.Sequential(
                torch.nn.Linear(num_ftrs, 128),
                torch.nn.Dropout(0.3),  # drop 30% of the neurons
                torch.nn.Linear(128, 7)
            )
        
        def forward(self, x):
            x = self.features(x)
            x = x.view(x.size(0), -1)
            x = self.classifier(x)
            return x
    model = CustomModel(model)
    freeze_layer(model.features)
    model.to(device)
    for param in model.parameters():
        param.requires_grad = False

    
    class CustomModel1(nn.Module):
        def __init__(self, model):
            super(CustomModel1, self).__init__()
            self.features = nn.Sequential(*list(model.children())[:-1])
            self.classifier = nn.Sequential(
                list(model.classifier.children())[0]  # keep only the first Linear
            )
        
        def forward(self, x):
            x = self.features(x)
            x = x.view(x.size(0), -1)
            x = self.classifier(x)
            return x

    custom_model = CustomModel1(model)
    num_ftrs = list(custom_model.classifier.children())[-1].out_features
    custom_model.to(device)
    return custom_model, num_ftrs
Example #14
    args = parser.parse_args()
    params = vars(args)
    os.makedirs(params["output_dir"], exist_ok=True)
    if params['model'] == 'inception_v3':
        C, H, W = 3, 299, 299
        model = pretrainedmodels.inceptionv3(pretrained='imagenet')
        load_image_fn = utils.LoadTransformImage(model)

    elif params['model'] == 'resnet152':
        C, H, W = 3, 224, 224
        model = pretrainedmodels.resnet152(pretrained='imagenet')
        load_image_fn = utils.LoadTransformImage(model)

    elif params['model'] == 'senet154':
        C, H, W = 3, 224, 224
        model = pretrainedmodels.senet154(pretrained='imagenet')
        load_image_fn = utils.LoadTransformImage(model)

    elif params['model'] == 'inception_v4':
        C, H, W = 3, 299, 299
        model = pretrainedmodels.inceptionv4(num_classes=1000,
                                             pretrained='imagenet')
        load_image_fn = utils.LoadTransformImage(model)

    else:
        print("doesn't support %s" % (params['model']))

    model.last_linear = utils.Identity()
    model = nn.DataParallel(model)

    model = model.cuda()
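With last_linear swapped for utils.Identity(), the model now emits pooled features instead of logits; an illustrative extraction step (the image path is a placeholder):

img = load_image_fn('frame_0001.jpg')       # preprocess to a C x H x W tensor
with torch.no_grad():
    feats = model(img.unsqueeze(0).cuda())  # e.g. shape (1, 2048) for senet154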
Example #15
def main(train_root, train_csv, train_split, val_root, val_csv, val_split,
         epochs, aug, model_name, batch_size, num_workers, val_samples,
         early_stopping_patience, n_classes, weighted_loss, balanced_loader,
         _run):
    assert (model_name
            in ('inceptionv4', 'resnet152', 'densenet161', 'senet154'))

    AUGMENTED_IMAGES_DIR = os.path.join(fs_observer.dir, 'images')
    CHECKPOINTS_DIR = os.path.join(fs_observer.dir, 'checkpoints')
    BEST_MODEL_PATH = os.path.join(CHECKPOINTS_DIR, 'model_best.pth')
    LAST_MODEL_PATH = os.path.join(CHECKPOINTS_DIR, 'model_last.pth')
    for directory in (AUGMENTED_IMAGES_DIR, CHECKPOINTS_DIR):
        os.makedirs(directory)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if model_name == 'inceptionv4':
        model = ptm.inceptionv4(num_classes=1000, pretrained='imagenet')
        model.last_linear = nn.Linear(model.last_linear.in_features, n_classes)
        aug['size'] = 299
        aug['mean'] = model.mean
        aug['std'] = model.std
    elif model_name == 'resnet152':
        model = models.resnet152(pretrained=True)
        model.fc = nn.Linear(model.fc.in_features, n_classes)
        aug['size'] = 224
        aug['mean'] = [0.485, 0.456, 0.406]
        aug['std'] = [0.229, 0.224, 0.225]
    elif model_name == 'densenet161':
        model = models.densenet161(pretrained=True)
        model.classifier = nn.Linear(model.classifier.in_features, n_classes)
        aug['size'] = 224
        aug['mean'] = [0.485, 0.456, 0.406]
        aug['std'] = [0.229, 0.224, 0.225]
    elif model_name == 'senet154':
        model = ptm.senet154(num_classes=1000, pretrained='imagenet')
        model.last_linear = nn.Linear(model.last_linear.in_features, n_classes)
        aug['size'] = model.input_size[1]
        aug['mean'] = model.mean
        aug['std'] = model.std
    model.to(device)

    augs = Augmentations(**aug)
    model.aug_params = aug

    train_ds = CSVDatasetWithName(train_root,
                                  train_csv,
                                  'image',
                                  'label',
                                  transform=augs.tf_transform,
                                  add_extension='.jpg',
                                  split=train_split)
    val_ds = CSVDatasetWithName(val_root,
                                val_csv,
                                'image',
                                'label',
                                transform=augs.tf_transform,
                                add_extension='.jpg',
                                split=val_split)

    datasets = {'train': train_ds, 'val': val_ds}

    if balanced_loader:
        data_sampler = sampler.WeightedRandomSampler(train_ds.sampler_weights,
                                                     len(train_ds))
        shuffle = False
    else:
        data_sampler = None
        shuffle = True

    dataloaders = {
        'train':
        DataLoader(datasets['train'],
                   batch_size=batch_size,
                   shuffle=shuffle,
                   num_workers=num_workers,
                   sampler=data_sampler,
                   worker_init_fn=set_seeds),
        'val':
        DataLoader(datasets['val'],
                   batch_size=batch_size,
                   shuffle=False,
                   num_workers=num_workers,
                   worker_init_fn=set_seeds),
    }

    if weighted_loss:
        criterion = nn.CrossEntropyLoss(
            weight=torch.Tensor(datasets['train'].class_weights_list).cuda())
    else:
        criterion = nn.CrossEntropyLoss()

    optimizer = optim.SGD(model.parameters(),
                          lr=0.001,
                          momentum=0.9,
                          weight_decay=0.001)

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     factor=0.1,
                                                     min_lr=1e-5,
                                                     patience=10)

    metrics = {
        'train': pd.DataFrame(columns=['epoch', 'loss', 'acc']),
        'val': pd.DataFrame(columns=['epoch', 'loss', 'acc'])
    }

    best_val_loss = 1000.0
    epochs_without_improvement = 0
    batches_per_epoch = None

    for epoch in range(epochs):
        print('train epoch {}/{}'.format(epoch + 1, epochs))
        epoch_train_result = train_epoch(device, model, dataloaders, criterion,
                                         optimizer, 'train', batches_per_epoch)

        metrics['train'] = metrics['train'].append(
            {
                **epoch_train_result, 'epoch': epoch
            }, ignore_index=True)
        print('train', epoch_train_result)

        epoch_val_result = train_epoch(device, model, dataloaders, criterion,
                                       optimizer, 'val', batches_per_epoch)

        metrics['val'] = metrics['val'].append(
            {
                **epoch_val_result, 'epoch': epoch
            }, ignore_index=True)
        print('val', epoch_val_result)

        scheduler.step(epoch_val_result['loss'])

        if epoch_val_result['loss'] < best_val_loss:
            best_val_loss = epoch_val_result['loss']
            epochs_without_improvement = 0
            torch.save(model, BEST_MODEL_PATH)
            print('Best loss at epoch {}'.format(epoch))
        else:
            epochs_without_improvement += 1

        print('-' * 40)

        if epochs_without_improvement > early_stopping_patience:
            torch.save(model, LAST_MODEL_PATH)
            break

        if epoch == (epochs - 1):
            torch.save(model, LAST_MODEL_PATH)

    for phase in ['train', 'val']:
        metrics[phase].epoch = metrics[phase].epoch.astype(int)
        metrics[phase].to_csv(os.path.join(fs_observer.dir, phase + '.csv'),
                              index=False)

    print('Best validation loss: {}'.format(best_val_loss))

    # TODO: return more metrics
    return {'max_val_acc': metrics['val']['acc'].max()}
Example #16
def main(train_root, train_csv, val_root, val_csv, test_root, test_csv, epochs,
         aug, model_name, batch_size, num_workers, val_samples, test_samples,
         early_stopping_patience, limit_data, images_per_epoch, _run):
    assert (model_name
            in ('inceptionv4', 'resnet152', 'densenet161', 'senet154'))

    AUGMENTED_IMAGES_DIR = os.path.join(fs_observer.dir, 'images')
    CHECKPOINTS_DIR = os.path.join(fs_observer.dir, 'checkpoints')
    BEST_MODEL_PATH = os.path.join(CHECKPOINTS_DIR, 'model_best.pth')
    LAST_MODEL_PATH = os.path.join(CHECKPOINTS_DIR, 'model_last.pth')
    RESULTS_CSV_PATH = os.path.join('results', 'results.csv')
    EXP_NAME = _run.meta_info['options']['--name']
    EXP_ID = _run._id

    for directory in (AUGMENTED_IMAGES_DIR, CHECKPOINTS_DIR):
        os.makedirs(directory)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if model_name == 'inceptionv4':
        model = ptm.inceptionv4(num_classes=1000, pretrained='imagenet')
        model.last_linear = nn.Linear(model.last_linear.in_features, 2)
        aug['size'] = 299
        aug['mean'] = model.mean
        aug['std'] = model.std
    elif model_name == 'resnet152':
        model = models.resnet152(pretrained=True)
        model.fc = nn.Linear(model.fc.in_features, 2)
        aug['size'] = 224
        aug['mean'] = [0.485, 0.456, 0.406]
        aug['std'] = [0.229, 0.224, 0.225]
    elif model_name == 'densenet161':
        model = models.densenet161(pretrained=True)
        model.classifier = nn.Linear(model.classifier.in_features, 2)
        aug['size'] = 224
        aug['mean'] = [0.485, 0.456, 0.406]
        aug['std'] = [0.229, 0.224, 0.225]
    elif model_name == 'senet154':
        model = ptm.senet154(num_classes=1000, pretrained='imagenet')
        model.last_linear = nn.Linear(model.last_linear.in_features, 2)
        aug['size'] = model.input_size[1]
        aug['mean'] = model.mean
        aug['std'] = model.std
    model.to(device)

    augs = Augmentations(**aug)
    model.aug_params = aug

    datasets = {
        'samples':
        CSVDataset(train_root,
                   train_csv,
                   'image_id',
                   'melanoma',
                   transform=augs.tf_augment,
                   add_extension='.jpg',
                   limit=(400, 433)),
        'train':
        CSVDataset(train_root,
                   train_csv,
                   'image_id',
                   'melanoma',
                   transform=augs.tf_transform,
                   add_extension='.jpg',
                   random_subset_size=limit_data),
        'val':
        CSVDatasetWithName(val_root,
                           val_csv,
                           'image_id',
                           'melanoma',
                           transform=augs.tf_transform,
                           add_extension='.jpg'),
        'test':
        CSVDatasetWithName(test_root,
                           test_csv,
                           'image_id',
                           'melanoma',
                           transform=augs.tf_transform,
                           add_extension='.jpg'),
        'test_no_aug':
        CSVDatasetWithName(test_root,
                           test_csv,
                           'image_id',
                           'melanoma',
                           transform=augs.no_augmentation,
                           add_extension='.jpg'),
        'test_144':
        CSVDatasetWithName(test_root,
                           test_csv,
                           'image_id',
                           'melanoma',
                           transform=augs.inception_crop,
                           add_extension='.jpg'),
    }

    dataloaders = {
        'train':
        DataLoader(datasets['train'],
                   batch_size=batch_size,
                   shuffle=True,
                   num_workers=num_workers,
                   worker_init_fn=set_seeds),
        'samples':
        DataLoader(datasets['samples'],
                   batch_size=batch_size,
                   shuffle=False,
                   num_workers=num_workers,
                   worker_init_fn=set_seeds),
    }

    save_images(datasets['samples'], to=AUGMENTED_IMAGES_DIR, n=32)
    sample_batch, _ = next(iter(dataloaders['samples']))
    save_image(make_grid(sample_batch, padding=0),
               os.path.join(AUGMENTED_IMAGES_DIR, 'grid.jpg'))

    criterion = nn.CrossEntropyLoss()

    optimizer = optim.SGD(model.parameters(),
                          lr=0.001,
                          momentum=0.9,
                          weight_decay=0.001)

    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=[10],
                                               gamma=0.1)
    metrics = {
        'train': pd.DataFrame(columns=['epoch', 'loss', 'acc', 'auc']),
        'val': pd.DataFrame(columns=['epoch', 'loss', 'acc', 'auc'])
    }

    best_val_auc = 0.0
    best_epoch = 0
    epochs_without_improvement = 0
    if images_per_epoch:
        batches_per_epoch = images_per_epoch // batch_size
    else:
        batches_per_epoch = None

    for epoch in range(epochs):
        print('train epoch {}/{}'.format(epoch + 1, epochs))
        epoch_train_result = train_epoch(device, model, dataloaders, criterion,
                                         optimizer, batches_per_epoch)

        metrics['train'] = metrics['train'].append(
            {
                **epoch_train_result, 'epoch': epoch
            }, ignore_index=True)
        print('train', epoch_train_result)

        epoch_val_result, _ = test_with_augmentation(model, datasets['val'],
                                                     device, num_workers,
                                                     val_samples)

        metrics['val'] = metrics['val'].append(
            {
                **epoch_val_result, 'epoch': epoch
            }, ignore_index=True)
        print('val', epoch_val_result)
        print('-' * 40)

        scheduler.step()

        if epoch_val_result['auc'] > best_val_auc:
            best_val_auc = epoch_val_result['auc']
            best_epoch = epoch
            epochs_without_improvement = 0
            torch.save(model, BEST_MODEL_PATH)
        else:
            epochs_without_improvement += 1

        if epochs_without_improvement > early_stopping_patience:
            torch.save(model, LAST_MODEL_PATH)
            break

        if epoch == (epochs - 1):
            torch.save(model, LAST_MODEL_PATH)

    for phase in ['train', 'val']:
        metrics[phase].epoch = metrics[phase].epoch.astype(int)
        metrics[phase].to_csv(os.path.join(fs_observer.dir, phase + '.csv'),
                              index=False)

    # Run testing
    test_result, _ = test_with_augmentation(torch.load(BEST_MODEL_PATH),
                                            datasets['test'], device,
                                            num_workers, test_samples)
    print('test', test_result)

    test_noaug_result, _ = test_with_augmentation(torch.load(BEST_MODEL_PATH),
                                                  datasets['test_no_aug'],
                                                  device, num_workers, 1)
    print('test (no augmentation)', test_noaug_result)

    test_144crop_result, _ = test_with_augmentation(
        torch.load(BEST_MODEL_PATH), datasets['test_144'], device, num_workers,
        1)
    print('test (144-crop)', test_144crop_result)

    with open(RESULTS_CSV_PATH, 'a') as file:
        file.write(','.join((EXP_NAME, str(EXP_ID), str(best_epoch),
                             str(best_val_auc), str(test_noaug_result['auc']),
                             str(test_result['auc']),
                             str(test_144crop_result['auc']))) + '\n')

    return (test_noaug_result['auc'], test_result['auc'],
            test_144crop_result['auc'])
Example #17
def get_senet154_pretrained_model(num_classes):
    pretrained_senet154 = senet154()
    pretrained_senet154.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
    pretrained_senet154.last_linear = nn.Linear(2048, num_classes)
    return pretrained_senet154
Example #18
def senet154(num_classes=1000, pretrained=None):
    model = pretrainedmodels.senet154(num_classes=num_classes, pretrained=pretrained)
    return model
Example #19
def initialize_model(model_name,
                     num_classes,
                     feature_extract,
                     use_pretrained=True):
    # Initialize these variables which will be set in this if statement. Each of these
    #   variables is model specific.
    model_ft = None
    input_size = 0

    if model_name == "se_resnext50_32x4d":
        model_ft = pretrainedmodels.se_resnext50_32x4d(num_classes=1000,
                                                       pretrained='imagenet')
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.last_linear.in_features
        # print('model_ft.last_linear.in_features: ', num_ftrs)
        model_ft.last_linear = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "se_resnext101_32x4d":
        model_ft = pretrainedmodels.se_resnext101_32x4d(num_classes=1000,
                                                        pretrained='imagenet')
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.last_linear.in_features
        model_ft.last_linear = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "PolyNet":
        model_ft = pretrainedmodels.polynet(num_classes=1000,
                                            pretrained='imagenet')
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.last_linear.in_features
        model_ft.last_linear = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "SENet154":
        model_ft = pretrainedmodels.senet154(num_classes=1000,
                                             pretrained='imagenet')
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.last_linear.in_features
        model_ft.last_linear = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "resnet-18":
        """ Resnet18
        """
        model_ft = models.resnet18(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "resnet-50":
        """ Resnet50
        """
        model_ft = models.resnet50(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features
        # model_ft.fc = nn.Linear(num_ftrs, num_classes)
        model_ft.fc = nn.Sequential(nn.Dropout(p=0.4),
                                    nn.Linear(num_ftrs, num_classes))
        input_size = 224

    elif model_name == "resnet-152":
        """ Resnet152
        """
        model_ft = models.resnet152(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "alexnet":
        """ Alexnet
        """
        model_ft = models.alexnet(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "vgg":
        """ VGG11_bn
        """
        model_ft = models.vgg11_bn(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "squeezenet":
        """ Squeezenet
        """
        model_ft = models.squeezenet1_0(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.classifier[1] = nn.Conv2d(512,
                                           num_classes,
                                           kernel_size=(1, 1),
                                           stride=(1, 1))
        model_ft.num_classes = num_classes
        input_size = 224

    elif model_name == "densenet":
        """ Densenet
        """
        model_ft = models.densenet121(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "inception":
        """ Inception v3
        Be careful, expects (299,299) sized images and has auxiliary output
        """
        model_ft = models.inception_v3(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        # Handle the auxilary net
        num_ftrs = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)
        # Handle the primary net
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 299

    else:
        print("Invalid model name, exiting...")
        exit()

    return model_ft, input_size
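An illustrative call: freeze the SENet154 backbone and train only a fresh 10-way head:

model_ft, input_size = initialize_model('SENet154', num_classes=10,
                                        feature_extract=True)
print(input_size)  # 224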
Example #20
def main(args):
    print(torch.cuda.device_count(), 'gpus available')
    # 0. Initializing training
    if args.fold is not None:
        if args.fold_prefix is None:
            print('Please add fold-prefix to arguments')
            return
        folder_name = f'{args.name}_{args.fold_prefix // 10}{args.fold_prefix % 10}_fold{args.fold}'
        checkpoint_path = os.path.join(args.data, 'checkpoints', folder_name)
        log_path = os.path.join(args.data, 'logs', folder_name)
        if not os.path.exists(checkpoint_path):
            os.mkdir(checkpoint_path)
    else:
        for i in range(100):
            folder_name = f'{args.name}_{i // 10}{i % 10}'
            checkpoint_path = os.path.join(args.data, 'checkpoints', folder_name)
            log_path = os.path.join(args.data, 'logs', folder_name)
            if not os.path.exists(checkpoint_path):
                os.mkdir(checkpoint_path)
                break
    if args.checkpoint is None:
        training_state = {
            'best_checkpoints': [],
            'best_scores': [],
            'epoch': []
        }
    else:
        # loading checkpoint
        from_checkpoint = f'{args.name}_{args.checkpoint // 10}{args.checkpoint % 10}'
        parent_checkpoint_path = os.path.join(args.data, 'checkpoints', from_checkpoint)
        training_state = torch.load(os.path.join(parent_checkpoint_path, 'training_state.pth'))
        training_state['from_checkpoint'] = from_checkpoint
        print(f'Using checkpoint {from_checkpoint}')
    print(f'Results can be found in {folder_name}')
    writer = SummaryWriter(log_dir=log_path)

    # 1. prepare data & models
    if args.name == 'senet154':
        crop_size = 224
    else:
        crop_size = CROP_SIZE
    train_transforms = transforms.Compose([
        # HorizontalFlip(p=0.5),
        ScaleMinSideToSize((crop_size, crop_size)),
        CropCenter(crop_size),
        TransformByKeys(transforms.ToPILImage(), ('image',)),
        TransformByKeys(transforms.ToTensor(), ('image',)),
        TransformByKeys(transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ('image',)),
    ])
    test_transforms = transforms.Compose([
        ScaleMinSideToSize((crop_size, crop_size)),
        CropCenter(crop_size),
        TransformByKeys(transforms.ToPILImage(), ('image',)),
        TransformByKeys(transforms.ToTensor(), ('image',)),
        TransformByKeys(transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ('image',)),
    ])
    albu_transforms = albu.Compose([
                            albu.Blur(p=0.1),
                            albu.MultiplicativeNoise(p=0.1, per_channel=True),
                            albu.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=20, val_shift_limit=20, p=0.2),
                            albu.ChannelShuffle(p=0.2)
                       ],
                      keypoint_params=albu.KeypointParams(format='xy'))
    print('\nTransforms:')
    print(albu_transforms)
    print(train_transforms)

    print('\nReading data...')
    datasets = torch.load(os.path.join(args.data, 'datasets.pth'))
    for d in datasets:
        datasets[d].transforms = train_transforms
    if args.fold is None:
        print('Using predefined data split')
        train_dataset = datasets['train_dataset']
        val_dataset = datasets['val_dataset']
    else:
        print(f'Using fold {args.fold}')
        train_dataset = FoldDatasetDataset(datasets['train_dataset'], datasets['val_dataset'], train_transforms,
                           albu_transforms, split='train', fold=args.fold, seed=42)
        val_dataset = FoldDatasetDataset(datasets['train_dataset'], datasets['val_dataset'], train_transforms,
                           None, split='val', fold=args.fold, seed=42)

    test_dataset = datasets['test_dataset']
    test_dataset.transforms = test_transforms

    train_dataloader = data.DataLoader(train_dataset, batch_size=args.batch_size, num_workers=32, pin_memory=True,
                                       shuffle=True, drop_last=True)
    val_dataloader = data.DataLoader(val_dataset, batch_size=args.batch_size, num_workers=32, pin_memory=True,
                                     shuffle=False, drop_last=False)

    print('Creating model...')
    device = torch.device('cuda:0') if args.gpu else torch.device('cpu')
    if args.name == 'senet154':
        model = pretrainedmodels.senet154(num_classes=1000, pretrained='imagenet')
        model.last_linear = nn.Linear(model.last_linear.in_features, 2 * NUM_PTS, bias=True)
    else:
        model = models.resnext50_32x4d(pretrained=True)
        model.fc = nn.Linear(model.fc.in_features, 2 * NUM_PTS, bias=True)
    model = nn.DataParallel(model)
    print(f'Using {torch.cuda.device_count()} gpus')
    if args.checkpoint is not None:
        model.load_state_dict(training_state['best_checkpoints'][0])
    model.to(device)

    # optimizer = optim.SGD(model.parameters(), lr=args.learning_rate, momentum=0.9, nesterov=True)
    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
    # optimizer = RAdam(model.parameters(), lr=args.learning_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)
    # scheduler = None
    print(f'Optimizer: {optimizer}')
    print(f'Scheduler: {scheduler}')
    loss_fn = fnn.mse_loss

    # 2. train & validate
    print('Ready for training...')
    if args.checkpoint is None:
        start_epoch = 0
        best_val_loss = np.inf
    else:
        start_epoch = training_state['epoch'][0]
        best_val_loss = training_state['best_scores'][0]

    for epoch in range(start_epoch, start_epoch + args.epochs):
        train_loss = train(model, train_dataloader, loss_fn, optimizer, device, writer, epoch)
        val_loss = validate(model, val_dataloader, loss_fn, device, writer, epoch, scheduler)
        print('Epoch #{:2}:\ttrain loss: {:5.2}\tval loss: {:5.2}'.format(epoch, train_loss, val_loss))
        print(f'Learning rate = {optimizer.param_groups[0]["lr"]}')
        if len(training_state['best_scores']) == 0:
            training_state['best_checkpoints'].append(model.state_dict())
            training_state['best_scores'].append(val_loss)
            training_state['epoch'].append(epoch)
            with open(os.path.join(checkpoint_path, 'training_state.pth'), 'wb') as fp:
                torch.save(training_state, fp)
        elif len(training_state['best_scores']) < 3 or val_loss < training_state['best_scores'][-1]:
            cur_val_index = 0
            for cur_val_index in range(len(training_state['best_scores'])):
                if val_loss < training_state['best_scores'][cur_val_index]:
                    break
            training_state['best_scores'].insert(cur_val_index, val_loss)
            training_state['best_checkpoints'].insert(cur_val_index, model.state_dict())
            training_state['epoch'].insert(cur_val_index, epoch)
            if len(training_state['best_scores']) > 3:
                training_state['best_scores'] = training_state['best_scores'][:3]
                training_state['best_checkpoints'] = training_state['best_checkpoints'][:3]
                training_state['epoch'] = training_state['epoch'][:3]
            with open(os.path.join(checkpoint_path, 'training_state.pth'), 'wb') as fp:
                torch.save(training_state, fp)
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            with open(os.path.join(checkpoint_path, f'{args.name}_best.pth'), 'wb') as fp:
                torch.save(model.state_dict(), fp)
    print('Training finished')
    print(f'Min val loss = {training_state["best_scores"]} at epoch {training_state["epoch"]}')
    print()

    # 3. predict
    test_dataloader = data.DataLoader(test_dataset, batch_size=args.batch_size, num_workers=16, pin_memory=True,
                                      shuffle=False, drop_last=False)

    with open(os.path.join(checkpoint_path, f'{args.name}_best.pth'), 'rb') as fp:
        best_state_dict = torch.load(fp, map_location='cpu')
        model.load_state_dict(best_state_dict)

    test_predictions = predict(model, test_dataloader, device)
    with open(os.path.join(checkpoint_path, f'{args.name}_test_predictions.pkl'), 'wb') as fp:
        pickle.dump({'image_names': test_dataset.image_names,
                     'landmarks': test_predictions}, fp)

    create_submission(args.data, test_predictions, os.path.join(checkpoint_path, f'{args.name}_submit.csv'))
Example #21
def main(train_root, train_csv, val_root, val_csv, test_root, test_csv,
         epochs, aug, model_name, batch_size, num_workers, val_samples,
         test_samples, early_stopping_patience, limit_data, images_per_epoch,
         split_id, _run):
    assert(model_name in ('inceptionv4', 'resnet152', 'densenet161',
                          'senet154', 'pnasnet5large', 'nasnetalarge',
                          'xception', 'squeezenet', 'resnext', 'dpn',
                          'inceptionresnetv2', 'mobilenetv2'))

    cv2.setNumThreads(0)

    AUGMENTED_IMAGES_DIR = os.path.join(fs_observer.dir, 'images')
    CHECKPOINTS_DIR = os.path.join(fs_observer.dir, 'checkpoints')
    BEST_MODEL_PATH = os.path.join(CHECKPOINTS_DIR, 'model_best.pth')
    LAST_MODEL_PATH = os.path.join(CHECKPOINTS_DIR, 'model_last.pth')
    RESULTS_CSV_PATH = os.path.join('results', 'results.csv')
    EXP_NAME = _run.meta_info['options']['--name']
    EXP_ID = _run._id

    for directory in (AUGMENTED_IMAGES_DIR, CHECKPOINTS_DIR):
        os.makedirs(directory)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if model_name == 'inceptionv4':
        model = ptm.inceptionv4(num_classes=1000, pretrained='imagenet')
        model.last_linear = nn.Linear(model.last_linear.in_features, 2)
        aug['size'] = 299
        aug['mean'] = model.mean
        aug['std'] = model.std
    elif model_name == 'resnet152':
        model = models.resnet152(pretrained=True)
        model.fc = nn.Linear(model.fc.in_features, 2)
        aug['size'] = 224
        aug['mean'] = [0.485, 0.456, 0.406]
        aug['std'] = [0.229, 0.224, 0.225]
    elif model_name == 'densenet161':
        model = models.densenet161(pretrained=True)
        model.classifier = nn.Linear(model.classifier.in_features, 2)
        aug['size'] = 224
        aug['mean'] = [0.485, 0.456, 0.406]
        aug['std'] = [0.229, 0.224, 0.225]
    elif model_name == 'senet154':
        model = ptm.senet154(num_classes=1000, pretrained='imagenet')
        model.last_linear = nn.Linear(model.last_linear.in_features, 2)
        aug['size'] = model.input_size[1]
        aug['mean'] = model.mean
        aug['std'] = model.std
    elif model_name == 'squeezenet':
        model = ptm.squeezenet1_1(num_classes=1000, pretrained='imagenet')
        model.last_conv = nn.Conv2d(
            512, 2, kernel_size=(1, 1), stride=(1, 1))
        aug['size'] = model.input_size[1]
        aug['mean'] = model.mean
        aug['std'] = model.std
    elif model_name == 'pnasnet5large':
        model = ptm.pnasnet5large(num_classes=1000, pretrained='imagenet')
        model.last_linear = nn.Linear(model.last_linear.in_features, 2)
        aug['size'] = model.input_size[1]
        aug['mean'] = model.mean
        aug['std'] = model.std
    elif model_name == 'nasnetalarge':
        model = ptm.nasnetalarge(num_classes=1000, pretrained='imagenet')
        model.last_linear = nn.Linear(model.last_linear.in_features, 2)
        aug['size'] = model.input_size[1]
        aug['mean'] = model.mean
        aug['std'] = model.std
    elif model_name == 'xception':
        model = ptm.xception(num_classes=1000, pretrained='imagenet')
        model.last_linear = nn.Linear(model.last_linear.in_features, 2)
        aug['size'] = model.input_size[1]
        aug['mean'] = model.mean
        aug['std'] = model.std
    elif model_name == 'dpn':
        model = ptm.dpn131(num_classes=1000, pretrained='imagenet')
        model.last_linear = nn.Conv2d(model.last_linear.in_channels, 2,
                                      kernel_size=1, bias=True)
        aug['size'] = model.input_size[1]
        aug['mean'] = model.mean
        aug['std'] = model.std
    elif model_name == 'resnext':
        model = ptm.resnext101_64x4d(num_classes=1000, pretrained='imagenet')
        model.last_linear = nn.Linear(model.last_linear.in_features, 2)
        aug['size'] = model.input_size[1]
        aug['mean'] = model.mean
        aug['std'] = model.std
    elif model_name == 'inceptionresnetv2':
        model = ptm.inceptionresnetv2(num_classes=1000, pretrained='imagenet')
        model.last_linear = nn.Linear(model.last_linear.in_features, 2)
        aug['size'] = model.input_size[1]
        aug['mean'] = model.mean
        aug['std'] = model.std
    elif model_name == 'mobilenetv2':
        model = MobileNetV2()
        model.load_state_dict(torch.load('./auglib/models/mobilenet_v2.pth'))
        model.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(model.last_channel, 2),
        )
        aug['size'] = 224
        aug['mean'] = [0.485, 0.456, 0.406]
        aug['std'] = [0.229, 0.224, 0.225]
    model.to(device)

    augs = Augmentations(**aug)
    model.aug_params = aug

    datasets = {
        'samples': CSVDataset(train_root, train_csv, 'image_id', 'melanoma',
                              transform=augs.tf_augment, add_extension='.jpg',
                              limit=(400, 433)),
        'train': CSVDataset(train_root, train_csv, 'image_id', 'melanoma',
                            transform=augs.tf_transform, add_extension='.jpg',
                            random_subset_size=limit_data),
        'val': CSVDatasetWithName(
            val_root, val_csv, 'image_id', 'melanoma',
            transform=augs.tf_transform, add_extension='.jpg'),
        'test': CSVDatasetWithName(
            test_root, test_csv, 'image_id', 'melanoma',
            transform=augs.tf_transform, add_extension='.jpg'),
        'test_no_aug': CSVDatasetWithName(
            test_root, test_csv, 'image_id', 'melanoma',
            transform=augs.no_augmentation, add_extension='.jpg'),
        'test_144': CSVDatasetWithName(
            test_root, test_csv, 'image_id', 'melanoma',
            transform=augs.inception_crop, add_extension='.jpg'),
    }

    dataloaders = {
        'train': DataLoader(datasets['train'], batch_size=batch_size,
                            shuffle=True, num_workers=num_workers,
                            worker_init_fn=set_seeds),
        'samples': DataLoader(datasets['samples'], batch_size=batch_size,
                              shuffle=False, num_workers=num_workers,
                              worker_init_fn=set_seeds),
    }

    save_images(datasets['samples'], to=AUGMENTED_IMAGES_DIR, n=32)
    sample_batch, _ = next(iter(dataloaders['samples']))
    save_image(make_grid(sample_batch, padding=0),
               os.path.join(AUGMENTED_IMAGES_DIR, 'grid.jpg'))

    criterion = nn.CrossEntropyLoss()

    optimizer = optim.SGD(model.parameters(),
                          lr=0.001,
                          momentum=0.9,
                          weight_decay=0.001)

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1,
                                                     min_lr=1e-5,
                                                     patience=8)
    metrics = {
        'train': pd.DataFrame(columns=['epoch', 'loss', 'acc', 'auc']),
        'val': pd.DataFrame(columns=['epoch', 'loss', 'acc', 'auc'])
    }

    best_val_auc = 0.0
    best_epoch = 0
    epochs_without_improvement = 0
    if images_per_epoch:
        batches_per_epoch = images_per_epoch // batch_size
    else:
        batches_per_epoch = None

    for epoch in range(epochs):
        print('train epoch {}/{}'.format(epoch+1, epochs))
        epoch_train_result = train_epoch(
            device, model, dataloaders, criterion, optimizer,
            batches_per_epoch)

        metrics['train'] = metrics['train'].append(
            {**epoch_train_result, 'epoch': epoch}, ignore_index=True)
        print('train', epoch_train_result)

        epoch_val_result, _ = test_with_augmentation(
            model, datasets['val'], device, num_workers, val_samples)

        metrics['val'] = metrics['val'].append(
            {**epoch_val_result, 'epoch': epoch}, ignore_index=True)
        print('val', epoch_val_result)
        print('-' * 40)

        scheduler.step(epoch_val_result['loss'])

        if epoch_val_result['auc'] > best_val_auc:
            best_val_auc = epoch_val_result['auc']
            best_val_result = epoch_val_result
            best_epoch = epoch
            epochs_without_improvement = 0
            torch.save(model, BEST_MODEL_PATH)
        else:
            epochs_without_improvement += 1

        if epochs_without_improvement > early_stopping_patience:
            last_val_result = epoch_val_result
            torch.save(model, LAST_MODEL_PATH)
            break

        if epoch == (epochs-1):
            last_val_result = epoch_val_result
            torch.save(model, LAST_MODEL_PATH)

    for phase in ['train', 'val']:
        metrics[phase].epoch = metrics[phase].epoch.astype(int)
        metrics[phase].to_csv(os.path.join(fs_observer.dir, phase + '.csv'),
                              index=False)

    # Run testing
    # TODO: reduce code repetition
    test_result, preds = test_with_augmentation(
        torch.load(BEST_MODEL_PATH), datasets['test'], device,
        num_workers, test_samples)
    print('[best] test', test_result)

    test_noaug_result, preds_noaug = test_with_augmentation(
        torch.load(BEST_MODEL_PATH), datasets['test_no_aug'], device,
        num_workers, 1)
    print('[best] test (no augmentation)', test_noaug_result)

    test_result_last, preds_last = test_with_augmentation(
        torch.load(LAST_MODEL_PATH), datasets['test'], device,
        num_workers, test_samples)
    print('[last] test', test_result_last)

    test_noaug_result_last, preds_noaug_last = test_with_augmentation(
        torch.load(LAST_MODEL_PATH), datasets['test_no_aug'], device,
        num_workers, 1)
    print('[last] test (no augmentation)', test_noaug_result_last)

    # Save predictions
    preds.to_csv(os.path.join(fs_observer.dir, 'test-aug-best.csv'),
                 index=False, columns=['image', 'label', 'score'])
    preds_noaug.to_csv(os.path.join(fs_observer.dir, 'test-noaug-best.csv'),
                 index=False, columns=['image', 'label', 'score'])
    preds_last.to_csv(os.path.join(fs_observer.dir, 'test-aug-last.csv'),
                 index=False, columns=['image', 'label', 'score'])
    preds_noaug_last.to_csv(os.path.join(fs_observer.dir, 'test-noaug-last.csv'),
                 index=False, columns=['image', 'label', 'score'])

    # TODO: Avoid repetition.
    #       use ordereddict, or create a pandas df before saving
    with open(RESULTS_CSV_PATH, 'a') as file:
        file.write(','.join((
            EXP_NAME,
            str(EXP_ID),
            str(split_id),
            str(best_epoch),
            str(best_val_result['loss']),
            str(best_val_result['acc']),
            str(best_val_result['auc']),
            str(best_val_result['avp']),
            str(best_val_result['sens']),
            str(best_val_result['spec']),
            str(last_val_result['loss']),
            str(last_val_result['acc']),
            str(last_val_result['auc']),
            str(last_val_result['avp']),
            str(last_val_result['sens']),
            str(last_val_result['spec']),
            str(best_val_auc),
            str(test_result['auc']),
            str(test_result_last['auc']),
            str(test_result['acc']),
            str(test_result_last['acc']),
            str(test_result['spec']),
            str(test_result_last['spec']),
            str(test_result['sens']),
            str(test_result_last['sens']),
            str(test_result['avp']),
            str(test_result_last['avp']),
            str(test_noaug_result['auc']),
            str(test_noaug_result_last['auc']),
            str(test_noaug_result['acc']),
            str(test_noaug_result_last['acc']),
            str(test_noaug_result['spec']),
            str(test_noaug_result_last['spec']),
            str(test_noaug_result['sens']),
            str(test_noaug_result_last['sens']),
            str(test_noaug_result['avp']),
            str(test_noaug_result_last['avp']),
            )) + '\n')

    return (test_noaug_result['auc'],
            test_result['auc'],
            )