Beispiel #1
0
 def infer_ctr(self):
     """Run the instruction-reconstruction (CTR) model in inference mode.

     Loads pretrained encoder weights for the configured dataset, decodes
     one validation batch, writes per-sample transcriptions and images under
     ``ins_result/`` and archives the folder as ``ins_result.tar.gz``.
     """
     cfg = self.cfg

     # Load the pretrained encoder weights matching the selected dataset.
     if cfg.dataset == 'codraw':
         self.model.ctr.E.load_state_dict(torch.load('models/codraw_1.0_e.pt'))
     elif cfg.dataset == 'iclevr':
         self.model.ctr.E.load_state_dict(torch.load('models/iclevr_1.0_e.pt'))

     dataset = DATASETS[cfg.dataset](path=keys[cfg.val_dataset], cfg=cfg, img_size=cfg.img_size)
     dataloader = DataLoader(dataset,
                             batch_size=cfg.batch_size, shuffle=True, num_workers=cfg.num_workers,
                             pin_memory=True, drop_last=True)

     if cfg.dataset == 'codraw':
         dataloader.collate_fn = codraw_dataset.collate_data
     elif cfg.dataset == 'iclevr':
         dataloader.collate_fn = clevr_dataset.collate_data

     glove_key = list(dataset.glove_key.keys())

     for batch in dataloader:
         rec_out, loss = self.model.train_ctr(batch, -1, -1, self.visualizer, self.logger,
                                              is_eval=True, is_infer=True)

         # Most likely vocabulary index per token position.
         rec_out = np.argmax(rec_out, axis=3)

         # os.makedirs is portable and idempotent, unlike shelling out to mkdir.
         os.makedirs('ins_result', exist_ok=True)

         for i in range(30):
             os.makedirs('ins_result/%d' % (i), exist_ok=True)

             # Context manager guarantees the file is closed even on error
             # (the original handle leaked if an exception was raised mid-loop).
             with open('ins_result/%d/ins.txt' % (i), 'w') as fout:
                 for j in range(rec_out.shape[1]):
                     predicted = [glove_key[rec_out[i, j, k]]
                                  for k in range(rec_out.shape[2])]
                     reference = [glove_key[int(batch['turn_word'][i, j, k].detach().cpu().numpy())]
                                  for k in range(rec_out.shape[2])]
                     print(predicted)
                     print(reference)
                     print()

                     fout.write(' '.join(predicted))
                     fout.write('\n')
                     fout.write(' '.join(reference))
                     fout.write('\n')
                     fout.write('\n')

                     TV.utils.save_image(batch['image'][i, j].data, 'ins_result/%d/%d.png' % (i, j), normalize=True, range=(-1, 1))

             print('\n----------------\n')

         # Only the first batch is needed for this qualitative dump.
         break

     os.system('tar zcvf ins_result.tar.gz ins_result')
    def get_dataloader(self,
                       batch_size=32,
                       shuffle=True,
                       device='cpu',
                       bucketing=False):
        """Build a DataLoader over this dataset.

        With ``bucketing`` the samples are grouped by tokenized length via
        ``SequenceBucketing`` and collated by a device-moving closure;
        otherwise a ``_Collator`` wraps the dataset and supplies its own
        ``collate_fn``.
        """
        def _move_to_device(batch, device=device):
            # Batches arrive wrapped in a single-element list.
            batch = batch[0]
            for key in batch.keys():
                value = batch[key]
                if key == 'y':
                    # 'y' is a list of tensors, moved element-wise.
                    batch[key] = [t.to(device) for t in value]
                elif isinstance(value, torch.Tensor):
                    batch[key] = value.to(device)
                elif isinstance(value, dict):
                    # Nested dicts are handled recursively.
                    batch[key] = _move_to_device([value], device)
            return batch

        if not bucketing:
            collator = self._Collator(self, device)
            loader = DataLoader(collator,
                                batch_size=batch_size,
                                shuffle=shuffle)
            loader.collate_fn = collator.collate_fn
            return loader

        # Sequence bucketing: batch samples of similar tokenized length.
        lengths = torch.tensor(
            [len(self.tokenizer(sample.text)) for sample in self.data])
        loader = SequenceBucketing.as_dataloader(
            self, lengths, batch_size, shuffle)
        loader.collate_fn = _move_to_device
        return loader
