Exemple #1
0
    def get_databunch(self, valid_pct=0.1):
        """Build a fastai DataBunch of sliding-window price samples.

        Each sample pairs `day_count` consecutive min-max-normalised
        adjusted-close values with the tabular features of the window's
        last day; the target is the normalised price right after the
        window. The final `valid_pct` fraction of samples is held out
        for validation (no shuffling, preserving time order).
        """
        df = self._get_df_from_file()

        scaler = preprocessing.MinMaxScaler()
        series = scaler.fit_transform(df["adj_close"].values.reshape(-1, 1))

        window = self.day_count
        sample_count = len(series) - window

        X = np.array([series[start:start + window].copy()
                      for start in range(sample_count)])
        y = np.array([series[:, 0][start + window].copy()
                      for start in range(sample_count)])
        # Tabular features come from the last day inside each window.
        features = df.drop(["adj_close", "Elapsed"], axis=1)
        tabular_data = np.array([features.iloc[start + window - 1]
                                 for start in range(sample_count)])
        y = np.expand_dims(y, -1)

        split = int(len(X) * (1 - valid_pct))
        train_ds = StockDataset(X[:split], tabular_data[:split], y[:split])
        valid_ds = StockDataset(X[split:], tabular_data[split:], y[split:])
        return DataBunch.create(train_ds, valid_ds, bs=self.batch_size)
Exemple #2
0
 def set_data(self, tr_data, val_data, bs=None):
     """Set data sources for this learner."""
     batch = ifnone(bs, defaults.batch_size)
     train_ds = self.create_dataset(tr_data)
     valid_ds = self.create_dataset(val_data)
     self.data = DataBunch(train_ds.as_loader(bs=batch),
                           valid_ds.as_loader(bs=batch))
     # Drop any cached 'data' entry so it is recomputed for the new sources.
     if 'data' in self.parameters:
         del self.parameters['data']
Exemple #3
0
 def after_prepare_data_hook(self):
     """Wrap the prepared train/validation datasets into a DataBunch."""
     logger.debug("kernel use device %s", self.device)
     cfg = self.config
     self.data = DataBunch.create(self.train_dataset,
                                  self.validation_dataset,
                                  bs=cfg.batch_size,
                                  num_workers=cfg.num_workers,
                                  device=self.device)
Exemple #4
0
def test_unet_without_gt(learn,
                         picture_input,
                         downsample=8,
                         batch_size=12,
                         picture=False):
    """Run `learn`'s model on a single image without ground truth.

    The image is resized to 224x224, fed through the network via a
    throwaway DataBunch, and the prediction is shrunk by `downsample`
    with nearest-neighbour interpolation.

    Args:
        learn: fastai Learner; its `data` attribute is replaced.
        picture_input: image tensor — assumes shape (C, H, W); a batch
            dimension is added here. TODO(review): confirm shape.
        downsample: factor by which the output is downsampled.
        batch_size: batch size of the temporary DataLoader.
        picture: if True, show input/output with matplotlib.

    Returns:
        The downsampled prediction as a numpy array.
    """
    # Fix: the original called the in-place `unsqueeze_`, silently
    # mutating the caller's tensor; use the out-of-place form instead.
    picture_input = picture_input.unsqueeze(dim=0)
    picture_input = F.interpolate(picture_input,
                                  size=(224, 224),
                                  mode='bilinear',
                                  align_corners=True).float()
    # The image doubles as a dummy target — get_preds only needs inputs.
    my_dataset = TensorDataset(picture_input, picture_input)
    my_dataloader = DataLoader(my_dataset, batch_size=batch_size)
    my_databunch = DataBunch(train_dl=my_dataloader,
                             test_dl=my_dataloader,
                             valid_dl=my_dataloader)
    learn.data = my_databunch
    output = learn.get_preds(ds_type=DatasetType.Valid)[0]
    output_inter = F.interpolate(output,
                                 scale_factor=1.0 / downsample,
                                 mode='nearest')

    if picture:
        import matplotlib.pyplot as plt

        idx = 0
        plt.figure()
        plt.subplot(221)
        aa = picture_input[idx, :, :, :].data.numpy()
        im_out = np.transpose(aa, (1, 2, 0))
        plt.imshow(im_out)
        plt.title('input')

        plt.subplot(222)
        # Drops the last channel for display — NOTE(review): presumably a
        # background/extra class channel; confirm against the model head.
        aa = output[idx, :-1, :, :].data.numpy()
        im_out = np.transpose(aa, (1, 2, 0))
        plt.imshow(im_out)
        plt.title('output')

        plt.subplot(223)
        aa = output_inter[idx, :-1, :, :].data.numpy()
        im_out = np.transpose(aa, (1, 2, 0))
        plt.imshow(im_out)
        plt.title('output_downsampled')

        plt.show()

    return output_inter.data.cpu().numpy()
Exemple #5
0
def check_data(data_bunch: DataBunch, vocab: Vocab, verbose: bool,
               allow_unks: bool) -> None:
    """Sanity-check the first batch of `data_bunch`.

    Raises ValueError if unknown tokens are present while `allow_unks`
    is False; optionally logs the first batch and its textified tokens.
    """
    first_batch = data_bunch.one_batch()[0]

    unks_forbidden = not allow_unks
    if unks_forbidden and not contains_no_value(first_batch,
                                                UNKNOWN_TOKEN_INDEX):
        raise ValueError(
            f"Unknown is found : {[vocab.textify(seq) for seq in first_batch]}"
        )
    if verbose:
        logger.info(f'Displaying the first batch:\n{first_batch}')
        logger.info(pformat([vocab.textify(seq) for seq in first_batch]))
Exemple #6
0
def run(ini_file='tinyimg.ini',
        data_in_dir='./../../dataset',
        model_cfg='../cfg/vgg-tiny.cfg',
        model_out_dir='./models',
        epochs=30,
        lr=3.0e-5,
        batch_sz=256,
        num_worker=4,
        log_freq=20,
        use_gpu=True):
    """Parse config, build data/model, and train with a fastai Learner."""
    # Step 1: parse the .ini config, with explicit arguments as overrides.
    cfg = parse_cfg(ini_file,
                    data_in_dir=data_in_dir,
                    model_cfg=model_cfg,
                    model_out_dir=model_out_dir,
                    epochs=epochs,
                    lr=lr,
                    batch_sz=batch_sz,
                    log_freq=log_freq,
                    num_worker=num_worker,
                    use_gpu=use_gpu)
    print_cfg(cfg)

    # Step 2: datasets and loaders.
    train_ds, val_ds = build_train_val_datasets(cfg, in_memory=True)
    train_loader, val_loader = DLFactory.create_train_val_dataloader(
        cfg, train_ds, val_ds)

    # Step 3: model.
    net = MFactory.create_model(cfg)

    # Step 4: train/validate — demonstrates integration with the fastai
    # Learner API.
    bunch = DataBunch(train_loader, val_loader, device=get_device(cfg))
    learner = Learner(bunch,
                      net,
                      loss_func=torch.nn.CrossEntropyLoss(),
                      metrics=accuracy)

    # An lr_find()/fit_one_cycle() sweep can be wired in here to pick a
    # better starting learning rate before the plain fit below.
    learner.fit(epochs, lr)
Exemple #7
0
def main():
    """Train or fine-tune a PSMNet disparity model on KITTI RoI data."""
    model = PSMNet(args.maxdisp, args.mindisp).cuda()

    if args.load_model is not None:
        if args.load is not None:
            warn('args.load is not None. load_model will be covered by load.')
        checkpoint = torch.load(args.load_model, 'cpu')
        if 'model' in checkpoint.keys():
            weights = checkpoint['model']
        elif 'state_dict' in checkpoint.keys():
            weights = checkpoint['state_dict']
        else:
            raise RuntimeError()
        # Strip any 'module.' prefix left over from DataParallel wrapping.
        weights = {k.replace('module.', ''): v for k, v in weights.items()}
        model.load_state_dict(weights)

    train_dl = DataLoader(KITTIRoiDataset(args.data_dir, 'train',
                                          args.resolution, args.maxdisp,
                                          args.mindisp),
                          batch_size=args.batch_size,
                          shuffle=True,
                          num_workers=args.workers)
    val_dl = DataLoader(KITTIRoiDataset(args.data_dir, 'val', args.resolution,
                                        args.maxdisp, args.mindisp),
                        batch_size=args.batch_size,
                        num_workers=args.workers)

    databunch = DataBunch(train_dl, val_dl, device='cuda')
    learner = Learner(databunch,
                      model,
                      loss_func=PSMLoss(),
                      model_dir=args.model_dir)
    learner.callbacks = [
        DistributedSaveModelCallback(learner),
        TensorBoardCallback(learner)
    ]
    if num_gpus > 1:
        learner.to_distributed(get_rank())
    if args.load is not None:
        learner.load(args.load)

    if args.mode == 'train':
        learner.fit(args.epochs, args.maxlr)
    elif args.mode == 'train_oc':
        fit_one_cycle(learner, args.epochs, args.maxlr)
    else:
        raise ValueError('args.mode not supported.')
Exemple #8
0
def test_fastai_pruning_callback(tmpdir):
    # type: (typing.Any) -> None
    """A pruning pruner yields a PRUNED trial; a non-pruning one COMPLETEs."""

    train_ds = ArrayDataset(np.zeros((16, 20), np.float32),
                            np.zeros((16, ), np.int64))
    valid_ds = ArrayDataset(np.zeros((4, 20), np.float32),
                            np.zeros((4, ), np.int64))

    data_bunch = DataBunch.create(
        train_ds=train_ds,
        valid_ds=valid_ds,
        test_ds=None,
        path=tmpdir,
        bs=1  # batch size
    )

    def objective(trial):
        # type: (optuna.trial.Trial) -> float
        net = nn.Sequential(nn.Linear(20, 1), nn.Sigmoid())
        pruning_cb = partial(FastAIPruningCallback,
                             trial=trial,
                             monitor="valid_loss")
        learn = Learner(
            data_bunch,
            net,
            metrics=[accuracy],
            callback_fns=[pruning_cb],
        )
        learn.fit(1)
        return 1.0

    study = optuna.create_study(pruner=DeterministicPruner(True))
    study.optimize(objective, n_trials=1)
    assert study.trials[0].state == optuna.structs.TrialState.PRUNED

    study = optuna.create_study(pruner=DeterministicPruner(False))
    study.optimize(objective, n_trials=1)
    assert study.trials[0].state == optuna.structs.TrialState.COMPLETE
    assert study.trials[0].value == 1.0
Exemple #9
0
    def __init__(self, tr_data, val_data, **kwargs):
        """Build datasets, model and the underlying Learner; attach tracking."""
        learner_args, other_args = self._partition_args(**kwargs)

        # self.parameters holds tracking metadata; see the keys listed at
        # the top of traintracker.py.
        self.parameters = {
            # hardwired description of significant code update
            "code_marker": "add mask to input",
            # may be overridden by subclasses to add more info
            "arch": self.__class__.__name__,
        }

        bs = defaults.batch_size
        tr_ds = self.create_dataset(tr_data)
        val_ds = self.create_dataset(val_data)
        databunch = DataBunch(tr_ds.as_loader(bs=bs),
                              val_ds.as_loader(bs=bs))
        model = self.create_model(**other_args)
        TrainTracker.default_model_id(model)

        super().__init__(databunch, model, **learner_args)

        # TrainTracker must run first among the callback factories.
        self.callback_fns.insert(0, TrainTracker)
