Example #1
def do_images_resize(new_sizes, overwrite=False):
    base_path = prop('images')

    resource_dir = prop('resources')
    assert len(new_sizes) == 2
    new_name = f'images_{new_sizes[0]}x{new_sizes[1]}'

    new_folder = os.path.join(resource_dir, new_name)
    if os.path.exists(new_folder):
        if overwrite:
            # os.removedirs only removes empty directories; remove the old output
            # directory with its contents instead (requires `import shutil`)
            shutil.rmtree(new_folder)
        else:
            print('Skipping, directory', new_folder, 'exists')
            return

    print('Start resizing to', new_folder)
    os.makedirs(new_folder)

    files = list(
        filter(lambda name: name.endswith('.jpg'), os.listdir(base_path)))

    ts = Resize(new_sizes, Image.NEAREST)
    for f in tqdm(files):
        img = Image.open(os.path.join(base_path, f))
        resized = ts(img)
        resized.save(os.path.join(new_folder, f))
Example #2
def init_wandb(config, name=None, offline=is_local_env(), tags=None):
    mode = None
    if offline:
        mode = 'dryrun'
    wandb.login(key=prop('wandb.secret'))
    wandb.init(project=prop('wandb.project'),
               config=config,
               dir=prop('wandb.directory'),
               name=name,
               mode=mode,
               tags=tags)
Example #3
def create_run(config, name=None, offline=False, tags=None):
    mode = None
    if offline:
        mode = 'dryrun'
    wandb.login(key=prop('wandb.secret'))
    wandb.init(project=prop('wandb.project'),
               config=config,
               dir=prop('wandb.directory'),
               name=name,
               mode=mode,
               tags=tags)
Example #4
    def process_one(self, image, target, model_output_all, masks, image_index,
                    in_batch_index):
        model_out, heatmaps = model_output_all

        heatmap = heatmaps[in_batch_index, 0]
        heatmap /= heatmap.max()

        resize = transforms.Resize((self.image_resize, self.image_resize))

        _, temp_path = tempfile.mkstemp(dir=prop('cache'), suffix='.png')
        plt.imshow(heatmap.detach().cpu().numpy())
        plt.savefig(temp_path)

        data = [
            wandb.Image(common.tensor_to_pil(image), caption='Original image'),
            mask_to_wandb(masks[self.label_to_index_total[self.label]],
                          caption=f'Mask of {self.label}'),
            wandb.Image(resize(Image.open(temp_path)),
                        caption='Normalized gradcam'),
            wandb_image_with_heatmap(image=image,
                                     heatmap=heatmap,
                                     caption='GradCam',
                                     alpha=0.6)
        ]

        wandb.log({f'GCAM_of_{image_index}_{self.label}': data})
Example #5
    def dump(self, file=None):
        if file is None:
            file = 'model_stat_' + str(int(round(time.time() * 1000)))

        full_path = os.path.abspath(os.path.join(prop('metrics_dump'), file))
        with open(full_path, 'w') as f:
            json.dump(self.data, f)
        return full_path
Example #6
    def _load_masks(self, f_name, shape):
        masks = torch.zeros((len(self.classes), 1, *shape))
        for i, attr in enumerate(self.classes):
            mask_file = f'{f_name[:-len(".jpg")]}_attribute_{attr}.png'
            mask_path = self.path_resolver(prop('masks'), mask_file)
            mask = self.image_loader(mask_path, mode=None)

            mask = self.transforms.cached(mask)
            masks[i] = mask
        return masks
Example #7
        def get_one(_a):
            file = ATTR_FILE_MASK.format(get_image_id(_id, id_len=self.id_len),
                                         _a)
            path = self.path_resolver(prop('masks'), file)

            img = self.image_loader(path)
            if self.ts is not None:
                img = self.ts(img)

            mask = img[0]

            return mask.unsqueeze(dim=0)  # we can take only one color
Example #8
def cache_file(file_path, to_path=None):
    print('caching', file_path)
    client = paramiko.SSHClient()
    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())

    client.connect(hostname=prop('ssh.host'),
                   port=prop('ssh.port'),
                   username=prop('ssh.user'))

    sftp = client.open_sftp()

    if to_path is None:
        _, temp_path = tempfile.mkstemp(dir=prop('cache'))
    else:
        temp_path = to_path

    sftp.get(file_path, temp_path)

    sftp.close()
    client.close()
    print('cached to', temp_path)
    return temp_path
Example #9
    def __init__(self, model: nn.Module, checkpoint, image_net, random_samples,
                 indices, images_dir, use_cpu, name, image_resize, **kwargs):
        self.checkpoint = checkpoint
        self.image_net = image_net
        self.random_samples = random_samples
        self.indices = indices
        self.image_resize = image_resize

        self.device = common.get_device(cpu_force=use_cpu)
        self.model = model
        self.images_dir = images_dir
        self.use_cpu = use_cpu
        self.name = name

        ch = get_checkpoint(checkpoint, image_net)
        if ch is not None:
            model.load_state_dict(ch, strict=False)
        else:
            print('Using random weights!')

        self.dataset = BasicDataset(labels_csv=dataset_path(
            prop('datasets.test')),
                                    transforms=self.get_image_transforms(),
                                    img_dir=self.images_dir)

        self.mask_loader = MaskLoader(preset_attr=self.dataset.labels(),
                                      ts=self.get_image_transforms())

        self.label_to_index_total = {}
        for i, name in enumerate(self.dataset.labels()):
            self.label_to_index_total[name] = i

        if self.random_samples is None and self.indices is None:
            raise AttributeError(
                'Expected one of `indices` or `random_samples`')

        if self.indices is None and self.random_samples is not None:
            # np.random.random_integers is deprecated; randint's upper bound is exclusive,
            # so len(self.dataset) already yields indices in [0, len - 1]
            self.indices = np.random.randint(low=0,
                                             high=len(self.dataset),
                                             size=self.random_samples)
Example #10
def save_torch_checkpoint(constants: dict, **kwargs):
    res_state = {}
    for k, v in constants.items():
        res_state[k] = v

    suffix = ''
    if 'suffix' in kwargs:
        suffix = f"_{kwargs['suffix']}"

    suffix = f'{suffix}_{datetime.datetime.now().strftime(prop("checkpoint.time_format"))}'
    kwargs.pop('suffix', None)

    for k, v in kwargs.items():
        try:
            state = v.state_dict()
            res_state[k] = state
        except Exception:  # skip objects without a usable state_dict(); a bare except would also swallow KeyboardInterrupt
            import warnings
            warnings.warn(f"Skipping save of {k}")
    path = os.path.join(prop('checkpoint.directory'),
                        f'checkpoint{suffix}.chp')
    torch.save(res_state, path)
    return path
Example #11
def base_exec_parser(run_name, tags=None):
    if tags is None:
        tags = []
    # noinspection PyTypeChecker
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument(
        '--images-dir',
        type=str,
        required=False,
        default=prop('images'),
        help='Location of train images, default value present in config.py')
    parser.add_argument('--use_cpu',
                        action='store_true',
                        required=False,
                        help='Use cpu only')
    parser.add_argument('--name',
                        type=str,
                        required=False,
                        default=run_name,
                        help='WandB run name')
    parser.add_argument('--image-resize',
                        type=int,
                        required=False,
                        default=256,
                        help='Image resize')

    parser.add_argument('--tags',
                        action='extend',
                        dest='tags',
                        nargs="+",
                        type=str,
                        help='Additional tags for the run',
                        default=tags)
    return parser
Example #12
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]
BATCH_SIZE = 10
FAST_TRAIN = False
N_CLASSES = 5
LEARNING_RATE = 1e-3
NUM_EPOCHS = 20

# %%

data_loaders = get_dataloader(image_resize=256,
                              mean=MEAN,
                              std=STD,
                              fast_train=FAST_TRAIN,
                              batch_size=BATCH_SIZE,
                              img_dir=prop('images256'))

# %%

GAIN_TARGET_LABEL = 'pigment_network'

gain_target = 4  # data_loaders[PHASE_TRAIN].dataset.labels().index(GAIN_TARGET_LABEL)

# %%

device = get_device(cpu_force=is_local_env())
print('Got device', device)

# %%
model = multi_label_resnet50(num_labels=N_CLASSES, pretrained=True)
model = MLGradCamResnet(model=model,
Example #13
def attach_logger():
    log_path = os.path.join(
        prop('logs_dir'),
        datetime.datetime.now().strftime('%Y%m%d_%H%M%S') + '.log')
    logging.basicConfig(filename=log_path, level=logging.INFO)
    print('Logs attached to', log_path)
Example #14
def generate_checkpoint_path():
    suffix = f'_{datetime.datetime.now().strftime(prop("checkpoint.time_format"))}'
    path = os.path.join(prop('checkpoint.directory'),
                        f'checkpoint{suffix}.chp')
    return path
Example #15
def main(cfg: RunConfig):
    torch.set_num_threads(4)
    images_dir = resource_path(config.prop(f'images{cfg.images_size}'),
                               strict=True)

    args, name, tags = get_wandb_starup(cfg)
    print('Got args', args)
    print('Got name', name)
    print('Got tags', tags)

    device = resolve_device(cfg.device)
    print(f'Got device {device}')

    cl_criterion = resolve_criterion(cfg.label, cfg.cl_criterion, device)
    a_criterion = resolve_criterion(cfg.label, cfg.attention_criterion, device)

    ds_index = DsIndex(images_dir=images_dir, masks_dir=config.prop('masks'))
    full_ds = FullDataset(ds_index=ds_index,
                          image_size=cfg.images_size,
                          labels=[cfg.label])
    with open(resource_path(config.prop('sample_indices'))) as f:
        samples = json.load(f)

    if cfg.train_on_test:
        samples[PHASE_TRAIN] = samples[PHASE_TEST]
        for l in LABELS:
            samples['balance'][l][PHASE_TRAIN] = samples['balance'][l][
                PHASE_TEST]

    if cfg.balanced:
        samples[PHASE_TRAIN] = samples['balance'][cfg.label][PHASE_TRAIN]
        print('Using balance samples')
        for p in PHASES:
            print(len(samples[p]), p, 'samples',
                  'balanced' if p == PHASE_TRAIN else '')

    if cfg.neg_percent != 100:
        samples[PHASE_TRAIN] = samples[f'neg_{cfg.neg_percent}'][
            cfg.label][PHASE_TRAIN]

    datasets = {}
    for p in [PHASE_TRAIN, PHASE_TEST, PHASE_VALIDATION]:
        datasets[p] = BinaryDataset(full_ds,
                                    sample_ids=samples[p],
                                    label=cfg.label)

    model = resnet50_with_cam(num_classes=1,
                              state_dict=image_net_state_dict(),
                              cam_layer=cfg.cam_layer)
    optimizer = Adam(model.parameters(), lr=cfg.lr)
    scheduler = CosineAnnealingLR(optimizer, T_max=5, eta_min=0.005)
    if cfg.no_scheduler:
        scheduler = None

    suspector = ImageSuspector(max_samples_by_label=5)

    wandb_help.create_run(args, name, tags=tags, offline=False)
    run_experiment(
        datasets,
        cl_criterion,
        a_criterion,
        model.to(device),
        optimizer,
        scheduler,
        device,
        suspector,
        cfg,
    )

    suspector.log_to_wandb(full_ds)
    wandb_help.log_summary_best()

    for ds in datasets.values():
        ds.clean()
Example #16
#     transforms=train_transformations(256, DEFAULT_MEAN, DEFAULT_STD),
#     img_dir=prop('images')
# )
#
# img, label, sample, masks = ds[1]
# print(img.shape, label, sample, masks.shape)
#
# import matplotlib.pyplot as plt
#
# data = [
#     grid_image_item(img.permute(1, 2, 0).numpy(), title='orig')
# ]
#
# for i in range(len(masks)):
#     data.append(grid_image_item(masks[i].permute(1, 2, 0).numpy(), title=f'Mask {i}'))
#
# draw_grid(data)
import torch

import torch.nn.functional as F

ds = BasicDatasetV2(
    labels_csv=dataset_path(prop('datasets.test')),
    transforms=val_test_transform(256, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    img_dir=prop('images')
)

image, labels, sample_id, masks, affects = ds[0]

print(masks.shape)
Example #17
def resource_path(path, strict=False):
    result_path = os.path.join(config.prop('resources'), path)
    if strict:
        assert os.path.exists(result_path), f'Path does not exist: {result_path}'
    return result_path
Example #18
                was_added = True
                test_res[f'UTest/{phase}'] = res
                test_res[f'pValue/{phase}'] = p_value
            except ValueError as e:
                test_res[f'UTest/{phase}'] = str(e)
                test_res[f'pValue/{phase}'] = "undefined"
                was_error = True

        if was_added:
            RESULT.append(test_res)

        if was_error and was_added:
            stat['partial_success'] += 1
            stat['partial_error'] += 1
        elif was_error and not was_added:
            stat['error'] += 1
        else:
            stat['success'] += 1

    else:
        stat['skipped'] += 1

res = pd.DataFrame(RESULT)

res.to_csv(os.path.join(prop('wandb_export'), 'mann.csv'))
print(stat)

# %%

res
Example #19
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255
    cam = heatmap + np.float32(img)
    cam = cam / np.max(cam)
    return np.uint8(255 * cam)


images = {
    'img': 'train19.jpg',
    'mask': 'orig_mask19.png',
    'baseline': 'baseline19.png',
    'gcam': 'gcam19.png'
}

for k, v in images.items():
    images[k] = os.path.join(prop('cache'), v)
    assert os.path.exists(images[k]), f'{k}'

rgb_img = cv2.imread(images['img'], 1)[:, :, ::-1]
rgb_img = np.float32(rgb_img) / 255

import matplotlib.pyplot as plt

plt.axis('off')

# %%


def show_cam(cam_path):
    mask = cv2.resize(cv2.imread(cam_path, 1), (256, 256))[:, :, 0]
Example #20
def load_checkpoint(name, full_path=False):
    result_path = name
    if not full_path:
        result_path = os.path.join(prop('checkpoint.directory'), result_path)
    return torch.load(result_path)
Example #21
import common
from common import get_device, setup_seed, foreach
from common import to_device, generate_checkpoint_path
from common.dataset_transformations import train_transformations, val_test_transform
from common.scorer import summary_best, ALL_SCORES
from common.utils import is_iter
from config import prop
from dataset.dataset import BasicDatasetV2
from dataset.utils import dataset_path

DEFAULT_EPOCHS = 20
DEFAULT_LR = 1e-3
DEFAULT_IMAGE_RESIZE = 256
DEFAULT_BATCH_SIZE = 10
DEFAULT_IMAGES_DIR = prop('images')
DEFAULT_SEED = 'rand'
DEFAULT_MEAN = (0.70843003, 0.58212194, 0.53605963)
DEFAULT_STD = (0.15741858, 0.1656929, 0.18091279)

VAL_PHASE = 'val'
TRAIN_PHASE = 'train'


def default_parser(run_name: str, tags=None):
    # noinspection PyTypeChecker
    parser = common.base_exec_parser(run_name, tags)
    parser.add_argument('--epochs',
                        type=int,
                        required=False,
                        default=DEFAULT_EPOCHS,
Example #22
# calculate mean and std deviation

import numpy as np
from pathlib import Path
import cv2

from config import prop
from tqdm import trange

imageFilesDir = Path(prop('images'))
files = list(imageFilesDir.rglob('*.jpg'))

# Unlike the mean, the std cannot be obtained by computing it per image and averaging,
# so we make two passes over the data: the first pass computes the overall per-channel
# mean, and the second pass uses that mean to accumulate the squared deviations for the std.

mean = np.array([0., 0., 0.])
stdTemp = np.array([0., 0., 0.])
std = np.array([0., 0., 0.])

numSamples = len(files)

for i in trange(numSamples, desc='mean'):
    im = cv2.imread(str(files[i]))
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    im = im.astype(float) / 255.

    for j in range(3):
        mean[j] += np.mean(im[:, :, j])

mean = (mean / numSamples)
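
# The example above stops after the mean pass. Below is a minimal sketch (an assumption,
# not part of the original snippet) of the second pass described in the comment, reusing
# the `files`, `mean`, `stdTemp` and `numSamples` variables defined above:
for i in trange(numSamples, desc='std'):
    im = cv2.imread(str(files[i]))
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    im = im.astype(float) / 255.

    for j in range(3):
        # accumulate the per-image mean squared deviation of channel j from the global mean
        stdTemp[j] += ((im[:, :, j] - mean[j]) ** 2).mean()

std = np.sqrt(stdTemp / numSamples)
print('mean:', mean, 'std:', std)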
Example #23
# %%


@mirrored('/mnt/tank/scratch/rbelyaev/checkpoints/')
def get_checkpoint_path(base_dir, file_name):
    return os.path.join(base_dir, file_name)


@mirrored('/mnt/tank/scratch/rbelyaev/isic_masks/')
def mask_loader(base_dir, file_name):
    return os.path.join(base_dir, file_name)


CHECKPOINT_PATH = get_checkpoint_path(
    prop('checkpoint.directory'),
    'checkpoint__img_resize256_20210102_163431.chp')


def gcam():
    model = multi_label_resnet50(True, 5)
    model = load_model_state(CHECKPOINT_PATH,
                             model,
                             is_remote_paths=False,
                             cleanup_cache=False)

    return MLGradCamResnet(model=model,
                           target_layer='layer4.2',
                           model_out_to_pred=None,
                           device=torch.device('cpu'),
                           cam_category=4)
Example #24
    def run(self):
        print(f'Creating datasets [{self.image_resize}x{self.image_resize}]')
        train_dataset = BasicDatasetV2(
            labels_csv=dataset_path(prop('datasets.train')),
            transforms=train_transformations(self.image_resize, self.mean,
                                             self.std),
            img_dir=self.images_dir)

        validation_dataset = BasicDatasetV2(
            labels_csv=dataset_path(prop('datasets.validation')),
            transforms=val_test_transform(self.image_resize, self.mean,
                                          self.std),
            img_dir=self.images_dir)

        # Force-cache both datasets up front
        self._warmup_ds(train_dataset, 'train')
        self._warmup_ds(validation_dataset, 'validation')

        for i, (train_label, validation_label) in enumerate(
                zip(train_dataset.labels(), validation_dataset.labels())):
            if train_label != validation_label:
                raise RuntimeError(
                    f'Train and validation label order differ at position {i + 1}'
                )
        print('Done')

        if self.device is None:
            # pick the least loaded GPU
            self.device = get_device(cpu_force=self.use_cpu)
        else:
            self.device = torch.device(self.device)
        print('Got device', self.device)

        # if no training attributes are given, use all labels from the dataset
        if self.only_labels is None:
            self.only_labels = train_dataset.labels()

        # Helper data for mapping an attribute name to its index and back
        self.labels_indices = {}
        for i, label in enumerate(train_dataset.labels()):
            self.labels_indices[label] = i

        target_labels_indices = [
            self.labels_indices[label] for label in self.only_labels
        ]
        self.index_to_label = [
            train_dataset.labels()[idx] for idx in target_labels_indices
        ]

        print(
            f'Running for {", ".join(map(str, self.index_to_label))} classes')

        # simple loaders without shuffling
        train_dataloader = DataLoader(train_dataset,
                                      batch_size=self.batch_size)
        validation_dataloader = DataLoader(validation_dataset,
                                           batch_size=self.batch_size)

        # create the optimizer
        self.optimizer = self.create_optimizer(self.model.parameters(),
                                               self.lr)

        # create the scheduler if needed
        if not self.no_scheduler:
            self.scheduler = self.create_scheduler(self.optimizer)

        self.model = self.model.to(self.device)
        for e in trange(self.n_epochs, desc='Epochs'):
            self.current_epoch = e
            self.model.train()
            train_saver = self._basic_epoch_block(train_dataloader,
                                                  target_labels_indices)
            self.model.eval()
            validation_saver = self._basic_epoch_block(validation_dataloader,
                                                       target_labels_indices)

            self.on_epoch_end(train_saver=train_saver,
                              validation_saver=validation_saver,
                              epoch=e)

        def do_save(obj, name):
            f = f'{self.save_path}.{name}'
            torch.save(obj.state_dict(), f)
            print(f'{name} saved to {f}')

        do_save(self.optimizer, 'optimizer')
        if not self.no_scheduler:
            do_save(self.scheduler, 'scheduler')
        do_save(self.model, 'model')