Beispiel #3
0
 def _add_sampler_metadata_collate(dataloader: DataLoader) -> None:
     """
     Wrap default collate function to retrive ``FastForwardSampler`` state dict when fault tolerant is enabled.
     """
     # Bind the current collate_fn and dataset now, then install the wrapper.
     wrapped = partial(
         _sampler_metadata_collate,
         dataset=dataloader.dataset,
         default_collate=dataloader.collate_fn,
     )
     dataloader.collate_fn = wrapped
    def get_train_dataloaders(self) -> Tuple[DataLoader, DataLoader]:
        """Loads and returns train and validation(if available) dataloaders

        :return: the dataloaders
        :rtype: Tuple[DataLoader, DataLoader]
        """
        language = self._arguments_service.language
        batch_size = self._arguments_service.batch_size

        def _build_loader(run_type, shuffle: bool) -> DataLoader:
            # One split: fetch dataset, wrap in a loader, honour optional collation.
            split = self._dataset_service.get_dataset(run_type, language)
            loader: DataLoader = DataLoader(
                split,
                batch_size=batch_size,
                shuffle=shuffle)
            if split.use_collate_function():
                loader.collate_fn = split.collate_function
            return loader

        data_loader_train = _build_loader(
            RunType.Train, self._arguments_service.shuffle)

        # Validation is skipped entirely when requested.
        if self._arguments_service.skip_validation:
            data_loader_validation = None
        else:
            data_loader_validation = _build_loader(RunType.Validation, False)

        return (data_loader_train, data_loader_validation)
    def predict(self,
                dataloader: DataLoader,
                activation_fct=None,
                need_labels=False,
                ):
        """Score every batch in *dataloader* and return the predictions.

        :param dataloader: batches of (features, labels); its collate_fn is
            replaced by ``self.batching_collate``.
        :param activation_fct: applied to the raw model outputs; defaults to
            sigmoid for a single label, identity otherwise.
        :param need_labels: when True, also return the collected labels.
        """
        self.to(self._target_device)
        self.eval()

        dataloader.collate_fn = self.batching_collate
        # Default activation: sigmoid for one logit, identity for many.
        if activation_fct is None:
            activation_fct = nn.Sigmoid() if self.num_labels == 1 else nn.Identity()

        scores = []
        collected_labels = []
        with torch.no_grad():
            for features, labels in tqdm(dataloader, desc="Iteration", smoothing=0.05):
                logits = activation_fct(self.forward(features))
                scores.extend(logits)
                collected_labels.extend(labels)

        scores = np.asarray([s.cpu().detach().numpy() for s in scores])
        if not need_labels:
            return scores
        return scores, np.asarray(collected_labels)
Beispiel #6
0
 def infer_gen(self):
     """Run the generator in inference mode on one validation batch.

     Loads pretrained weights for the configured dataset, generates images,
     and dumps instructions, ground-truth frames (``_<j>.png``) and generated
     frames (``<j>.png``) under ``gen_result/``, then archives the folder.
     """
     cfg = self.cfg

     # Load the pretrained weights matching the selected dataset.
     if cfg.dataset == 'codraw':
         self.model.ctr.E.load_state_dict(torch.load('models/codraw_1.0.pt'))
     elif cfg.dataset == 'iclevr':
         self.model.ctr.E.load_state_dict(torch.load('models/iclevr_1.0.pt'))

     dataset = DATASETS[cfg.dataset](path=keys[cfg.val_dataset], cfg=cfg, img_size=cfg.img_size)
     dataloader = DataLoader(dataset,
                             batch_size=cfg.batch_size, shuffle=True, num_workers=cfg.num_workers,
                             pin_memory=True, drop_last=True)

     if cfg.dataset == 'codraw':
         dataloader.collate_fn = codraw_dataset.collate_data
     elif cfg.dataset == 'iclevr':
         dataloader.collate_fn = clevr_dataset.collate_data

     glove_key = list(dataset.glove_key.keys())

     for batch in dataloader:
         rec_out = self.model.infer_gen(batch)

         # os.makedirs is portable and idempotent, unlike shelling out to mkdir.
         os.makedirs('gen_result', exist_ok=True)

         for i in range(30):
             os.makedirs('gen_result/%d' % (i), exist_ok=True)

             # Context manager guarantees the file is closed even on error
             # (the original handle leaked if an exception was raised mid-loop).
             with open('gen_result/%d/ins.txt' % (i), 'w') as fout:
                 for j in range(rec_out.shape[1]):
                     words = [glove_key[int(batch['turn_word'][i, j, k].detach().cpu().numpy())]
                              for k in range(batch['turn_word'].shape[2])]
                     fout.write(' '.join(words))
                     fout.write('\n')

                     # '_<j>.png' is the reference frame, '<j>.png' the generated one.
                     TV.utils.save_image(batch['image'][i, j].data, 'gen_result/%d/_%d.png' % (i, j), normalize=True, range=(-1, 1))
                     TV.utils.save_image(torch.from_numpy(rec_out[i, j]).data, 'gen_result/%d/%d.png' % (i, j), normalize=True, range=(-1, 1))

         # Only the first batch is needed for this qualitative dump.
         break

     os.system('tar zcvf gen_result.tar.gz gen_result')
