Example #1
def reset_dropout_to_eval(m):
    if type(m) == nn.Dropout:
        p = dropout_ps[m]
        logger.info("resetting dropout into eval mode (%s) p=%f",
                    str(m), p)
        m.p = p
        m.eval()
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.loss_func, *loss_func_params = config.loss_func.split(':')

        try:
            self.loss_func_params = [
                float(param) for param in loss_func_params
            ]
        except ValueError as exp:
            raise ValueError(f"invalid loss function parameters") from exp

        logger.info('using loss function: %s', self.loss_func)

        if config.label_weights:
            self.label_weights = torch.tensor(config.label_weights)
            logger.info('using label weights: %s', self.label_weights)
        else:
            self.label_weights = None

        self.multi_label = config.multi_label
        self.soft_label = config.soft_label
        self.classifier = None
        self.classifier = eval(config.head_cls)(config)
        self.head = eval(config.head_cls)(config)
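For reference, `config.loss_func` above is expected to pack the loss name and its numeric parameters into a single colon-separated string. A minimal illustration of the parsing; the value "focal:0.25:2.0" is a hypothetical example, not taken from the code above:

loss_func, *loss_func_params = "focal:0.25:2.0".split(':')
loss_func_params = [float(p) for p in loss_func_params]
print(loss_func, loss_func_params)  # -> focal [0.25, 2.0]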
Example #3
def load_model(config, model_config):
    # config_cls, _, _, head_class = MODEL_CLASSES[config.model_type]

    # model_cls = get_model_cls(config)

    model_cls = AutoHeadlessConfig.model_class(model_config)
    head_class = AutoHeadlessConfig.head_class(model_config)

    logger.debug("loaded model_config: %s", model_config)

    if config.no_pretrain:
        logger.warning("Using non pretrained model!")
        model = model_cls(config=model_config)
    else:
        model = model_cls.from_pretrained(
            config.model_path,
            config=model_config
        )

        if config.reinit_layers:
            logger.info(f"reinitializing layers... {config.reinit_layers}")
            model.reinit_layers(config.reinit_layers)

        if config.reinit_pooler:
            logger.info(f"reinitializing pooler...")
            model.reinit_pooler()

    return model
Example #4
    def save_model(self, model_path):
        if not os.path.exists(model_path):
            os.makedirs(model_path)

        logger.info("Saving model to %s", model_path)

        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`

        model_to_save = (self.model.module
                         if hasattr(self.model, "module") else self.model)
        model_to_save.save_pretrained(model_path)
        self.tokenizer.save_pretrained(model_path)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.config.as_dict(),
                   os.path.join(model_path, "training_config.bin"))
Example #5
    def save_checkpoint(self):
        output_dir = os.path.join(self.config.output_model_path,
                                  "checkpoint-{}".format(self.global_step))
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        model_to_save = (
            self.model.module if hasattr(self.model, "module") else self.model
        )  # Take care of distributed/parallel training
        model_to_save.save_pretrained(output_dir)
        self.tokenizer.save_pretrained(output_dir)

        torch.save(self.config.as_dict(),
                   os.path.join(output_dir, "training_self.config.bin"))
        logger.info("Saving model checkpoint to %s", output_dir)

        torch.save(self.optimizer.state_dict(),
                   os.path.join(output_dir, "optimizer.pt"))
        torch.save(self.scheduler.state_dict(),
                   os.path.join(output_dir, "scheduler.pt"))
        logger.info("Saving optimizer and scheduler states to %s", output_dir)
def active_learn(config,
                 model_config,
                 tokenizer,
                 results,
                 label_names,
                 test_df,
                 full_pool_df,
                 backtrans_pool_dfs,
                 get_dataloader_func,
                 run_configs,
                 active_learning_iters=10,
                 dropout_iters=20,
                 balance=False):
    test_dataloader = get_dataloader_func(test_df, bs=config.eval_bs)

    for run_config in run_configs:
        method, dropout, backtrans_langs, cluster_size = run_config
        run_name = method.__name__
        if dropout:
            run_name += '_dropout'
        run_name = '_'.join([run_name, *backtrans_langs, f"c{cluster_size}"])

        util.set_seed(config)

        model = tu.load_model(config, model_config)
        model.to(config.device)

        # remove initial seed from pool
        train_df, pool_df = train_test_split(
            full_pool_df,
            train_size=config.active_learn_seed_size,
            random_state=config.seed)

        logger.info("RUN CONFIG: %s (pool size: %d)", run_name,
                    pool_df.shape[0])

        experiment = Experiment(config,
                                model,
                                tokenizer,
                                label_names=label_names,
                                run_name=run_name,
                                results=results)

        cur_iter = 0

        extra_log = {'iter': cur_iter, 'pool': pool_df.shape[0]}
        experiment.evaluate('test', test_dataloader, extra_log=extra_log)

        while pool_df.shape[0] > 0:
            train_dataloader = get_dataloader_func(train_df,
                                                   bs=config.train_bs,
                                                   balance=balance)

            # DON'T SHUFFLE THE POOL!
            dataloader_pool = get_dataloader_func(pool_df,
                                                  bs=config.eval_bs,
                                                  shuffle=False)

            logger.info(
                "=================== Remaining %d (%s) ================",
                pool_df.shape[0], run_config)
            logger.info(
                "Evaluating: training set size: %d | pool set size: %d",
                train_df.shape[0], pool_df.shape[0])

            global_step, tr_loss = experiment.train(train_dataloader)

            extra_log = {'iter': cur_iter, 'pool': pool_df.shape[0]}

            _, _, preds = experiment.evaluate('pool',
                                              dataloader_pool,
                                              extra_log=extra_log)
            experiment.evaluate('test', test_dataloader, extra_log=extra_log)

            if method != af.random_conf:
                if dropout:
                    for i in range(dropout_iters):
                        torch.manual_seed(i)

                        _, _, preds_i = experiment.evaluate('pool_dropout',
                                                            dataloader_pool,
                                                            mc_dropout=True,
                                                            skip_cb=True)
                        preds_i = torch.from_numpy(preds_i)
                        probs_i = F.softmax(preds_i, dim=1)

                        if i == 0:
                            probs = probs_i
                        else:
                            probs.add_(probs_i)
                    probs.div_(dropout_iters)
                else:
                    preds = torch.from_numpy(preds)
                    probs = F.softmax(preds, dim=1)
            else:
                preds = torch.from_numpy(preds)

                # only need the shape
                probs = preds

            scores = method(probs)
            _, topk_indices = torch.topk(
                scores,
                min(cluster_size * config.active_learn_step_size,
                    scores.shape[0]))

            if cluster_size > 1:
                topk_preds = preds[topk_indices]
                n_clusters = min(config.active_learn_step_size,
                                 scores.shape[0])
                kmeans = KMeans(n_clusters=n_clusters).fit(topk_preds.numpy())
                _, unique_indices = np.unique(kmeans.labels_,
                                              return_index=True)
                topk_indices = topk_indices[unique_indices]
                # assert(topk_indices.shape[0] == n_clusters)
                logger.debug("top_k: %s", topk_indices.shape)

            logger.debug("%s %s", scores.shape, pool_df.shape)

            assert (scores.shape[0] == pool_df.shape[0])

            uncertain_rows = pool_df.iloc[topk_indices]
            train_df = train_df.append(uncertain_rows, ignore_index=True)

            for backtrans_lang in backtrans_langs:
                backtrans_pool_df = backtrans_pool_dfs[backtrans_lang]
                backtrans_uncertain_rows = backtrans_pool_df[
                    backtrans_pool_df.id.isin(uncertain_rows.id)]
                train_df = train_df.append(backtrans_uncertain_rows,
                                           ignore_index=True)

            pool_df = pool_df.drop(pool_df.index[topk_indices])
            cur_iter += 1

        logger.debug(
            "Pool exhausted, stopping active learning loop (%d remaining)",
            pool_df.shape[0])

        results = experiment.results
    return results
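For context, `method` in the loop above is an acquisition function (e.g. `af.random_conf`) that maps class probabilities to per-example uncertainty scores; the highest scores are then selected by `torch.topk`. A minimal sketch of such a function, assuming `probs` is an (N, C) tensor of class probabilities; this is an illustration only, not the project's `af` module:

def least_confidence(probs):
    # higher score = model is less confident in its top prediction,
    # so the example is queried earlier by topk
    return 1.0 - probs.max(dim=1).values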
Example #7
    def interpret(self, dataloader, df, label_names=None):

        dataset = dataloader.dataset
        sampler = SequentialSampler(dataset)

        # We need a sequential dataloader with bs=1
        dataloader = DataLoader(dataset,
                                sampler=sampler,
                                batch_size=1,
                                num_workers=4)

        logger.info("***** Running interpretation *****")
        logger.info("  Num examples = %d", len(dataset))

        # preds = None
        losses = None
        pred_labels = []

        self.model.eval()

        for batch in tqdm(dataloader, desc="Interpretation"):

            with torch.no_grad():
                inputs = self.__inputs_from_batch(batch)
                # if config.model_type != "distilbert":
                #    inputs["token_type_ids"] = (
                #        batch[2] if config.model_type in [
                #            "bert", "xlnet", "albert"] else None
                #    )  # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
                outputs = self.model(**inputs)
                batch_loss, logits = outputs[:2]

                if self.config.n_gpu > 1:
                    batch_loss = batch_loss.mean(
                    )  # mean() to average on multi-gpu parallel training

                batch_loss = batch_loss.detach().cpu().view(1)

                pred_label_ids = self.logits_to_label_ids(
                    logits.detach().cpu())
                pred_label_id = pred_label_ids[0]
                if label_names:
                    pred_labels.append(label_names[pred_label_id])
                else:
                    pred_labels.append(pred_label_id)

            if losses is None:
                # preds = logits.detach().cpu().numpy()
                losses = batch_loss
            else:
                # preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
                losses = torch.cat((losses, batch_loss), dim=0)

        top_values, top_indices = torch.topk(losses, min(100, losses.shape[0]))
        top_indices = top_indices.numpy()
        top_pred_labels = [pred_labels[top_index] for top_index in top_indices]

        top_df = df.iloc[top_indices]
        top_df = top_df.assign(loss=top_values.numpy(),
                               pred_label=top_pred_labels)

        return top_df
Example #8
def set_dropout_to_train(m):
    if type(m) == nn.Dropout:
        logger.info("setting dropout into train mode (%s)", str(m))
        m.p = 0.5
        m.train()
Example #9
    def evaluate(self,
                 eval_name,
                 dataloader,
                 mc_dropout=False,
                 skip_cb=False,
                 pred_label_ids_func=None,
                 backtrans=True,
                 extra_log=None):
        # avoid a shared mutable default argument
        extra_log = extra_log if extra_log is not None else {}

        dropout_ps = {}

        def set_dropout_to_train(m):
            if type(m) == nn.Dropout:
                logger.info("setting dropout into train mode (%s)", str(m))
                # remember the original dropout probability so it can be restored
                dropout_ps[m] = m.p
                m.p = 0.5
                m.train()

        def reset_dropout_to_eval(m):
            if type(m) == nn.Dropout:
                p = dropout_ps[m]
                logger.info("resetting dropout into eval mode (%s) p=%f",
                            str(m), p)
                m.p = p
                m.eval()

        # Eval!
        logger.info("***** Running evaluation %s*****", eval_name)
        logger.info("  Num examples = %d", len(dataloader.dataset))
        logger.info("  Batch size = %d", self.config.eval_bs)
        eval_loss = 0.0
        nb_eval_steps = 0
        preds = None
        true_label_ids = None

        self.model.eval()

        if mc_dropout:
            self.model.apply(set_dropout_to_train)

        for batch in tqdm(dataloader, desc="Evaluating"):

            with torch.no_grad():
                inputs = self.__inputs_from_batch(batch)
                labels = inputs['labels']

                outputs = self.model(**inputs)
                tmp_eval_loss, logits = outputs[:2]

                eval_loss += tmp_eval_loss.mean().item()
            nb_eval_steps += 1
            if preds is None:
                preds = logits.detach().cpu().numpy()
                true_label_ids = labels.detach().cpu().numpy()
            else:
                preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
                true_label_ids = np.append(true_label_ids,
                                           labels.detach().cpu().numpy(),
                                           axis=0)

        if mc_dropout:
            self.model.apply(reset_dropout_to_eval)

        eval_loss = eval_loss / nb_eval_steps

        if self.config.test_backtrans_langs and backtrans:
            logger.info('Using test augmentation...')
            groups = np.split(preds, len(self.config.test_backtrans_langs) + 1)
            #preds = sum(groups)

            preds = np.mean(groups, axis=0)
            #preds = np.maximum.reduce(groups)
            true_label_ids = true_label_ids[:preds.shape[0]]

        label_idxs = list(range(len(self.label_names)))

        if self.config.soft_label:
            true_label_ids = np.argmax(true_label_ids, axis=1)

        pred_label_ids = self.logits_to_label_ids(preds)

        if pred_label_ids_func:
            pred_label_ids = pred_label_ids_func(pred_label_ids)

        # print(out_label_ids)
        # print(max_preds)
        # print(out_label_ids.shape, max_preds.shape)

        result = {
            'acc':
            accuracy_score(true_label_ids, pred_label_ids),
            'macro_f1':
            f1_score(true_label_ids, pred_label_ids, average='macro'),
            'micro_f1':
            f1_score(true_label_ids, pred_label_ids, average='micro'),
            'prfs':
            precision_recall_fscore_support(true_label_ids,
                                            pred_label_ids,
                                            labels=label_idxs)
        }

        if not self.config.multi_label:
            result['cm'] = confusion_matrix(true_label_ids,
                                            pred_label_ids).ravel()

        if self.config.num_labels == 2:
            result['macro_auc'] = roc_auc_score(true_label_ids,
                                                pred_label_ids,
                                                average='macro')
            result['avg_precision'] = average_precision_score(
                true_label_ids, pred_label_ids)

        logger.info("***** Eval results {} *****".format(eval_name))

        try:
            logger.info(
                "\n %s",
                classification_report(
                    true_label_ids,
                    pred_label_ids,
                    labels=label_idxs,
                    target_names=self.label_names,
                ))

            result['report'] = classification_report(
                true_label_ids,
                pred_label_ids,
                labels=label_idxs,
                target_names=self.label_names,
                output_dict=True)
        except ValueError as e:
            logger.warning("classification_report failed: %s", e)

        logger.info("\n Accuracy = %f", result['acc'])

        if self.config.num_labels == 2:
            logger.info("\n MacroAUC = %f", result['macro_auc'])
            logger.info("\n AUPRC = %f", result['avg_precision'])

        logger.info("***** Done evaluation *****")

        if not skip_cb:
            self.after_eval_cb(eval_name, result, pred_label_ids, preds,
                               extra_log)
        return result, pred_label_ids, preds
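`self.logits_to_label_ids`, used above, is not included in these excerpts. A plausible standalone sketch, assuming argmax for single-label and a sigmoid threshold of 0.5 for multi-label (both assumptions, not the project's actual implementation, which is a method reading `config.multi_label`):

import numpy as np

def logits_to_label_ids(logits, multi_label=False):
    logits = np.asarray(logits)
    if multi_label:
        # independent sigmoid per class, thresholded at 0.5
        return (1.0 / (1.0 + np.exp(-logits)) >= 0.5).astype(int)
    # single-label: most probable class per row
    return np.argmax(logits, axis=1)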
Example #10
    def train(self,
              train_dataloader,
              valid_dataloader=None,
              test_dataloader=None,
              should_continue=False):
        """ Train the model """
        tb_writer = SummaryWriter()

        train_epochs = self.config.train_epochs

        if self.config.max_steps > 0:
            train_steps = self.config.max_steps
            train_epochs = self.config.max_steps // (
                len(train_dataloader) // self.config.grad_acc_steps) + 1
        else:
            train_steps = len(
                train_dataloader) // self.config.grad_acc_steps * train_epochs

        if self.total_samples and should_continue:
            steps_total = (self.total_samples // self.config.train_bs //
                           self.config.grad_acc_steps * train_epochs)
        else:
            steps_total = train_steps

        # Prepare optimizer and schedule (linear warmup and decay)
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in self.model.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay":
                self.config.weight_decay,
            },
            {
                "params": [
                    p for n, p in self.model.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay":
                0.0
            },
        ]

        self.optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=self.config.lr,
            eps=self.config.adam_eps,
        )

        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=self.config.warmup_steps,
            num_training_steps=steps_total)

        # self.scheduler = get_constant_schedule(self.optimizer)

        if should_continue and self.global_step > 0:
            logger.info("loading saved optimizer and scheduler states")
            assert (self.optimizer_state_dict)
            assert (self.scheduler_state_dict)
            self.optimizer.load_state_dict(self.optimizer_state_dict)
            self.scheduler.load_state_dict(self.scheduler_state_dict)
        else:
            logger.info("Using fresh optimizer and scheduler")

        if self.config.fp16:
            try:
                from apex import amp
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
                )
            self.model, self.optimizer = amp.initialize(
                self.model,
                self.optimizer,
                opt_level=self.config.fp16_opt_level)

        # multi-gpu training (should be after apex fp16 initialization)
        if self.config.n_gpu > 1 and not isinstance(self.model,
                                                    torch.nn.DataParallel):
            self.model = torch.nn.DataParallel(self.model)

        logger.info("***** Running training *****")
        logger.info("  Num examples = %d (%d)", len(train_dataloader.dataset),
                    len(train_dataloader))
        logger.info("  Num Epochs = %d", train_epochs)
        logger.info("  Batch size = %d", self.config.train_bs)
        logger.info("  Learning rate = %e", self.config.lr)
        logger.info("  Loss label weights = %s",
                    self.config.loss_label_weights)
        logger.info(
            "  Total train batch size (w. parallel, distributed & accumulation) = %d",
            self.config.train_bs * self.config.grad_acc_steps)
        logger.info("  Gradient Accumulation steps = %d",
                    self.config.grad_acc_steps)
        logger.info("  Total optimization steps = %d", train_steps)

        if not should_continue:
            self.global_step = 0

        epochs_trained = 0
        steps_trained_in_current_epoch = 0

        # # Check if continuing training from a checkpoint
        # if os.path.exists(self.config.model_path):
        #     if self.config.should_continue:
        #         step_str = self.config.model_path.split("-")[-1].split("/")[0]

        #         if step_str:
        #             # set self.global_step to gobal_step of last saved checkpoint from model path
        #             self.global_step = int(step_str)
        #             epochs_trained = self.global_step // (len(train_dataloader) //
        #                                                   self.config.grad_acc_steps)
        #             steps_trained_in_current_epoch = self.global_step % (
        #                 len(train_dataloader) // self.config.grad_acc_steps)

        #             logger.info(
        #                 "  Continuing training from checkpoint, will skip to saved self.global_step")
        #             logger.info(
        #                 "  Continuing training from epoch %d", epochs_trained)
        #             logger.info(
        #                 "  Continuing training from global step %d", self.global_step)
        #             logger.info("  Will skip the first %d steps in the first epoch",
        #                         steps_trained_in_current_epoch)

        train_loss = 0.0
        self.model.zero_grad()
        train_iterator = trange(
            epochs_trained,
            int(train_epochs),
            desc="Epoch",
        )
        util.set_seed(self.config)  # Added here for reproducibility

        self.model.train()

        if self.config.train_head_only:
            for param in self.model.roberta.embeddings.parameters():
                param.requires_grad = False
            logger.info("Training only head")
            # for param in self.model.__getattr__(self.config.model_type).roberta.parameters():
            #     param.requires_grad = False

        for _ in train_iterator:
            epoch_iterator = tqdm(train_dataloader, desc="Iteration")
            for step, batch in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue

                self.model.train()

                inputs = self.__inputs_from_batch(batch)
                outputs = self.model(**inputs)

                # model outputs are always tuple in transformers (see doc)
                loss = outputs[0]

                if self.config.n_gpu > 1:
                    loss = loss.mean(
                    )  # mean() to average on multi-gpu parallel training
                if self.config.grad_acc_steps > 1:
                    loss = loss / self.config.grad_acc_steps

                if self.config.fp16:
                    with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                batch_loss = loss.item()
                train_loss += batch_loss

                if (step + 1) % self.config.grad_acc_steps == 0:
                    if self.config.fp16:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(self.optimizer),
                            self.config.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), self.config.max_grad_norm)

                    self.optimizer.step()
                    self.scheduler.step()  # Update learning rate schedule
                    self.model.zero_grad()
                    self.global_step += 1

                    if self.config.logging_steps > 0 and self.global_step % self.config.logging_steps == 0:
                        logs = {}
                        if valid_dataloader:
                            result_valid, *_ = self.evaluate(
                                'valid', valid_dataloader,
                                backtrans=(test_dataloader is None))
                            logs.update({
                                f"valid_{k}": v
                                for k, v in result_valid.items()
                            })

                        if test_dataloader:
                            test_dataloader = test_dataloader if isinstance(
                                test_dataloader, dict) else {
                                    'test': test_dataloader
                                }
                            for eval_name, dataloader_or_tuple in test_dataloader.items(
                            ):
                                if isinstance(dataloader_or_tuple, tuple):
                                    dataloader, kwargs = dataloader_or_tuple
                                else:
                                    dataloader = dataloader_or_tuple
                                    kwargs = {}

                                result_test, *_ = self.evaluate(
                                    eval_name, dataloader, **kwargs)
                                logs.update({
                                    f"{eval_name}_{k}": v
                                    for k, v in result_test.items()
                                })

                        learning_rate_scalar = self.scheduler.get_last_lr()[0]
                        logger.info("Learning rate: %f (at step %d)",
                                    learning_rate_scalar, step)
                        logs["learning_rate"] = learning_rate_scalar
                        logs["train_loss"] = train_loss

                        self.after_logging(logs)

                        logger.info("Batch loss: %f", batch_loss)

                        # for key, value in logs.items():
                        #     tb_writer.add_scalar(key, value, self.global_step)

                    if self.config.save_steps > 0 and self.global_step % self.config.save_steps == 0:
                        # Save model checkpoint
                        self.save_checkpoint()

                if self.config.max_steps > 0 and self.global_step > self.config.max_steps:
                    epoch_iterator.close()
                    break
            if self.config.max_steps > 0 and self.global_step > self.config.max_steps:
                train_iterator.close()
                break

        if self.config.train_head_only:
            # restore gradients for the layers frozen before training
            logger.info("Unfreezing embeddings after head-only training")
            # for param in self.model.__getattr__(self.config.model_type).parameters():
            #     param.requires_grad = True

            for param in self.model.roberta.embeddings.parameters():
                param.requires_grad = True

        tb_writer.close()
        self.optimizer_state_dict = self.optimizer.state_dict()
        self.scheduler_state_dict = self.scheduler.state_dict()

        avg_train_loss = train_loss / self.global_step

        logger.info("Learning rate now: %s", self.scheduler.get_last_lr())
        logger.info("***** Done training *****")
        return self.global_step, avg_train_loss
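For reference, the schedule created by get_linear_schedule_with_warmup above scales the base learning rate by a factor that ramps up linearly over warmup_steps and then decays linearly to zero at steps_total; roughly:

def linear_warmup_decay_factor(step, warmup_steps, total_steps):
    # LR multiplier: linear warmup, then linear decay to zero
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))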
Example #11
    def __add_special_tokens(self):
        bins = [f'<l{b}>' for b in self.all_bins]
        logger.info('Adding special tokens %s', bins)
        self.tokenizer.add_special_tokens({'additional_special_tokens': bins})
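Note that after add_special_tokens the model's token-embedding matrix usually has to be resized to cover the enlarged vocabulary. A one-line sketch, assuming the same class also holds a Hugging Face model as self.model (an assumption based on the other examples):

        # grow the embedding matrix to the new vocabulary size
        self.model.resize_token_embeddings(len(self.tokenizer))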
Example #12
def get_dataloader(config, tokenizer, text_values, label_ids, bs,
                   text_pair_values=None, shuffle=True, balance=False,
                   enumerated=False, extra_features=None):

    # for t in text_values:
    #     print(t)
    #     x = tokenizer.tokenize(t)
    #     print(x)

    # for t, l in zip(text_values, label_ids):
    #     x = tokenizer.tokenize(t)[255:]
    #     if x: print(l, x)

    # logger.info('Original: %s', sents_list[0])
    # logger.info('Tokenized: %s', tokenizer.tokenize(sents_list[0]))
    # logger.info('Token IDs: %s', tokenizer.convert_tokens_to_ids(
    #     tokenizer.tokenize(sents_list[0])))

    input_ids = [tokenizer.encode(t,
                                  text_pair=text_pair_values[i] if text_pair_values is not None else None,
                                  add_special_tokens=True,
                                  max_length=config.max_seq_len,
                                  truncation=True,
                                  pad_to_max_length=True) for i, t in enumerate(text_values)]

    logger.debug(tokenizer.decode(input_ids[0]))

    attention_masks = build_attention_masks(input_ids)

    input_ids_t = torch.tensor(input_ids)
    label_ids_t = torch.tensor(label_ids, dtype=(
        torch.float32 if config.multi_label or config.soft_label else torch.int64))
    attention_masks_t = torch.tensor(attention_masks)

    tensors = [input_ids_t, attention_masks_t, label_ids_t]

    if extra_features is not None:
        extra_features_t = torch.tensor(extra_features, dtype=torch.float32)
        tensors.append(extra_features_t)

    dataset = TensorDataset(*tensors)

    if enumerated:
        dataset = EnumeratedDataset(dataset)

    if config.local_rank != -1:
        sampler = DistributedSampler(dataset)
    elif balance:
        if config.soft_label:
            raise ValueError('balancing for soft labels is not implemented')
        if not config.multi_label:
            label_weights = 1.0 / np.bincount(label_ids)
        else:
            # label values are boolean
            label_weights = 1.0 / np.sum(label_ids, axis=0)
        logger.info('label weights %s', label_weights / label_weights.sum())
        if not config.multi_label:
            weights = [label_weights[l] for l in label_ids]
        else:
            weights = (label_ids_t * torch.tensor(label_weights)).mean(dim=1)
        sampler = WeightedRandomSampler(weights, len(weights))
    elif shuffle:
        sampler = RandomSampler(dataset)
    else:
        sampler = SequentialSampler(dataset)

    logger.info(f"Using sampler: {sampler}")
    dataloader = DataLoader(dataset, sampler=sampler, batch_size=bs, num_workers=4)

    return dataloader
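build_attention_masks is called above but not shown in these excerpts. A minimal sketch, assuming the padding token id is 0 (true for BERT-style tokenizers, but not for e.g. RoBERTa, whose pad id is 1):

def build_attention_masks(input_ids):
    # 1 for real tokens, 0 for padding positions
    return [[int(token_id != 0) for token_id in seq] for seq in input_ids]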