Example #1
    def test_performance(self):
        """Test slicing performance with 2 corresponding slice tasks that
        represent roughly <10% of the data."""

        dataloaders = []
        for df, split in [(self.df_train, "train"), (self.df_valid, "valid")]:
            dataloader = create_dataloader(df, split)
            dataloaders.append(dataloader)

        base_task = create_task("task", module_suffixes=["A", "B"])

        # Apply SFs
        slicing_functions = [f, g]  # low-coverage slices
        slice_names = [sf.name for sf in slicing_functions]
        applier = PandasSFApplier(slicing_functions)
        S_train = applier.apply(self.df_train, progress_bar=False)
        S_valid = applier.apply(self.df_valid, progress_bar=False)

        # Add slice labels
        add_slice_labels(dataloaders[0], base_task, S_train)
        add_slice_labels(dataloaders[1], base_task, S_valid)

        # Convert to slice tasks
        tasks = convert_to_slice_tasks(base_task, slice_names)
        model = MultitaskClassifier(tasks=tasks)

        # Train
        # NOTE: Needs more epochs to converge with more heads
        trainer = Trainer(lr=0.001, n_epochs=65, progress_bar=False)
        trainer.fit(model, dataloaders)
        scores = model.score(dataloaders)

        # Confirm reasonably high slice scores
        # Check train scores
        self.assertGreater(scores["task/TestData/train/f1"], 0.9)
        self.assertGreater(scores["task_slice:f_pred/TestData/train/f1"], 0.9)
        self.assertGreater(scores["task_slice:f_ind/TestData/train/f1"], 0.9)
        self.assertGreater(scores["task_slice:g_pred/TestData/train/f1"], 0.9)
        self.assertGreater(scores["task_slice:g_ind/TestData/train/f1"], 0.9)
        self.assertGreater(scores["task_slice:base_pred/TestData/train/f1"],
                           0.9)
        self.assertEqual(scores["task_slice:base_ind/TestData/train/f1"], 1.0)

        # Check valid scores
        self.assertGreater(scores["task/TestData/valid/f1"], 0.9)
        self.assertGreater(scores["task_slice:f_pred/TestData/valid/f1"], 0.9)
        self.assertGreater(scores["task_slice:f_ind/TestData/valid/f1"], 0.9)
        self.assertGreater(scores["task_slice:g_pred/TestData/valid/f1"], 0.9)
        self.assertGreater(scores["task_slice:g_ind/TestData/valid/f1"], 0.9)
        self.assertGreater(scores["task_slice:base_pred/TestData/valid/f1"],
                           0.9)
        # base_ind is trivial: all labels are positive
        self.assertEqual(scores["task_slice:base_ind/TestData/valid/f1"], 1.0)
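The slicing functions f and g used above (and h in Example #2) are defined elsewhere in the test module. A minimal sketch of what such low-coverage SFs could look like; the field names and predicates are illustrative assumptions, not the test's actual definitions:

from snorkel.slicing import slicing_function

@slicing_function()
def f(x):
    # Hypothetical low-coverage slice: fires on one corner of the feature space
    return x.x1 > 0.9 and x.x2 > 0.9

@slicing_function()
def g(x):
    # Hypothetical low-coverage slice over a different region
    return x.x1 < -0.9 and x.x2 < -0.9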
Example #2
    def test_convergence(self):
        """Test slicing convergence with 1 slice task that represents ~25% of
        the data."""

        dataloaders = []
        for df, split in [(self.df_train, "train"), (self.df_valid, "valid")]:
            dataloader = create_dataloader(df, split)
            dataloaders.append(dataloader)

        base_task = create_task("task", module_suffixes=["A", "B"])

        # Apply SFs
        slicing_functions = [h]  # high coverage slice
        slice_names = [sf.name for sf in slicing_functions]
        applier = PandasSFApplier(slicing_functions)
        S_train = applier.apply(self.df_train, progress_bar=False)
        S_valid = applier.apply(self.df_valid, progress_bar=False)

        self.assertEqual(S_train.shape, (self.N_TRAIN,))
        self.assertEqual(S_valid.shape, (self.N_VALID,))
        self.assertIn("h", S_train.dtype.names)

        # Add slice labels
        add_slice_labels(dataloaders[0], base_task, S_train)
        add_slice_labels(dataloaders[1], base_task, S_valid)

        # Convert to slice tasks
        tasks = convert_to_slice_tasks(base_task, slice_names)
        model = MultitaskClassifier(tasks=tasks)

        # Train
        trainer = Trainer(lr=0.001, n_epochs=50, progress_bar=False)
        trainer.fit(model, dataloaders)
        scores = model.score(dataloaders)

        # Confirm near perfect scores
        self.assertGreater(scores["task/TestData/valid/accuracy"], 0.94)
        self.assertGreater(scores["task_slice:h_pred/TestData/valid/accuracy"],
                           0.94)
        self.assertGreater(scores["task_slice:h_ind/TestData/valid/f1"], 0.94)

        # Calculate/check train/val loss
        train_dataset = dataloaders[0].dataset
        train_loss_output = model.calculate_loss(train_dataset.X_dict,
                                                 train_dataset.Y_dict)
        train_loss = train_loss_output[0]["task"].item()
        self.assertLess(train_loss, 0.1)

        val_dataset = dataloaders[1].dataset
        val_loss_output = model.calculate_loss(val_dataset.X_dict,
                                               val_dataset.Y_dict)
        val_loss = val_loss_output[0]["task"].item()
        self.assertLess(val_loss, 0.1)
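These tests also rely on create_dataloader and create_task helpers from the test module. A minimal sketch of the dataloader helper, assuming feature columns x1/x2 and a label column y (all names here are assumptions); the dataset name "TestData" matches the metric keys asserted above:

import torch
from snorkel.classification import DictDataset, DictDataLoader

def create_dataloader(df, split):
    # Wrap the DataFrame in a DictDataset; the Y_dict key matches the task name
    X_dict = {"coordinates": torch.FloatTensor(df[["x1", "x2"]].values)}
    Y_dict = {"task": torch.LongTensor(df["y"].values)}
    dataset = DictDataset("TestData", split, X_dict, Y_dict)
    return DictDataLoader(dataset, batch_size=32)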
Example #3
    def test_checkpointer_init(self):
        with tempfile.TemporaryDirectory() as temp_dir:
            more_config = {
                "checkpointing": True,
                "checkpointer_config": {
                    "checkpoint_dir": temp_dir
                },
                "log_writer_config": {
                    "log_dir": temp_dir
                },
            }
            trainer = Trainer(**base_config, **more_config, logging=True)
            trainer.fit(model, [dataloaders[0]])
            self.assertIsNotNone(trainer.checkpointer)

            broken_config = {
                "checkpointing": True,
                "checkpointer_config": {
                    "checkpoint_dir": None
                },
                "log_writer_config": {
                    "log_dir": temp_dir
                },
            }
            with self.assertRaises(TypeError):
                trainer = Trainer(**base_config,
                                  **broken_config,
                                  logging=False)
                trainer.fit(model, [dataloaders[0]])
Example #4
    def test_log_writer_json(self):
        # Addresses issue #1439
        # Confirm that a log file is written to the specified location after training
        run_name = "log.json"
        with tempfile.TemporaryDirectory() as temp_dir:
            log_writer_config = {"log_dir": temp_dir, "run_name": run_name}
            trainer = Trainer(
                **base_config,
                logging=True,
                log_writer="json",
                log_writer_config=log_writer_config,
            )
            trainer.fit(model, [dataloaders[0]])
            log_path = os.path.join(trainer.log_writer.log_dir, run_name)
            with open(log_path, "r") as f:
                log = json.load(f)
            self.assertIn("model/all/train/loss", log)
Example #5
    def test_convergence(self):
        """Test multitask classifier convergence with two tasks."""

        dataloaders = []

        for offset, task_name in zip([0.0, 0.25], ["task1", "task2"]):
            df = create_data(N_TRAIN, offset)
            dataloader = create_dataloader(df, "train", task_name)
            dataloaders.append(dataloader)

        for offset, task_name in zip([0.0, 0.25], ["task1", "task2"]):
            df = create_data(N_VALID, offset)
            dataloader = create_dataloader(df, "valid", task_name)
            dataloaders.append(dataloader)

        task1 = create_task("task1", module_suffixes=["A", "A"])
        task2 = create_task("task2", module_suffixes=["A", "B"])
        model = MultitaskClassifier(tasks=[task1, task2])

        # Train
        trainer = Trainer(lr=0.001, n_epochs=10, progress_bar=False)
        trainer.fit(model, dataloaders)
        scores = model.score(dataloaders)

        # Confirm near perfect scores on both tasks
        for idx, task_name in enumerate(["task1", "task2"]):
            self.assertGreater(scores[f"{task_name}/TestData/valid/accuracy"], 0.95)

            # Calculate/check train/val loss
            train_dataset = dataloaders[idx].dataset
            train_loss_output = model.calculate_loss(
                train_dataset.X_dict, train_dataset.Y_dict
            )
            train_loss = train_loss_output[0][task_name].item()
            self.assertLess(train_loss, 0.05)

            val_dataset = dataloaders[2 + idx].dataset
            val_loss_output = model.calculate_loss(
                val_dataset.X_dict, val_dataset.Y_dict
            )
            val_loss = val_loss_output[0][task_name].item()
            self.assertLess(val_loss, 0.05)
Example #6
    def test_trainer_errors(self):
        dataloader = copy.deepcopy(dataloaders[0])

        # No train split
        trainer = Trainer(**base_config)
        dataloader.dataset.split = "valid"
        with self.assertRaisesRegex(ValueError, "Cannot find any dataloaders"):
            trainer.fit(model, [dataloader])

        # Unused split
        trainer = Trainer(**base_config, valid_split="val")
        with self.assertRaisesRegex(ValueError, "Dataloader splits must be"):
            trainer.fit(model, [dataloader])
Example #7
test_dl_slice = slice_model.make_slice_dataloader(test_dl.dataset,
                                                  S_test,
                                                  shuffle=False,
                                                  batch_size=BATCH_SIZE)

# %% [markdown]
# ### Representation learning with slices

# %% [markdown]
# Using Snorkel's [`Trainer`](https://snorkel.readthedocs.io/en/master/packages/_autosummary/classification/snorkel.classification.Trainer.html), we fit our classifier with the training set dataloader.

# %%
from snorkel.classification import Trainer

# For demonstration purposes, we set n_epochs=2
trainer = Trainer(n_epochs=2, lr=1e-4, progress_bar=True)
trainer.fit(slice_model, [train_dl_slice])

# %% [markdown]
# At inference time, the primary task head (`spam_task`) will make all final predictions.
# We'd like to evaluate all the slice heads on the original task head — [`score_slices`](https://snorkel.readthedocs.io/en/v0.9.3/packages/_autosummary/slicing/snorkel.slicing.SliceAwareClassifier.html#snorkel.slicing.SliceAwareClassifier.score_slices) remaps all slice-related labels, denoted `spam_task_slice:{slice_name}_pred`, to be evaluated on the `spam_task`.

# %%
slice_model.score_slices([test_dl_slice], as_dataframe=True)

# %% [markdown]
# *Note: in this toy dataset, we see high variance in slice performance, because our dataset is so small that (i) there are few data points in the train split, giving little signal to learn over, and (ii) there are few data points in the test split, making our evaluation metrics very noisy.
# For a demonstration of data slicing deployed in state-of-the-art models, please see our [SuperGLUE](https://github.com/HazyResearch/snorkel-superglue/tree/master/tutorials) tutorials.*

# %% [markdown]
# ---
Example #8
# %%
import torchvision.models as models

# initialize pretrained feature extractor
cnn = models.resnet18(pretrained=True)
model = create_model(cnn)

# %% [markdown]
# ### Train and Evaluate Model

# %% {"tags": ["md-exclude-output"]}
from snorkel.classification import Trainer

trainer = Trainer(
    n_epochs=1,  # increase for improved performance
    lr=1e-3,
    checkpointing=True,
    checkpointer_config={"checkpoint_dir": "checkpoint"},
)
trainer.fit(model, [dl_train])

# %%
model.score([dl_valid])

# %% [markdown]
# ## Recap
# We have successfully trained a visual relationship detection model! Using categorical and spatial intuition about how objects in a visual relationship interact with each other, we are able to assign high quality training labels to object pairs in the VRD dataset in a multi-class classification setting.
#
# For more on how Snorkel can be used for visual relationship tasks, please see our [ICCV 2019 paper](https://arxiv.org/abs/1904.11622)!
Example #9
    def test_save_load(self):
        non_base_config = {"n_epochs": 2, "progress_bar": False}
        trainer1 = Trainer(**base_config, lr_scheduler="exponential")
        trainer1.fit(model, [dataloaders[0]])
        trainer2 = Trainer(**non_base_config, lr_scheduler="linear")
        trainer3 = Trainer(**non_base_config, lr_scheduler="linear")

        with tempfile.NamedTemporaryFile() as fd:
            checkpoint_path = fd.name
            trainer1.save(checkpoint_path)
            trainer2.load(checkpoint_path, model=model)
            trainer3.load(checkpoint_path, None)

        self.assertEqual(trainer1.config, trainer2.config)
        self.dict_check(
            trainer1.optimizer.state_dict(), trainer2.optimizer.state_dict()
        )

        # continue training after load
        trainer2.fit(model, [dataloaders[0]])

        # Check that loading without a model restores the trainer config
        # but not the optimizer state
        self.assertEqual(trainer1.config, trainer3.config)
        self.assertFalse(hasattr(trainer3, "optimizer"))
        trainer3.fit(model, [dataloaders[0]])
Example #10
    def test_trainer_onetask(self):
        """Train a single-task model."""
        trainer = Trainer(**base_config)
        trainer.fit(model, [dataloaders[0]])
Example #11
    def test_warmup(self):
        lr_scheduler_config = {"warmup_steps": 1, "warmup_unit": "batches"}
        trainer = Trainer(**base_config,
                          lr_scheduler_config=lr_scheduler_config)
        trainer.fit(model, [dataloaders[0]])
        self.assertEqual(trainer.warmup_steps, 1)

        lr_scheduler_config = {"warmup_steps": 1, "warmup_unit": "epochs"}
        trainer = Trainer(**base_config,
                          lr_scheduler_config=lr_scheduler_config)
        trainer.fit(model, [dataloaders[0]])
        self.assertEqual(trainer.warmup_steps, BATCHES_PER_EPOCH)

        lr_scheduler_config = {"warmup_percentage": 1 / BATCHES_PER_EPOCH}
        trainer = Trainer(**base_config,
                          lr_scheduler_config=lr_scheduler_config)
        trainer.fit(model, [dataloaders[0]])
        self.assertEqual(trainer.warmup_steps, 1)
Example #12
    def test_scheduler_init(self):
        trainer = Trainer(**base_config, lr_scheduler="constant")
        trainer.fit(model, [dataloaders[0]])
        self.assertIsNone(trainer.lr_scheduler)

        trainer = Trainer(**base_config, lr_scheduler="linear")
        trainer.fit(model, [dataloaders[0]])
        self.assertIsInstance(trainer.lr_scheduler,
                              optim.lr_scheduler.LambdaLR)

        trainer = Trainer(**base_config, lr_scheduler="exponential")
        trainer.fit(model, [dataloaders[0]])
        self.assertIsInstance(trainer.lr_scheduler,
                              optim.lr_scheduler.ExponentialLR)

        trainer = Trainer(**base_config, lr_scheduler="step")
        trainer.fit(model, [dataloaders[0]])
        self.assertIsInstance(trainer.lr_scheduler, optim.lr_scheduler.StepLR)

        with self.assertRaisesRegex(ValueError, "Unrecognized lr scheduler"):
            trainer = Trainer(**base_config, lr_scheduler="foo")
            trainer.fit(model, [dataloaders[0]])
Example #13
    def test_optimizer_init(self):
        trainer = Trainer(**base_config, optimizer="sgd")
        trainer.fit(model, [dataloaders[0]])
        self.assertIsInstance(trainer.optimizer, optim.SGD)

        trainer = Trainer(**base_config, optimizer="adam")
        trainer.fit(model, [dataloaders[0]])
        self.assertIsInstance(trainer.optimizer, optim.Adam)

        trainer = Trainer(**base_config, optimizer="adamax")
        trainer.fit(model, [dataloaders[0]])
        self.assertIsInstance(trainer.optimizer, optim.Adamax)

        with self.assertRaisesRegex(ValueError, "Unrecognized optimizer"):
            trainer = Trainer(**base_config, optimizer="foo")
            trainer.fit(model, [dataloaders[0]])
Example #14
    def test_log_writer_init(self):
        with tempfile.TemporaryDirectory() as temp_dir:
            log_writer_config = {"log_dir": temp_dir}
            trainer = Trainer(
                **base_config,
                logging=True,
                log_writer="json",
                log_writer_config=log_writer_config,
            )
            trainer.fit(model, [dataloaders[0]])
            self.assertIsInstance(trainer.log_writer, LogWriter)

            log_writer_config = {"log_dir": temp_dir}
            trainer = Trainer(
                **base_config,
                logging=True,
                log_writer="tensorboard",
                log_writer_config=log_writer_config,
            )
            trainer.fit(model, [dataloaders[0]])
            self.assertIsInstance(trainer.log_writer, TensorBoardWriter)

            log_writer_config = {"log_dir": temp_dir}
            with self.assertRaisesRegex(ValueError, "Unrecognized writer"):
                trainer = Trainer(
                    **base_config,
                    logging=True,
                    log_writer="foo",
                    log_writer_config=log_writer_config,
                )
                trainer.fit(model, [dataloaders[0]])
Example #15
def slicing_evaluation(df_train, df_test, train_model=None):
    if train_model is None:
        train_model = "mlp"

    sfs = [
        SlicingFunction.short_comment, SlicingFunction.ind_keyword,
        SlicingFunction.cmp_re, SlicingFunction.industry_keyword
    ]

    slice_names = [sf.name for sf in sfs]
    scorer = Scorer(metrics=["f1"])

    ft = FT.load(f"{WORK_PATH}/snorkel_flow/sources/fasttext_name_model.bin")

    def get_ftr(text):
        return ft.get_sentence_vector(' '.join(jieba.lcut(text.strip())))

    X_train = np.array(list(df_train.text.apply(get_ftr).values))
    X_test = np.array(list(df_test.text.apply(get_ftr).values))
    Y_train = df_train.label.values
    Y_test = df_test.label.values

    if train_model == "lr":
        sklearn_model = LogisticRegression(C=0.001, solver="liblinear")
        sklearn_model.fit(X=X_train, y=Y_train)
        preds_test = sklearn_model.predict(X_test)
        probs_test = preds_to_probs(
            preds_test,
            len([c for c in dir(Polarity) if not c.startswith("__")]))
        print(f"Test set F1: {100 * f1_score(Y_test, preds_test):.1f}%")
        applier = PandasSFApplier(sfs)
        S_test = applier.apply(df_test)
        analysis = scorer.score_slices(S=S_test,
                                       golds=Y_test,
                                       preds=preds_test,
                                       probs=probs_test,
                                       as_dataframe=True)
        return analysis

    if train_model == "mlp":
        # Define model architecture
        bow_dim = X_train.shape[1]
        hidden_dim = bow_dim
        mlp = get_pytorch_mlp(hidden_dim=hidden_dim, num_layers=2)

        # Initialize slice model
        slice_model = SliceAwareClassifier(
            base_architecture=mlp,
            head_dim=hidden_dim,
            slice_names=slice_names,
            scorer=scorer,
        )

        # generate the remaining S matrices with the new set of slicing functions
        applier = PandasSFApplier(sfs)
        S_train = applier.apply(df_train)
        S_test = applier.apply(df_test)

        # add slice labels to an existing dataloader
        BATCH_SIZE = 64

        train_dl = create_dict_dataloader(X_train, Y_train, "train")
        train_dl_slice = slice_model.make_slice_dataloader(
            train_dl.dataset, S_train, shuffle=True, batch_size=BATCH_SIZE)
        test_dl = create_dict_dataloader(X_test, Y_test, "test")
        test_dl_slice = slice_model.make_slice_dataloader(
            test_dl.dataset, S_test, shuffle=False, batch_size=BATCH_SIZE)

        # Fit the classifier with the training-set dataloader
        trainer = Trainer(n_epochs=2, lr=1e-4, progress_bar=True)
        trainer.fit(slice_model, [train_dl_slice])

        analysis = slice_model.score_slices([test_dl_slice], as_dataframe=True)
        return analysis
Example #16
    def test_trainer_twotask(self):
        """Train a model with overlapping modules and flows."""
        multitask_model = MultitaskClassifier(tasks)
        trainer = Trainer(**base_config)
        trainer.fit(multitask_model, dataloaders)
Example #17
        )

        # Add task to list of tasks
        tasks.append(task_object)
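
# NOTE: The snippet above begins mid-loop, so the construction of each
# task_object is cut off. Below is a rough, hypothetical sketch of how a task
# might be built with Snorkel's Task/Operation API; the module names,
# dimensions, and input field are placeholders, not the original author's code.
import torch.nn as nn
from snorkel.analysis import Scorer
from snorkel.classification import Operation, Task

def make_task(task_name, input_field="features", in_dim=100, hidden_dim=50):
    # Hypothetical per-task architecture: an MLP body plus a 2-class head
    module_pool = nn.ModuleDict({
        f"{task_name}_mlp": nn.Sequential(nn.Linear(in_dim, hidden_dim), nn.ReLU()),
        f"{task_name}_head": nn.Linear(hidden_dim, 2),
    })
    op_sequence = [
        Operation(name="mlp", module_name=f"{task_name}_mlp",
                  inputs=[("_input_", input_field)]),
        Operation(name="head", module_name=f"{task_name}_head",
                  inputs=[("mlp", 0)]),
    ]
    return Task(name=task_name, module_pool=module_pool,
                op_sequence=op_sequence, scorer=Scorer(metrics=["accuracy"]))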

# Pass the list of tasks to MultitaskClassifier to create a model with an architecture for each task
model = MultitaskClassifier(tasks)

# Set the trainer settings, i.e., how the model will train
trainer_config = {
    "progress_bar": True,
    "n_epochs": 2,
    "lr": 0.02,
    "logging": True,
    "log_writer": "json",
    "checkpointing": True,
}

# Create trainer object using above settings
trainer = Trainer(**trainer_config)

# Train model using above settings on the datasets linked
trainer.fit(model, dataloaders)

# Output training stats of model
trainer.log_writer.write_log("output_statistics.json")

# Score model using test set and print
model_scores = model.score(dataloaders)
print(model_scores)
Example #18
def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    if args.local_rank == -1:
        train_sampler = RandomSampler(train_dataset)
    else:
        train_sampler = DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = (len(train_dataloader) // args.gradient_accumulation_steps *
                   args.num_train_epochs)
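        # E.g., with 1000 batches per epoch, gradient_accumulation_steps=4, and
        # num_train_epochs=3: t_total = (1000 // 4) * 3 = 750 optimization
        # steps (illustrative numbers).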

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {
            'params': [p for n, p in model.named_parameters()
                       if not any(nd in n for nd in no_decay)],
            'weight_decay': args.weight_decay,
        },
        {
            'params': [p for n, p in model.named_parameters()
                       if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer,
                                     warmup_steps=args.warmup_steps,
                                     t_total=t_total)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = range(int(args.num_train_epochs))
    set_seed(args)  # Added here for reproducibility (even between python 2 and 3)

    if args.model_type in ('bert-slice-aware', 'bert-slice-aware-random-slices'):
        if args.model_type == 'bert-slice-aware':
            sfs = slicing_functions[args.task_name]
        elif args.model_type == 'bert-slice-aware-random-slices':
            if args.number_random_slices is None or args.size_random_slices is None:
                sfs = random_slicing_functions[args.task_name]
            else:
                sfs = args.sfs
        processor = slicing_processors[args.task_name]()
        examples_train = processor.get_train_examples(args.data_dir)

        snorkel_sf_applier = SFApplier(sfs)

        if os.path.isfile(args.data_dir + "/snorkel_slices_train.pickle"):
            with open(args.data_dir + "/snorkel_slices_train.pickle",
                      "rb") as f:
                logger.info("loaded cached pickle for sliced train.")
                snorkel_slices_train = pickle.load(f)
        else:
            snorkel_slices_train = snorkel_sf_applier.apply(examples_train)
            with open(args.data_dir + "/snorkel_slices_train.pickle",
                      "wb") as f:
                pickle.dump(snorkel_slices_train, f)
                logger.info("dumped pickle with sliced train.")

        snorkel_slices_with_ns = []
        for i, example in enumerate(examples_train):
            for _ in range(len(example.documents)):
                snorkel_slices_with_ns.append(snorkel_slices_train[i])

        snorkel_slices_with_ns_np = np.array(snorkel_slices_with_ns,
                                             dtype=snorkel_slices_train.dtype)
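        # Illustration (hypothetical numbers): if example 0 has 3 candidate
        # documents and example 1 has 2, its slice-label row is repeated 3 and
        # 2 times, so snorkel_slices_with_ns_np gets shape (5,) and lines up
        # with the flattened training instances.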

        slice_model = SliceAwareClassifier(
            task_name='labels',
            input_data_key='input_ids',
            base_architecture=model,
            head_dim=768,  #* args.max_seq_length,
            slice_names=[sf.name for sf in sfs])

        X_dict = {
            'input_ids': train_dataset.tensors[0],
            'attention_mask': train_dataset.tensors[1],
            'token_type_ids': train_dataset.tensors[2]
        }
        Y_dict = {'labels': train_dataset.tensors[3]}

        ds = DictDataset(name='labels',
                         split='train',
                         X_dict=X_dict,
                         Y_dict=Y_dict)
        train_dl_slice = slice_model.make_slice_dataloader(
            ds,
            snorkel_slices_with_ns_np,
            shuffle=True,
            batch_size=args.train_batch_size)

        trainer = Trainer(lr=args.learning_rate,
                          n_epochs=int(args.num_train_epochs),
                          l2=args.weight_decay,
                          optimizer="adamax",
                          max_steps=args.max_steps,
                          seed=args.seed)

        trainer.fit(slice_model, [train_dl_slice])
        model = slice_model
    else:
        for _ in train_iterator:
            epoch_iterator = train_dataloader
            for step, batch in enumerate(epoch_iterator):
                model.train()
                batch = tuple(t.to(args.device) for t in batch)
                inputs = {
                    'input_ids': batch[0],
                    'attention_mask': batch[1],
                    'labels': batch[3]
                }
                if args.model_type != 'distilbert':
                    inputs['token_type_ids'] = batch[2] if args.model_type in [
                        'bert', 'xlnet'
                    ] else None  # XLM, DistilBERT and RoBERTa don't use segment_ids
                if args.model_type == 'bert-mtl':
                    inputs["clf_head"] = 0
                outputs = model(**inputs)
                loss = outputs[0]  # model outputs are always a tuple in transformers (see docs)

                if args.n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu parallel training
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                tr_loss += loss.item()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    scheduler.step()  # Update learning rate schedule
                    model.zero_grad()
                    global_step += 1

                    if (args.local_rank in [-1, 0] and args.logging_steps > 0
                            and global_step % args.logging_steps == 0):
                        # Log metrics
                        if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                            results = evaluate(args,
                                               model,
                                               tokenizer,
                                               sample_percentage=0.01)
                            for key, value in results.items():
                                tb_writer.add_scalar('eval_{}'.format(key),
                                                     value, global_step)
                                ex.log_scalar('eval_{}'.format(key), value,
                                              global_step)
                                logger.info('eval_{}'.format(key) + ": " +
                                            str(value) + ", step: " +
                                            str(global_step))
                        tb_writer.add_scalar('lr',
                                             scheduler.get_lr()[0],
                                             global_step)
                        tb_writer.add_scalar('loss', (tr_loss - logging_loss) /
                                             args.logging_steps, global_step)
                        ex.log_scalar("lr", scheduler.get_lr()[0], global_step)
                        ex.log_scalar("loss", (tr_loss - logging_loss) /
                                      args.logging_steps, global_step)
                        logging_loss = tr_loss

                    if (args.local_rank in [-1, 0] and args.save_steps > 0
                            and global_step % args.save_steps == 0):
                        # Save model checkpoint
                        output_dir = os.path.join(
                            args.output_dir,
                            'checkpoint-{}'.format(global_step))
                        if not os.path.exists(output_dir):
                            os.makedirs(output_dir)
                        model_to_save = model.module if hasattr(
                            model, 'module'
                        ) else model  # Take care of distributed/parallel training
                        model_to_save.save_pretrained(output_dir)
                        torch.save(
                            args, os.path.join(output_dir,
                                               'training_args.bin'))
                        logger.info("Saving model checkpoint to %s",
                                    output_dir)

                if args.max_steps > 0 and global_step > args.max_steps:
                    break
            if args.max_steps > 0 and global_step > args.max_steps:
                break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return model