Beispiel #7
0
def data_to_loader(args, torch_datasets):
    """Wrap the 'train'/'test'/'valid' datasets in DataLoaders.

    :param args: namespace providing batch_size, batch_size_test, num_workers
    :param torch_datasets: dict with 'train', 'test' and 'valid' datasets
    :return: (train_loader, test_loader, valid_loader)
    """
    # Pass collate_fn at construction instead of patching the attribute after
    # the fact — functionally equivalent, but the loader is fully configured
    # from the start (matters once worker processes snapshot its state).
    # collate_fn constructs atom_mask and edge_mask for each batch.
    train_loader = DataLoader(torch_datasets['train'],
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers,
                              collate_fn=collate_fn)

    # NOTE(review): shuffle=True on the test split is unusual — confirm it is
    # intentional before changing it.
    test_loader = DataLoader(torch_datasets['test'],
                             batch_size=args.batch_size_test,
                             shuffle=True,
                             num_workers=args.num_workers,
                             collate_fn=collate_fn)

    valid_loader = DataLoader(torch_datasets['valid'],
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers,
                              collate_fn=collate_fn)

    return train_loader, test_loader, valid_loader
Beispiel #8
0
 def eval_ctr(self, epoch, iteration_counter):
     """Evaluate the CTR model on the validation split; return the mean loss."""
     cfg = self.cfg

     dataset = DATASETS[cfg.dataset](path=keys[cfg.val_dataset], cfg=cfg, img_size=cfg.img_size)
     dataloader = DataLoader(dataset,
                             batch_size=cfg.batch_size, shuffle=False,
                             num_workers=cfg.num_workers, pin_memory=True, drop_last=True)

     # Dataset-specific collation.
     if cfg.dataset == 'codraw':
         dataloader.collate_fn = codraw_dataset.collate_data
     elif cfg.dataset == 'iclevr':
         dataloader.collate_fn = clevr_dataset.collate_data

     # Dataset-specific reshuffle before evaluation (CoDraw only).
     if cfg.dataset == 'codraw':
         dataset.shuffle()

     # One eval-mode pass per batch; collect losses and average.
     per_batch_losses = [
         self.model.train_ctr(batch, epoch, iteration_counter,
                              self.visualizer, self.logger, is_eval=True)
         for batch in dataloader
     ]
     return np.average(per_batch_losses)
    def fit(self,
            train_dataloader: DataLoader,
            evaluator=None,
            epochs: int = 1,
            optimizer_class: Type[torch.optim.Optimizer] = transformers.AdamW,
            optimizer_params: Dict[str, object] = None,
            weight_decay: float = 0.01,
            evaluation_steps: int = 0,
            output_path: str = None,
            save_best_model: bool = True,
            max_grad_norm: float = 1
            ):
        """Train the model on *train_dataloader*.

        :param train_dataloader: training batches of (features, labels)
        :param evaluator: optional evaluator run every ``evaluation_steps``
            steps and at the end of each epoch
        :param epochs: number of training epochs
        :param optimizer_class: optimizer type to instantiate
        :param optimizer_params: optimizer keyword arguments; defaults to
            ``{'lr': 2e-5, 'eps': 1e-6, 'correct_bias': False}``
        :param weight_decay: weight decay for non-bias/LayerNorm parameters
        :param evaluation_steps: if > 0, evaluate every that many steps
        :param output_path: where evaluation results / best model are stored
        :param save_best_model: keep the best model (according to evaluator)
        :param max_grad_norm: gradient clipping threshold
        """
        # Avoid a shared mutable default argument: build the dict per call.
        if optimizer_params is None:
            optimizer_params = {'lr': 2e-5, 'eps': 1e-6, 'correct_bias': False}

        self.to(self._target_device)
        self.best_score = -1e9

        tools.ensure_path_exist(output_path)

        train_dataloader.collate_fn = self.batching_collate
        num_train_steps = int(len(train_dataloader) * epochs)

        # Prepare optimizers: no weight decay for biases and LayerNorm params.
        param_optimizer = list(self.named_parameters())

        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': weight_decay},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]

        optimizer = optimizer_class(optimizer_grouped_parameters, **optimizer_params)
        for epoch in trange(epochs, desc='Epoch'):
            training_steps = 0
            self.zero_grad()
            self.train()

            for features, labels in tqdm(train_dataloader, desc="Iteration", smoothing=0.05):
                outputs = self.forward(features)
                loss_value = self.loss_fct(outputs, labels.float().view(-1))
                loss_value.backward()
                torch.nn.utils.clip_grad_norm_(self.parameters(), max_grad_norm)
                optimizer.step()
                optimizer.zero_grad()
                training_steps += 1

                if evaluator is not None and evaluation_steps > 0 and training_steps % evaluation_steps == 0:
                    self._eval_during_training(evaluator, output_path, save_best_model, epoch, training_steps)
                    # Evaluation may have toggled eval mode; restore training state.
                    self.zero_grad()
                    self.train()

            if evaluator is not None:
                self._eval_during_training(evaluator, output_path, save_best_model, epoch, -1)