def do_train(
    cfg,
    model,
    train_dl,
    valid_dl,
    optimizer,
    loss_fn,
    metrics=None,
    callbacks: list = None,
):
    """Train `model` on `train_dl`/`valid_dl` with fastai's one-cycle policy.

    Args:
        cfg: config node providing SOLVER.* and MODEL.DEVICE entries.
        model: the network to train.
        train_dl, valid_dl: training / validation DataLoaders.
        optimizer: NOTE(review): accepted but unused — the Learner builds
            its own optimizer.
        loss_fn: loss function handed to the Learner.
        metrics: optional list of fastai metrics.
        callbacks: optional list; the LoggingLog callback is appended to it
            in place, as before.
    """
    # Fix: the original used mutable default arguments (metrics=[],
    # callbacks=[]); the shared default list accumulated a LoggingLog
    # callback on every call.
    if metrics is None:
        metrics = []
    if callbacks is None:
        callbacks = []

    # NOTE(review): log_period, checkpoint_period, output_dir and device
    # are read from cfg but not used below; kept so missing config keys
    # still fail fast here.
    log_period = cfg.SOLVER.LOG_PERIOD
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    output_dir = cfg.OUTPUT_DIR
    device = cfg.MODEL.DEVICE
    epochs = cfg.SOLVER.MAX_EPOCHS

    data_bunch = DataBunch(train_dl, valid_dl)
    learn = Learner(data_bunch, model, loss_func=loss_fn)
    callbacks.append(LoggingLog(learn, "template_model.train"))
    if metrics:
        learn.metrics = metrics
    learn.fit_one_cycle(epochs, cfg.SOLVER.BASE_LR)
def classifier_data_bunch(config):
    """Build the train/valid DataBunch for the steel-defect classifier.

    Splits on the 'is_valid' column; when `config.load_valid_crops` is
    set, each validation image is expanded into its 7 pre-cut crops
    (suffix '_c1'..'_c7').
    """
    split_df = pd.read_csv(config.split_csv)
    if config.debug_run:
        split_df = split_df.loc[:200]
    train_df = split_df[split_df['is_valid']==False].reset_index(drop=True)
    valid_df = split_df[split_df['is_valid']==True].reset_index(drop=True)

    if config.load_valid_crops:
        crop_rows = []
        for i in range(len(valid_df)):
            row = valid_df.loc[i]
            for j in range(1, 8):
                crop_rows.append({
                    'ImageId_ClassId': row['ImageId_ClassId'].replace(
                        '.jpg', '_c{}.jpg'.format(j)),
                    '1': row['1'],
                    '2': row['2'],
                    '3': row['3'],
                    '4': row['4'],
                    'is_valid': row['is_valid'],
                })
        valid_df = pd.DataFrame(crop_rows)

    train_tf = alb_transform_train(config.imsize)
    valid_tf = alb_transform_test(config.imsize)

    train_ds = SteelClassifierDataset(train_df, transforms=train_tf)
    valid_ds = SteelClassifierDataset(valid_df, transforms=valid_tf)
    return DataBunch.create(train_ds, valid_ds, bs=config.batch_size,
                            num_workers=config.num_workers)
Exemple #12
0
def run_ner(
        lang: str = 'eng',
        log_dir: str = 'logs',
        task: str = NER,
        batch_size: int = 1,
        epochs: int = 1,
        dataset: str = 'data/conll-2003/',
        loss: str = 'cross',
        max_seq_len: int = 128,
        do_lower_case: bool = False,
        warmup_proportion: float = 0.1,
        rand_seed: int = None,
        ds_size: int = None,
        data_bunch_path: str = 'data/conll-2003/db',
        tuned_learner: str = None,
        do_train: str = False,
        do_eval: str = False,
        save: bool = False,
        nameX: str = 'ner',
        mask: tuple = ('s', 's'),
):
    """Fine-tune BERT for NER on CoNLL-2003-style data and log results to CSV.

    Trains with gradual unfreezing (freeze -> freeze_to(-3/-6/-12) ->
    unfreeze), one fit_one_cycle per stage with discriminative learning
    rates, then optionally evaluates on the test split. Results are
    appended to `<log_dir>/<run name>.csv`.

    NOTE(review): `do_train`/`do_eval` are annotated `str` but used as
    booleans; `do_lower_case` and `epochs` beyond the warmup-step count
    are not used in the training loop itself.
    """
    # Run name encodes the key hyper-parameters; used for log/CSV files.
    name = "_".join(
        map(str, [
            nameX, task, lang, mask[0], mask[1], loss, batch_size, max_seq_len,
            do_train, do_eval
        ]))
    log_dir = Path(log_dir)
    log_dir.mkdir(parents=True, exist_ok=True)
    init_logger(log_dir, name)

    # Seed all RNG sources for reproducibility.
    if rand_seed:
        random.seed(rand_seed)
        np.random.seed(rand_seed)
        torch.manual_seed(rand_seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(rand_seed)

    trainset = dataset + lang + '/train.txt'
    devset = dataset + lang + '/dev.txt'
    testset = dataset + lang + '/test.txt'

    # English uses the cased base model; anything else the multilingual one.
    bert_model = 'bert-base-cased' if lang == 'eng' else 'bert-base-multilingual-cased'
    print(f'Lang: {lang}\nModel: {bert_model}\nRun: {name}')
    model = BertForTokenClassification.from_pretrained(bert_model,
                                                       num_labels=len(VOCAB),
                                                       cache_dir='bertm')
    if tuned_learner:
        print('Loading pretrained learner: ', tuned_learner)
        model.bert.load_state_dict(torch.load(tuned_learner))

    model = torch.nn.DataParallel(model)
    # Layer groups enable partial freezing / discriminative lr ranges below.
    model_lr_group = bert_layer_list(model)
    layers = len(model_lr_group)
    kwargs = {'max_seq_len': max_seq_len, 'ds_size': ds_size, 'mask': mask}

    train_dl = DataLoader(dataset=NerDataset(trainset,
                                             bert_model,
                                             train=True,
                                             **kwargs),
                          batch_size=batch_size,
                          shuffle=True,
                          collate_fn=partial(pad, train=True))

    dev_dl = DataLoader(dataset=NerDataset(devset, bert_model, **kwargs),
                        batch_size=batch_size,
                        shuffle=False,
                        collate_fn=pad)

    test_dl = DataLoader(dataset=NerDataset(testset, bert_model, **kwargs),
                         batch_size=batch_size,
                         shuffle=False,
                         collate_fn=pad)

    data = DataBunch(train_dl=train_dl,
                     valid_dl=dev_dl,
                     test_dl=test_dl,
                     collate_fn=pad,
                     path=Path(data_bunch_path))

    # BertAdam with linear warmup over the total number of optimiser steps.
    train_opt_steps = int(len(train_dl.dataset) / batch_size) * epochs
    optim = BertAdam(model.parameters(),
                     lr=0.01,
                     warmup=warmup_proportion,
                     t_total=train_opt_steps)

    loss_fun = ner_loss_func if loss == 'cross' else partial(ner_loss_func,
                                                             zero=True)
    metrics = [Conll_F1()]

    learn = Learner(
        data,
        model,
        BertAdam,
        loss_func=loss_fun,
        metrics=metrics,
        true_wd=False,
        layer_groups=model_lr_group,
        path='learn' + nameX,
    )

    # Swap in the warmup-aware BertAdam built above.
    learn.opt = OptimWrapper(optim)

    lrm = 1.6  # NOTE(review): assigned but unused; the 1.6 below is inlined

    # select set of starting lrs
    lrs_eng = [0.01, 5e-4, 3e-4, 3e-4, 1e-5]
    lrs_deu = [0.01, 5e-4, 5e-4, 3e-4, 2e-5]

    startlr = lrs_eng if lang == 'eng' else lrs_deu
    results = [['epoch', 'lr', 'f1', 'val_loss', 'train_loss', 'train_losses']]
    if do_train:
        # Gradual unfreezing: head first, then progressively deeper layers;
        # each stage runs one cycle with its own discriminative lr range.
        learn.freeze()
        learn.fit_one_cycle(1, startlr[0], moms=(0.8, 0.7))
        learn.freeze_to(-3)
        lrs = learn.lr_range(slice(startlr[1] / (1.6**15), startlr[1]))
        learn.fit_one_cycle(1, lrs, moms=(0.8, 0.7))
        learn.freeze_to(-6)
        lrs = learn.lr_range(slice(startlr[2] / (1.6**15), startlr[2]))
        learn.fit_one_cycle(1, lrs, moms=(0.8, 0.7))
        learn.freeze_to(-12)
        lrs = learn.lr_range(slice(startlr[3] / (1.6**15), startlr[3]))
        learn.fit_one_cycle(1, lrs, moms=(0.8, 0.7))
        learn.unfreeze()
        lrs = learn.lr_range(slice(startlr[4] / (1.6**15), startlr[4]))
        learn.fit_one_cycle(1, lrs, moms=(0.8, 0.7))

    if do_eval:
        # res[0] is the loss; the rest align with `metrics`.
        res = learn.validate(test_dl, metrics=metrics)
        met_res = [f'{m.__name__}: {r}' for m, r in zip(metrics, res[1:])]
        print(f'Validation on TEST SET:\nloss {res[0]}, {met_res}')
        results.append(['val', '-', res[1], res[0], '-', '-'])

    with open(log_dir / (name + '.csv'), 'a') as resultFile:
        wr = csv.writer(resultFile)
        wr.writerows(results)
Exemple #13
0
import cv2
import torch
from fastai.vision import Learner
from fastai.basic_data import DataBunch
from torch.utils.data import DataLoader

from unet import Unet
from data import Dataset
from run_check import should_stop

# Batch size shared by the train and validation loaders.
batch_size = 64
# Train on the first 16+64*4 samples; validate on samples [16, 16+64*4).
databunch = DataBunch(
    DataLoader(Dataset(16 + 64 * 4),
               batch_size=batch_size,
               shuffle=True,
               num_workers=6,
               pin_memory=False),
    DataLoader(Dataset(16, 16 + 64 * 4),
               batch_size=batch_size,
               num_workers=2,
               pin_memory=True))

model = Unet(1, 3, n=4)
learner = Learner(databunch, model, loss_func=torch.nn.MSELoss())
# Held-out test samples [0, 16), moved to GPU as a single batch.
test_data = list(Dataset(0, 16))
test_x = torch.stack([a[0] for a in test_data]).cuda()
# NOTE(review): `np` is used here but not imported in the visible header —
# presumably `import numpy as np` exists elsewhere; confirm.
test_y = [np.array(a[1]) for a in test_data]

epoch = -1
# Training loop driven by the external stop flag; the body appears
# truncated in this excerpt (only the epoch counter/print is visible).
while not should_stop():
    epoch += 1
    print('Epoch:', epoch)
Exemple #14
0
def main():
    """Continue BERT pre-training on pregenerated data via a fastai Learner.

    Reads epoch_<i>.json / epoch_<i>_metrics.json files produced by
    pregenerate_training_data.py, trains with fit_one_cycle, and saves
    the BERT weights at the halfway point and at the end.

    NOTE(review): several parsed args (--reduce_memory, --local_rank,
    --no_cuda, --fp16, --loss_scale, --warmup_proportion) are accepted
    but not used below; `n_gpu` is referenced without being defined in
    this function — presumably a module-level global; confirm.
    """
    parser = ArgumentParser()
    parser.add_argument('--pregenerated_data', type=Path, required=True)
    parser.add_argument('--output_dir', type=Path, required=True)
    parser.add_argument("--bert_model", type=str, required=True,
                        choices=["bert-base-uncased", "bert-large-uncased", "bert-base-cased",
                                 "bert-base-multilingual-cased", "bert-base-chinese"])
    parser.add_argument("--do_lower_case", action="store_true")
    parser.add_argument("--reduce_memory", action="store_true",
                        help="Store training data as on-disc memmaps to massively reduce memory usage")

    parser.add_argument("--epochs", type=int, default=3, help="Number of epochs to train for")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--train_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument('--fp16',
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale',
                        type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                        "0 (default value): dynamic loss scaling.\n"
                        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--learning_rate",
                        default=3e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument('--seed',
                        type=int,
                        default=None,
                        help="random seed for initialization")
    args = parser.parse_args()

    assert args.pregenerated_data.is_dir(), \
        "--pregenerated_data should point to the folder of files made by pregenerate_training_data.py!"

    # Count the examples available per pregenerated epoch; if fewer data
    # epochs exist than training epochs, the available data is looped over.
    samples_per_epoch = []
    for i in range(args.epochs):
        epoch_file = args.pregenerated_data / f"epoch_{i}.json"
        metrics_file = args.pregenerated_data / f"epoch_{i}_metrics.json"
        if epoch_file.is_file() and metrics_file.is_file():
            metrics = json.loads(metrics_file.read_text())
            samples_per_epoch.append(metrics['num_training_examples'])
        else:
            if i == 0:
                exit("No training data was found!")
            print(f"Warning! There are fewer epochs of pregenerated data ({i}) than training epochs ({args.epochs}).")
            print("This script will loop over the available data, but training diversity may be negatively impacted.")
            num_data_epochs = i
            break
    else:
        num_data_epochs = args.epochs
    print(samples_per_epoch)

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                            args.gradient_accumulation_steps))

    # Effective per-step batch size shrinks with gradient accumulation.
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    # Seed all RNG sources for reproducibility.
    if args.seed:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        if n_gpu > 0:
            torch.cuda.manual_seed_all(args.seed)

    if args.output_dir.is_dir() and list(args.output_dir.iterdir()):
        logging.warning(f"Output directory ({args.output_dir}) already exists and is not empty!")
    args.output_dir.mkdir(parents=True, exist_ok=True)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)

    total_train_examples = 0
    for i in range(args.epochs):
        # The modulo takes into account the fact that we may loop over limited epochs of data
        total_train_examples += samples_per_epoch[i % len(samples_per_epoch)]

    num_train_optimization_steps = int(
        total_train_examples / args.train_batch_size / args.gradient_accumulation_steps)

    # Prepare model
    model = BertForPreTraining.from_pretrained(args.bert_model)
    model = torch.nn.DataParallel(model)

    # Prepare optimizer (class, not instance — the Learner instantiates it)
    optimizer = BertAdam

    train_dataloader = DataLoader(
        PregeneratedData(args.pregenerated_data,  tokenizer,args.epochs, args.train_batch_size),
        batch_size=args.train_batch_size,
    )

    # The same loader doubles as the validation loader.
    data = DataBunch(train_dataloader,train_dataloader)
    global_step = 0
    logging.info("***** Running training *****")
    logging.info(f"  Num examples = {total_train_examples}")
    logging.info("  Batch size = %d", args.train_batch_size)
    logging.info("  Num steps = %d", num_train_optimization_steps)
    # Pre-training loss comes out of the model itself; the first element of
    # the model output is averaged as the training loss.
    def loss(x, y):
        return x.mean()

    learn = Learner(data, model, optimizer,
                    loss_func=loss,
                    true_wd=False,
                    path='learn',
                    layer_groups=bert_layer_list(model),
    )

    # Discriminative learning rates across BERT layers.
    lr= args.learning_rate
    layers = len(bert_layer_list(model))
    lrs = learn.lr_range(slice(lr/(2.6**4), lr))
    for epoch in range(args.epochs):
        learn.fit_one_cycle(1, lrs, wd=0.01)
        # save model at half way point
        if epoch == args.epochs//2:
            savem = learn.model.module.bert if hasattr(learn.model, 'module') else learn.model.bert
            output_model_file = args.output_dir / (f"pytorch_fastai_model_{args.bert_model}_{epoch}.bin")
            torch.save(savem.state_dict(), str(output_model_file))
            print(f'Saved bert to {output_model_file}')

    # Final save of the (possibly DataParallel-wrapped) BERT weights.
    savem = learn.model.module.bert if hasattr(learn.model, 'module') else learn.model.bert
    output_model_file = args.output_dir / (f"pytorch_fastai_model_{args.bert_model}_{args.epochs}.bin")
    torch.save(savem.state_dict(), str(output_model_file))
    print(f'Saved bert to {output_model_file}')
                   nn.Linear(n_channels[3], n_classes)]
        
        self.features = nn.Sequential(*layers)
        
    def forward(self, x): return self.features(x)
    
