Example #1
    def _init_svae_model(self):
        self.svae = SVAE(**self.model_config['SVAE']['Parameters'],
                         vocab_size=self.vocab_size).to(self.device)
        self.svae_optim = optim.Adam(
            self.svae.parameters(),
            lr=self.model_config['SVAE']['learning_rate'])
        self.svae.train()
        print(f'{datetime.now()}: Initialised SVAE successfully')
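The method above assumes the loaded config exposes an `SVAE` block with a `Parameters` mapping (unpacked straight into the `SVAE` constructor) and a `learning_rate` for Adam. A minimal sketch of that structure follows; the inner parameter names and values are placeholder assumptions, not the project's real defaults.

# Hypothetical shape of self.model_config['SVAE'] implied by _init_svae_model.
# The keys mirror the lookups above; the inner parameter names/values are assumptions.
model_config = {
    'SVAE': {
        'Parameters': {            # unpacked into SVAE(**..., vocab_size=...)
            'embedding_dim': 128,
            'hidden_dim': 256,
            'latent_size': 16,
        },
        'learning_rate': 1e-3,     # passed to optim.Adam
    }
}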
Example #2
    def _load_pretrained_model(self):
        # Initialise SVAE with saved parameters. TODO: Save model hyperparameters to disk with the saved weights
        self.svae = SVAE(**self.model_config['SVAE']['Parameters'], vocab_size=self.vocab_size).to(self.device)
        # Load the pre-trained SVAE weights from disk, then modify the model
        svae_best_model_path = 'best models/svae.pt'
        self.svae.load_state_dict(torch.load(svae_best_model_path))
        print(f'{datetime.now()}: Loaded pretrained SVAE\n{self.svae}')

        class Identity(nn.Module):
            def __init__(self):
                super(Identity, self).__init__()

            def forward(self, x):
                return x

        self.svae.outputs2vocab = Identity()  # Removes the hidden2output (outputs2vocab) projection layer
        print(f'{datetime.now()}: Modified pretrained SVAE\n{self.svae}')
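Recent PyTorch releases (1.2+) ship torch.nn.Identity, so the hand-rolled Identity class above can be replaced by the built-in. A small standalone sketch of the same head-removal trick (the two-layer model here is illustrative only):

import torch
import torch.nn as nn

# Swap a model's output projection for a pass-through so the forward pass
# returns the penultimate representation instead of vocabulary logits.
model = nn.Sequential(nn.Linear(8, 16), nn.Linear(16, 4))
model[1] = nn.Identity()      # built-in equivalent of the custom Identity class above

x = torch.randn(2, 8)
print(model(x).shape)         # torch.Size([2, 16]) - hidden features, not logits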
Example #3
    def _init_models(self, mode: str):
        """ Initialises models, loss functions, optimisers and sets models to training mode """
        # Task Learner
        self.task_learner = TaskLearner(**self.model_config['TaskLearner']['Parameters'],
                                        vocab_size=self.vocab_size,
                                        tagset_size=self.tagset_size,
                                        task_type=self.task_type).to(self.device)
        # Loss functions
        if self.task_type == 'SEQ':
            self.tl_loss_fn = nn.NLLLoss().to(self.device)
        if self.task_type == 'CLF':
            self.tl_loss_fn = nn.CrossEntropyLoss().to(self.device)

        # Optimisers
        self.tl_optim = optim.SGD(self.task_learner.parameters(),
                                  lr=self.model_config['TaskLearner']['learning_rate'])  # momentum=0, weight_decay=0.1 currently disabled

        # Learning rate scheduler (disabled)
        # Note: the SGD learning rate here is likely greater than the Adam learning rates used for the SVAE and Discriminator
        # self.tl_sched = optim.lr_scheduler.ReduceLROnPlateau(self.tl_optim, 'min', factor=0.5, patience=10)
        # Training Modes
        self.task_learner.train()

        # SVAAL needs to initialise SVAE and DISC in addition to TL
        if mode == 'svaal':
            # Models
            self.svae = SVAE(**self.model_config['SVAE']['Parameters'],
                             vocab_size=self.vocab_size).to(self.device)
            self.discriminator = Discriminator(**self.model_config['Discriminator']['Parameters']).to(self.device)
            
            # Loss Function (SVAE defined within its class)
            self.dsc_loss_fn = nn.BCELoss().to(self.device)
            
            # Optimisers
            # Note: Adam will likely have a lower lr than SGD
            self.svae_optim = optim.Adam(self.svae.parameters(),
                                         lr=self.model_config['SVAE']['learning_rate'])
            self.dsc_optim = optim.Adam(self.discriminator.parameters(),
                                        lr=self.model_config['Discriminator']['learning_rate'])
            
            # Training Modes
            self.svae.train()
            self.discriminator.train()

        print(f'{datetime.now()}: Models initialised successfully')
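A note on the loss choice above: nn.NLLLoss expects log-probabilities (so a TaskLearner used for the 'SEQ' task presumably ends in a log_softmax), whereas nn.CrossEntropyLoss takes raw logits and applies log_softmax internally. A small sketch of that equivalence, independent of the models in this repo:

import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.randn(4, 7)              # batch of 4 predictions over 7 classes
targets = torch.tensor([0, 3, 6, 2])

ce = nn.CrossEntropyLoss()(logits, targets)                  # expects raw logits
nll = nn.NLLLoss()(F.log_softmax(logits, dim=-1), targets)   # expects log-probabilities

assert torch.allclose(ce, nll)   # identical losses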
Example #4
class Trainer:
    """ Prepares and trains S-VAAL model """
    def __init__(self):
        self.config = load_config()
        self.model_config = self.config['Models']

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Model
        self.task_type = self.config['Utils']['task_type']
        self.max_sequence_length = self.config['Utils'][self.task_type]['max_sequence_length']
        
        # Real data
        self.data_name = self.config['Utils'][self.task_type]['data_name']
        self.data_splits = self.config['Utils'][self.task_type]['data_split']
        self.pad_idx = self.config['Utils']['special_token2idx']['<PAD>']
        
        # Test run properties
        self.epochs = self.config['Train']['epochs']
        self.svae_iterations = self.config['Train']['svae_iterations']
        self.dsc_iterations = self.config['Train']['discriminator_iterations']
        self.adv_hyperparam = self.config['Models']['SVAE']['adversarial_hyperparameter']

    def _init_dataset(self, batch_size=None):
        """ Initialise real datasets by reading encoding data
        Returns
        -------
            self : dict
                Dictionary of DataLoaders
        Notes
        -----
        - Task type and data name are specified in the configuration file
        - Keys in 'data' are the splits used and the keys in 'vocab' are words and tags
        """
        
        kfold_xval = False
        
        self.x_y_pair_name = 'seq_label_pairs_enc' if self.data_name == 'ag_news' else 'seq_tags_pairs_enc' # Key in dataset - semantically correct for the task at hand.

        if batch_size is None:
            batch_size = self.config['Train']['batch_size']


        # Load pre-processed data
        path_data = os.path.join('/home/tyler/Desktop/Repos/s-vaal/data', self.task_type, self.data_name, 'data.json')
        path_vocab = os.path.join('/home/tyler/Desktop/Repos/s-vaal/data', self.task_type, self.data_name, 'vocabs.json')
        data = load_json(path_data)
        self.vocab = load_json(path_vocab)       # Required for decoding sequences for interpretations. TODO: Find suitable location... or leave be...
        self.vocab_size = len(self.vocab['words'])  # word vocab is used for model dimensionality setting + includes special characters (EOS, SOS, UNK, PAD)
        self.tagset_size = len(self.vocab['tags'])  # this includes special characters (EOS, SOS, UNK, PAD)

        self.datasets = dict()
        if kfold_xval:
            # Perform k-fold cross-validation
            # Join all datasets and then randomly assign train/val/test
            print('Performing k-fold cross-validation')
            
            
            for split in self.data_splits:
                print(data[split][self.x_y_pair_name])
            
            
        else:    
            for split in self.data_splits:
                # Access data
                split_data = data[split][self.x_y_pair_name]
                # Convert lists of encoded sequences into tensors and stack into one large tensor
                split_seqs = torch.stack([torch.tensor(enc_pair[0]) for key, enc_pair in split_data.items()])
                split_tags = torch.stack([torch.tensor(enc_pair[1]) for key, enc_pair in split_data.items()])
                # Create torch dataset from tensors
                split_dataset = RealDataset(sequences=split_seqs, tags=split_tags)
                # Add to dictionary
                self.datasets[split] = split_dataset #split_dataloader
                
                # Create torch dataloader generator from dataset
                if split == 'test':
                    self.test_dataloader = DataLoader(dataset=split_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
                if split == 'valid':
                    self.val_dataloader = DataLoader(dataset=split_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

        print(f'{datetime.now()}: Data initialised successfully')

    def _init_models(self, mode: str):
        """ Initialises models, loss functions, optimisers and sets models to training mode """
        # Task Learner
        self.task_learner = TaskLearner(**self.model_config['TaskLearner']['Parameters'],
                                        vocab_size=self.vocab_size,
                                        tagset_size=self.tagset_size,
                                        task_type=self.task_type).to(self.device)
        # Loss functions
        if self.task_type == 'SEQ':
            self.tl_loss_fn = nn.NLLLoss().to(self.device)
        if self.task_type == 'CLF':
            self.tl_loss_fn = nn.CrossEntropyLoss().to(self.device)

        # Optimisers
        self.tl_optim = optim.SGD(self.task_learner.parameters(),
                                  lr=self.model_config['TaskLearner']['learning_rate'])  # momentum=0, weight_decay=0.1 currently disabled

        # Learning rate scheduler (disabled)
        # Note: the SGD learning rate here is likely greater than the Adam learning rates used for the SVAE and Discriminator
        # self.tl_sched = optim.lr_scheduler.ReduceLROnPlateau(self.tl_optim, 'min', factor=0.5, patience=10)
        # Training Modes
        self.task_learner.train()

        # SVAAL needs to initialise SVAE and DISC in addition to TL
        if mode == 'svaal':
            # Models
            self.svae = SVAE(**self.model_config['SVAE']['Parameters'],
                             vocab_size=self.vocab_size).to(self.device)
            self.discriminator = Discriminator(**self.model_config['Discriminator']['Parameters']).to(self.device)
            
            # Loss Function (SVAE defined within its class)
            self.dsc_loss_fn = nn.BCELoss().to(self.device)
            
            # Optimisers
            # Note: Adam will likely have a lower lr than SGD
            self.svae_optim = optim.Adam(self.svae.parameters(),
                                         lr=self.model_config['SVAE']['learning_rate'])
            self.dsc_optim = optim.Adam(self.discriminator.parameters(),
                                        lr=self.model_config['Discriminator']['learning_rate'])
            
            # Training Modes
            self.svae.train()
            self.discriminator.train()

        print(f'{datetime.now()}: Models initialised successfully')

    def train(self, dataloader_l, dataloader_u, dataloader_v, dataloader_t, mode: str, meta: str):
        """ 
        Sequentially train S-VAAL in the following training sequence:
            ```
                for epoch in epochs:
                    train Task Learner
                    for step in steps:
                        train SVAE
                    for step in steps:
                        train Discriminator
            ```
        Arguments
        ---------
            dataloader_l : TODO
                DataLoader for labelled data
            dataloader_u : TODO
                DataLoader for unlabelled data
            dataloader_v : TODO
                DataLoader for validation data
            dataloader_t : TODO
                DataLoader for test data
            mode : str
                Training mode (svaal, random, least_confidence, etc.)
            meta : str
                Meta data about the current training run

        Returns
        -------
            eval_metrics : tuple
                Task dependent evaluation metrics (F1 micro/macro or Accuracy)
            svae : TODO
                Sentence variational autoencoder
            discriminator : TODO
                Discriminator

        Notes
        -----

        """
        self.tb_writer = SummaryWriter(comment=meta, filename_suffix=meta)

        early_stopping = EarlyStopping(patience=self.config['Train']['es_patience'], verbose=True, path="checkpoints/checkpoint.pt")  # TODO: Set EarlyStopping params in config

        dataset_size = len(dataloader_l) + len(dataloader_u) if dataloader_u is not None else len(dataloader_l)
        train_iterations = dataset_size * (self.epochs+1)
        print(f'{datetime.now()}: Dataset size {dataset_size} Training iterations {train_iterations}')


        write_freq = 50 # number of iters to write to TensorBoard
        train_str = ''
        step = 0    # Used for KL annealing
        epoch = 1
        for train_iter in tqdm(range(train_iterations), desc='Training iteration'):            
            batch_sequences_l, batch_lengths_l, batch_tags_l = next(iter(dataloader_l))
            batch_sequences_u, batch_lengths_u, _ = next(iter(dataloader_u))

            if torch.cuda.is_available():
                batch_sequences_l = batch_sequences_l.to(self.device)
                batch_lengths_l = batch_lengths_l.to(self.device)
                batch_tags_l = batch_tags_l.to(self.device)
                batch_sequences_u = batch_sequences_u.to(self.device)
                batch_lengths_u = batch_lengths_u.to(self.device)

            # Strip off tag padding and flatten
            # Don't do sequences here as it's done in the forward pass of the seq2seq models
            batch_tags_l = trim_padded_seqs(batch_lengths=batch_lengths_l,
                                            batch_sequences=batch_tags_l,
                                            pad_idx=self.pad_idx).view(-1)

            # Task Learner Step
            self.tl_optim.zero_grad()
            tl_preds = self.task_learner(batch_sequences_l, batch_lengths_l)
            tl_loss = self.tl_loss_fn(tl_preds, batch_tags_l)
            tl_loss.backward()
            self.tl_optim.step()
            
            if (train_iter > 0) and (train_iter % dataset_size == 0):
                # TODO: Reinstate LR scheduling in the future
                # Decay learning rate at the end of each epoch (if required)
                # self.tl_sched.step(tl_loss)     # Decay learning rate
                
                # Manually decay LR at each epoch
                # self.tl_optim.param_groups[0]["lr"] = self.tl_optim.param_groups[0]["lr"] / 10
                pass

            if mode == 'svaal':
                # Used in SVAE and Discriminator
                batch_size_l = batch_sequences_l.size(0)
                batch_size_u = batch_sequences_u.size(0)

                # SVAE Step
                # TODO: Extend for unsupervised - need to review svae.loss_fn for unsupervised case
                for i in range(self.svae_iterations):
                    # Labelled and unlabelled forward passes through SVAE and loss computation
                    logp_l, mean_l, logv_l, z_l = self.svae(batch_sequences_l, batch_lengths_l)
                    NLL_loss_l, KL_loss_l, KL_weight_l = self.svae.loss_fn(
                                                                    logp=logp_l,
                                                                    target=batch_sequences_l,
                                                                    length=batch_lengths_l,
                                                                    mean=mean_l,
                                                                    logv=logv_l,
                                                                    anneal_fn=self.model_config['SVAE']['anneal_function'],
                                                                    step=step,
                                                                    k=self.model_config['SVAE']['k'],
                                                                    x0=self.model_config['SVAE']['x0'])

                    logp_u, mean_u, logv_u, z_u = self.svae(batch_sequences_u, batch_lengths_u)
                    NLL_loss_u, KL_loss_u, KL_weight_u = self.svae.loss_fn(
                                                                    logp=logp_u,
                                                                    target=batch_sequences_u,
                                                                    length=batch_lengths_u,
                                                                    mean=mean_u,
                                                                    logv=logv_u,
                                                                    anneal_fn=self.model_config['SVAE']['anneal_function'],
                                                                    step=step,
                                                                    k=self.model_config['SVAE']['k'],
                                                                    x0=self.model_config['SVAE']['x0'])
                    # VAE loss
                    svae_loss_l = (NLL_loss_l + KL_weight_l * KL_loss_l) / batch_size_l
                    svae_loss_u = (NLL_loss_u + KL_weight_u * KL_loss_u) / batch_size_u

                    # Adversary loss - trying to fool the discriminator!
                    dsc_preds_l = self.discriminator(z_l)   # mean_l
                    dsc_preds_u = self.discriminator(z_u)   # mean_u
                    dsc_real_l = torch.ones(batch_size_l)
                    dsc_real_u = torch.ones(batch_size_u)

                    if torch.cuda.is_available():
                        dsc_real_l = dsc_real_l.to(self.device)
                        dsc_real_u = dsc_real_u.to(self.device)

                    # Higher loss = discriminator is having trouble figuring out the real vs fake
                    # Generator wants to maximise this loss
                    adv_dsc_loss_l = self.dsc_loss_fn(dsc_preds_l, dsc_real_l)
                    adv_dsc_loss_u = self.dsc_loss_fn(dsc_preds_u, dsc_real_u)
                    adv_dsc_loss = adv_dsc_loss_l + adv_dsc_loss_u

                    total_svae_loss = svae_loss_u + svae_loss_l + self.adv_hyperparam * adv_dsc_loss
                    self.svae_optim.zero_grad()
                    total_svae_loss.backward()
                    self.svae_optim.step()

                    # Add scalar for adversarial loss
                    # self.tb_writer.add_scalar('Loss/SVAE/train/labelled/ADV', NLL_loss_l, i + (train_iter*self.svae_iterations))
                    # self.tb_writer.add_scalar('Loss/SVAE/train/unlabelled/ADV', NLL_loss_l, i + (train_iter*self.svae_iterations))
                    # self.tb_writer.add_scalar('Loss/SVAE/train/ADV_total', NLL_loss_l, i + (train_iter*self.svae_iterations))
                    # Add scalars for ELBO (NLL), KL divergence, and Total loss 
                    # self.tb_writer.add_scalar('Utils/SVAE/train/kl_weight_l', KL_weight_l, i + (train_iter*self.svae_iterations))
                    # self.tb_writer.add_scalar('Utils/SVAE/train/kl_weight_u', KL_weight_u, i + (train_iter*self.svae_iterations))
                    # self.tb_writer.add_scalar('Loss/SVAE/train/labelled/NLL', NLL_loss_l, i + (train_iter*self.svae_iterations))
                    # self.tb_writer.add_scalar('Loss/SVAE/train/unlabelled/NLL', NLL_loss_u, i + (train_iter*self.svae_iterations))
                    # self.tb_writer.add_scalar('Loss/SVAE/train/labelled/KL_loss', KL_loss_l, i + (train_iter*self.svae_iterations))
                    # self.tb_writer.add_scalar('Loss/SVAE/train/unlabelled/KL_loss', KL_loss_u, i + (train_iter*self.svae_iterations))
                    # self.tb_writer.add_scalar('Loss/SVAE/train/labelled/total', svae_loss_l, i + (train_iter*self.svae_iterations))
                    # self.tb_writer.add_scalar('Loss/SVAE/train/unlabelled/total', svae_loss_u, i + (train_iter*self.svae_iterations))

                    # Sample new batch of data while training adversarial network
                    if i < self.svae_iterations - 1:
                        batch_sequences_l, batch_lengths_l, _ = next(iter(dataloader_l))
                        batch_sequences_u, batch_lengths_u, _ = next(iter(dataloader_u))

                        if torch.cuda.is_available():
                            batch_sequences_l = batch_sequences_l.to(self.device)
                            batch_lengths_l = batch_lengths_l.to(self.device)
                            batch_sequences_u = batch_sequences_u.to(self.device)
                            batch_lengths_u = batch_lengths_u.to(self.device)
                        
                    # Increment step
                    step += 1

                # SVAE train_iter loss after iterative cycle
                self.tb_writer.add_scalar('Loss/SVAE/train/Total', total_svae_loss, train_iter)

                # Discriminator Step
                for j in range(self.dsc_iterations):

                    with torch.no_grad():
                        _, mean_l, _, z_l = self.svae(batch_sequences_l, batch_lengths_l)
                        _, mean_u, _, z_u = self.svae(batch_sequences_u, batch_lengths_u)

                    dsc_preds_l = self.discriminator(z_l)  #mean_l
                    dsc_preds_u = self.discriminator(z_u)  #mean_u

                    dsc_real_l = torch.ones(batch_size_l)
                    dsc_real_u = torch.zeros(batch_size_u)

                    if torch.cuda.is_available():
                        dsc_real_l = dsc_real_l.to(self.device)
                        dsc_real_u = dsc_real_u.to(self.device)

                    # Discriminator wants to minimise the loss here
                    dsc_loss_l = self.dsc_loss_fn(dsc_preds_l, dsc_real_l)
                    dsc_loss_u = self.dsc_loss_fn(dsc_preds_u, dsc_real_u)
                    total_dsc_loss = dsc_loss_l + dsc_loss_u
                    self.dsc_optim.zero_grad()
                    total_dsc_loss.backward()
                    self.dsc_optim.step()

                    # Sample new batch of data while training adversarial network
                    if j < self.dsc_iterations - 1:
                        # TODO: strip out unnecessary information
                        batch_sequences_l, batch_lengths_l, _ = next(iter(dataloader_l))
                        batch_sequences_u, batch_lengths_u, _ = next(iter(dataloader_u))

                        if torch.cuda.is_available():
                            batch_sequences_l = batch_sequences_l.to(self.device)
                            batch_lengths_l = batch_lengths_l.to(self.device)
                            batch_sequences_u = batch_sequences_u.to(self.device)
                            batch_lengths_u = batch_lengths_u.to(self.device)
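The SVAE loss calls above are parameterised by anneal_fn, step, k, and x0, which suggests the KL-weight annealing schedule popularised by Bowman et al.'s sentence VAE. SVAE.loss_fn itself is not shown in this listing, so the following is only a sketch of what such a schedule typically looks like, not the project's exact implementation:

import math

def kl_anneal_weight(anneal_fn: str, step: int, k: float, x0: int) -> float:
    """Typical KL-weight schedules: 'logistic' ramps smoothly around step x0,
    'linear' ramps from 0 to 1 over x0 steps. An assumption, not the repo's code."""
    if anneal_fn == 'logistic':
        return float(1 / (1 + math.exp(-k * (step - x0))))
    elif anneal_fn == 'linear':
        return min(1.0, step / x0)
    raise ValueError(f'Unknown anneal function: {anneal_fn}')

# The weight grows towards 1 as the global step passes x0,
# so the KL term is phased in gradually during training.
for s in (0, 1000, 2500, 5000):
    print(s, round(kl_anneal_weight('logistic', s, k=0.0025, x0=2500), 3))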
Example #5
def main(args):
    def interpolate(start, end, steps):
        interpolation = np.zeros((start.shape[0], steps + 2))
        for dim, (s, e) in enumerate(zip(start, end)):
            interpolation[dim] = np.linspace(s, e, steps + 2)

        return interpolation.T

    def idx2word(sent_list, i2w, pad_idx):
        sent = []
        for s in sent_list:
            sent.append(" ".join([i2w[str(int(idx))] \
                                 for idx in s if int(idx) is not pad_idx]))
        return sent

    with open(args.data_dir + '/vocab.json', 'r') as file:
        vocab = json.load(file)
    w2i, i2w = vocab['w2i'], vocab['i2w']

    #Load model
    model = SVAE(
        vocab_size=len(w2i),
        embed_dim=args.embedding_dimension,
        hidden_dim=args.hidden_dimension,
        latent_dim=args.latent_dimension,
        teacher_forcing=False,
        dropout=args.dropout,
        n_direction=(2 if args.bidirectional else 1),
        n_parallel=args.n_layer,
        max_src_len=args.max_src_length,  # affects the inference stage
        max_tgt_len=args.max_tgt_length,
        sos_idx=w2i['<sos>'],
        eos_idx=w2i['<eos>'],
        pad_idx=w2i['<pad>'],
        unk_idx=w2i['<unk>'],
    )

    path = os.path.join('checkpoint', args.load_checkpoint)
    if not os.path.exists(path):
        raise FileNotFoundError(path)

    model.load_state_dict(torch.load(path))
    print("Model loaded from %s" % (path))

    if torch.cuda.is_available():
        model = model.cuda()

    model.eval()

    samples, z = model.inference(n=args.num_samples)
    print('----------SAMPLES----------')
    print(*idx2word(sent_list=samples, i2w=i2w, pad_idx=w2i['<pad>']),
          sep='\n')

    z1 = torch.randn([args.latent_dimension]).numpy()
    z2 = torch.randn([args.latent_dimension]).numpy()
    z = torch.from_numpy(interpolate(start=z1, end=z2, steps=8)).float()
    samples, _ = model.inference(z=z)
    print('-------INTERPOLATION-------')
    print(*idx2word(sent_list=samples, i2w=i2w, pad_idx=w2i['<pad>']),
          sep='\n')
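For reference, interpolate above returns steps + 2 latent vectors (both endpoints included) with the same dimensionality as start and end, i.e. an array of shape (steps + 2, latent_dim) that can be passed to model.inference. A quick standalone check of that behaviour:

import numpy as np

def interpolate(start, end, steps):
    # Same logic as the helper in main(): one linspace per latent dimension, then transpose
    interpolation = np.zeros((start.shape[0], steps + 2))
    for dim, (s, e) in enumerate(zip(start, end)):
        interpolation[dim] = np.linspace(s, e, steps + 2)
    return interpolation.T

z1 = np.zeros(3)
z2 = np.ones(3)
path = interpolate(z1, z2, steps=2)
print(path.shape)                 # (4, 3): start, two intermediate points, end
print(path[0], path[-1])          # endpoints equal z1 and z2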
Example #6
def main(args):

    splits = ['train', 'valid'] + (['dev'] if args.test else [])
    print(args)
    #Load dataset
    datasets = OrderedDict()
    for split in splits:
        datasets[split] = seq_data(data_dir=args.data_dir,
                                   split=split,
                                   mt=args.mt,
                                   create_data=args.create_data,
                                   max_src_len=args.max_src_length,
                                   max_tgt_len=args.max_tgt_length,
                                   min_occ=args.min_occ)
    print('Data OK')
    #Load model
    model = SVAE(
        vocab_size=datasets['train'].vocab_size,
        embed_dim=args.embedding_dimension,
        hidden_dim=args.hidden_dimension,
        latent_dim=args.latent_dimension,

        #word_drop=args.word_dropout,
        teacher_forcing=args.teacher_forcing,
        dropout=args.dropout,
        n_direction=args.bidirectional,
        n_parallel=args.n_layer,
        attn=args.attention,
        max_src_len=args.max_src_length,  # affects the inference stage
        max_tgt_len=args.max_tgt_length,
        sos_idx=datasets['train'].sos_idx,
        eos_idx=datasets['train'].eos_idx,
        pad_idx=datasets['train'].pad_idx,
        unk_idx=datasets['train'].unk_idx)
    if args.fasttext:
        prt = torch.load(args.data_dir + '/prt_fasttext.model')
        model.load_prt(prt)
    print('Model OK')
    if torch.cuda.is_available():
        model = model.cuda()
    device = model.device
    #Training phase with validation(earlystopping)
    tracker = Tracker(patience=10,
                      verbose=True)  # records training history & provides early stopping
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    step = 0

    for epoch in range(args.epochs):
        for split in splits:
            data_loader = DataLoader(dataset=datasets[split],
                                     batch_size=args.n_batch,
                                     shuffle=(split == 'train'),
                                     num_workers=cpu_count(),
                                     pin_memory=torch.cuda.is_available())
            if split == 'train':
                model.train()
            else:
                model.eval()

            #Executing
            for i, data in enumerate(data_loader):
                src, srclen,  tgt, tgtlen = \
                     data['src'], data['srclen'], data['tgt'], data['tgtlen']
                #FP
                logits, (mu, logv,
                         z), generations = model(src, srclen, tgt, tgtlen,
                                                 split)

                #FP for groundtruth
                #h_pred, h_tgt = model.forward_gt(generations, tgt, tgtlen)

                #LOSS(weighted)
                NLL, KL, KL_W = model.loss(logits, tgt.to(device),
                                           data['tgtlen'], mu, logv, step,
                                           args.k, args.x0, args.af)
                #GLOBAL = model.global_loss(h_pred, h_tgt)
                GLOBAL = 0

                loss = (NLL + KL * KL_W + GLOBAL) / data['src'].size(0)
                #BP & OPTIM
                if split == 'train':
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    step += 1

                #RECORD & RESULT(batch)
                if i % 50 == 0 or i + 1 == len(data_loader):
                    #NLL.data = torch.cuda.FloatTensor([NLL.data])
                    #KL.data = torch.cuda.FloatTensor([KL.data])
                    print(
                        "{} Phase - Batch {}/{}, Loss: {}, NLL: {}, KL: {}, KL-W: {}, G: {}"
                        .format(split.upper(), i,
                                len(data_loader) - 1, loss, NLL, KL, KL_W,
                                GLOBAL))
                tracker._elbo(torch.Tensor([loss]))
                if split == 'valid':
                    tracker.record(tgt, generations, datasets['train'].i2w,
                                   datasets['train'].pad_idx,
                                   datasets['train'].eos_idx,
                                   datasets['train'].unk_idx, z)

            #SAVING & RESULT(epoch)
            if split == 'valid':
                tracker.dumps(epoch, args.dump_file)  #dump the predicted text.
            else:
                tracker._save_checkpoint(
                    epoch, args.model_file,
                    model.state_dict())  #save the checkpoint
            print("{} Phase - Epoch {} , Mean ELBO: {}".format(
                split.upper(), epoch, torch.mean(tracker.elbo)))

            tracker._purge()
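This training entry point clearly consumes an argparse namespace, but the parser itself is not part of the listing. The sketch below only declares flags for the args.* attributes actually referenced in main(); the types and defaults are illustrative assumptions, not the project's real CLI:

import argparse

def build_parser():
    # Flags inferred from the args.* attributes used in main(); defaults are placeholders.
    p = argparse.ArgumentParser()
    p.add_argument('--data_dir', default='data')
    p.add_argument('--create_data', action='store_true')
    p.add_argument('--mt', action='store_true')
    p.add_argument('--test', action='store_true')
    p.add_argument('--min_occ', type=int, default=1)
    p.add_argument('--max_src_length', type=int, default=60)
    p.add_argument('--max_tgt_length', type=int, default=60)
    p.add_argument('--embedding_dimension', type=int, default=300)
    p.add_argument('--hidden_dimension', type=int, default=256)
    p.add_argument('--latent_dimension', type=int, default=16)
    p.add_argument('--teacher_forcing', type=float, default=1.0)
    p.add_argument('--dropout', type=float, default=0.5)
    p.add_argument('--bidirectional', type=int, default=1)
    p.add_argument('--n_layer', type=int, default=1)
    p.add_argument('--attention', action='store_true')
    p.add_argument('--fasttext', action='store_true')
    p.add_argument('--learning_rate', type=float, default=1e-3)
    p.add_argument('--epochs', type=int, default=10)
    p.add_argument('--n_batch', type=int, default=32)
    p.add_argument('--k', type=float, default=0.0025)    # KL annealing slope
    p.add_argument('--x0', type=int, default=2500)       # KL annealing midpoint
    p.add_argument('--af', default='logistic')           # anneal function name
    p.add_argument('--dump_file', default='dump.json')
    p.add_argument('--model_file', default='model.pt')
    return p

if __name__ == '__main__':
    main(build_parser().parse_args())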
Example #7
class ModularTrainer(Sampler):
    def __init__(self):
        self.config = load_config()
        self.model_config = self.config['Models']
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")s
        self.pretrain = True

        # Model
        self.task_type = self.config['Utils']['task_type']
        self.max_sequence_length = self.config['Utils'][self.task_type]['max_sequence_length']
        
        self.budget_frac = self.config['Train']['budget_frac']
        self.batch_size = self.config['Train']['batch_size']
        self.data_splits_frac = np.round(np.linspace(self.budget_frac, self.budget_frac*10, num=10, endpoint=True), 2)
        
        
        # Real data
        self.data_name = self.config['Utils'][self.task_type]['data_name']
        self.data_splits = self.config['Utils'][self.task_type]['data_split']
        self.pad_idx = self.config['Utils']['special_token2idx']['<PAD>']
        
        # Test run properties
        self.epochs = self.config['Train']['epochs']
        self.svae_iterations = self.config['Train']['svae_iterations']
        self.dsc_iterations = self.config['Train']['discriminator_iterations']
        self.adv_hyperparam = self.config['Models']['SVAE']['adversarial_hyperparameter']
        
    def _init_data(self, batch_size=None):
        kfold_xval = False
        
        self.x_y_pair_name = 'seq_label_pairs_enc' if self.data_name == 'ag_news' else 'seq_tags_pairs_enc' # Key in dataset - semantically correct for the task at hand.

        if batch_size is None:
            batch_size = self.config['Train']['batch_size']

        # Load pre-processed data
        path_data = os.path.join('/home/tyler/Desktop/Repos/s-vaal/data', self.task_type, self.data_name, 'data.json')
        path_vocab = os.path.join('/home/tyler/Desktop/Repos/s-vaal/data', self.task_type, self.data_name, 'vocabs.json')
        self.preprocess_data = load_json(path_data)
        self.vocab = load_json(path_vocab)       # Required for decoding sequences for interpretations. TODO: Find suitable location... or leave be...
        self.vocab_size = len(self.vocab['words'])  # word vocab is used for model dimensionality setting + includes special characters (EOS, SOS, UNK, PAD)
        self.tagset_size = len(self.vocab['tags'])  # this includes special characters (EOS, SOS, UNK, PAD)
        
        self.datasets = dict()
        if kfold_xval:
            # Perform k-fold cross-validation
            # Join all datasets and then randomly assign train/val/test
            print('Performing k-fold x-val')
            for split in self.data_splits:
                print(self.preprocess_data[split][self.x_y_pair_name])
            
        else:    
            for split in self.data_splits:
                # Access data
                split_data = self.preprocess_data[split][self.x_y_pair_name]
                # Convert lists of encoded sequences into tensors and stack into one large tensor
                split_seqs = torch.stack([torch.tensor(enc_pair[0]) for key, enc_pair in split_data.items()])
                split_tags = torch.stack([torch.tensor(enc_pair[1]) for key, enc_pair in split_data.items()])
                # Create torch dataset from tensors
                split_dataset = RealDataset(sequences=split_seqs, tags=split_tags)
                # Add to dictionary
                self.datasets[split] = split_dataset #split_dataloader
                
                # Create torch dataloader generator from dataset
                if split == 'test':
                    self.test_dataloader = DataLoader(dataset=split_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
                if split == 'valid':
                    self.val_dataloader = DataLoader(dataset=split_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

        print(f'{datetime.now()}: Data loaded successfully')
        
        train_dataset = self.datasets['train']
        dataset_size = len(train_dataset)
        
        self.budget = math.ceil(self.budget_frac*dataset_size)
        Sampler.__init__(self, self.budget)
        
        all_indices = set(np.arange(dataset_size))
        k_initial = math.ceil(len(all_indices)*self.budget_frac)
        initial_indices = random.sample(list(all_indices), k=k_initial)
        
        sampler_init = torch.utils.data.sampler.SubsetRandomSampler(initial_indices)
        
        self.labelled_dataloader = DataLoader(train_dataset, sampler=sampler_init, batch_size=self.batch_size, drop_last=True)
        self.val_dataloader = DataLoader(self.datasets['valid'], batch_size=self.batch_size, shuffle=True, drop_last=False)
        self.test_dataloader = DataLoader(self.datasets['test'], batch_size=self.batch_size, shuffle=True, drop_last=False)

        print(f'{datetime.now()}: Dataloaders sizes: Train {len(self.labelled_dataloader)} Valid {len(self.val_dataloader)} Test {len(self.test_dataloader)}')
        return all_indices, initial_indices
    
    def _init_svae_model(self):
        self.svae = SVAE(**self.model_config['SVAE']['Parameters'],vocab_size=self.vocab_size).to(self.device)
        self.svae_optim = optim.Adam(self.svae.parameters(), lr=self.model_config['SVAE']['learning_rate'])
        if self.pretrain:
            print('Setting SVAE to EVAL mode')
            self.svae.eval()
        else:
            self.svae.train()
        print(f'{datetime.now()}: Initialised SVAE successfully')
    
    def _init_disc_model(self):
        self.discriminator = Discriminator(**self.model_config['Discriminator']['Parameters']).to(self.device)
        self.dsc_loss_fn = nn.BCELoss().to(self.device)
        self.dsc_optim = optim.Adam(self.discriminator.parameters(), lr=self.model_config['Discriminator']['learning_rate'])
        self.discriminator.train()
        print(f'{datetime.now()}: Initialised Discriminator successfully')
    
    def _init_gen_model(self):
        self.generator = Generator(**self.model_config['Generator']['Parameters']).to(self.device)
        self.gen_loss_fn = nn.BCELoss().to(self.device)
        self.gen_optim = optim.Adam(self.generator.parameters(), lr=self.model_config['Generator']['learning_rate'])
        self.generator.train()
        print(f'{datetime.now()}: Initialised Generator successfully')
    
    def _load_pretrained_model(self):
        # Initialise SVAE with saved parameters. TODO: Save model hyperparameters to disk with the saved weights
        self.svae = SVAE(**self.model_config['SVAE']['Parameters'], vocab_size=self.vocab_size).to(self.device)
        # Load the pre-trained SVAE weights from disk, then modify the model
        svae_best_model_path = 'best models/svae.pt'
        self.svae.load_state_dict(torch.load(svae_best_model_path))
        print(f'{datetime.now()}: Loaded pretrained SVAE\n{self.svae}')
        
        class Identity(nn.Module):
            def __init__(self):
                super(Identity, self).__init__()
            def forward(self, x):
                return x
        
        self.svae.outputs2vocab = Identity()  # Removes the hidden2output (outputs2vocab) projection layer
        print(f'{datetime.now()}: Modified pretrained SVAE\n{self.svae}')
        
    def _disc_train(self, sequences_l, lengths_l, sequences_u, lengths_u):
        # Train discriminator
        batch_size_l = sequences_l.size(0)
        batch_size_u = sequences_u.size(0)

        # Pass through pretrained svae
        with torch.no_grad():
            z_l = self.svae(sequences_l, lengths_l, pretrain=True)
            z_u = self.svae(sequences_u, lengths_u, pretrain=True)

        # Train discriminator on labelled samples
        dsc_preds_l = self.discriminator(z_l)
        dsc_real_l = torch.ones_like(dsc_preds_l).to(self.device)

        # Train discriminator on unlabelled samples
        dsc_preds_u = self.discriminator(z_u)
        dsc_real_u = torch.zeros_like(dsc_preds_u).to(self.device)
        
        if torch.cuda.is_available():
            dsc_real_l = dsc_real_l.to(self.device)
            dsc_real_u = dsc_real_u.to(self.device)
        
        dsc_loss_l = self.dsc_loss_fn(dsc_preds_l, dsc_real_l) / batch_size_l
        dsc_loss_u = self.dsc_loss_fn(dsc_preds_u, dsc_real_u) / batch_size_u

        # Discriminator wants to minimise the loss here
        total_dsc_loss = dsc_loss_l + dsc_loss_u
        self.discriminator.zero_grad()
        self.dsc_optim.zero_grad()
        total_dsc_loss.backward()
        self.dsc_optim.step()
        
        return total_dsc_loss.data.item()
        
    def _gen_train(self, sequences_l, lengths_l, sequences_u, lengths_u):
        # Train Generator
        batch_size_l = sequences_l.size(0)
        batch_size_u = sequences_u.size(0)
        
        with torch.no_grad():
            z_l = self.svae(sequences_l, lengths_l, pretrain=True)
            z_u = self.svae(sequences_u, lengths_u, pretrain=True)
        
        # Adversarial loss - trying to fool the discriminator!
        gen_preds_l = self.discriminator(z_l)
        gen_preds_u = self.discriminator(z_u)
        gen_real_l = torch.ones_like(gen_preds_l)
        gen_real_u = torch.ones_like(gen_preds_u)

        if torch.cuda.is_available():
            gen_real_l = gen_real_l.to(self.device)
            gen_real_u = gen_real_u.to(self.device)

        # Higher loss = discriminator is having trouble figuring out the real vs fake
        # Generator wants to maximise this loss
        gen_loss_l = self.gen_loss_fn(gen_preds_l, gen_real_l) / batch_size_l
        gen_loss_u = self.gen_loss_fn(gen_preds_u, gen_real_u) / batch_size_u

        total_gen_loss = gen_loss_l + gen_loss_u
        self.generator.zero_grad()
        self.gen_optim.zero_grad()
        total_gen_loss.backward()
        self.gen_optim.step()
        
        return total_gen_loss.data.item()
    
    def _train_svaal_pretrained(self):
        # Trains SVAAL using pretrained SVAE and adversarial training routine
        
        all_sampled_indices_dict = dict()
        all_indices, initial_indices = self._init_data()
        current_indices = list(initial_indices)
        
        self._load_pretrained_model()   # Load pretrained SVAE
                
        print(f'{datetime.now()}: Split regime: {self.data_splits_frac}')
        for split in self.data_splits_frac:
            if split == 1:
                # Break if dataset is 100% as there will be no unlabelled data
                print('Exiting training')
                break
            
            print(f'{datetime.now()}: Running {split*100:0.0f}% of training dataset')
            meta = f" adv train {str(split*100)} "
            tb_writer = SummaryWriter(comment=meta, filename_suffix=meta)

            # Initialise discriminator and generator models for training
            self._init_disc_model()
            self._init_gen_model()
            
            unlabelled_indices = np.setdiff1d(list(all_indices), current_indices)
            unlabelled_sampler = data.sampler.SubsetRandomSampler(unlabelled_indices)
            unlabelled_dataloader = data.DataLoader(self.datasets['train'],
                                                    sampler=unlabelled_sampler,
                                                    batch_size=self.config['Train']['batch_size'],
                                                    drop_last=False)
            
            print(f'{datetime.now()}: Indice Counts - Labelled {len(current_indices)} Unlabelled {len(unlabelled_indices)} Total {len(all_indices)}')
            
            # Save indices of X_l, X_u
            all_sampled_indices_dict[str(int(split*100))] = {'Labelled': current_indices, 'Unlabelled': unlabelled_indices}
            # print(all_sampled_indices_dict)
            
            dataloader_l = self.labelled_dataloader
            dataloader_u = unlabelled_dataloader
            dataset_size = len(dataloader_l) + len(dataloader_u)
            train_iterations = dataset_size * self.epochs
            
            print(f'{datetime.now()}: Dataset size (batches) {dataset_size} Training iterations (batches) {train_iterations}')

            epoch = 1
            step = 1
            for train_iter in tqdm(range(train_iterations), desc='Training iteration'):
                batch_sequences_l, batch_lengths_l, _ = next(iter(dataloader_l))
                batch_sequences_u, batch_lengths_u, _ = next(iter(dataloader_u))

                if torch.cuda.is_available():
                    batch_sequences_l = batch_sequences_l.to(self.device)
                    batch_lengths_l = batch_lengths_l.to(self.device)
                    batch_sequences_u = batch_sequences_u.to(self.device)
                    batch_length_u = batch_lengths_u.to(self.device)

                # Discriminator
                disc_loss = self._disc_train(sequences_l=batch_sequences_l,
                                             lengths_l=batch_lengths_l,
                                             sequences_u=batch_sequences_u,
                                             lengths_u=batch_length_u)
                # Generator
                gen_loss = self._gen_train(sequences_l=batch_sequences_l,
                                           lengths_l=batch_lengths_l,
                                           sequences_u=batch_sequences_u,
                                           lengths_u=batch_length_u)
            
                tb_writer.add_scalars("Loss/Train",
                                     {'Discriminator': disc_loss,
                                      'Generator': gen_loss},
                                     step)

                # if (train_iter > 0) & (train_iter % dataset_size == 0):
                #     train_iter_str = f'{datetime.now()}: Epoch {epoch} - Losses ({self.task_type}) | Disc {disc_loss:0.2f} | Gen {gen_loss:0.2f} | Learning rates: ...'
                #     print(train_iter_str)
                #     epoch += 1
                
                step += 1
            
            # Adversarially sample from unlabelled pool
            sampled_indices, preds_topk, _ = self.sample_adversarial(svae=self.svae,
                                                                     discriminator=self.discriminator,
                                                                     data=dataloader_u,
                                                                     indices=unlabelled_indices,
                                                                     pretrain=True)
            
            # Update indices -> Update dataloaders
            current_indices = list(current_indices) + list(sampled_indices)
            sampler = torch.utils.data.sampler.SubsetRandomSampler(current_indices)
            self.labelled_dataloader = DataLoader(self.datasets['train'], sampler=sampler, batch_size=self.batch_size, drop_last=True)

            # Save sampled data
            try:
                path = os.path.join(os.getcwd(), 'results', str(int(split*100)))
                if os.path.exists(path):
                    pass
                else:
                    os.mkdir(path)

                # torch.save(self.labelled_dataloader, os.path.join(path, 'labelled_data.pth'))
                
                # Save adversarial predictions
                preds_topk = "\n".join([str(pred.item()) for pred in preds_topk])
                with open('preds_topk.txt', 'w') as fw:
                    fw.write(preds_topk)

                # Save sampled training data (this is reconstructed from the sampled indices)
                output_str = ''
                for i in sampled_indices:
                    sample = self.preprocess_data['train']['seq_tags_pairs'][str(i)]
                    seq = sample[0]
                    tags = sample[1]
                    temp_str = ''
                    for idx, token in enumerate(seq):
                        if token == '<START>':
                            pass
                        elif token == '<STOP>':
                            break
                        else:
                            temp_str += seq[idx] + ' x x ' + tags[idx] + '\n'
                    output_str += temp_str + '\n'
                with open(os.path.join(path, 'train.txt'), 'w') as fw:
                    fw.write(output_str)
                    

                # Save test/valid (dev) sets for local outputs
                for split_name in ['test', 'valid']:
                    output_str = ''
                    for i, pair in self.preprocess_data[split_name][self.x_y_pair_name].items():
                        seq, tags = pair
                        temp_str = ''
                        for idx, token in enumerate(seq):
                            if token == '<START>':
                                pass
                            elif token == '<STOP>':
                                break
                            else:
                                temp_str += seq[idx] + ' x x ' + tags[idx] + '\n'
                        output_str += temp_str + '\n'

                    with open(os.path.join(path, f'{split_name}.txt'), 'w') as fw:
                        fw.write(output_str)
            
                # # Reconstructing test/valid for local output
                # output_str = ''
                # for i, pair in self.preprocess_data['test']["seq_tags_pairs"].items():      # TODO: Fix hard coded seq tags pairs...
                #     seq_test, tags_test = pair
            
                #     temp_str = ''
                #     for idx, token in enumerate(seq_test):
                #         if token == '<START>':
                #             pass
                #         elif token == '<STOP>':
                #             break
                #         else:
                #             temp_str += seq_test[idx] + ' x x ' + tags_test[idx] + '\n'
            
                #     output_str += temp_str + '\n'
                
                # with open(os.path.join(path, 'test.txt'), 'w') as fw:
                #     fw.write(output_str)
                    
                # output_str = ''
                # for i, pair in self.preprocess_data['valid']["seq_tags_pairs"].items():      # TODO: Fix hard coded seq tags pairs...
                #     seq_test, tags_test = pair
            
                #     temp_str = ''
                #     for idx, token in enumerate(seq_test):
                #         if token == '<START>':
                #             pass
                #         elif token == '<STOP>':
                #             break
                #         else:
                #             # add x and x as placeholders for pos and nn tags for CoNLL
                #             temp_str += seq_test[idx] + ' x x ' + tags_test[idx] + '\n'
            
                #     output_str += temp_str + '\n'
                
                # with open(os.path.join(path, 'valid.txt'), 'w') as fw:
                #     fw.write(output_str)
                    
            except Exception:
                print('Saving sampled data and local outputs failed')
                traceback.print_exc(file=sys.stdout)
            
            
    def _train_svaal(self):
        
        all_sampled_indices_dict = dict()
        
        all_indices, initial_indices = self._init_data()
        current_indices = list(initial_indices)
        
        print(f'{datetime.now()}: Split regime: {self.data_splits_frac}')
        for split in self.data_splits_frac:
            print(f'{datetime.now()}: Running {split*100:0.0f}% of training dataset')

            self._init_svae_model()
            self._init_disc_model()
            
            unlabelled_indices = np.setdiff1d(list(all_indices), current_indices)
            unlabelled_sampler = data.sampler.SubsetRandomSampler(unlabelled_indices)
            unlabelled_dataloader = data.DataLoader(self.datasets['train'],
                                                    sampler=unlabelled_sampler,
                                                    batch_size=self.config['Train']['batch_size'],
                                                    drop_last=False)
            print(f'{datetime.now()}: Indices - Labelled {len(current_indices)} Unlabelled {len(unlabelled_indices)} Total {len(all_indices)}')
            # Save indices of X_l, X_u
            all_sampled_indices_dict[str(int(split*100))] = {'Labelled': current_indices, 'Unlabelled': unlabelled_indices}
            # print(all_sampled_indices_dict)
            
            
            dataloader_l = self.labelled_dataloader
            dataloader_u = unlabelled_dataloader
            
            
            dataset_size = len(dataloader_l) + len(dataloader_u)
            train_iterations = dataset_size * self.epochs
            print(f'{datetime.now()}: Dataset size (batches) {dataset_size} Training iterations (batches) {train_iterations}')

            step = 0
            epoch = 1      
            for train_iter in tqdm(range(train_iterations), desc='Training iteration'):
                batch_sequences_l, batch_lengths_l, _ = next(iter(dataloader_l))
                batch_sequences_u, batch_lengths_u, _ = next(iter(dataloader_u))

                if torch.cuda.is_available():
                    batch_sequences_l = batch_sequences_l.to(self.device)
                    batch_lengths_l = batch_lengths_l.to(self.device)
                    batch_sequences_u = batch_sequences_u.to(self.device)
                    batch_lengths_u = batch_lengths_u.to(self.device)

                batch_size_l = batch_sequences_l.size(0)
                batch_size_u = batch_sequences_u.size(0)

                # SVAE Step
                for i in range(self.svae_iterations):
                    logp_l, mean_l, logv_l, z_l = self.svae(batch_sequences_l, batch_lengths_l)
                    NLL_loss_l, KL_loss_l, KL_weight_l = self.svae.loss_fn(
                                                                    logp=logp_l,
                                                                    target=batch_sequences_l,
                                                                    length=batch_lengths_l,
                                                                    mean=mean_l,
                                                                    logv=logv_l,
                                                                    anneal_fn=self.model_config['SVAE']['anneal_function'],
                                                                    step=step,
                                                                    k=self.model_config['SVAE']['k'],
                                                                    x0=self.model_config['SVAE']['x0'])

                    logp_u, mean_u, logv_u, z_u = self.svae(batch_sequences_u, batch_lengths_u)
                    NLL_loss_u, KL_loss_u, KL_weight_u = self.svae.loss_fn(
                                                                    logp=logp_u,
                                                                    target=batch_sequences_u,
                                                                    length=batch_lengths_u,
                                                                    mean=mean_u,
                                                                    logv=logv_u,
                                                                    anneal_fn=self.model_config['SVAE']['anneal_function'],
                                                                    step=step,
                                                                    k=self.model_config['SVAE']['k'],
                                                                    x0=self.model_config['SVAE']['x0'])
                    # VAE loss
                    svae_loss_l = (NLL_loss_l + KL_weight_l * KL_loss_l) / batch_size_l
                    svae_loss_u = (NLL_loss_u + KL_weight_u * KL_loss_u) / batch_size_u

                    # Adversarial loss - trying to fool the discriminator!
                    dsc_preds_l = self.discriminator(z_l)   # mean_l
                    dsc_preds_u = self.discriminator(z_u)   # mean_u
                    dsc_real_l = torch.ones(batch_size_l)
                    dsc_real_u = torch.ones(batch_size_u)

                    if torch.cuda.is_available():
                        dsc_real_l = dsc_real_l.to(self.device)
                        dsc_real_u = dsc_real_u.to(self.device)

                    # Higher loss = discriminator is having trouble figuring out the real vs fake
                    # Generator wants to maximise this loss
                    adv_dsc_loss_l = self.dsc_loss_fn(dsc_preds_l, dsc_real_l)
                    adv_dsc_loss_u = self.dsc_loss_fn(dsc_preds_u, dsc_real_u)
                    adv_dsc_loss = adv_dsc_loss_l + adv_dsc_loss_u

                    total_svae_loss = svae_loss_u + svae_loss_l + self.adv_hyperparam * adv_dsc_loss
                    self.svae_optim.zero_grad()
                    total_svae_loss.backward()
                    self.svae_optim.step()

                    # Sample new batch of data while training adversarial network
                    if i < self.svae_iterations - 1:
                        batch_sequences_l, batch_lengths_l, _ = next(iter(dataloader_l))
                        batch_sequences_u, batch_lengths_u, _ = next(iter(dataloader_u))

                        if torch.cuda.is_available():
                            batch_sequences_l = batch_sequences_l.to(self.device)
                            batch_lengths_l = batch_lengths_l.to(self.device)
                            batch_sequences_u = batch_sequences_u.to(self.device)
                            batch_lengths_u = batch_lengths_u.to(self.device)
                    
                    # Increment step
                    step += 1

                # SVAE train_iter loss after iterative cycle
                # self.tb_writer.add_scalar('Loss/SVAE/train/Total', total_svae_loss, train_iter)

                # Discriminator Step
                for j in range(self.dsc_iterations):

                    with torch.no_grad():
                        _, mean_l, _, z_l = self.svae(batch_sequences_l, batch_lengths_l)
                        _, mean_u, _, z_u = self.svae(batch_sequences_u, batch_lengths_u)

                    dsc_preds_l = self.discriminator(z_l)
                    dsc_preds_u = self.discriminator(z_u)

                    dsc_real_l = torch.ones(batch_size_l)
                    dsc_real_u = torch.zeros(batch_size_u)

                    if torch.cuda.is_available():
                        dsc_real_l = dsc_real_l.to(self.device)
                        dsc_real_u = dsc_real_u.to(self.device)

                    # Discriminator wants to minimise the loss here
                    dsc_loss_l = self.dsc_loss_fn(dsc_preds_l, dsc_real_l)
                    dsc_loss_u = self.dsc_loss_fn(dsc_preds_u, dsc_real_u)
                    total_dsc_loss = dsc_loss_l + dsc_loss_u
                    self.dsc_optim.zero_grad()
                    total_dsc_loss.backward()
                    self.dsc_optim.step()

                    # Sample new batch of data while training adversarial network
                    if j < self.dsc_iterations - 1:
                        # TODO: strip out unnecessary information
                        batch_sequences_l, batch_lengths_l, _ = next(iter(dataloader_l))
                        batch_sequences_u, batch_lengths_u, _ = next(iter(dataloader_u))

                        if torch.cuda.is_available():
                            batch_sequences_l = batch_sequences_l.to(self.device)
                            batch_lengths_l = batch_lengths_l.to(self.device)
                            batch_sequences_u = batch_sequences_u.to(self.device)
                            batch_lengths_u = batch_lengths_u.to(self.device)
            
                if (train_iter > 0) and (train_iter % dataset_size == 0):
                    train_iter_str = f'{datetime.now()}: Epoch {epoch} - Losses ({self.task_type}) | SVAE {total_svae_loss:0.2f} | Disc {total_dsc_loss:0.2f} | Learning rates: '
                    print(train_iter_str)
                    
                    epoch += 1
                    
            
            # Adversarial sample
            sampled_indices, _, _ = self.sample_adversarial(svae=self.svae,
                                                      discriminator=self.discriminator,
                                                      data=dataloader_u,
                                                      indices=unlabelled_indices,
                                                      cuda=True)
            
            # Update indices -> Update dataloaders
            current_indices = list(current_indices) + list(sampled_indices)
            sampler = torch.utils.data.sampler.SubsetRandomSampler(current_indices)
            self.labelled_dataloader = DataLoader(self.datasets['train'], sampler=sampler, batch_size=self.batch_size, drop_last=True)
    
    
    def _init_tl_model(self):
        """ Initialises models, loss functions, optimisers and sets models to training mode """
        self.task_learner = TaskLearner(**self.model_config['TaskLearner']['Parameters'],
                                        vocab_size=self.vocab_size,
                                        tagset_size=self.tagset_size,
                                        task_type=self.task_type).to(self.device)
        if self.task_type == 'SEQ':
            self.tl_loss_fn = nn.NLLLoss().to(self.device)
        if self.task_type == 'CLF':
            self.tl_loss_fn = nn.CrossEntropyLoss().to(self.device)

        self.tl_optim = optim.SGD(self.task_learner.parameters(), lr=self.model_config['TaskLearner']['learning_rate'])#, momentum=0, weight_decay=0.1)
        
        # Learning rate scheduler
        # Note: LR likely GT Adam
        # self.tl_sched = optim.lr_scheduler.ReduceLROnPlateau(self.tl_optim, 'min', factor=0.5, patience=10)
        # Training Modes
        self.task_learner.train()
    
        print(f'{datetime.now()}: Initialised Task Learner successfully')
    
    def _train_tl(self, dataloader_l):
        """ Trains the Task Learner in isolation on batches drawn from the labelled dataloader """
        self._init_tl_model()
        
        train_iterations = len(dataloader_l) * (self.epochs+1)
        
        for train_iter in tqdm(range(train_iterations), desc='Training iteration'):
            batch_sequences_l, batch_lengths_l, batch_tags_l =  next(iter(dataloader_l))

            if torch.cuda.is_available():
                batch_sequences_l = batch_sequences_l.to(self.device)
                batch_lengths_l = batch_lengths_l.to(self.device)
                batch_tags_l = batch_tags_l.to(self.device)
            
            # Strip off tag padding and flatten
            # Don't trim the sequences here as it's done in the forward pass of the seq2seq models
            batch_tags_l = trim_padded_seqs(batch_lengths=batch_lengths_l,
                                            batch_sequences=batch_tags_l,
                                            pad_idx=self.pad_idx).view(-1)

            # Task Learner Step
            self.tl_optim.zero_grad()
            tl_preds = self.task_learner(batch_sequences_l, batch_lengths_l)
            tl_loss = self.tl_loss_fn(tl_preds, batch_tags_l)
            tl_loss.backward()
            self.tl_optim.step()
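
A quick aside on the adversarial objectives used in the discriminator step above: the discriminator is pushed towards predicting 1 for latent codes of labelled data and 0 for unlabelled data, while the SVAE's earlier adversarial term uses all-ones targets to try to fool it. The following is a minimal, self-contained sketch of those two losses with random stand-in predictions (the batch size of 8 is an arbitrary assumption), not the repository's actual Discriminator:

import torch
import torch.nn as nn

bce = nn.BCELoss()
preds_l = torch.rand(8)  # stand-in for discriminator(z_l) on a labelled batch
preds_u = torch.rand(8)  # stand-in for discriminator(z_u) on an unlabelled batch

# Discriminator objective: labelled -> 1, unlabelled -> 0
dsc_loss = bce(preds_l, torch.ones(8)) + bce(preds_u, torch.zeros(8))

# Adversarial (SVAE) objective: make everything look labelled, i.e. all-ones targets
adv_dsc_loss = bce(preds_l, torch.ones(8)) + bce(preds_u, torch.ones(8))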
Example #8
class ModularTrainer(Sampler):
    def __init__(self):
        self.config = load_config()
        self.model_config = self.config['Models']

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        # Model
        self.task_type = self.config['Utils']['task_type']
        self.max_sequence_length = self.config['Utils'][
            self.task_type]['max_sequence_length']

        # Real data
        self.data_name = self.config['Utils'][self.task_type]['data_name']
        self.data_splits = self.config['Utils'][self.task_type]['data_split']
        self.pad_idx = self.config['Utils']['special_token2idx']['<PAD>']

        # Test run properties
        self.epochs = self.config['Train']['epochs']
        self.svae_iterations = self.config['Train']['svae_iterations']

        self.kfold_xval = False

    def _init_data(self, batch_size=None):
        if batch_size is None:
            batch_size = self.config['Train']['batch_size']

        # Load pre-processed data
        path_data = os.path.join('/home/tyler/Desktop/Repos/s-vaal/data',
                                 self.task_type, self.data_name, 'pretrain',
                                 'data.json')
        path_vocab = os.path.join('/home/tyler/Desktop/Repos/s-vaal/data',
                                  self.task_type, self.data_name, 'pretrain',
                                  'vocab.json')  # not vocabs
        data = load_json(path_data)

        self.vocab = load_json(
            path_vocab
        )  # Required for decoding sequences for interpretations. TODO: Find suitable location... or leave be...
        self.vocab_size = len(self.vocab['word2idx'])

        self.idx2word = self.vocab['idx2word']
        self.word2idx = self.vocab['word2idx']

        self.datasets = dict()
        if self.kfold_xval:
            # Perform k-fold cross-validation
            # Join all datasets and then randomly assign train/val/test
            # TODO: k-fold cross-validation is not implemented yet; this branch
            # currently only inspects the raw splits.
            for split in self.data_splits:
                print(data[split][self.x_y_pair_name])

        else:
            for split in self.data_splits:
                # Access data
                split_data = data[split]
                # print(split_data)
                # Convert lists of encoded sequences into tensors and stack into one large tensor
                split_inputs = torch.stack([
                    torch.tensor(value['input'])
                    for key, value in split_data.items()
                ])
                split_targets = torch.stack([
                    torch.tensor(value['target'])
                    for key, value in split_data.items()
                ])
                # Create torch dataset from tensors
                split_dataset = RealDataset(sequences=split_inputs,
                                            tags=split_targets)
                # Add to dictionary
                self.datasets[split] = split_dataset  #split_dataloader

                # Create torch dataloader generator from dataset
                if split == 'test':
                    self.test_dataloader = DataLoader(dataset=split_dataset,
                                                      batch_size=batch_size,
                                                      shuffle=True,
                                                      num_workers=0)
                if split == 'valid':
                    self.val_dataloader = DataLoader(dataset=split_dataset,
                                                     batch_size=batch_size,
                                                     shuffle=True,
                                                     num_workers=0)
                if split == 'train':
                    self.train_dataloader = DataLoader(dataset=split_dataset,
                                                       batch_size=batch_size,
                                                       shuffle=True,
                                                       num_workers=0)

        print(f'{datetime.now()}: Data loaded successfully')
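
        # Note (inferred from the stacking loop above, not stated in the source): data.json is
        # assumed to map each split to a dict of examples, e.g.
        #   {"train": {"0": {"input": [...], "target": [...]}, ...}, "valid": {...}, "test": {...}}
        # where every 'input'/'target' is an already-encoded, equal-length (padded) index sequence,
        # since torch.stack requires tensors of identical shape.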

    def _init_svae_model(self):
        self.svae = SVAE(**self.model_config['SVAE']['Parameters'],
                         vocab_size=self.vocab_size).to(self.device)
        self.svae_optim = optim.Adam(
            self.svae.parameters(),
            lr=self.model_config['SVAE']['learning_rate'])
        self.svae.train()
        print(f'{datetime.now()}: Initialised SVAE successfully')

    def interpolate(self, start, end, steps):
        """ Linearly interpolates between two latent vectors, returning steps + 2 points per dimension (endpoints included) """
        interpolation = np.zeros((start.shape[0], steps + 2))

        for dim, (s, e) in enumerate(zip(start, end)):
            interpolation[dim] = np.linspace(s, e, steps + 2)

        return interpolation.T
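
        # Example (hypothetical values): with start = np.array([0., 0.]), end = np.array([1., 2.])
        # and steps = 2, this returns a (4, 2) array whose rows move linearly from start to end,
        # e.g. for decoding a path through the SVAE latent space.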

    def _idx2word_inf(self, idx, i2w, pad_idx):
        """ Decodes batches of token indices into strings for printing inference samples """
        sent_str = [str()] * len(idx)

        for i, sent in enumerate(idx):
            for word_id in sent:
                if word_id == pad_idx:
                    break

                sent_str[i] += i2w[str(word_id.item())] + " "
            sent_str[i] = sent_str[i].strip()
        return sent_str
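
        # Example (hypothetical values): with i2w = {'1': 'hello', '2': 'world'}, pad_idx = 0 and
        # idx = torch.tensor([[1, 2, 0, 0]]), this returns ['hello world'].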

    def _pretrain_svae(self):
        self._init_data()
        self._init_svae_model()

        tb_writer = SummaryWriter(
            comment=f"pretrain svae {self.data_name}",
            filename_suffix=f"pretrain svae {self.data_name}")
        print(f'{datetime.now()}: Training started')

        step = 0
        for epoch in range(1, self.config['Train']['epochs'] + 1, 1):
            for batch_inputs, batch_lengths, batch_targets in self.train_dataloader:
                if torch.cuda.is_available():
                    batch_inputs = batch_inputs.to(self.device)
                    batch_lengths = batch_lengths.to(self.device)
                    batch_targets = batch_targets.to(self.device)

                batch_size = batch_inputs.size(0)
                logp, mean, logv, _ = self.svae(batch_inputs,
                                                batch_lengths,
                                                pretrain=False)
                NLL_loss, KL_loss, KL_weight = self.svae.loss_fn(
                    logp=logp,
                    target=batch_targets,
                    length=batch_lengths,
                    mean=mean,
                    logv=logv,
                    anneal_fn=self.model_config['SVAE']['anneal_function'],
                    step=step,
                    k=self.model_config['SVAE']['k'],
                    x0=self.model_config['SVAE']['x0'])
                svae_loss = (NLL_loss + KL_weight * KL_loss) / batch_size
                self.svae_optim.zero_grad()
                svae_loss.backward()
                self.svae_optim.step()

                tb_writer.add_scalar('Loss/train/KLL', KL_loss, step)
                tb_writer.add_scalar('Loss/train/NLL', NLL_loss, step)
                tb_writer.add_scalar('Loss/train/Total', svae_loss, step)
                tb_writer.add_scalar('Utils/train/KL_weight', KL_weight, step)

                # Increment step after each batch of data
                step += 1

            # Log once per epoch
            print(
                f'{datetime.now()}: Epoch {epoch} Loss {svae_loss:0.2f} Step {step}'
            )

            if epoch % 5 == 0:
                # Perform inference
                self.svae.eval()
                try:
                    samples, z = self.svae.inference(n=2)
                    print(*self._idx2word_inf(samples,
                                              i2w=self.idx2word,
                                              pad_idx=self.config['Utils']
                                              ['special_token2idx']['<PAD>']),
                          sep='\n')
                except Exception:
                    traceback.print_exc(file=sys.stdout)
                self.svae.train()

        # Save final model
        save_path = os.getcwd() + '/best models/svae.pt'
        torch.save(self.svae.state_dict(), save_path)
        print(f'{datetime.now()}: Model saved')

        print(f'{datetime.now()}: Training finished')
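
The KL term in the pretraining loss above is scaled by KL_weight, which SVAE.loss_fn derives from anneal_function, k and x0. The exact schedule lives inside the SVAE class and is not shown here; a common warm-up formulation, given purely as an assumption about what loss_fn computes, is:

import math

def kl_anneal_weight(anneal_fn: str, step: int, k: float, x0: int) -> float:
    """ Assumed KL annealing schedule: 'logistic' ramps smoothly around step x0 with slope k,
        'linear' ramps from 0 to 1 over the first x0 steps. """
    if anneal_fn == 'logistic':
        return float(1 / (1 + math.exp(-k * (step - x0))))
    elif anneal_fn == 'linear':
        return min(1.0, step / x0)
    return 1.0

# e.g. kl_anneal_weight('logistic', step=2500, k=0.0025, x0=2500) == 0.5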
Example #9
    def train_single_cycle(self, parameterisation=None):
        """ Performs a single cycle of the active learning routine at a specified data split for optimisation purposes"""
        print(parameterisation)

        if parameterisation is None:
            params = {
                'tl': self.model_config['TaskLearner']['Parameters'],
                'svae': self.model_config['SVAE']['Parameters'],
                'disc': self.model_config['Discriminator']['Parameters']
            }
            params.update({
                'tl_learning_rate':
                self.model_config['TaskLearner']['learning_rate']
            })
            params.update({
                'svae_learning_rate':
                self.model_config['SVAE']['learning_rate']
            })
            params.update({
                'disc_learning_rate':
                self.model_config['Discriminator']['learning_rate']
            })
            params.update({'epochs': self.config['Train']['epochs']})
            params.update({'k': self.model_config['SVAE']['k']})
            params.update({'x0': self.model_config['SVAE']['x0']})
            params.update({
                'adv_hyperparameter':
                self.model_config['SVAE']['adversarial_hyperparameter']
            })

        if parameterisation:
            params = {
                'epochs':
                parameterisation['epochs'],
                'batch_size':
                parameterisation['batch_size'],
                'tl_learning_rate':
                self.model_config['TaskLearner']
                ['learning_rate'],  #parameterisation['tl_learning_rate'],
                'svae_learning_rate':
                parameterisation['svae_learning_rate'],
                'disc_learning_rate':
                parameterisation['disc_learning_rate'],
                'k':
                parameterisation['svae_k'],
                'x0':
                parameterisation['svae_x0'],
                'adv_hyperparameter':
                self.model_config['SVAE']
                ['adversarial_hyperparameter'],  #parameterisation['svae_adv_hyperparameter'], 
                'tl':
                self.model_config['TaskLearner']['Parameters'],
                #{'embedding_dim': parameterisation['tl_embedding_dim'],
                # 'hidden_dim': parameterisation['tl_hidden_dim'],
                # 'rnn_type': parameterisation['tl_rnn_type']},
                'svae': {
                    'embedding_dim':
                    parameterisation['svae_embedding_dim'],
                    'hidden_dim':
                    parameterisation['svae_hidden_dim'],
                    'word_dropout':
                    parameterisation['svae_word_dropout'],
                    'embedding_dropout':
                    parameterisation['svae_embedding_dropout'],
                    'num_layers':
                    self.model_config['SVAE']['Parameters']
                    ['num_layers'],  #parameterisation['svae_num_layers'],
                    'bidirectional':
                    self.model_config['SVAE']['Parameters']
                    ['bidirectional'],  #parameterisation['svae_bidirectional'],
                    'rnn_type':
                    self.model_config['SVAE']['Parameters']
                    ['rnn_type'],  #parameterisation['svae_rnn_type'],
                    'latent_size':
                    parameterisation['latent_size']
                },
                'disc': {
                    'z_dim': parameterisation['latent_size'],
                    'fc_dim': parameterisation['disc_fc_dim']
                }
            }

        split = self.config['Train']['cycle_frac']
        print(f'\n{datetime.now()}\nSplit: {split*100:0.0f}%')
        meta = f' {self.al_mode} run x data split {split*100:0.0f}'

        self._init_dataset()
        train_dataset = self.datasets['train']
        dataset_size = len(train_dataset)
        self.budget = math.ceil(
            self.budget_frac *
            dataset_size)  # currently can only have a fixed budget size
        Sampler.__init__(self, self.budget)

        all_indices = set(
            np.arange(dataset_size))  # indices of all samples in train
        k_initial = math.ceil(
            len(all_indices) *
            split)  # number of initial samples given split size
        initial_indices = random.sample(
            list(all_indices),
            k=k_initial)  # random sample of initial indices from train
        sampler_init = data.sampler.SubsetRandomSampler(
            initial_indices
        )  # sampler method for dataloader to randomly sample initial indices
        current_indices = list(
            initial_indices)  # current set of labelled indices

        dataloader_l = data.DataLoader(train_dataset,
                                       sampler=sampler_init,
                                       batch_size=params['batch_size'],
                                       drop_last=True)
        dataloader_v = data.DataLoader(self.datasets['valid'],
                                       batch_size=params['batch_size'],
                                       shuffle=True,
                                       drop_last=False)
        dataloader_t = data.DataLoader(self.datasets['test'],
                                       batch_size=params['batch_size'],
                                       shuffle=True,
                                       drop_last=False)

        unlabelled_indices = np.setdiff1d(
            list(all_indices), current_indices
        )  # set of unlabelled indices (all - initial/current)
        unlabelled_sampler = data.sampler.SubsetRandomSampler(
            unlabelled_indices
        )  # sampler method for dataloader to randomly sample unlabelled indices
        dataloader_u = data.DataLoader(self.datasets['train'],
                                       sampler=unlabelled_sampler,
                                       batch_size=params['batch_size'],
                                       drop_last=False)
        print(
            f'{datetime.now()}: Indices - Labelled {len(current_indices)} Unlabelled {len(unlabelled_indices)} Total {len(all_indices)}'
        )

        # Initialise models
        task_learner = TaskLearner(**params['tl'],
                                   vocab_size=self.vocab_size,
                                   tagset_size=self.tagset_size,
                                   task_type=self.task_type).to(self.device)
        if self.task_type == 'SEQ':
            tl_loss_fn = nn.NLLLoss().to(self.device)
        if self.task_type == 'CLF':
            tl_loss_fn = nn.CrossEntropyLoss().to(self.device)

        tl_optim = optim.SGD(
            task_learner.parameters(),
            lr=params['tl_learning_rate'])  #, momentum=0, weight_decay=0.1)
        task_learner.train()

        svae = SVAE(**params['svae'],
                    vocab_size=self.vocab_size).to(self.device)
        discriminator = Discriminator(**params['disc']).to(self.device)
        dsc_loss_fn = nn.BCELoss().to(self.device)
        svae_optim = optim.Adam(svae.parameters(),
                                lr=params['svae_learning_rate'])
        dsc_optim = optim.Adam(discriminator.parameters(),
                               lr=params['disc_learning_rate'])

        # Training Modes
        svae.train()
        discriminator.train()

        print(f'{datetime.now()}: Models initialised successfully')

        # Perform AL training and sampling
        early_stopping = EarlyStopping(
            patience=self.config['Train']['es_patience'],
            verbose=True,
            path="checkpoints/checkpoint.pt")

        dataset_size = len(dataloader_l) + len(
            dataloader_u) if dataloader_u is not None else len(dataloader_l)
        train_iterations = dataset_size * (params['epochs'] + 1)

        print(
            f'{datetime.now()}: Dataset size (batches) {dataset_size} Training iterations (batches) {train_iterations}'
        )

        step = 0
        epoch = 1
        for train_iter in tqdm(range(train_iterations),
                               desc='Training iteration'):
            batch_sequences_l, batch_lengths_l, batch_tags_l = next(
                iter(dataloader_l))

            if torch.cuda.is_available():
                batch_sequences_l = batch_sequences_l.to(self.device)
                batch_lengths_l = batch_lengths_l.to(self.device)
                batch_tags_l = batch_tags_l.to(self.device)

            if dataloader_u is not None:
                batch_sequences_u, batch_lengths_u, _ = next(
                    iter(dataloader_u))
                batch_sequences_u = batch_sequences_u.to(self.device)
                batch_lengths_u = batch_lengths_u.to(self.device)

            # Strip off tag padding and flatten
            # Don't trim the sequences here as it's done in the forward pass of the seq2seq models
            batch_tags_l = trim_padded_seqs(batch_lengths=batch_lengths_l,
                                            batch_sequences=batch_tags_l,
                                            pad_idx=self.pad_idx).view(-1)

            # Task Learner Step
            tl_optim.zero_grad()
            tl_preds = task_learner(batch_sequences_l, batch_lengths_l)
            tl_loss = tl_loss_fn(tl_preds, batch_tags_l)
            tl_loss.backward()
            tl_optim.step()

            # Used in SVAE and Discriminator
            batch_size_l = batch_sequences_l.size(0)
            batch_size_u = batch_sequences_u.size(0)

            # SVAE Step
            for i in range(self.svae_iterations):
                logp_l, mean_l, logv_l, z_l = svae(batch_sequences_l,
                                                   batch_lengths_l)
                NLL_loss_l, KL_loss_l, KL_weight_l = svae.loss_fn(
                    logp=logp_l,
                    target=batch_sequences_l,
                    length=batch_lengths_l,
                    mean=mean_l,
                    logv=logv_l,
                    anneal_fn=self.model_config['SVAE']['anneal_function'],
                    step=step,
                    k=params['k'],
                    x0=params['x0'])

                logp_u, mean_u, logv_u, z_u = svae(batch_sequences_u,
                                                   batch_lengths_u)
                NLL_loss_u, KL_loss_u, KL_weight_u = svae.loss_fn(
                    logp=logp_u,
                    target=batch_sequences_u,
                    length=batch_lengths_u,
                    mean=mean_u,
                    logv=logv_u,
                    anneal_fn=self.model_config['SVAE']['anneal_function'],
                    step=step,
                    k=params['k'],
                    x0=params['x0'])
                # VAE loss
                svae_loss_l = (NLL_loss_l +
                               KL_weight_l * KL_loss_l) / batch_size_l
                svae_loss_u = (NLL_loss_u +
                               KL_weight_u * KL_loss_u) / batch_size_u

                # Adversary loss - trying to fool the discriminator!
                dsc_preds_l = discriminator(z_l)  # mean_l
                dsc_preds_u = discriminator(z_u)  # mean_u
                dsc_real_l = torch.ones(batch_size_l)
                dsc_real_u = torch.ones(batch_size_u)
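                # Note: the all-ones target for the unlabelled batch is intentional here - the SVAE
                # is trained to make unlabelled z indistinguishable from labelled z (contrast with
                # the zeros target used in the discriminator step below).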

                if torch.cuda.is_available():
                    dsc_real_l = dsc_real_l.to(self.device)
                    dsc_real_u = dsc_real_u.to(self.device)

                adv_dsc_loss_l = dsc_loss_fn(dsc_preds_l, dsc_real_l)
                adv_dsc_loss_u = dsc_loss_fn(dsc_preds_u, dsc_real_u)
                adv_dsc_loss = adv_dsc_loss_l + adv_dsc_loss_u

                total_svae_loss = svae_loss_u + svae_loss_l + params[
                    'adv_hyperparameter'] * adv_dsc_loss
                svae_optim.zero_grad()
                total_svae_loss.backward()
                svae_optim.step()

                if i < self.svae_iterations - 1:
                    batch_sequences_l, batch_lengths_l, _ = next(
                        iter(dataloader_l))
                    batch_sequences_u, batch_lengths_u, _ = next(
                        iter(dataloader_u))

                    if torch.cuda.is_available():
                        batch_sequences_l = batch_sequences_l.to(self.device)
                        batch_lengths_l = batch_lengths_l.to(self.device)
                        batch_sequences_u = batch_sequences_u.to(self.device)
                        batch_lengths_u = batch_lengths_u.to(self.device)

                step += 1

            # Discriminator Step
            for j in range(self.dsc_iterations):

                with torch.no_grad():
                    _, _, _, z_l = svae(batch_sequences_l, batch_lengths_l)
                    _, _, _, z_u = svae(batch_sequences_u, batch_lengths_u)

                dsc_preds_l = discriminator(z_l)
                dsc_preds_u = discriminator(z_u)

                dsc_real_l = torch.ones(batch_size_l)
                dsc_real_u = torch.zeros(batch_size_u)

                if torch.cuda.is_available():
                    dsc_real_l = dsc_real_l.to(self.device)
                    dsc_real_u = dsc_real_u.to(self.device)

                # Discriminator wants to minimise the loss here
                dsc_loss_l = dsc_loss_fn(dsc_preds_l, dsc_real_l)
                dsc_loss_u = dsc_loss_fn(dsc_preds_u, dsc_real_u)
                total_dsc_loss = dsc_loss_l + dsc_loss_u
                dsc_optim.zero_grad()
                total_dsc_loss.backward()
                dsc_optim.step()

                # Sample new batch of data while training adversarial network
                if j < self.dsc_iterations - 1:
                    batch_sequences_l, batch_lengths_l, _ = next(
                        iter(dataloader_l))
                    batch_sequences_u, batch_lengths_u, _ = next(
                        iter(dataloader_u))

                    if torch.cuda.is_available():
                        batch_sequences_l = batch_sequences_l.to(self.device)
                        batch_lengths_l = batch_lengths_l.to(self.device)
                        batch_sequences_u = batch_sequences_u.to(self.device)
                        batch_lengths_u = batch_lengths_u.to(self.device)

            if (train_iter % dataset_size == 0):
                print("Updating early stopping monitor")
                early_stopping(
                    tl_loss, task_learner
                )  # TODO: Review. Should this be the metric we early stop on?

                if early_stopping.early_stop:
                    print(
                        f'Early stopping at {train_iter}/{train_iterations} training iterations'
                    )
                    break

            if (train_iter > 0) and (train_iter % dataset_size == 0):
                val_metrics = self.evaluation(task_learner=task_learner, dataloader=dataloader_v, task_type=self.task_type)

                if self.task_type == 'SEQ':
                    val_string = f'Task Learner ({self.task_type}) Validation Scores:\nF1: Macro {val_metrics["f1 macro"]*100:0.2f}% Micro {val_metrics["f1 micro"]*100:0.2f}%\n'
                else:
                    val_string = f'Task Learner ({self.task_type}) Validation Accuracy {val_metrics["accuracy"]*100:0.2f}%'
                print(val_string)

            if (train_iter > 0) and (train_iter % dataset_size == 0):
                # Completed an epoch
                train_iter_str = f'Train Iter {train_iter} - Losses: TL-{self.task_type} {tl_loss:0.2f} | SVAE {total_svae_loss:0.2f} | Disc {total_dsc_loss:0.2f} | Learning rates: TL ({tl_optim.param_groups[0]["lr"]})'
                print(train_iter_str)

                print(f'Completed epoch: {epoch}')
                epoch += 1

        # Evaluation at the end of the first training cycle
        test_metrics = self.evaluation(task_learner=task_learner,
                                       dataloader=dataloader_t,
                                       task_type='SEQ')

        f1_macro_1 = test_metrics['f1 macro']

        # SVAE and Discriminator need to be evaluated on the TL metric n+1 split from their current training split
        # So, data needs to be sampled via SVAAL and then the TL retrained. The final metric from the retrained TL is then used to
        # optimise the SVAE and Discriminator. For this optimisation problem, the TL parameters are fixed.

        # Sample data via SVAE and Discriminator
        sampled_indices, _, _ = self.sample_adversarial(
            svae=svae,
            discriminator=discriminator,
            data=dataloader_u,
            indices=unlabelled_indices,
            cuda=True)  # TODO: review usage of indices arg

        # Add new samples to labelled dataset
        current_indices = list(current_indices) + list(sampled_indices)
        sampler = data.sampler.SubsetRandomSampler(current_indices)
        self.labelled_dataloader = data.DataLoader(self.datasets['train'],
                                                   sampler=sampler,
                                                   batch_size=self.batch_size,
                                                   drop_last=True)
        dataloader_l = self.labelled_dataloader  # to maintain naming conventions
        print(
            f'{datetime.now()}: Indices - Labelled {len(current_indices)} Unlabelled {len(unlabelled_indices)} Total {len(all_indices)}'
        )

        task_learner = TaskLearner(**params['tl'],
                                   vocab_size=self.vocab_size,
                                   tagset_size=self.tagset_size,
                                   task_type=self.task_type).to(self.device)
        if self.task_type == 'SEQ':
            tl_loss_fn = nn.NLLLoss().to(self.device)
        if self.task_type == 'CLF':
            tl_loss_fn = nn.CrossEntropyLoss().to(self.device)
        tl_optim = optim.SGD(
            task_learner.parameters(),
            lr=params['tl_learning_rate'])  #, momentum=0, weight_decay=0.1)
        task_learner.train()

        early_stopping = EarlyStopping(
            patience=self.config['Train']['es_patience'],
            verbose=True,
            path="checkpoints/checkpoint.pt")

        print(f'{datetime.now()}: Task Learner initialised successfully')

        # Train Task Learner on adversarially selected samples
        dataset_size = len(dataloader_l)
        train_iterations = dataset_size * (params['epochs'] + 1)

        epoch = 1
        for train_iter in tqdm(range(train_iterations),
                               desc='Training iteration'):
            batch_sequences_l, batch_lengths_l, batch_tags_l = next(
                iter(dataloader_l))

            if torch.cuda.is_available():
                batch_sequences_l = batch_sequences_l.to(self.device)
                batch_lengths_l = batch_lengths_l.to(self.device)
                batch_tags_l = batch_tags_l.to(self.device)

            # Strip off tag padding and flatten
            # Don't trim the sequences here as it's done in the forward pass of the seq2seq models
            batch_tags_l = trim_padded_seqs(batch_lengths=batch_lengths_l,
                                            batch_sequences=batch_tags_l,
                                            pad_idx=self.pad_idx).view(-1)

            # Task Learner Step
            tl_optim.zero_grad()
            tl_preds = task_learner(batch_sequences_l, batch_lengths_l)
            tl_loss = tl_loss_fn(tl_preds, batch_tags_l)
            tl_loss.backward()
            tl_optim.step()

            if (train_iter % dataset_size == 0):
                print("Updating early stopping monitor")
                early_stopping(
                    tl_loss, task_learner
                )  # TODO: Review. Should this be the metric we early stop on?

                if early_stopping.early_stop:
                    print(
                        f'Early stopping at {train_iter}/{train_iterations} training iterations'
                    )
                    break

            if (train_iter > 0) and (train_iter % dataset_size == 0):
                val_metrics = self.evaluation(task_learner=task_learner, dataloader=dataloader_v, task_type=self.task_type)

                if self.task_type == 'SEQ':
                    val_string = f'Task Learner ({self.task_type}) Validation Scores:\nF1: Macro {val_metrics["f1 macro"]*100:0.2f}% Micro {val_metrics["f1 micro"]*100:0.2f}%\n'
                else:
                    val_string = f'Task Learner ({self.task_type}) Validation Accuracy {val_metrics["accuracy"]*100:0.2f}%'
                print(val_string)

            if (train_iter > 0) and (train_iter % dataset_size == 0):
                # Completed an epoch
                train_iter_str = f'Train Iter {train_iter} - Losses: TL-{self.task_type} {tl_loss:0.2f} | Learning rate: TL ({tl_optim.param_groups[0]["lr"]})'
                print(train_iter_str)
                print(f'Completed epoch: {epoch}')
                epoch += 1

        # Compute test metrics
        test_metrics = self.evaluation(task_learner=task_learner,
                                       dataloader=dataloader_t,
                                       task_type='SEQ')

        print(
            f'{datetime.now()}: Test Eval.: F1 Scores - Macro {test_metrics["f1 macro"]*100:0.2f}% Micro {test_metrics["f1 micro"]*100:0.2f}%'
        )

        # return test_metrics["f1 macro"]
        # TODO: Review the objective - the change in macro F1 between cycle n and cycle n+1 is
        # returned below, on the basis that it is more informative than the raw macro F1 alone.
        f1_macro_diff = test_metrics['f1 macro'] - f1_macro_1
        print(f'Macro f1 difference: {f1_macro_diff*100:0.2f}%')
        return f1_macro_diff
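
Note that train_single_cycle returns the change in test macro F1 between the first cycle's Task Learner and the one retrained on the adversarially sampled data, which makes it usable as an objective for hyperparameter optimisation. A purely hypothetical invocation sketch (the class name and every value below are placeholders, not taken from the source; only the dictionary keys come from the method above) might look like:

trainer = SVAALTrainer()  # placeholder name - stands in for whichever class defines train_single_cycle
f1_delta = trainer.train_single_cycle(parameterisation={
    'epochs': 10,
    'batch_size': 32,
    'svae_learning_rate': 1e-3,
    'disc_learning_rate': 5e-3,
    'svae_k': 0.0025,
    'svae_x0': 2500,
    'svae_embedding_dim': 128,
    'svae_hidden_dim': 256,
    'svae_word_dropout': 0.5,
    'svae_embedding_dropout': 0.5,
    'latent_size': 16,
    'disc_fc_dim': 128,
})
print(f'Macro F1 improvement after one AL cycle: {f1_delta*100:0.2f}%')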