Beispiel #10
0
    def load_dataloader2(self, arguments, set_name: str) -> DataLoader:
        """ loads specific dataset as a DataLoader

        :param arguments: run configuration (dataset_class, data_folder,
            genre, normalize_data, batch_size)
        :param set_name: which split to load; training split is shuffled
        """

        dataset: BaseDataset = find_right_model(
            DATASETS,
            arguments.dataset_class,
            folder=arguments.data_folder,
            set_name=set_name,
            genre=Genre.from_str(arguments.genre),
            normalize=arguments.normalize_data)

        # Use equality, not identity: `set_name is TRAIN_SET` compared string
        # identity and only worked when CPython happened to intern both
        # strings — e.g. a set_name read from config would never match.
        loader = DataLoader(dataset,
                            shuffle=(set_name == TRAIN_SET),
                            batch_size=arguments.batch_size)

        if dataset.use_collate_function():
            loader.collate_fn = pad_and_sort_batch

        return loader
Beispiel #11
0
class Grid(Figure):
    """Figure that renders model outputs as an image grid with *ncol* columns."""

    def __init__(self, cfg, ncol):
        super(Grid, self).__init__(cfg)
        self.ncol = ncol
        # self.dataset is presumably set up by Figure.__init__ — TODO confirm.
        self.dataloader = DataLoader(self.dataset,
                                     batch_size=self.ncol,
                                     shuffle=False)
        # Collate the first ncol samples once as the fixed input batch.
        first_samples = [self.dataset[idx] for idx in range(self.ncol)]
        self.input_imgs = self.dataloader.collate_fn(first_samples)

    @torch.no_grad()
    def draw(self, pl_module):
        """Render the rows produced by the model into a clamped HWC array."""
        rows = torch.cat(list(self.create_rows(pl_module)), dim=0)
        grid = torchvision.utils.make_grid(rows, nrow=self.ncol)
        # CHW -> HWC, clamp to the displayable [0, 1] range.
        grid = grid.permute(1, 2, 0)
        grid = grid.clamp(0, 1)
        return grid.detach().cpu().numpy()
    def fit(self,
            train_dataloader: DataLoader,
            evaluator: SentenceEvaluator = None,
            epochs: int = 1,
            loss_fct = None,
            activation_fct = nn.Identity(),
            scheduler: str = 'WarmupLinear',
            warmup_steps: int = 10000,
            optimizer_class: Type[Optimizer] = transformers.AdamW,
            optimizer_params: Dict[str, object] = None,
            weight_decay: float = 0.01,
            evaluation_steps: int = 0,
            output_path: str = None,
            save_best_model: bool = True,
            max_grad_norm: float = 1,
            use_amp: bool = False,
            callback: Callable[[float, int, int], None] = None,
            ):
        """
        Train the model with the given training objective
        Each training objective is sampled in turn for one batch.
        We sample only as many batches from each objective as there are in the smallest one
        to make sure of equal training with each dataset.

        :param train_dataloader: DataLoader with training InputExamples
        :param evaluator: An evaluator (sentence_transformers.evaluation) evaluates the model performance during training on held-out dev data. It is used to determine the best model that is saved to disc.
        :param epochs: Number of epochs for training
        :param loss_fct: Which loss function to use for training. If None, will use nn.BCEWithLogitsLoss() if self.config.num_labels == 1 else nn.CrossEntropyLoss()
        :param activation_fct: Activation function applied on top of logits output of model.
        :param scheduler: Learning rate scheduler. Available schedulers: constantlr, warmupconstant, warmuplinear, warmupcosine, warmupcosinewithhardrestarts
        :param warmup_steps: Behavior depends on the scheduler. For WarmupLinear (default), the learning rate is increased from o up to the maximal learning rate. After these many training steps, the learning rate is decreased linearly back to zero.
        :param optimizer_class: Optimizer
        :param optimizer_params: Optimizer parameters. Defaults to {'lr': 2e-5}.
        :param weight_decay: Weight decay for model parameters
        :param evaluation_steps: If > 0, evaluate the model using evaluator after each number of training steps
        :param output_path: Storage path for the model and evaluation files
        :param save_best_model: If true, the best model (according to evaluator) is stored at output_path
        :param max_grad_norm: Used for gradient normalization.
        :param use_amp: Use Automatic Mixed Precision (AMP). Only for Pytorch >= 1.6.0
        :param callback: Callback function that is invoked after each evaluation.
                It must accept the following three parameters in this order:
                `score`, `epoch`, `steps`
        """
        # Avoid a shared mutable default argument: build the dict per call.
        # (nn.Identity() as a default is harmless since it is stateless.)
        if optimizer_params is None:
            optimizer_params = {'lr': 2e-5}

        train_dataloader.collate_fn = self.smart_batching_collate

        if use_amp:
            from torch.cuda.amp import autocast
            scaler = torch.cuda.amp.GradScaler()

        self.model.to(self._target_device)

        if output_path is not None:
            os.makedirs(output_path, exist_ok=True)

        self.best_score = -9999999
        num_train_steps = int(len(train_dataloader) * epochs)

        # Prepare optimizers: no weight decay for biases and LayerNorm params.
        param_optimizer = list(self.model.named_parameters())

        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {
                'params': [
                    p
                    for n, p in param_optimizer
                    if all(nd not in n for nd in no_decay)
                ],
                'weight_decay': weight_decay,
            },
            {
                'params': [
                    p
                    for n, p in param_optimizer
                    if any(nd in n for nd in no_decay)
                ],
                'weight_decay': 0.0,
            },
        ]


        optimizer = optimizer_class(optimizer_grouped_parameters, **optimizer_params)

        if isinstance(scheduler, str):
            scheduler = SentenceTransformer._get_scheduler(optimizer, scheduler=scheduler, warmup_steps=warmup_steps, t_total=num_train_steps)

        if loss_fct is None:
            loss_fct = nn.BCEWithLogitsLoss() if self.config.num_labels == 1 else nn.CrossEntropyLoss()


        skip_scheduler = False
        for epoch in trange(epochs, desc="Epoch"):
            self.model.zero_grad()
            self.model.train()

            for training_steps, (features, labels) in enumerate(tqdm(train_dataloader, desc="Iteration", smoothing=0.05), start=1):
                if use_amp:
                    with autocast():
                        model_predictions = self.model(**features, return_dict=True)
                        logits = activation_fct(model_predictions.logits)
                        if self.config.num_labels == 1:
                            logits = logits.view(-1)
                        loss_value = loss_fct(logits, labels)

                    scale_before_step = scaler.get_scale()
                    scaler.scale(loss_value).backward()
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
                    scaler.step(optimizer)
                    scaler.update()

                    # When the gradient scaler shrinks its scale, the optimizer
                    # step was skipped — skip the scheduler step too.
                    skip_scheduler = scaler.get_scale() != scale_before_step
                else:
                    model_predictions = self.model(**features, return_dict=True)
                    logits = activation_fct(model_predictions.logits)
                    if self.config.num_labels == 1:
                        logits = logits.view(-1)

                    loss_value = loss_fct(logits, labels)
                    loss_value.backward()
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
                    optimizer.step()

                optimizer.zero_grad()

                if not skip_scheduler:
                    scheduler.step()

                if evaluator is not None and evaluation_steps > 0 and training_steps % evaluation_steps == 0:
                    self._eval_during_training(evaluator, output_path, save_best_model, epoch, training_steps, callback)

                    # Evaluation may have toggled eval mode; restore training state.
                    self.model.zero_grad()
                    self.model.train()

            if evaluator is not None:
                self._eval_during_training(evaluator, output_path, save_best_model, epoch, -1, callback)