def wrn_22():
    """Return a WideResNet (3 groups, N=3, widen factor k=6) for 10 classes."""
    net = WideResNet(n_groups=3, N=3, n_classes=10, k=6)
    return net

model = wrn_22()

from fastai.basic_data import DataBunch
from fastai.train import Learner
from fastai.metrics import accuracy

# train_ds / valid_ds / batch_size are defined earlier (not shown in this
# excerpt); the DataBunch caches under ./data/cifar10.
data = DataBunch.create(train_ds, valid_ds, bs=batch_size, path='./data/cifar10')
learner = Learner(data, model, loss_func=F.cross_entropy, metrics=[accuracy])
learner.clip = 0.1 # gradient is clipped to be in range of [-0.1, 0.1]

# Find best learning rate
learner.lr_find()
learner.recorder.plot() # select lr with largest negative gradient (about 5e-3)

# Training
epochs = 1
lr = 5e-3
wd = 1e-4

import time

# Start a wall-clock timer for the training run that follows.
t0 = time.time()
def run_ner(
        lang: str = 'eng',
        log_dir: str = 'logs',
        task: str = NER,
        batch_size: int = 1,
        lr: float = 5e-5,
        epochs: int = 1,
        dataset: str = 'data/conll-2003/',
        loss: str = 'cross',
        max_seq_len: int = 128,
        do_lower_case: bool = False,
        warmup_proportion: float = 0.1,
        grad_acc_steps: int = 1,
        rand_seed: int = None,
        fp16: bool = False,
        loss_scale: float = None,
        ds_size: int = None,
        data_bunch_path: str = 'data/conll-2003/db',
        bertAdam: bool = False,
        freez: bool = False,
        one_cycle: bool = False,
        discr: bool = False,
        lrm: int = 2.6,
        div: int = None,
        tuned_learner: str = None,
        do_train: str = False,
        do_eval: str = False,
        save: bool = False,
        name: str = 'ner',
        mask: tuple = ('s', 's'),
):
    """Fine-tune BERT for NER with configurable freezing/optimiser/schedule.

    A more configurable variant of the earlier run_ner: optional BertAdam
    optimiser (`bertAdam`), optional gradual unfreezing (`freez`), one-cycle
    vs plain fit (`one_cycle`), and optional discriminative learning rates
    (`discr`, with multiplier `lrm` and depth `div`). Per-epoch results are
    appended to `<log_dir>/<run name>.csv`.

    NOTE(review): `do_train`/`do_eval` are annotated `str` but used as
    booleans; `do_lower_case`, `grad_acc_steps`, `fp16` and `loss_scale`
    are accepted but not used in this body.
    """
    # Run name encodes the key hyper-parameters; used for log/CSV files.
    name = "_".join(
        map(str, [
            name, task, lang, mask[0], mask[1], loss, batch_size, lr,
            max_seq_len, do_train, do_eval
        ]))

    log_dir = Path(log_dir)
    log_dir.mkdir(parents=True, exist_ok=True)
    init_logger(log_dir, name)

    # Seed all RNG sources for reproducibility.
    if rand_seed:
        random.seed(rand_seed)
        np.random.seed(rand_seed)
        torch.manual_seed(rand_seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(rand_seed)

    trainset = dataset + lang + '/train.txt'
    devset = dataset + lang + '/dev.txt'
    testset = dataset + lang + '/test.txt'

    # English uses the cased base model; anything else the multilingual one.
    bert_model = 'bert-base-cased' if lang == 'eng' else 'bert-base-multilingual-cased'
    print(f'Lang: {lang}\nModel: {bert_model}\nRun: {name}')
    model = BertForTokenClassification.from_pretrained(bert_model,
                                                       num_labels=len(VOCAB),
                                                       cache_dir='bertm')

    model = torch.nn.DataParallel(model)
    # Layer groups enable partial freezing / discriminative lr ranges below.
    model_lr_group = bert_layer_list(model)
    layers = len(model_lr_group)
    kwargs = {'max_seq_len': max_seq_len, 'ds_size': ds_size, 'mask': mask}

    train_dl = DataLoader(dataset=NerDataset(trainset,
                                             bert_model,
                                             train=True,
                                             **kwargs),
                          batch_size=batch_size,
                          shuffle=True,
                          collate_fn=partial(pad, train=True))

    dev_dl = DataLoader(dataset=NerDataset(devset, bert_model, **kwargs),
                        batch_size=batch_size,
                        shuffle=False,
                        collate_fn=pad)

    test_dl = DataLoader(dataset=NerDataset(testset, bert_model, **kwargs),
                         batch_size=batch_size,
                         shuffle=False,
                         collate_fn=pad)

    data = DataBunch(train_dl=train_dl,
                     valid_dl=dev_dl,
                     test_dl=test_dl,
                     collate_fn=pad,
                     path=Path(data_bunch_path))

    loss_fun = ner_loss_func if loss == 'cross' else partial(ner_loss_func,
                                                             zero=True)
    metrics = [Conll_F1()]

    learn = Learner(
        data,
        model,
        BertAdam,
        loss_func=loss_fun,
        metrics=metrics,
        true_wd=False,
        layer_groups=None if not freez else model_lr_group,
        path='learn',
    )

    # initialise bert adam optimiser
    train_opt_steps = int(len(train_dl.dataset) / batch_size) * epochs
    optim = BertAdam(model.parameters(),
                     lr=lr,
                     warmup=warmup_proportion,
                     t_total=train_opt_steps)

    # Only swap in the warmup-aware optimiser when requested.
    if bertAdam: learn.opt = OptimWrapper(optim)
    else: print("No Bert Adam")

    # load fine-tuned learner
    if tuned_learner:
        print('Loading pretrained learner: ', tuned_learner)
        learn.load(tuned_learner)

    # Uncomment to graph learning rate plot
    # learn.lr_find()
    # learn.recorder.plot(skip_end=15)

    # set lr (discriminative learning rates)
    if div: layers = div
    lrs = lr if not discr else learn.lr_range(slice(lr / lrm**(layers), lr))

    results = [['epoch', 'lr', 'f1', 'val_loss', 'train_loss', 'train_losses']]

    if do_train:
        for epoch in range(epochs):
            if freez:
                # Unfreeze progressively deeper layer groups as epochs pass.
                lay = (layers // (epochs - 1)) * epoch * -1
                if lay == 0:
                    print('Freeze')
                    learn.freeze()
                elif lay == layers:
                    print('unfreeze')
                    learn.unfreeze()
                else:
                    print('freeze2')
                    learn.freeze_to(lay)
                print('Freezing layers ', lay, ' off ', layers)

            # Fit Learner - eg train model
            if one_cycle: learn.fit_one_cycle(1, lrs, moms=(0.8, 0.7))
            else: learn.fit(1, lrs)

            # Record per-epoch metrics for the CSV log.
            results.append([
                epoch,
                lrs,
                learn.recorder.metrics[0][0],
                learn.recorder.val_losses[0],
                np.array(learn.recorder.losses).mean(),
                learn.recorder.losses,
            ])

            if save:
                m_path = learn.save(f"{lang}_{epoch}_model", return_path=True)
                print(f'Saved model to {m_path}')
    if save: learn.export(f'{lang}.pkl')

    if do_eval:
        # res[0] is the loss; the rest align with `metrics`.
        res = learn.validate(test_dl, metrics=metrics)
        met_res = [f'{m.__name__}: {r}' for m, r in zip(metrics, res[1:])]
        print(f'Validation on TEST SET:\nloss {res[0]}, {met_res}')
        results.append(['val', '-', res[1], res[0], '-', '-'])

    with open(log_dir / (name + '.csv'), 'a') as resultFile:
        wr = csv.writer(resultFile)
        wr.writerows(results)
Exemple #17
0
def main():
    """End-to-end driver for the robust document-ranking experiment.

    Reads hyperparameters via ``MyRabbit(args)``, builds (mostly cached)
    document/query datasets and token embeddings, trains a pointwise or
    pairwise scorer, and optionally computes per-sample influence scores.

    Results are exposed through the module-level globals ``model_to_save``,
    ``experiment`` and ``rabbit`` so external tooling can inspect state
    after a run.
    """
    global model_to_save
    global experiment
    global rabbit
    rabbit = MyRabbit(args)
    # Fail fast on option combinations that are not implemented.
    if rabbit.model_params.dont_limit_num_uniq_tokens:
        raise NotImplementedError()
    if rabbit.model_params.frame_as_qa: raise NotImplementedError
    if rabbit.run_params.drop_val_loss_calc: raise NotImplementedError
    if rabbit.run_params.use_softrank_influence and not rabbit.run_params.freeze_all_but_last_for_influence:
        raise NotImplementedError
    if rabbit.train_params.weight_influence: raise NotImplementedError
    experiment = Experiment(rabbit.train_params + rabbit.model_params +
                            rabbit.run_params)
    print('Model name:', experiment.model_name)
    use_pretrained_doc_encoder = rabbit.model_params.use_pretrained_doc_encoder
    use_pointwise_loss = rabbit.train_params.use_pointwise_loss
    query_token_embed_len = rabbit.model_params.query_token_embed_len
    document_token_embed_len = rabbit.model_params.document_token_embed_len
    # Suffixes appended to cache file names so different preprocessing
    # options do not collide in the cache.
    _names = []
    if not rabbit.model_params.dont_include_titles:
        _names.append('with_titles')
    if rabbit.train_params.num_doc_tokens_to_consider != -1:
        _names.append('num_doc_toks_' +
                      str(rabbit.train_params.num_doc_tokens_to_consider))
    # Load the (possibly cached) raw document collection.
    if not rabbit.run_params.just_caches:
        if rabbit.model_params.dont_include_titles:
            document_lookup = read_cache(name('./doc_lookup.json', _names),
                                         get_robust_documents)
        else:
            document_lookup = read_cache(name('./doc_lookup.json', _names),
                                         get_robust_documents_with_titles)
    num_doc_tokens_to_consider = rabbit.train_params.num_doc_tokens_to_consider
    document_title_to_id = read_cache(
        './document_title_to_id.json',
        lambda: create_id_lookup(document_lookup.keys()))
    # Vocabulary = the most common document tokens plus every token that
    # appears in an evaluation query.
    with open('./caches/106756_most_common_doc.json', 'r') as fh:
        doc_token_set = set(json.load(fh))
        tokenizer = Tokenizer()
        tokenized = set(
            sum(
                tokenizer.process_all(list(
                    get_robust_eval_queries().values())), []))
        doc_token_set = doc_token_set.union(tokenized)
    # Bag-of-words documents are used unless a sequence/encoder model or an
    # explicit opt-out is requested.
    use_bow_model = not any([
        rabbit.model_params[attr] for attr in
        ['use_doc_out', 'use_cnn', 'use_lstm', 'use_pretrained_doc_encoder']
    ])
    use_bow_model = use_bow_model and not rabbit.model_params.dont_use_bow
    if use_bow_model:
        documents, document_token_lookup = read_cache(
            name('./docs_fs_tokens_limit_uniq_toks_qrels_and_106756.pkl',
                 _names),
            lambda: prepare_fs(document_lookup,
                               document_title_to_id,
                               num_tokens=num_doc_tokens_to_consider,
                               token_set=doc_token_set))
        if rabbit.model_params.keep_top_uniq_terms is not None:
            # Keep only the highest-count terms of each BoW document.
            documents = [
                dict(
                    nlargest(rabbit.model_params.keep_top_uniq_terms,
                             _.to_pairs(doc), itemgetter(1)))
                for doc in documents
            ]
    else:
        documents, document_token_lookup = read_cache(
            name(
                f'./parsed_docs_{num_doc_tokens_to_consider}_tokens_limit_uniq_toks_qrels_and_106756.json',
                _names), lambda: prepare(document_lookup,
                                         document_title_to_id,
                                         num_tokens=num_doc_tokens_to_consider,
                                         token_set=doc_token_set))
    if not rabbit.run_params.just_caches:
        train_query_lookup = read_cache('./robust_train_queries.json',
                                        get_robust_train_queries)
        train_query_name_to_id = read_cache(
            './train_query_name_to_id.json',
            lambda: create_id_lookup(train_query_lookup.keys()))
    train_queries, query_token_lookup = read_cache(
        './parsed_robust_queries_dict.json',
        lambda: prepare(train_query_lookup,
                        train_query_name_to_id,
                        token_lookup=document_token_lookup,
                        token_set=doc_token_set,
                        drop_if_any_unk=True))
    # Map query-vocabulary ids to document-vocabulary ids, falling back to
    # the document '<unk>' id for tokens missing from the document vocab.
    query_tok_to_doc_tok = {
        idx: document_token_lookup.get(query_token)
        or document_token_lookup['<unk>']
        for query_token, idx in query_token_lookup.items()
    }
    names = [RANKER_NAME_TO_SUFFIX[rabbit.train_params.ranking_set]]
    if rabbit.train_params.use_pointwise_loss or not rabbit.run_params.just_caches:
        train_data = read_cache(
            name('./robust_train_query_results_tokens_qrels_and_106756.json',
                 names), lambda: read_query_result(
                     train_query_name_to_id,
                     document_title_to_id,
                     train_queries,
                     path='./indri/query_result' + RANKER_NAME_TO_SUFFIX[
                         rabbit.train_params.ranking_set]))
    else:
        train_data = []
    q_embed_len = rabbit.model_params.query_token_embed_len
    doc_embed_len = rabbit.model_params.document_token_embed_len
    if rabbit.model_params.append_difference or rabbit.model_params.append_hadamard:
        assert q_embed_len == doc_embed_len, 'Must use same size doc and query embeds when appending diff or hadamard'
    # Share one pretrained-embedding lookup when query and document embed
    # sizes match; otherwise load one per size.
    if q_embed_len == doc_embed_len:
        glove_lookup = get_glove_lookup(
            embedding_dim=q_embed_len,
            use_large_embed=rabbit.model_params.use_large_embed,
            use_word2vec=rabbit.model_params.use_word2vec)
        q_glove_lookup = glove_lookup
        doc_glove_lookup = glove_lookup
    else:
        q_glove_lookup = get_glove_lookup(
            embedding_dim=q_embed_len,
            use_large_embed=rabbit.model_params.use_large_embed,
            use_word2vec=rabbit.model_params.use_word2vec)
        doc_glove_lookup = get_glove_lookup(
            embedding_dim=doc_embed_len,
            use_large_embed=rabbit.model_params.use_large_embed,
            use_word2vec=rabbit.model_params.use_word2vec)
    num_query_tokens = len(query_token_lookup)
    num_doc_tokens = len(document_token_lookup)
    doc_encoder = None
    if use_pretrained_doc_encoder or rabbit.model_params.use_doc_out:
        doc_encoder, document_token_embeds = get_doc_encoder_and_embeddings(
            document_token_lookup, rabbit.model_params.only_use_last_out)
        if rabbit.model_params.use_glove:
            query_token_embeds_init = init_embedding(q_glove_lookup,
                                                     query_token_lookup,
                                                     num_query_tokens,
                                                     query_token_embed_len)
        else:
            query_token_embeds_init = from_doc_to_query_embeds(
                document_token_embeds, document_token_lookup,
                query_token_lookup)
        if not rabbit.train_params.dont_freeze_pretrained_doc_encoder:
            dont_update(doc_encoder)
        if rabbit.model_params.use_doc_out:
            # With use_doc_out the encoder itself is not used downstream.
            doc_encoder = None
    else:
        document_token_embeds = init_embedding(doc_glove_lookup,
                                               document_token_lookup,
                                               num_doc_tokens,
                                               document_token_embed_len)
        if rabbit.model_params.use_single_word_embed_set:
            query_token_embeds_init = document_token_embeds
        else:
            query_token_embeds_init = init_embedding(q_glove_lookup,
                                                     query_token_lookup,
                                                     num_query_tokens,
                                                     query_token_embed_len)
    if not rabbit.train_params.dont_freeze_word_embeds:
        dont_update(document_token_embeds)
        dont_update(query_token_embeds_init)
    else:
        do_update(document_token_embeds)
        do_update(query_token_embeds_init)
    if rabbit.train_params.add_rel_score:
        query_token_embeds, additive = get_additive_regularized_embeds(
            query_token_embeds_init)
        rel_score = RelScore(query_token_embeds, document_token_embeds,
                             rabbit.model_params, rabbit.train_params)
    else:
        query_token_embeds = query_token_embeds_init
        additive = None
        rel_score = None
    # Split evaluation queries: the first 50 (in lookup iteration order)
    # become the validation set, the remainder the test set.
    eval_query_lookup = get_robust_eval_queries()
    eval_query_name_document_title_rels = get_robust_rels()
    test_query_names = []
    val_query_names = []
    for query_name in eval_query_lookup:
        if len(val_query_names) >= 50: test_query_names.append(query_name)
        else: val_query_names.append(query_name)
    test_query_name_document_title_rels = _.pick(
        eval_query_name_document_title_rels, test_query_names)
    test_query_lookup = _.pick(eval_query_lookup, test_query_names)
    test_query_name_to_id = create_id_lookup(test_query_lookup.keys())
    test_queries, __ = prepare(test_query_lookup,
                               test_query_name_to_id,
                               token_lookup=query_token_lookup)
    eval_ranking_candidates = read_query_test_rankings(
        './indri/query_result_test' +
        RANKER_NAME_TO_SUFFIX[rabbit.train_params.ranking_set])
    test_candidates_data = read_query_result(
        test_query_name_to_id,
        document_title_to_id,
        dict(zip(range(len(test_queries)), test_queries)),
        path='./indri/query_result_test' +
        RANKER_NAME_TO_SUFFIX[rabbit.train_params.ranking_set])
    test_ranking_candidates = process_raw_candidates(test_query_name_to_id,
                                                     test_queries,
                                                     document_title_to_id,
                                                     test_query_names,
                                                     eval_ranking_candidates)
    test_data = process_rels(test_query_name_document_title_rels,
                             document_title_to_id, test_query_name_to_id,
                             test_queries)
    val_query_name_document_title_rels = _.pick(
        eval_query_name_document_title_rels, val_query_names)
    val_query_lookup = _.pick(eval_query_lookup, val_query_names)
    val_query_name_to_id = create_id_lookup(val_query_lookup.keys())
    val_queries, __ = prepare(val_query_lookup,
                              val_query_name_to_id,
                              token_lookup=query_token_lookup)
    val_candidates_data = read_query_result(
        val_query_name_to_id,
        document_title_to_id,
        dict(zip(range(len(val_queries)), val_queries)),
        path='./indri/query_result_test' +
        RANKER_NAME_TO_SUFFIX[rabbit.train_params.ranking_set])
    val_ranking_candidates = process_raw_candidates(val_query_name_to_id,
                                                    val_queries,
                                                    document_title_to_id,
                                                    val_query_names,
                                                    eval_ranking_candidates)
    val_data = process_rels(val_query_name_document_title_rels,
                            document_title_to_id, val_query_name_to_id,
                            val_queries)
    train_normalized_score_lookup = read_cache(
        name('./train_normalized_score_lookup.pkl', names),
        lambda: get_normalized_score_lookup(train_data))
    test_normalized_score_lookup = get_normalized_score_lookup(
        test_candidates_data)
    val_normalized_score_lookup = get_normalized_score_lookup(
        val_candidates_data)
    if use_pointwise_loss:
        # Pointwise path: per-sample regression on query-wise normalized
        # relevance scores.
        normalized_train_data = read_cache(
            name('./normalized_train_query_data_qrels_and_106756.json', names),
            lambda: normalize_scores_query_wise(train_data))
        collate_fn = lambda samples: collate_query_samples(
            samples,
            use_bow_model=use_bow_model,
            use_dense=rabbit.model_params.use_dense)
        train_dl = build_query_dataloader(
            documents,
            normalized_train_data,
            rabbit.train_params,
            rabbit.model_params,
            cache=name('train_ranking_qrels_and_106756.json', names),
            limit=10,
            query_tok_to_doc_tok=query_tok_to_doc_tok,
            normalized_score_lookup=train_normalized_score_lookup,
            use_bow_model=use_bow_model,
            collate_fn=collate_fn,
            is_test=False)
        test_dl = build_query_dataloader(
            documents,
            test_data,
            rabbit.train_params,
            rabbit.model_params,
            query_tok_to_doc_tok=query_tok_to_doc_tok,
            normalized_score_lookup=test_normalized_score_lookup,
            use_bow_model=use_bow_model,
            collate_fn=collate_fn,
            is_test=True)
        val_dl = build_query_dataloader(
            documents,
            val_data,
            rabbit.train_params,
            rabbit.model_params,
            query_tok_to_doc_tok=query_tok_to_doc_tok,
            normalized_score_lookup=val_normalized_score_lookup,
            use_bow_model=use_bow_model,
            collate_fn=collate_fn,
            is_test=True)
        model = PointwiseScorer(query_token_embeds, document_token_embeds,
                                doc_encoder, rabbit.model_params,
                                rabbit.train_params)
    else:
        if rabbit.train_params.use_noise_aware_loss:
            # Train a Snorkel-style noise model over the rankers' outputs to
            # produce marginal label probabilities for pairwise training.
            ranker_query_str_to_rankings = get_ranker_query_str_to_rankings(
                train_query_name_to_id,
                document_title_to_id,
                train_queries,
                limit=rabbit.train_params.num_snorkel_train_queries)
            # Intersection of the query sets covered by every ranker.
            query_names = reduce(
                lambda acc, query_to_ranking: acc.intersection(
                    set(query_to_ranking.keys()))
                if len(acc) != 0 else set(query_to_ranking.keys()),
                ranker_query_str_to_rankings.values(), set())
            all_ranked_lists_by_ranker = _.map_values(
                ranker_query_str_to_rankings, lambda query_to_ranking:
                [query_to_ranking[query] for query in query_names])
            ranker_query_str_to_pairwise_bins = get_ranker_query_str_to_pairwise_bins(
                train_query_name_to_id,
                document_title_to_id,
                train_queries,
                limit=rabbit.train_params.num_train_queries)
            snorkeller = Snorkeller(ranker_query_str_to_pairwise_bins)
            snorkeller.train(all_ranked_lists_by_ranker)
            calc_marginals = snorkeller.calc_marginals
        else:
            calc_marginals = None
        collate_fn = lambda samples: collate_query_pairwise_samples(
            samples,
            use_bow_model=use_bow_model,
            calc_marginals=calc_marginals,
            use_dense=rabbit.model_params.use_dense)
        if rabbit.run_params.load_influences:
            # Optionally flip training pairs whose previously computed
            # influence falls below the configured threshold.
            try:
                with open(rabbit.run_params.influences_path) as fh:
                    pairs_to_flip = defaultdict(set)
                    for pair, influence in json.load(fh):
                        if rabbit.train_params.use_pointwise_loss:
                            condition = True
                        else:
                            condition = influence < rabbit.train_params.influence_thresh
                        if condition:
                            query = tuple(pair[1])
                            pairs_to_flip[query].add(tuple(pair[0]))
            except FileNotFoundError:
                pairs_to_flip = None
        else:
            pairs_to_flip = None
        train_dl = build_query_pairwise_dataloader(
            documents,
            train_data,
            rabbit.train_params,
            rabbit.model_params,
            pairs_to_flip=pairs_to_flip,
            cache=name('train_ranking_qrels_and_106756.json', names),
            limit=10,
            query_tok_to_doc_tok=query_tok_to_doc_tok,
            normalized_score_lookup=train_normalized_score_lookup,
            use_bow_model=use_bow_model,
            collate_fn=collate_fn,
            is_test=False)
        test_dl = build_query_pairwise_dataloader(
            documents,
            test_data,
            rabbit.train_params,
            rabbit.model_params,
            query_tok_to_doc_tok=query_tok_to_doc_tok,
            normalized_score_lookup=test_normalized_score_lookup,
            use_bow_model=use_bow_model,
            collate_fn=collate_fn,
            is_test=True)
        val_dl = build_query_pairwise_dataloader(
            documents,
            val_data,
            rabbit.train_params,
            rabbit.model_params,
            query_tok_to_doc_tok=query_tok_to_doc_tok,
            normalized_score_lookup=val_normalized_score_lookup,
            use_bow_model=use_bow_model,
            collate_fn=collate_fn,
            is_test=True)
        # Relevant-vs-irrelevant loader over the validation candidates; used
        # both as the DataBunch's validation split and for influence HVPs.
        val_rel_dl = build_query_pairwise_dataloader(
            documents,
            val_data,
            rabbit.train_params,
            rabbit.model_params,
            query_tok_to_doc_tok=query_tok_to_doc_tok,
            normalized_score_lookup=val_normalized_score_lookup,
            use_bow_model=use_bow_model,
            collate_fn=collate_fn,
            is_test=True,
            rel_vs_irrel=True,
            candidates=val_ranking_candidates,
            num_to_rank=rabbit.run_params.num_to_rank)
        model = PairwiseScorer(query_token_embeds,
                               document_token_embeds,
                               doc_encoder,
                               rabbit.model_params,
                               rabbit.train_params,
                               use_bow_model=use_bow_model)
    train_ranking_dataset = RankingDataset(
        documents,
        train_dl.dataset.rankings,
        rabbit.train_params,
        rabbit.model_params,
        rabbit.run_params,
        query_tok_to_doc_tok=query_tok_to_doc_tok,
        normalized_score_lookup=train_normalized_score_lookup,
        use_bow_model=use_bow_model,
        use_dense=rabbit.model_params.use_dense)
    test_ranking_dataset = RankingDataset(
        documents,
        test_ranking_candidates,
        rabbit.train_params,
        rabbit.model_params,
        rabbit.run_params,
        relevant=test_dl.dataset.rankings,
        query_tok_to_doc_tok=query_tok_to_doc_tok,
        cheat=rabbit.run_params.cheat,
        normalized_score_lookup=test_normalized_score_lookup,
        use_bow_model=use_bow_model,
        use_dense=rabbit.model_params.use_dense)
    val_ranking_dataset = RankingDataset(
        documents,
        val_ranking_candidates,
        rabbit.train_params,
        rabbit.model_params,
        rabbit.run_params,
        relevant=val_dl.dataset.rankings,
        query_tok_to_doc_tok=query_tok_to_doc_tok,
        cheat=rabbit.run_params.cheat,
        normalized_score_lookup=val_normalized_score_lookup,
        use_bow_model=use_bow_model,
        use_dense=rabbit.model_params.use_dense)
    if rabbit.train_params.memorize_test:
        # Sanity-check mode: train directly on the test split.
        train_dl = test_dl
        train_ranking_dataset = test_ranking_dataset
    model_data = DataBunch(train_dl,
                           val_rel_dl,
                           test_dl,
                           collate_fn=collate_fn,
                           device=torch.device('cuda') if
                           torch.cuda.is_available() else torch.device('cpu'))
    multi_objective_model = MultiObjective(model, rabbit.train_params,
                                           rel_score, additive)
    model_to_save = multi_objective_model
    if rabbit.train_params.memorize_test:
        # Free the (unused) training data; NameError is the only expected
        # failure here — a bare except would also swallow KeyboardInterrupt.
        try:
            del train_data
        except NameError:
            pass
    # Drop large lookups that are no longer needed to reduce peak memory.
    if not rabbit.run_params.just_caches:
        del document_lookup
        del train_query_lookup
    del query_token_lookup
    del document_token_lookup
    del train_queries
    try:
        # glove_lookup only exists when query/doc embed sizes matched.
        del glove_lookup
    except UnboundLocalError:
        del q_glove_lookup
        del doc_glove_lookup
    if rabbit.run_params.load_model:
        try:
            multi_objective_model.load_state_dict(
                torch.load(rabbit.run_params.load_path))
        except RuntimeError:
            # Checkpoint was saved from a DataParallel wrapper; load through
            # one and unwrap.
            dp = nn.DataParallel(multi_objective_model)
            dp.load_state_dict(torch.load(rabbit.run_params.load_path))
            multi_objective_model = dp.module
    else:
        train_model(multi_objective_model, model_data, train_ranking_dataset,
                    val_ranking_dataset, test_ranking_dataset,
                    rabbit.train_params, rabbit.model_params,
                    rabbit.run_params, experiment)
    if rabbit.train_params.fine_tune_on_val:
        fine_tune_model_data = DataBunch(
            val_rel_dl,
            val_rel_dl,
            test_dl,
            collate_fn=collate_fn,
            device=torch.device('cuda')
            if torch.cuda.is_available() else torch.device('cpu'))
        train_model(multi_objective_model,
                    fine_tune_model_data,
                    val_ranking_dataset,
                    val_ranking_dataset,
                    test_ranking_dataset,
                    rabbit.train_params,
                    rabbit.model_params,
                    rabbit.run_params,
                    experiment,
                    load_path=rabbit.run_params.load_path)
    multi_objective_model.eval()
    device = model_data.device
    gpu_multi_objective_model = multi_objective_model.to(device)
    if rabbit.run_params.calc_influence:
        # Influence functions: HVPs of the validation loss, then per-sample
        # (or whole-dataset) influence of each training example.
        if rabbit.run_params.freeze_all_but_last_for_influence:
            last_layer_idx = _.find_last_index(
                multi_objective_model.model.pointwise_scorer.layers,
                lambda layer: isinstance(layer, nn.Linear))
            to_last_layer = lambda x: gpu_multi_objective_model(
                *x, to_idx=last_layer_idx)
            last_layer = gpu_multi_objective_model.model.pointwise_scorer.layers[
                last_layer_idx]
            diff_wrt = [p for p in last_layer.parameters() if p.requires_grad]
        else:
            diff_wrt = None
        test_hvps = calc_test_hvps(
            multi_objective_model.loss,
            gpu_multi_objective_model,
            DeviceDataLoader(train_dl, device, collate_fn=collate_fn),
            val_rel_dl,
            rabbit.run_params,
            diff_wrt=diff_wrt,
            show_progress=True,
            use_softrank_influence=rabbit.run_params.use_softrank_influence)
        influences = []
        if rabbit.train_params.use_pointwise_loss:
            num_real_samples = len(train_dl.dataset)
        else:
            num_real_samples = train_dl.dataset._num_pos_pairs
        if rabbit.run_params.freeze_all_but_last_for_influence:
            _sampler = SequentialSamplerWithLimit(train_dl.dataset,
                                                  num_real_samples)
            _batch_sampler = BatchSampler(_sampler,
                                          rabbit.train_params.batch_size,
                                          False)
            _dl = DataLoader(train_dl.dataset,
                             batch_sampler=_batch_sampler,
                             collate_fn=collate_fn)
            sequential_train_dl = DeviceDataLoader(_dl,
                                                   device,
                                                   collate_fn=collate_fn)
            influences = calc_dataset_influence(gpu_multi_objective_model,
                                                to_last_layer,
                                                sequential_train_dl,
                                                test_hvps,
                                                sum_p=True).tolist()
        else:
            for i in progressbar(range(num_real_samples)):
                train_sample = train_dl.dataset[i]
                x, labels = to_device(collate_fn([train_sample]), device)
                device_train_sample = (x, labels.squeeze())
                influences.append(
                    calc_influence(multi_objective_model.loss,
                                   gpu_multi_objective_model,
                                   device_train_sample,
                                   test_hvps,
                                   diff_wrt=diff_wrt).sum().tolist())
        with open(rabbit.run_params.influences_path, 'w+') as fh:
            json.dump([[train_dl.dataset[idx][1], influence]
                       for idx, influence in enumerate(influences)], fh)
Exemple #18
0
def main(config, args):
    """Train a 2D KBP dose-prediction model for one cross-validation fold.

    Loads the data CSV, builds train/valid datasets (optionally augmented
    with pseudo-labelled test data), constructs a fastai learner with the
    configured optimizer and callbacks, runs one-cycle training, and appends
    the best validation metrics to a per-fold log file.

    Args:
        config: experiment configuration object (paths, model name,
            optimizer settings, flags).
        args: CLI arguments; ``args.foldidx`` selects the validation fold
            and ``args.exp`` suppresses config snapshotting when set.
    """
    if torch.cuda.is_available():
        cudnn.benchmark = True
        print('Using CUDA')
    else:
        print('**** CUDA is not available ****')

    pprint.pprint(config)

    # Snapshot the active config so the run can be reproduced later.
    # exist_ok=True avoids the check-then-create race of the original
    # `if not os.path.exists: os.makedirs` pattern.
    if args.exp is None:
        os.makedirs('./config/old_configs/' + config.exp_name, exist_ok=True)
        shutil.copy2(
            './config/config.py',
            './config/old_configs/{}/config.py'.format(config.exp_name))

    os.makedirs('./model_weights/' + config.exp_name, exist_ok=True)
    os.makedirs('./logs/' + config.exp_name, exist_ok=True)

    data_df = pd.read_csv(config.DATA_CSV_PATH)
    if os.path.exists('/content/data'):
        # Colab mounts the data under /content; rewrite the relative paths.
        print('On Colab')
        data_df['Id'] = data_df['Id'].apply(lambda x: '/content' + x[1:])

    if config.dataclass is not None:
        data_df = data_df[data_df['Type(Full/Head/Unclean/Bad)'] ==
                          config.dataclass].reset_index(drop=True)
    # Fold `args.foldidx` becomes validation; remaining train-split folds
    # are used for training.
    split_train_mask = (data_df['Fold'] != 'Fold{}'.format(args.foldidx))
    train_df = data_df[split_train_mask
                       & (data_df['Split'] == 'Train')].reset_index(drop=True)
    valid_df = data_df[(~split_train_mask)
                       & (data_df['Split'] == 'Train')].reset_index(drop=True)
    test_df = data_df[data_df['Split'] == 'Test'].reset_index(drop=True)
    maintest_df = data_df[data_df['Split'] == 'MainTest'].reset_index(
        drop=True)

    print("Training with valid fold: ", args.foldidx)
    print(valid_df.head())

    # Optionally extend the training set with pseudo-labelled samples.
    if config.pseudo_path is not None:
        assert not (config.add_val_pseudo and config.add_val_orig)
        if config.add_val_pseudo:
            pseudo_df = pd.concat((valid_df, test_df, maintest_df))
        else:
            pseudo_df = pd.concat((test_df, maintest_df))
        pseudo_df['Id'] = pseudo_df['Id'] + '_pseudo'
        if config.add_val_orig:
            pseudo_df = pd.concat((pseudo_df, valid_df))
        train_df = pd.concat((train_df, pseudo_df)).reset_index(drop=True)

    train_tfms = get_train_tfms(config)
    print(train_tfms)
    if config.debug and config.reduce_dataset:
        # Tiny subsets for fast debugging runs.
        if config.pseudo_path is not None:
            train_df = pd.concat(
                (train_df[:10], pseudo_df[:10])).reset_index(drop=True)
        else:
            train_df = train_df[:10]
        valid_df = valid_df[:10]


#     DatasetClass = KBPDataset2D if psutil.virtual_memory().total < 20e9 else KBPDataset2DStack
    DatasetClass = KBPDataset2D
    train_ds = DatasetClass(config, train_df, transform=train_tfms)
    valid_ds = DatasetClass(config, valid_df, valid=True)

    # valid_dl = DataLoader(valid_ds, batch_size=128, shuffle=False, num_workers=config.num_workers)

    criterion = KBPLoss(config)

    Net = getattr(model_list, config.model_name)

    net = Net(config=config).to(config.device)
    print(net)

    def count_parameters(model):
        # Number of trainable parameters, for logging only.
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    print("Number of parameters: ", count_parameters(net))

    if config.load_model_ckpt is not None:
        print('Loading model from {}'.format(
            config.load_model_ckpt.format(args.foldidx)))
        net.load_state_dict(
            torch.load(config.load_model_ckpt.format(args.foldidx))['model'])

    gpu = setup_distrib(config.gpu)
    opt = config.optimizer
    mom = config.mom
    alpha = config.alpha
    eps = config.eps

    # Build the optimizer factory for the fastai Learner.
    if opt == 'adam':
        opt_func = partial(optim.Adam,
                           betas=(mom, alpha),
                           eps=eps,
                           amsgrad=config.amsgrad)
    elif opt == 'adamw':
        opt_func = partial(optim.AdamW, betas=(mom, alpha), eps=eps)
    elif opt == 'radam':
        opt_func = partial(RAdam,
                           betas=(mom, alpha),
                           eps=eps,
                           degenerated_to_sgd=config.radam_degenerated_to_sgd)
    elif opt == 'sgd':
        opt_func = partial(optim.SGD, momentum=mom, nesterov=config.nesterov)
    elif opt == 'ranger':
        opt_func = partial(Ranger, betas=(mom, alpha), eps=eps)
    else:
        raise ValueError("Optimizer not recognized")
    print(opt_func)

    data = DataBunch.create(train_ds,
                            valid_ds,
                            bs=config.batch_size,
                            num_workers=config.num_workers)

    # metrics = [dose_score, dvh_score, pred_mean, target_mean]
    metrics = [dose_score2D, dvh_score2D, pred_mean2D, target_mean2D]
    evalbatchaccum = EvalBatchAccumulator(config,
                                          target_bs=128,
                                          num_metrics=len(metrics))
    learn = (Learner(evalbatchaccum,
                     data,
                     net,
                     wd=config.weight_decay,
                     opt_func=opt_func,
                     bn_wd=False,
                     true_wd=True,
                     loss_func=criterion,
                     metrics=metrics,
                     path='./model_weights/{}/'.format(config.exp_name)))
    if config.fp16:
        print('Training with mixed precision...')
        learn = learn.to_fp16(dynamic=True)
    else:
        print('Full precision training...')
    if gpu is None: learn.to_parallel()
    elif num_distrib() > 1: learn.to_distributed(gpu)
    if config.mixup: learn = learn.mixup(alpha=config.mixup, stack_y=False)
    print("Learn path: ", learn.path)
    best_save_cb = SaveBestModel(learn,
                                 config,
                                 outfile='_fold{}'.format(args.foldidx))
    logger_cb = CSVLogger(learn)
    # Redirect the CSV log from model_weights/ into logs/ with a per-fold name.
    logger_cb.path = Path(
        str(logger_cb.path).replace('model_weights/', 'logs/').replace(
            '.csv', '_fold{}.csv'.format(args.foldidx)))
    callbacks = [best_save_cb, logger_cb]

    # Optional knowledge distillation: load one frozen teacher per fold.
    if config.teachers is not None:
        package = 'config.old_configs.{}.config'.format(config.teachers)
        teacherconfig = importlib.import_module(package).config
        teachers = []
        for fold in range(5):
            teacher = getattr(model_list, teacherconfig.model_name)
            teacher = teacher(teacherconfig)
            model_ckpt = './model_weights/{}/models/best_dose_fold{}.pth'.format(
                teacherconfig.exp_name, fold)
            print("Loading teacher {} encoder from {}".format(
                fold, model_ckpt))
            teacher.load_state_dict(torch.load(model_ckpt)['model'])
            teacher.to(config.device)
            teacher.eval()
            for param in teacher.parameters():
                param.requires_grad = False
            teachers.append(teacher)
    else:
        teachers = None

    if config.wandb:
        wandb.init(project=config.wandb_project, name=config.exp_name)
        wandb_cb = WandbCallback(learn)
        callbacks.append(wandb_cb)

    print(learn.loss_func.config.loss_dict)
    print(learn.opt_func)
    print("Weight decay: ", learn.wd)

    learn.fit_one_cycle(config.epochs,
                        config.lr,
                        callbacks=callbacks,
                        div_factor=config.div_factor,
                        pct_start=config.pct_start,
                        final_div=config.final_div,
                        teachers=teachers)

    best_str = "Best valid loss: {}, dose score: {}, dvh score: {}".format(
        best_save_cb.best_loss, best_save_cb.best_dose.item(),
        best_save_cb.best_dvh.item())
    print(best_str)
    # Append the fold's best metrics; `with` guarantees the handle is closed
    # even if the write raises (the original open/write/close could leak it).
    with open(
            "./logs/{}/bestmetrics_fold{}.txt".format(config.exp_name,
                                                      args.foldidx),
            "a") as f:
        f.write(best_str)
Exemple #19
0
import config
from dataset import Dataset
from model import WideResNet22
from fastai.train import Learner
from fastai.metrics import accuracy
from torch.nn import functional as f
from fastai.basic_data import DataBunch

# --- CIFAR-10 training script: dataset -> dataloaders -> WideResNet -> fit ---
cifar10 = Dataset()
# cifar10.download_dataset()
train_dataloader, valid_dataloader = cifar10.get_dataloader()
# WideResNet22(3, 10): presumably (input channels, number of classes) -- confirm
# against the model definition.
model = WideResNet22(3, 10)

# Wrap the plain PyTorch dataloaders so fastai's training loop can drive them.
data = DataBunch(train_dataloader, valid_dataloader)
learner = Learner(data, model, loss_func=f.cross_entropy, metrics=[accuracy])
# Gradient-clipping threshold. NOTE(review): confirm fastai v1 actually reads a
# bare `clip` attribute on Learner -- it may be silently ignored here.
learner.clip = 0.1
learner.fit_one_cycle(config.EPOCHS, config.LEARNING_RATE, wd=1e-4)
# Validation split: shares all the grouped lookup tables (gb_mol_*) with the
# training set; only the molecule-id list differs.
val_ds = MoleculeDataset(val_mol_ids, gb_mol_sc, gb_mol_atom, gb_mol_bond,
                         gb_mol_struct, gb_mol_angle_in, gb_mol_angle_out,
                         gb_mol_graph_dist)
# Test split has its own scalar-coupling table (test_gb_mol_sc).
test_ds = MoleculeDataset(test_mol_ids, test_gb_mol_sc, gb_mol_atom,
                          gb_mol_bond, gb_mol_struct, gb_mol_angle_in,
                          gb_mol_angle_out, gb_mol_graph_dist)

# Only the training loader shuffles; evaluation loaders keep dataset order.
train_dl = DataLoader(train_ds, args.batch_size, shuffle=True, num_workers=8)
val_dl = DataLoader(val_ds, args.batch_size, num_workers=8)
# test=True presumably tells the collate fn there are no targets -- confirm
# against collate_parallel_fn.
test_dl = DeviceDataLoader.create(test_ds,
                                  args.batch_size,
                                  num_workers=8,
                                  collate_fn=partial(collate_parallel_fn,
                                                     test=True))

db = DataBunch(train_dl, val_dl, collate_fn=collate_parallel_fn)
# Attach the test loader manually; the DataBunch was built with train/valid only.
db.test_dl = test_dl

# set up model
set_seed(100)  # fixed seed so weight init is reproducible
d_model = args.d_model
# enn/ann: presumably edge- and atom-network configs (3 resp. 1 hidden layers
# of width d_model, no dropout) -- confirm against Transformer's signature.
enn_args = dict(layers=3 * [d_model], dropout=3 * [0.0], layer_norm=True)
ann_args = dict(layers=1 * [d_model],
                dropout=1 * [0.0],
                layer_norm=True,
                out_act=nn.Tanh())
model = Transformer(C.N_ATOM_FEATURES,
                    C.N_BOND_FEATURES,
                    C.N_SC_EDGE_FEATURES,
                    C.N_SC_MOL_FEATURES,
                    N=args.N,
Exemple #21
0
# NOTE(review): this seeds only the CUDA RNG; CPU torch / numpy / random stay
# unseeded -- confirm whether full run-to-run determinism is intended.
torch.cuda.manual_seed(12345)

from torch.utils.data import DataLoader, RandomSampler
from dataset import dataset

# Root directory of the pre-processed .pth samples, split train/val/test.
PATH = '/home/fernand/raven/neutral_pth/'
train = dataset(PATH, 'train')
valid = dataset(PATH, 'val')
test = dataset(PATH, 'test')

# Only the training loader shuffles; eval loaders keep a fixed order.
trainloader = DataLoader(train, batch_size=32, shuffle=True, num_workers=6)
#trainloader = DataLoader(train, batch_size=32, sampler=RandomSampler(train, replacement=True, num_samples=3200), shuffle=False, num_workers=6)
validloader = DataLoader(valid, batch_size=32, shuffle=False, num_workers=6)
testloader = DataLoader(test, batch_size=32, shuffle=False, num_workers=6)

from functools import partial
from fastai.basic_data import DataBunch
from fastai.basic_train import Learner
from torch.optim import Adam

from loss import loss_fn, Accuracy
from wren import WReN

# Bundle all three loaders so fastai's training loop can drive them.
db = DataBunch(train_dl=trainloader, valid_dl=validloader, test_dl=testloader)
wren = WReN()
# Adam with explicit (default) betas/eps, passed as a factory to the Learner.
opt = partial(Adam, betas=(0.9, 0.999), eps=1e-8)
learn = Learner(data=db, model=wren, opt_func=opt, loss_func=loss_fn, metrics=[Accuracy()])
#from fastai.train import to_fp16
#learn = to_fp16(learn)
# 20 epochs at a flat learning rate, no weight decay.
learn.fit(20, lr=1e-4, wd=0.0)
Exemple #22
0
    #frames = enhance_frames('./train')

    # Load per-frame speed labels into a dataframe.
    labels = load_data_labels('train.txt') #returns dataframe

    # Read frames into a dataset that a DataLoader can consume.
    train = imageDataset(labels, './trainbright') #gives me a dataset that can be converted into dataloader
    complete_train = pd.DataFrame(columns = ['image', 'speed'])
    # NOTE(review): row-by-row DataFrame.append is quadratic and deprecated in
    # pandas, and __getitem__ is called twice per frame (loads each image twice).
    # Consider collecting (image, speed) tuples and building the frame once.
    for i in range(5000): #too many open files, doing first 5000 frames as subset of whole dataset
       complete_train = complete_train.append({'image' : train.__getitem__(i)[0], 'speed' : train.__getitem__(i)[1]}, ignore_index=True)
    #print(complete_train) yay works

    # Split the 5000-row subset into train/validation index sets (70/30).
    dataset_size = len(complete_train)
    validation_split = 0.3
    indices = list(range(dataset_size))
    split = int(np.floor(validation_split * dataset_size))
    np.random.shuffle(indices)
    train_indices, val_indices = indices[split:], indices[:split]

    # Samplers draw disjoint random subsets from the same dataframe.
    train_sampler = SubsetRandomSampler(train_indices)
    valid_sampler = SubsetRandomSampler(val_indices)

    dataload_train = DataLoader(complete_train, batch_size=1, shuffle=False, sampler=train_sampler, batch_sampler=None, num_workers=4)
    dataload_valid = DataLoader(complete_train, batch_size=1, shuffle=False, sampler=valid_sampler, batch_sampler=None, num_workers=4)

    # no_check=True skips fastai's sanity check, which failed on this data.
    databunch_for_model = DataBunch(dataload_train, dataload_valid, no_check=True) #no_check problem (sanity_check error)

    # Train a ResNet-18-based CNN learner on the databunch.
    learner = cnn_learner(databunch_for_model, models.resnet18)
Exemple #23
0
def build_learner(params, project_dir, pindex=0, comm_file=None, queues=None):
    """
    Builds a fastai `Learner` object containing the model and data specified by
    `params`. It is configured to run on GPU `device_id`. Assumes it is GPU
    `pindex` of `world_size` total GPUs. In case more than one GPU is being
    used, a file named `comm_file` is used to communicate between processes.

    Args:
        params: nested configuration dict; the required keys are listed in the
            `check_params` call below.
        project_dir: root directory under which `model/<model_name>` is created.
        pindex: index of this process among all GPU processes.
        comm_file: file used for `torch.distributed` file-based init
            (distributed runs only).
        queues: unused here; accepted for interface compatibility with callers.

    Returns:
        A tuple `(learn, src_vocab, tgt_vocab)` where `learn` is the fastai
        `Learner` and both vocabs come from the training dataloader.

    Raises:
        ValueError: if a GPU is requested but none is available, or if
            `params['network']['type']` is not a recognized network type.
    """
    # For user friendly error messages, check these parameters exist.
    check_params(params, [
        'cpu',
        'data.batch_size',
        'data.dir',
        'data.epoch_size',
        'data.max_length',
        'data.max_val_size',
        'data.src',
        'data.tgt',
        'data.vocab',
        'decoder.embedding_dim',
        'decoder.embedding_dropout',
        'decoder.prediction_dropout',
        'encoder.embedding_dim',
        'encoder.embedding_dropout',
        'network.bias',
        'network.block_sizes',
        'network.division_factor',
        'network.dropout',
        'network.efficient',
        'network.growth_rate',
        'network.kernel_size',
    ])

    model_name = params['model_name']

    # Try to make the directory for saving models.
    model_dir = os.path.join(project_dir, 'model', model_name)
    os.makedirs(model_dir, exist_ok=True)

    # Configure GPU/CPU device settings.
    cpu = params['cpu']
    gpu_ids = params['gpu_ids'] if not cpu else []
    world_size = len(gpu_ids) if len(gpu_ids) > 0 else 1
    distributed = world_size > 1
    if gpu_ids:
        device_id = gpu_ids[pindex]
        device = torch.device(device_id)
        torch.cuda.set_device(device_id)
    else:
        device_id = None
        device = torch.device('cpu')

    # If distributed, initialize inter-process communication using shared file.
    if distributed:
        torch.distributed.init_process_group(backend='nccl',
                                             world_size=world_size,
                                             rank=pindex,
                                             init_method=f'file://{comm_file}')

    # Load vocabulary (shared between source and target sides).
    vocab_path = os.path.join(params['data']['dir'], params['data']['vocab'])
    vocab = VocabData(vocab_path)

    # Load data. The configured batch size is split evenly across processes.
    src_l = params['data']['src']
    tgt_l = params['data']['tgt']
    loader = PervasiveDataLoader(os.path.join(params['data']['dir'],
                                              f'{src_l}.h5'),
                                 os.path.join(params['data']['dir'],
                                              f'{tgt_l}.h5'),
                                 vocab,
                                 vocab,
                                 params['data']['batch_size'] // world_size,
                                 params['data']['max_length'],
                                 epoch_size=params['data']['epoch_size'],
                                 max_val_size=params['data']['max_val_size'],
                                 distributed=distributed,
                                 world_size=world_size,
                                 pindex=pindex)

    # Define neural network.
    # Max length is 1 more than setting to account for BOS.
    net_type = params['network']['type']
    if net_type == 'pervasive-embeddings':
        model = PervasiveEmbedding(
            params['network']['block_sizes'], vocab.bos, loader.max_length,
            loader.max_length, loader.datasets['train'].arrays[0].shape[2],
            params['encoder']['embedding_dim'],
            params['encoder']['embedding_dropout'],
            params['network']['dropout'],
            params['decoder']['prediction_dropout'],
            params['network']['division_factor'],
            params['network']['growth_rate'], params['network']['bias'],
            params['network']['efficient'])
        # Rescale loss by 100 for easier display in training output.
        loss_func = scaled_mse_loss
    elif net_type == 'pervasive-downsample':
        model = PervasiveDownsample(
            params['network']['block_sizes'], vocab.bos, loader.max_length,
            loader.max_length, params['encoder']['embedding_dim'],
            params['encoder']['embedding_dropout'],
            params['network']['dropout'],
            params['decoder']['prediction_dropout'],
            params['network']['division_factor'],
            params['network']['growth_rate'], params['network']['bias'],
            params['network']['efficient'], params['network']['kernel_size'])
        loss_func = F.cross_entropy
    elif net_type == 'pervasive-bert':
        model = PervasiveBert(
            params['network']['block_sizes'], vocab.bos, loader.max_length,
            loader.max_length, params['encoder']['embedding_dim'],
            params['encoder']['embedding_dropout'],
            params['network']['dropout'],
            params['decoder']['prediction_dropout'],
            params['network']['division_factor'],
            params['network']['growth_rate'], params['network']['bias'],
            params['network']['efficient'], params['network']['kernel_size'])
        loss_func = F.cross_entropy
    elif net_type == 'pervasive-original':
        model = PervasiveOriginal(
            params['network']['block_sizes'], len(vocab), vocab.bos,
            loader.max_length, loader.max_length,
            params['encoder']['embedding_dim'],
            params['encoder']['embedding_dropout'],
            params['network']['dropout'],
            params['decoder']['prediction_dropout'],
            params['network']['division_factor'],
            params['network']['growth_rate'], params['network']['bias'],
            params['network']['efficient'], params['network']['kernel_size'])
        loss_func = F.cross_entropy
    elif net_type == 'pervasive':
        model = Pervasive(
            params['network']['block_sizes'], len(vocab), vocab.bos,
            loader.max_length, loader.max_length,
            params['encoder']['initial_emb_dim'],
            params['encoder']['embedding_dim'],
            params['encoder']['embedding_dropout'],
            params['network']['dropout'],
            params['decoder']['prediction_dropout'],
            params['network']['division_factor'],
            params['network']['growth_rate'], params['network']['bias'],
            params['network']['efficient'], params['network']['kernel_size'])
        loss_func = F.cross_entropy
    else:
        # Fail fast with a clear message instead of a NameError on `model`.
        raise ValueError(f'Unknown network type: {net_type!r}')

    model.init_weights()
    if device_id is not None:
        if not torch.cuda.is_available():
            # f-prefix was missing here, so '{device_id}' printed literally.
            raise ValueError(
                f'Request to train on GPU {device_id}, but no GPU found.')
        model.cuda(device_id)
        if distributed:
            model = DistributedDataParallel(model, device_ids=[device_id])
    # The validation loader doubles as the test loader.
    data = DataBunch(loader.loaders['train'],
                     loader.loaders['valid'],
                     loader.loaders['valid'],
                     device=device)

    # Create Learner with Adam optimizer.
    learn = Learner(data, model, loss_func=loss_func, model_dir=model_dir)
    AdamP = partial(torch.optim.Adam,
                    betas=(params['optim']['beta1'], params['optim']['beta2']))
    learn.opt_func = AdamP
    learn.wd = params['optim']['wd']

    return (learn, loader.loaders['train'].src_vocab,
            loader.loaders['train'].tgt_vocab)
Exemple #24
0
def train(train_dataset: torch.utils.data.Dataset,
          test_dataset: torch.utils.data.Dataset,
          training_config: dict = train_config,
          global_config: dict = global_config):
    """
    Template training routine. Takes a training and a test dataset wrapped
    as torch.utils.data.Dataset type and two corresponding generic
    configs for both global path settings and training settings.
    Returns the fitted fastai.train.Learner object which can be
    used to assess the resulting metrics and error curves etc.

    NOTE: all settings are read from the ``training_config`` parameter.
    Previously most reads went to the module-level ``train_config``, so a
    caller-supplied config was silently ignored for those settings; the
    default argument keeps the old behavior for default calls.

    Returns:
        (learner, log_content, name) -- the fitted Learner, the dict written
        to the csv log, and the model/session name.
    """

    # Make sure every configured output directory exists.
    for path in global_config.values():
        create_dirs(path)

    # wrap datasets with Dataloader classes
    train_loader = torch.utils.data.DataLoader(
        train_dataset, **training_config["DATA_LOADER_CONFIG"])
    test_loader = torch.utils.data.DataLoader(
        test_dataset, **training_config["DATA_LOADER_CONFIG"])
    databunch = DataBunch(train_loader, test_loader)

    # instantiate model and learner
    if training_config["WEIGHTS"] is None:
        model = training_config["MODEL"](**training_config["MODEL_CONFIG"])
    else:
        model = load_model(training_config["MODEL"],
                           training_config["MODEL_CONFIG"],
                           training_config["WEIGHTS"],
                           training_config["DEVICE"])

    learner = Learner(databunch,
                      model,
                      metrics=training_config["METRICS"],
                      path=global_config["ROOT_PATH"],
                      model_dir=global_config["WEIGHT_DIR"],
                      loss_func=training_config["LOSS"])

    # model name & paths
    name = "_".join([training_config["DATE"], training_config["SESSION_NAME"]])
    modelpath = os.path.join(global_config["WEIGHT_DIR"], name)

    if training_config["MIXED_PRECISION"]:
        learner.to_fp16()

    # Save once before training so the weight file exists even on a crash.
    learner.save(modelpath)

    torch.backends.cudnn.benchmark = True

    cbs = [
        SaveModelCallback(learner),
        LearnerTensorboardWriter(
            learner,
            Path(os.path.join(global_config["LOG_DIR"]), "tensorboardx"),
            name),
        TerminateOnNaNCallback()
    ]

    # perform training iteration
    try:
        if training_config["ONE_CYCLE"]:
            learner.fit_one_cycle(training_config["EPOCHS"],
                                  max_lr=training_config["LR"],
                                  callbacks=cbs)
        else:
            learner.fit(training_config["EPOCHS"],
                        lr=training_config["LR"],
                        callbacks=cbs)
    # save model files before propagating the interrupt
    except KeyboardInterrupt:
        learner.save(modelpath)
        raise  # bare raise preserves the original traceback

    learner.save(modelpath)
    val_loss = min(learner.recorder.val_losses)
    val_metrics = learner.recorder.metrics

    # log using the logging tool
    logger = log.Log(training_config, run_name=training_config['SESSION_NAME'])
    logger.log_metric('Validation Loss', val_loss)
    # Was ``logger.log.metrics(val_metrics)`` -- an AttributeError at runtime;
    # presumably ``log_metrics`` was intended. TODO confirm against log.Log.
    logger.log_metrics(val_metrics)
    logger.end_run()

    # write csv log file
    log_content = training_config.copy()
    log_content["VAL_LOSS"] = val_loss
    log_content["VAL_METRICS"] = val_metrics
    log_path = os.path.join(global_config["LOG_DIR"],
                            training_config["LOGFILE"])
    write_log(log_path, log_content)

    return learner, log_content, name
Exemple #25
0
# Chain the pre-built validation batch datasets into one iterable dataset.
v_chain = torch.utils.data.ChainDataset(v_batch_sets)
t_final = time() - start  # seconds spent building the chained datasets
print(t_final)
# Collator turning dataframe chunks into training tensors; apply_ops=False
# presumably skips the preprocessing workflow at load time -- confirm
# against DLCollator.
dlc = DLCollator(preproc=proc, apply_ops=False)
print("Creating dataloaders")
start = time()
t_data = DLDataLoader(t_chain,
                      collate_fn=dlc.gdf_col,
                      pin_memory=False,
                      num_workers=0)
v_data = DLDataLoader(v_chain,
                      collate_fn=dlc.gdf_col,
                      pin_memory=False,
                      num_workers=0)

# Bundle the loaders for fastai; batches are delivered on the GPU.
databunch = DataBunch(t_data, v_data, collate_fn=dlc.gdf_col, device="cuda")
t_final = time() - start
print(t_final)
print("Creating model")
start = time()
# Tabular net: categorical embeddings plus len(cont_names) continuous
# features, 2 output classes, two hidden layers of 512 and 256 units.
model = TabularModel(emb_szs=embeddings,
                     n_cont=len(cont_names),
                     out_sz=2,
                     layers=[512, 256])
learn = Learner(databunch, model, metrics=[accuracy])
learn.loss_func = torch.nn.CrossEntropyLoss()
t_final = time() - start
print(t_final)
print("Finding learning rate")
start = time()
# LR range test; results are recorded on learn.recorder.
learn.lr_find()
Exemple #26
0
def main():
    """Train StegNet, or run batch encoding/decoding with a trained model.

    Behavior is driven by the module-level ``args``:
      * ``args.train``  -- fit the model on images under ``<datapath>/train``
        and ``<datapath>/val``, then save the weights to ``model.pth``.
      * ``args.encode`` -- hide payloads in every image in ``<datapath>/cover``,
        writing results under ``<datapath>/encoded``.
      * otherwise       -- recover payloads from ``<datapath>/encoded`` into
        ``<datapath>/decoded``.
    """
    # NOTE(review): the meaning of the magic sizes (10, 6) is not visible
    # here -- confirm against StegNet's constructor.
    model = StegNet(10, 6)
    print("Created Model")

    if args.train:
        data_train = ImageLoader(args.datapath + '/train', args.num_train,
                                 args.fourierSeed, args.size, args.bs)
        data_val = ImageLoader(args.datapath + '/val', args.num_val,
                               args.fourierSeed, args.size, args.bs)
        data = DataBunch(data_train, data_val)

        print("Loaded DataSets")

        # Optionally warm-start from a checkpoint.
        if args.model is not None:
            model.load_state_dict(torch.load(args.model))
            print("Loaded pretrained model")

        loss_fn = mse

        learn = Learner(data,
                        model,
                        loss_func=loss_fn,
                        metrics=[mse_cov, mse_hidden])

        print("training")
        fit_one_cycle(learn, args.epochs, 1e-2)

        torch.save(learn.model.state_dict(), "model.pth")
        print("model saved")

    else:
        # Fall back to prompting for a checkpoint when none was given.
        path = input(
            "Enter path of the model: ") if args.model is None else args.model
        # Bug fix: load from `path`; the original used `args.model` here,
        # which is None exactly when the user was prompted for a path.
        model.load_state_dict(torch.load(path))
        model.eval()

        if args.encode:
            f_paths = [
                args.datapath + '/cover/' + f
                for f in os.listdir(args.datapath + '/cover')
            ]
            # exist_ok replaces the old `try: os.mkdir ... except OSError:
            # pass`, which also hid real failures such as permission errors.
            os.makedirs(args.datapath + '/encoded', exist_ok=True)
            fourier_func = partial(encrypt, seed=args.fourierSeed)
            encode_partial = partial(encode,
                                     model=model.encoder,
                                     size=args.size,
                                     fourier_func=fourier_func)
            parallel(encode_partial, f_paths)

        else:
            f_paths = [
                args.datapath + '/encoded/' + f
                for f in os.listdir(args.datapath + '/encoded')
            ]
            os.makedirs(args.datapath + '/decoded', exist_ok=True)
            fourier_func = partial(decrypt, seed=args.fourierSeed)
            decode_partial = partial(decode,
                                     model=model.decoder,
                                     size=args.size,
                                     fourier_func=fourier_func)
            parallel(decode_partial, f_paths)