Example #1
def build_transforms(vocab, config):
    train_transform = eval_transform = T.Compose([
        ApplyTo(
            "text",
            T.Compose([
                Normalize(),
                VocabEncode(vocab),
                ToTorch(),
            ]),
        ),
        ApplyTo(
            "audio",
            T.Compose([
                LoadAudio(config.sample_rate),
                ToTorch(),
            ]),
        ),
        Extract(["text", "audio"]),
    ])

    return train_transform, eval_transform
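
Every example on this page leans on the same two dict-transforms, ApplyTo and Extract, whose definitions are not shown. A minimal sketch consistent with the usage here (samples assumed to be dicts keyed by field name; ApplyTo accepts a single key or a list of keys):

class ApplyTo(object):
    # Sketch (assumption): apply a transform to the listed keys of a
    # dict-shaped sample, leaving the remaining keys untouched.
    def __init__(self, keys, transform):
        self.keys = [keys] if isinstance(keys, str) else keys
        self.transform = transform

    def __call__(self, input):
        return {
            k: self.transform(v) if k in self.keys else v
            for k, v in input.items()
        }


class Extract(object):
    # Sketch (assumption): pull the listed keys out of the dict and
    # return them as a tuple, which is what DataLoader batching expects.
    def __init__(self, fields):
        self.fields = fields

    def __call__(self, input):
        return tuple(input[k] for k in self.fields)

Extract returning a plain tuple is what allows the `for images, in data_loader` unpacking seen in Example #4.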
Example #2
def build_transforms(config):
    train_transform = T.Compose(
        [
            LoadImage(T.Resize(config.image_size)),
            ApplyTo(
                "image",
                T.Compose(
                    [
                        T.RandomCrop(config.crop_size),
                        Random8(),
                        T.ColorJitter(0.1, 0.1, 0.1),
                        T.ToTensor(),
                        T.Normalize(mean=MEAN, std=STD),
                    ]
                ),
            ),
            Extract(["image", "meta", "target"]),
        ]
    )
    eval_transform = T.Compose(
        [
            LoadImage(T.Resize(config.image_size)),
            ApplyTo(
                "image",
                T.Compose(
                    [
                        T.CenterCrop(config.crop_size),
                        T.ToTensor(),
                        T.Normalize(mean=MEAN, std=STD),
                    ]
                ),
            ),
            Extract(["image", "meta", "target"]),
        ]
    )

    return train_transform, eval_transform
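
Downstream, the returned pair is consumed the way Example #8 wires it; a hypothetical call site reusing that pattern (TrainEvalDataset comes from Example #8):

train_transform, eval_transform = build_transforms(config)
train_dataset = TrainEvalDataset(train_data, transform=train_transform)
eval_dataset = TrainEvalDataset(eval_data, transform=eval_transform)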
Example #3
def build_transforms(config):
    eval_transform = T.Compose([
        LoadImage(T.Resize(config.image_size)),
        ApplyTo(
            "image",
            T.Compose([
                T.CenterCrop(config.crop_size),
                TTA8(),
                Map(
                    T.Compose([
                        T.ToTensor(),
                        T.Normalize(mean=MEAN, std=STD),
                    ])),
                T.Lambda(lambda x: torch.stack(x, 0)),
            ]),
        ),
        Extract(["image", "meta", "id"]),
    ])

    return eval_transform
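
TTA8() evidently fans one image out into a list of eight test-time-augmentation views, so the per-image steps that follow must be mapped over the list before the final Lambda stacks it into an [8, C, H, W] tensor. Map's definition is not shown; a minimal sketch, assuming sequence input:

class Map(object):
    # Sketch (assumption): apply the wrapped transform to every
    # element of a sequence, returning a list of results.
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, input):
        return [self.transform(x) for x in input]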
Example #4
def main(dataset_path, workers):
    transform = T.Compose([
        ApplyTo(
            ['image'],
            T.Compose([
                SplitInSites(),
                T.Lambda(
                    lambda xs: torch.stack([ToTensor()(x) for x in xs], 0)),
            ])),
        Extract(['image']),
    ])

    train_data = pd.read_csv(os.path.join(dataset_path, 'train.csv'))
    train_data['root'] = os.path.join(dataset_path, 'train')
    test_data = pd.read_csv(os.path.join(dataset_path, 'test.csv'))
    test_data['root'] = os.path.join(dataset_path, 'test')
    data = pd.concat([train_data, test_data])

    stats = {}
    for (exp, plate), group in tqdm(data.groupby(['experiment', 'plate'])):
        dataset = TestDataset(group, transform=transform)
        data_loader = torch.utils.data.DataLoader(dataset,
                                                  batch_size=32,
                                                  num_workers=workers)

        with torch.no_grad():
            images = [images for images, in data_loader]
            images = torch.cat(images, 0)
            mean = images.mean((0, 1, 3, 4))
            std = images.std((0, 1, 3, 4))
            stats[(exp, plate)] = mean, std

            del images, mean, std
            gc.collect()

    torch.save(stats, 'plate_stats.pth')
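
Reducing over dims (0, 1, 3, 4) of the [batch, site, channel, H, W] stack leaves one mean and std per channel for each (experiment, plate) pair. A hypothetical consumer of the saved file, assuming `image` is a [C, H, W] tensor:

stats = torch.load('plate_stats.pth')
mean, std = stats[(exp, plate)]  # each of shape [num_channels]
image = (image - mean[:, None, None]) / std[:, None, None]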
Example #5
    Resize(config.resize_size),
    center_crop,
    to_tensor,
])
test_image_transform = T.Compose([
    Resize(config.resize_size),
    center_crop,
    SplitInSites(),
    T.Lambda(lambda xs: torch.stack([to_tensor(x) for x in xs], 0)),
])
train_transform = T.Compose([
    ApplyTo(['image'],
            T.Compose([
                RandomSite(),
                Resize(config.resize_size),
                random_crop,
                RandomFlip(),
                RandomTranspose(),
                to_tensor,
                ChannelReweight(config.aug.channel_reweight),
            ])),
    normalize,
    Extract(['image', 'exp', 'label', 'id']),
])
eval_transform = T.Compose([
    ApplyTo(['image'], infer_image_transform),
    normalize,
    Extract(['image', 'exp', 'label', 'id']),
])
unsup_transform = T.Compose([
    ApplyTo(['image'],
            T.Compose([
Example #6
class Resetable(object):
    # Class header and __init__ reconstructed from the usage below
    # (Resetable(RandomCrop), random_crop.reset(crop_size)); the snippet
    # starts mid-class. The wrapped transform is (re)built via reset().
    def __init__(self, build_transform):
        self.build_transform = build_transform

    def __call__(self, input):
        return self.transform(input)

    def reset(self, *args, **kwargs):
        self.transform = self.build_transform(*args, **kwargs)


random_crop = Resetable(RandomCrop)
center_crop = Resetable(CenterCrop)
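# Resetable builds its wrapped transform lazily: nothing exists until
# reset(...) is called, which is what lets update_transforms in Example #10
# swap the crop size mid-training. Hypothetical call sequence:
#
#     random_crop.reset(224)      # instantiates RandomCrop(224)
#     patch = random_crop(image)  # delegates to the current RandomCrop
#     random_crop.reset(256)      # rebuilds with a larger crop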

train_transform = T.Compose([
    ApplyTo(
        ['image'],
        T.Compose([
            RandomSite(),
            Resize(config.resize_size),
            random_crop,
            RandomFlip(),
            RandomTranspose(),
            RandomRotation(180),  # FIXME:
            ToTensor(),
            ChannelReweight(config.aug.channel_weight),
        ])),
    # NormalizeByRefStats(),
    Extract(['image', 'feat', 'label', 'id']),
])
eval_transform = T.Compose([
    ApplyTo(
        ['image'],
        T.Compose([
            RandomSite(),  # FIXME:
            Resize(config.resize_size),
            center_crop,
Example #7
    center_crop,
    to_tensor,
])
test_image_transform = T.Compose([
    Resize(config.resize_size),
    center_crop,
    SplitInSites(),
    T.Lambda(lambda xs: torch.stack([to_tensor(x) for x in xs], 0)),
])
train_transform = T.Compose([
    ApplyTo(
        ['image'],
        T.Compose([
            RandomSite(),
            Resize(config.resize_size),
            random_crop,
            RandomFlip(),
            RandomTranspose(),
            to_tensor,
            ChannelReweight(config.aug.channel_reweight),
        ])),
    normalize,
    Extract(['image', 'feat', 'exp', 'label', 'id']),
])
eval_transform = T.Compose([
    ApplyTo(
        ['image'],
        infer_image_transform),
    normalize,
    Extract(['image', 'feat', 'exp', 'label', 'id']),
])
Example #8
def main(experiment_path, dataset_path, config_path, restore_path, workers):
    logging.basicConfig(level=logging.INFO)
    config = Config.from_json(config_path)
    fix_seed(config.seed)

    train_data = pd.concat([
        load_data(os.path.join(dataset_path, 'train-clean-100'),
                  workers=workers),
        load_data(os.path.join(dataset_path, 'train-clean-360'),
                  workers=workers),
    ])
    eval_data = pd.concat([
        load_data(os.path.join(dataset_path, 'dev-clean'), workers=workers),
    ])

    if config.vocab == 'char':
        vocab = CharVocab(CHAR_VOCAB)
    elif config.vocab == 'word':
        vocab = WordVocab(train_data['syms'], 30000)
    elif config.vocab == 'subword':
        vocab = SubWordVocab(10000)
    else:
        raise AssertionError('invalid config.vocab: {}'.format(config.vocab))

    train_transform = T.Compose([
        ApplyTo(['sig'], T.Compose([
            LoadSignal(SAMPLE_RATE),
            ToTensor(),
        ])),
        ApplyTo(['syms'], T.Compose([
            VocabEncode(vocab),
            ToTensor(),
        ])),
        Extract(['sig', 'syms']),
    ])
    eval_transform = train_transform

    train_dataset = TrainEvalDataset(train_data, transform=train_transform)
    eval_dataset = TrainEvalDataset(eval_data, transform=eval_transform)

    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_sampler=BatchSampler(train_data,
                                   batch_size=config.batch_size,
                                   shuffle=True,
                                   drop_last=True),
        num_workers=workers,
        collate_fn=collate_fn)

    eval_data_loader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_sampler=BatchSampler(eval_data, batch_size=config.batch_size),
        num_workers=workers,
        collate_fn=collate_fn)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = Model(SAMPLE_RATE, len(vocab))
    model_to_save = model
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    model.to(device)
    if restore_path is not None:
        load_weights(model_to_save, restore_path)

    if config.opt.type == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     config.opt.lr,
                                     weight_decay=1e-4)
    elif config.opt.type == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    config.opt.lr,
                                    momentum=0.9,
                                    weight_decay=1e-4)
    else:
        raise AssertionError('invalid config.opt.type {}'.format(
            config.opt.type))

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        len(train_data_loader) * config.epochs)

    # ==================================================================================================================
    # main loop

    train_writer = SummaryWriter(os.path.join(experiment_path, 'train'))
    eval_writer = SummaryWriter(os.path.join(experiment_path, 'eval'))
    best_wer = float('inf')

    for epoch in range(config.epochs):
        if epoch % 10 == 0:
            logging.info(experiment_path)

        # ==============================================================================================================
        # training

        metrics = {
            'loss': Mean(),
            'fps': Mean(),
        }

        model.train()
        t1 = time.time()
        for (sigs, labels), (sigs_mask, labels_mask) in tqdm(
                train_data_loader,
                desc='epoch {} training'.format(epoch),
                smoothing=0.01):
            sigs, labels = sigs.to(device), labels.to(device)
            sigs_mask, labels_mask = sigs_mask.to(device), labels_mask.to(
                device)

            logits, etc = model(sigs, labels[:, :-1], sigs_mask,
                                labels_mask[:, :-1])

            loss = compute_loss(input=logits,
                                target=labels[:, 1:],
                                mask=labels_mask[:, 1:],
                                smoothing=config.label_smoothing)
            metrics['loss'].update(loss.data.cpu().numpy())

            lr = np.squeeze(scheduler.get_lr())

            optimizer.zero_grad()
            loss.mean().backward()
            optimizer.step()
            scheduler.step()

            t2 = time.time()
            metrics['fps'].update(1 / ((t2 - t1) / sigs.size(0)))
            t1 = t2

        with torch.no_grad():
            metrics = {k: metrics[k].compute_and_reset() for k in metrics}
            print('[EPOCH {}][TRAIN] {}'.format(
                epoch, ', '.join('{}: {:.4f}'.format(k, metrics[k])
                                 for k in metrics)))
            for k in metrics:
                train_writer.add_scalar(k, metrics[k], global_step=epoch)
            train_writer.add_scalar('learning_rate', lr, global_step=epoch)

            train_writer.add_image('spectras',
                                   torchvision.utils.make_grid(
                                       etc['spectras'],
                                       nrow=compute_nrow(etc['spectras']),
                                       normalize=True),
                                   global_step=epoch)
            for k in etc['weights']:
                w = etc['weights'][k]
                train_writer.add_image('weights/{}'.format(k),
                                       torchvision.utils.make_grid(
                                           w,
                                           nrow=compute_nrow(w),
                                           normalize=True),
                                       global_step=epoch)

            for i, (true, pred) in enumerate(
                    zip(labels[:, 1:][:4].detach().data.cpu().numpy(),
                        np.argmax(logits[:4].detach().data.cpu().numpy(),
                                  -1))):
                print('{}:'.format(i))
                text = vocab.decode(
                    take_until_token(true.tolist(), vocab.eos_id))
                print(colored(text, 'green'))
                text = vocab.decode(
                    take_until_token(pred.tolist(), vocab.eos_id))
                print(colored(text, 'yellow'))

        # ==============================================================================================================
        # evaluation

        metrics = {
            # 'loss': Mean(),
            'wer': Mean(),
        }

        model.eval()
        with torch.no_grad(), Pool(workers) as pool:
            for (sigs, labels), (sigs_mask, labels_mask) in tqdm(
                    eval_data_loader,
                    desc='epoch {} evaluating'.format(epoch),
                    smoothing=0.1):
                sigs, labels = sigs.to(device), labels.to(device)
                sigs_mask, labels_mask = sigs_mask.to(device), labels_mask.to(
                    device)

                logits, etc = model.infer(sigs,
                                          sigs_mask,
                                          sos_id=vocab.sos_id,
                                          eos_id=vocab.eos_id,
                                          max_steps=labels.size(1) + 10)

                # loss = compute_loss(
                #     input=logits, target=labels[:, 1:], mask=labels_mask[:, 1:], smoothing=config.label_smoothing)
                # metrics['loss'].update(loss.data.cpu().numpy())

                wer = compute_wer(input=logits,
                                  target=labels[:, 1:],
                                  vocab=vocab,
                                  pool=pool)
                metrics['wer'].update(wer)

        with torch.no_grad():
            metrics = {k: metrics[k].compute_and_reset() for k in metrics}
            print('[EPOCH {}][EVAL] {}'.format(
                epoch, ', '.join('{}: {:.4f}'.format(k, metrics[k])
                                 for k in metrics)))
            for k in metrics:
                eval_writer.add_scalar(k, metrics[k], global_step=epoch)

            eval_writer.add_image('spectras',
                                  torchvision.utils.make_grid(
                                      etc['spectras'],
                                      nrow=compute_nrow(etc['spectras']),
                                      normalize=True),
                                  global_step=epoch)
            for k in etc['weights']:
                w = etc['weights'][k]
                eval_writer.add_image('weights/{}'.format(k),
                                      torchvision.utils.make_grid(
                                          w,
                                          nrow=compute_nrow(w),
                                          normalize=True),
                                      global_step=epoch)

        save_model(model_to_save, experiment_path)
        if metrics['wer'] < best_wer:
            best_wer = metrics['wer']
            save_model(model_to_save,
                       mkdir(os.path.join(experiment_path, 'best')))
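
The loops above unpack each batch as `(sigs, labels), (sigs_mask, labels_mask)`, so `collate_fn` has to pad the variable-length signals and transcripts to the batch maximum and return matching boolean masks. Its definition is not shown; a minimal sketch of such a collate (an assumption, not the repo's actual code, presuming 1-D float signals and 1-D integer label sequences):

def collate_fn(batch):
    sigs, syms = zip(*batch)

    def pad_stack(seqs, dtype):
        # right-pad 1-D sequences with zeros and mark the valid positions
        max_len = max(s.size(0) for s in seqs)
        padded = torch.zeros(len(seqs), max_len, dtype=dtype)
        mask = torch.zeros(len(seqs), max_len, dtype=torch.bool)
        for i, s in enumerate(seqs):
            padded[i, :s.size(0)] = s
            mask[i, :s.size(0)] = True
        return padded, mask

    sigs, sigs_mask = pad_stack(sigs, torch.float)
    syms, syms_mask = pad_stack(syms, torch.long)
    return (sigs, syms), (sigs_mask, syms_mask)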
Example #9
from config import Config
from stal.dataset import NUM_CLASSES, TestDataset, build_data
from stal.model_cls import Model, Ensemble
from stal.utils import rle_encode
from transforms import ApplyTo, Extract

FOLDS = list(range(1, 5 + 1))
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

test_transform = T.Compose([
    ApplyTo(['image'],
            T.Compose([
                T.ToTensor(), normalize,
                T.Lambda(lambda x: torch.stack([x], 0))
            ])),
    Extract(['image', 'id']),
])


def update_transforms(p):
    assert 0. <= p <= 1.


def worker_init_fn(_):
    seed_python(torch.initial_seed() % 2**32)


def one_hot(input):
Example #10
    raise AssertionError('invalid normalization {}'.format(config.normalize))

eval_image_transform = T.Compose([
    RandomSite(),
    Resize(config.resize_size),
    center_crop,
    to_tensor,
])
test_image_transform = T.Compose([
    Resize(config.resize_size),
    center_crop,
    SplitInSites(),
    T.Lambda(lambda xs: torch.stack([to_tensor(x) for x in xs], 0)),
])
test_transform = T.Compose([
    ApplyTo(['image'], infer_image_transform),
    normalize,
    Extract(['image', 'feat', 'exp', 'id']),
])


def update_transforms(p):
    if not config.progressive_resize:
        p = 1.

    assert 0. <= p <= 1.

    crop_size = round(224 + (config.crop_size - 224) * p)
    print('update transforms p: {:.2f}, crop_size: {}'.format(p, crop_size))
    random_crop.reset(crop_size)
    center_crop.reset(crop_size)
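
update_transforms takes a progress value p in [0, 1] and grows the crop from 224 px up to config.crop_size, so it is presumably invoked with training progress once per epoch; the call site is not shown in these snippets. A hypothetical hook:

for epoch in range(config.epochs):
    update_transforms((epoch + 1) / config.epochs)  # p ramps up to 1.0
    # ... train and evaluate one epoch ...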