Beispiel #13
0
def visual_feature_map(model, path, device):
    """Visualize conv1/conv5 feature maps of *model* on one CIFAR-10 batch.

    Saves the feature maps and per-channel deconvolution reconstructions as
    image grids under ``<path>/feature_map/``. Only the first batch is
    processed (the loop breaks at the end).

    :param model: network exposing output_feature_conv1/output_feature_conv5
    :param path: output directory root
    :param device: torch device the model and tensors are moved to
    """
    mean, std = train_classify.load_mestd()
    bz = 16  # batch size for both loaders
    # Normalized inputs for the model; raw (un-normalized) copies for display.
    visual_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize(mean, std)])
    visual_img_transform = transforms.Compose([transforms.ToTensor()])
    visual_set = torchvision.datasets.CIFAR10(root='./data',
                                              train=False,
                                              download=True,
                                              transform=visual_transform)
    visual_img_set = torchvision.datasets.CIFAR10(
        root='./data',
        train=False,
        download=True,
        transform=visual_img_transform)
    visual_loader = DataLoader(dataset=visual_set, batch_size=bz)
    visual_img_loader = DataLoader(dataset=visual_img_set, batch_size=bz)

    plot_path = os.path.join(path, 'feature_map/')
    if not os.path.exists(plot_path):
        os.makedirs(plot_path)
    with torch.no_grad():
        for n, (img, label) in enumerate(visual_loader):
            img, label = img.to(device), label.to(device)
            # NOTE(review): `n` is the batch index, so range(n, n + bz) only
            # lines up with `img` when n == 0; the trailing `break` means only
            # that case runs today — confirm before removing the break.
            img2, label2 = visual_img_loader.collate_fn(
                [visual_img_loader.dataset[i] for i in range(n, n + bz)])

            img2 = img2.to(device)

            feature1 = model.output_feature_conv1(img)
            # Channels-first grid: one tile per feature channel.
            visual_in_grid(feature1.cpu().permute(1, 0, 2, 3).numpy(),
                           'conv1_feature_map',
                           plot_path,
                           split_channels=True)
            # plt.close()
            feature5 = model.output_feature_conv5(img)
            visual_in_grid(feature5.cpu().permute(1, 0, 2, 3).numpy(),
                           'conv5_feature_map',
                           plot_path,
                           split_channels=True)

            # Deconvolution models reconstruct input-space images from features.
            deconv1_model, deconv5_model = load_deconv_model(model)
            deconv1_model = deconv1_model.to(device)
            deconv5_model = deconv5_model.to(device)

            # One reconstruction per conv1 channel: (channels, batch, C, H, W).
            recon_img1_all = torch.zeros(
                (feature1.size()[1], bz, img.size()[1], img.size()[2],
                 img.size()[3]))
            for i in range(feature1.size()[1]):
                # Zero out all channels except channel i, then reconstruct.
                one_map = extract_one_feature(feature1.cpu(), i).to(device)

                recon_img1 = deconv1_model(one_map)
                recon_img1_all[i] = recon_img1

                # Show raw images next to their reconstructions.
                recon_show = torch.cat((img2, recon_img1), 0)
                visual_in_grid(recon_show.cpu().permute(0, 2, 3, 1).numpy(),
                               'recons_img_conv1_fea{}'.format(i),
                               plot_path,
                               split_channels=False)
                plt.close()

            visual_in_grid(recon_img1_all.cpu().permute(0, 1, 3, 4, 2).numpy(),
                           'recons_img1_all',
                           plot_path,
                           split_channels=True,
                           normalize=True)
            plt.close()

            # Same procedure for the conv5 features.
            recon_img5_all = torch.zeros(
                (feature5.size()[1], bz, img.size()[1], img.size()[2],
                 img.size()[3]))
            for i in range(feature5.size()[1]):
                one_map = extract_one_feature(feature5.cpu(), i).to(device)
                # print(one_map.size())
                recon_img5 = deconv5_model(one_map)
                recon_img5_all[i] = recon_img5
                # print('recon:{}'.format(recon_img1.size()))
                recon_show = torch.cat((img2, recon_img5), 0)
                visual_in_grid(recon_show.cpu().permute(0, 2, 3, 1).numpy(),
                               'recons_img_conv5_fea{}'.format(i),
                               plot_path,
                               split_channels=False)
                plt.close()

            visual_in_grid(recon_img5_all.cpu().permute(0, 1, 3, 4, 2).numpy(),
                           'recons_img5_all',
                           plot_path,
                           split_channels=True,
                           normalize=True)
            plt.close()

            # Only the first batch is visualized.
            break
Beispiel #14
0
# Red-white-green colormap for plotting signed saliency/contribution maps.
colors = [(1, 0, 0.5), (1, 1, 1), (0, 0.5, 0)]
cmap = LinearSegmentedColormap.from_list('rg', colors, N=256)
plt.register_cmap(cmap=cmap)
# Hand-picked sample indices — presumably over-/under-predicted cases from a
# previous run; TODO confirm where these ids come from.
over = [
    29714, 29234, 23436, 25512, 15542, 22496, 16544, 6301, 19823, 23196, 25747,
    24913, 15885, 13164, 18695, 5973, 20092, 21161, 26815, 28571
]
under = [
    6184, 14914, 20894, 5644, 6308, 10608, 9823, 12829, 4747, 9134, 12527,
    27883, 12319, 24707, 16316, 3199, 3292, 8681, 19292, 20234
]
for i in under:  #range(iter_num):

    # Collate a single sample into a batch of one.
    data = dataset[i]
    data = dataloader_son_val.collate_fn([data])
    #Baseline
    out_sun2, out_son2 = net2(data['image'].to(device))
    maps2 = basenet2(data['image'].to(device))
    maps2 = maps2.squeeze().cpu().numpy()
    # Weighted sum of per-channel maps using precomputed contributions.
    map2 = np.zeros(maps2.shape[1:])
    for m in range(maps2.shape[0]):
        map2 += maps2[m, :, :] * contrib[m]

    # Zero a 2-pixel border to suppress edge artifacts.
    map2[0:2, :] = 0
    map2[-2:, :] = 0
    map2[:, 0:2] = 0
    map2[:, -2:] = 0
    # Ours
    out_sun, out_son, maps, attr_contrib = net(data['image'].to(device))
    # NOTE(review): `map` shadows the builtin — consider renaming if this
    # script grows. The loop body appears to continue beyond this chunk.
    map = np.zeros(maps.shape[2:])
Beispiel #15
0
def train(train_cluster_data, val_cluster_data, test_cluster_data, output_path, eval_steps,
          num_epochs, warmup_frac, lambda_val, reg, use_model_device, max_train_size=-1, train_psg_model=False,
          model_name='distilbert-base-uncased', out_features=256, steps_per_epoch=None, weight_decay=0.01,
          optimizer_class=transformers.AdamW, scheduler='WarmupLinear', optimizer_params=None,
          show_progress_bar=True, max_grad_norm=1, save_best_model=True):
    """Train a query-specific clustering model with the BB cluster loss.

    Builds query and passage sentence-transformer towers, trains with
    ``BBClusterLossModel``, periodically evaluates on the validation data and
    (per epoch) on the test data, and saves the best model to *output_path*.

    :param optimizer_params: optimizer kwargs; defaults to ``{'lr': 2e-5}``
    """
    # Avoid a shared mutable default argument: build the dict per call.
    if optimizer_params is None:
        optimizer_params = {'lr': 2e-5}

    tensorboard_writer = SummaryWriter('./tensorboard_logs')
    task = Task.init(project_name='Query Specific BB Clustering', task_name='query_bbc_fixed_lambda')
    config_dict = {'lambda_val': lambda_val, 'reg': reg}
    config_dict = task.connect(config_dict)
    if torch.cuda.is_available():
        device = torch.device('cuda')
        print('CUDA is available and using device: '+str(device))
    else:
        device = torch.device('cpu')
        print('CUDA not available, using device: '+str(device))
    ### Configure sentence transformers for training and train on the provided dataset
    # Use Huggingface/transformers model (like BERT, RoBERTa, XLNet, XLM-R) for mapping tokens to embeddings
    query_word_embedding_model = models.Transformer(model_name)

    # Apply mean pooling to get one fixed sized sentence vector
    query_pooling_model = models.Pooling(query_word_embedding_model.get_word_embedding_dimension(),
                                   pooling_mode_mean_tokens=True,
                                   pooling_mode_cls_token=False,
                                   pooling_mode_max_tokens=False)

    query_dense_model = models.Dense(in_features=query_pooling_model.get_sentence_embedding_dimension(),
                                     out_features=out_features,
                                     activation_function=nn.Sigmoid())
    psg_word_embedding_model = models.Transformer(model_name)

    # Apply mean pooling to get one fixed sized sentence vector
    psg_pooling_model = models.Pooling(psg_word_embedding_model.get_word_embedding_dimension(),
                                         pooling_mode_mean_tokens=True,
                                         pooling_mode_cls_token=False,
                                         pooling_mode_max_tokens=False)

    psg_dense_model = models.Dense(in_features=psg_pooling_model.get_sentence_embedding_dimension(),
                                     out_features=out_features,
                                     activation_function=nn.Tanh())

    query_model = CustomSentenceTransformer(modules=[query_word_embedding_model, query_pooling_model,
                                                     query_dense_model])
    psg_model = SentenceTransformer(modules=[psg_word_embedding_model, psg_pooling_model, psg_dense_model])

    model = QuerySpecificClusterModel(query_transformer=query_model, psg_transformer=psg_model, device=device)

    train_dataloader = DataLoader(train_cluster_data, shuffle=True, batch_size=1)
    evaluator = QueryClusterEvaluator.from_input_examples(val_cluster_data, use_model_device)
    test_evaluator = QueryClusterEvaluator.from_input_examples(test_cluster_data, use_model_device)

    warmup_steps = int(len(train_dataloader) * num_epochs * warmup_frac)  # 10% of train data

    print("Untrained performance")
    model.to(device)
    evaluator(model)

    train_dataloader.collate_fn = model.query_batch_collate_fn

    # Train the model
    best_score = -9999999
    if steps_per_epoch is None or steps_per_epoch == 0:
        steps_per_epoch = len(train_dataloader)
    num_train_steps = int(steps_per_epoch * num_epochs)
    param_optimizer = list(model.named_parameters())
    # No weight decay for biases and LayerNorm parameters.
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay': weight_decay},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    data_iter = iter(train_dataloader)
    optimizer = optimizer_class(optimizer_grouped_parameters, **optimizer_params)
    scheduler_obj = model._get_scheduler(optimizer, scheduler=scheduler, warmup_steps=warmup_steps,
                                        t_total=num_train_steps)
    config = {'epochs': num_epochs, 'steps_per_epoch': steps_per_epoch}
    global_step = 0
    loss_model = BBClusterLossModel(model, device, lambda_val, reg)
    for epoch in trange(config.get('epochs'), desc="Epoch", disable=not show_progress_bar):
        training_steps = 0
        running_loss_0 = 0.0
        model.zero_grad()
        model.train()
        if not train_psg_model:
            # NOTE(review): setting m.training = False differs from m.eval()
            # (no recursion into submodules' hooks) — confirm this is intended.
            for m in model.psg_model.modules():
                m.training = False
        for _ in trange(config.get('steps_per_epoch'), desc="Iteration", smoothing=0.05, disable=not show_progress_bar):
            # Cycle the dataloader indefinitely across epochs.
            try:
                data = next(data_iter)
            except StopIteration:
                data_iter = iter(train_dataloader)
                data = next(data_iter)
            query_feature, psg_features, labels = data
            if max_train_size > 0 and labels.shape[1] > max_train_size:
                print('skipping instance with '+str(labels.shape[1])+' passages')
                continue
            loss_val = loss_model(query_feature, psg_features, labels)
            running_loss_0 += loss_val.item()
            loss_val.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()
            scheduler_obj.step()
            training_steps += 1
            global_step += 1

            if eval_steps > 0 and training_steps % eval_steps == 0:
                tensorboard_writer.add_scalar('training_loss', running_loss_0 / eval_steps, global_step)
                # logger.report_scalar('Loss', 'training_loss', iteration=global_step, v
                # alue=running_loss_0/evaluation_steps)
                running_loss_0 = 0.0
                # self._eval_during_training(evaluator, output_path, save_best_model, epoch, training_steps, callback)
                if evaluator is not None:
                    score = evaluator(model, output_path=output_path, epoch=epoch, steps=training_steps)
                    tensorboard_writer.add_scalar('val_ARI', score, global_step)
                    # logger.report_scalar('Training progress', 'val_ARI', iteration=global_step, value=score)
                    if score > best_score:
                        best_score = score
                        if save_best_model:
                            print('Saving model at: ' + output_path)
                            model.save(output_path)
                # Evaluation may have toggled eval mode; restore training state.
                model.zero_grad()
                model.train()
                if not train_psg_model:
                    for m in model.psg_model.modules():
                        m.training = False
        if evaluator is not None:
            score = evaluator(model, output_path=output_path, epoch=epoch, steps=training_steps)
            tensorboard_writer.add_scalar('val_ARI', score, global_step)
            # logger.report_scalar('Training progress', 'val_ARI', iteration=global_step, value=score)
            if score > best_score:
                best_score = score
                if save_best_model:
                    model.save(output_path)
        if test_evaluator is not None:
            # NOTE(review): constructing QuerySpecificClusterModel from a path
            # differs from the keyword constructor above — confirm this
            # overload exists.
            best_model = QuerySpecificClusterModel(output_path)
            if torch.cuda.is_available():
                # Swap the models between CPU/GPU to fit both in memory.
                model.to(torch.device('cpu'))
                best_model.to(device)
                test_ari = test_evaluator(best_model)
                best_model.to(torch.device('cpu'))
                model.to(device)
            else:
                test_ari = test_evaluator(best_model)
            tensorboard_writer.add_scalar('test_ARI', test_ari, global_step)
            # logger.report_scalar('Training progress', 'test_ARI', iteration=global_step, value=test_ari)
    if evaluator is None and output_path is not None:  # No evaluator, but output path: save final model version
        model.save(output_path)