Example #1
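# NOTE: the snippet below omits its imports. A minimal set, inferred from the
# names it uses, might look like the following; the module paths for the
# project-local pieces (seq_data, SVAE, Tracker) are assumptions, not the
# repository's actual layout.
#
# from collections import OrderedDict
# from multiprocessing import cpu_count
# import torch
# from torch.utils.data import DataLoader
# from data import seq_data      # project-local (assumed path)
# from model import SVAE         # project-local (assumed path)
# from utils import Tracker      # project-local (assumed path)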
def main(args):

    splits = ['train', 'valid'] + (['dev'] if args.test else [])
    print(args)
    #Load dataset
    datasets = OrderedDict()
    for split in splits:
        datasets[split] = seq_data(data_dir=args.data_dir,
                                   split=split,
                                   mt=args.mt,
                                   create_data=args.create_data,
                                   max_src_len=args.max_src_length,
                                   max_tgt_len=args.max_tgt_length,
                                   min_occ=args.min_occ)
    print('Data OK')
    #Load model
    model = SVAE(
        vocab_size=datasets['train'].vocab_size,
        embed_dim=args.embedding_dimension,
        hidden_dim=args.hidden_dimension,
        latent_dim=args.latent_dimension,

        #word_drop=args.word_dropout,
        teacher_forcing=args.teacher_forcing,
        dropout=args.dropout,
        n_direction=args.bidirectional,
        n_parallel=args.n_layer,
        attn=args.attention,
        max_src_len=args.max_src_length,  # affects the inference stage
        max_tgt_len=args.max_tgt_length,
        sos_idx=datasets['train'].sos_idx,
        eos_idx=datasets['train'].eos_idx,
        pad_idx=datasets['train'].pad_idx,
        unk_idx=datasets['train'].unk_idx)
    if args.fasttext:
        prt = torch.load(args.data_dir + '/prt_fasttext.model')
        model.load_prt(prt)
    print('Model OK')
    if torch.cuda.is_available():
        model = model.cuda()
    device = model.device
    # Training phase with validation (early stopping)
    tracker = Tracker(patience=10,
                      verbose=True)  # records training history & handles early stopping
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    step = 0

    for epoch in range(args.epochs):
        for split in splits:
            data_loader = DataLoader(dataset=datasets[split],
                                     batch_size=args.n_batch,
                                     shuffle=(split == 'train'),
                                     num_workers=cpu_count(),
                                     pin_memory=torch.cuda.is_available())
            if split == 'train':
                model.train()
            else:
                model.eval()

            #Executing
            for i, data in enumerate(data_loader):
                src, srclen,  tgt, tgtlen = \
                     data['src'], data['srclen'], data['tgt'], data['tgtlen']
                #FP
                logits, (mu, logv,
                         z), generations = model(src, srclen, tgt, tgtlen,
                                                 split)

                #FP for groundtruth
                #h_pred, h_tgt = model.forward_gt(generations, tgt, tgtlen)

                #LOSS(weighted)
                NLL, KL, KL_W = model.loss(logits, tgt.to(device),
                                           data['tgtlen'], mu, logv, step,
                                           args.k, args.x0, args.af)
                #GLOBAL = model.global_loss(h_pred, h_tgt)
                GLOBAL = 0

                loss = (NLL + KL * KL_W + GLOBAL) / data['src'].size(0)
                #BP & OPTIM
                if split == 'train':
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    step += 1

                #RECORD & RESULT(batch)
                if i % 50 == 0 or i + 1 == len(data_loader):
                    #NLL.data = torch.cuda.FloatTensor([NLL.data])
                    #KL.data = torch.cuda.FloatTensor([KL.data])
                    print(
                        "{} Phase - Batch {}/{}, Loss: {}, NLL: {}, KL: {}, KL-W: {}, G: {}"
                        .format(split.upper(), i,
                                len(data_loader) - 1, loss, NLL, KL, KL_W,
                                GLOBAL))
                tracker._elbo(torch.Tensor([loss]))
                if split == 'valid':
                    tracker.record(tgt, generations, datasets['train'].i2w,
                                   datasets['train'].pad_idx,
                                   datasets['train'].eos_idx,
                                   datasets['train'].unk_idx, z)

            #SAVING & RESULT(epoch)
            if split == 'valid':
                tracker.dumps(epoch, args.dump_file)  #dump the predicted text.
            else:
                tracker._save_checkpoint(
                    epoch, args.model_file,
                    model.state_dict())  # save the checkpoint
            print("{} Phase - Epoch {} , Mean ELBO: {}".format(
                split.upper(), epoch, torch.mean(tracker.elbo)))

            tracker._purge()
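
# A minimal, hypothetical entry point for main() above. The flag names are taken
# from the attributes accessed on `args` in the function body; the types and
# defaults shown here are assumptions, not the original script's values.
#
#   if __name__ == '__main__':
#       import argparse
#       parser = argparse.ArgumentParser()
#       parser.add_argument('--data_dir', type=str, default='data')
#       parser.add_argument('--test', action='store_true')
#       parser.add_argument('--epochs', type=int, default=10)
#       parser.add_argument('--n_batch', type=int, default=32)
#       parser.add_argument('--learning_rate', type=float, default=1e-3)
#       parser.add_argument('--k', type=float, default=0.0025)     # KL-annealing slope (assumed)
#       parser.add_argument('--x0', type=int, default=2500)        # KL-annealing midpoint (assumed)
#       parser.add_argument('--af', type=str, default='logistic')  # KL-annealing function (assumed)
#       # ... plus the remaining flags read in main(): mt, create_data,
#       # max_src_length, max_tgt_length, min_occ, embedding_dimension,
#       # hidden_dimension, latent_dimension, teacher_forcing, dropout,
#       # bidirectional, n_layer, attention, fasttext, dump_file, model_file
#       main(parser.parse_args())
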
class ModularTrainer(Sampler):
    def __init__(self):
        self.config = load_config()
        self.model_config = self.config['Models']

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        # Model
        self.task_type = self.config['Utils']['task_type']
        self.max_sequence_length = self.config['Utils'][
            self.task_type]['max_sequence_length']

        # Real data
        self.data_name = self.config['Utils'][self.task_type]['data_name']
        self.data_splits = self.config['Utils'][self.task_type]['data_split']
        self.pad_idx = self.config['Utils']['special_token2idx']['<PAD>']

        # Test run properties
        self.epochs = self.config['Train']['epochs']
        self.svae_iterations = self.config['Train']['svae_iterations']

        self.kfold_xval = False

    def _init_data(self, batch_size=None):
        if batch_size is None:
            batch_size = self.config['Train']['batch_size']

        # Load pre-processed data
        path_data = os.path.join('/home/tyler/Desktop/Repos/s-vaal/data',
                                 self.task_type, self.data_name, 'pretrain',
                                 'data.json')
        path_vocab = os.path.join('/home/tyler/Desktop/Repos/s-vaal/data',
                                  self.task_type, self.data_name, 'pretrain',
                                  'vocab.json')  # not vocabs
        data = load_json(path_data)

        self.vocab = load_json(
            path_vocab
        )  # Required for decoding sequences for interpretations. TODO: Find suitable location... or leave be...
        self.vocab_size = len(self.vocab['word2idx'])

        self.idx2word = self.vocab['idx2word']
        self.word2idx = self.vocab['word2idx']

        self.datasets = dict()
        if self.kfold_xval:
            # Perform k-fold cross-validation:
            # join all datasets, then randomly assign train/val/test splits
            # (placeholder -- not implemented in this snippet)

            for split in self.data_splits:
                print(data[split][self.x_y_pair_name])

        else:
            for split in self.data_splits:
                # Access data
                split_data = data[split]
                # print(split_data)
                # Convert lists of encoded sequences into tensors and stack into one large tensor
                split_inputs = torch.stack([
                    torch.tensor(value['input'])
                    for key, value in split_data.items()
                ])
                split_targets = torch.stack([
                    torch.tensor(value['target'])
                    for key, value in split_data.items()
                ])
                # Create torch dataset from tensors
                split_dataset = RealDataset(sequences=split_inputs,
                                            tags=split_targets)
                # Add to dictionary
                self.datasets[split] = split_dataset  #split_dataloader

                # Create torch dataloader generator from dataset
                if split == 'test':
                    self.test_dataloader = DataLoader(dataset=split_dataset,
                                                      batch_size=batch_size,
                                                      shuffle=True,
                                                      num_workers=0)
                if split == 'valid':
                    self.val_dataloader = DataLoader(dataset=split_dataset,
                                                     batch_size=batch_size,
                                                     shuffle=True,
                                                     num_workers=0)
                if split == 'train':
                    self.train_dataloader = DataLoader(dataset=split_dataset,
                                                       batch_size=batch_size,
                                                       shuffle=True,
                                                       num_workers=0)

        print(f'{datetime.now()}: Data loaded successfully')

    def _init_svae_model(self):
        self.svae = SVAE(**self.model_config['SVAE']['Parameters'],
                         vocab_size=self.vocab_size).to(self.device)
        self.svae_optim = optim.Adam(
            self.svae.parameters(),
            lr=self.model_config['SVAE']['learning_rate'])
        self.svae.train()
        print(f'{datetime.now()}: Initialised SVAE successfully')

    def interpolate(self, start, end, steps):

        interpolation = np.zeros((start.shape[0], steps + 2))

        for dim, (s, e) in enumerate(zip(start, end)):
            interpolation[dim] = np.linspace(s, e, steps + 2)

        return interpolation.T
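
    # Example (illustrative): with two 3-dimensional latent codes and steps=1,
    # interpolate() returns the 3 points on the straight line between them:
    #   interpolate(np.array([0., 0., 0.]), np.array([1., 1., 1.]), steps=1)
    #   -> array([[0. , 0. , 0. ],
    #             [0.5, 0.5, 0.5],
    #             [1. , 1. , 1. ]])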

    def _idx2word_inf(self, idx, i2w, pad_idx):
        # inf-erence
        sent_str = [str()] * len(idx)

        for i, sent in enumerate(idx):
            for word_id in sent:
                if word_id == pad_idx:
                    break

                sent_str[i] += i2w[str(word_id.item())] + " "
            sent_str[i] = sent_str[i].strip()
        return sent_str
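
    # Example (illustrative): given idx = torch.tensor([[5, 7, 0, 0]]),
    # pad_idx = 0 and i2w = {'5': 'hello', '7': 'world'}, this returns
    # ['hello world']; decoding stops at the first <PAD> in each sequence.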

    def _pretrain_svae(self):
        self._init_data()
        self._init_svae_model()

        tb_writer = SummaryWriter(
            comment=f"pretrain svae {self.data_name}",
            filename_suffix=f"pretrain svae {self.data_name}")
        print(f'{datetime.now()}: Training started')

        step = 0
        for epoch in range(1, self.config['Train']['epochs'] + 1, 1):
            for batch_inputs, batch_lengths, batch_targets in self.train_dataloader:
                if torch.cuda.is_available():
                    batch_inputs = batch_inputs.to(self.device)
                    batch_lengths = batch_lengths.to(self.device)
                    batch_targets = batch_targets.to(self.device)

                batch_size = batch_inputs.size(0)
                logp, mean, logv, _ = self.svae(batch_inputs,
                                                batch_lengths,
                                                pretrain=False)
                NLL_loss, KL_loss, KL_weight = self.svae.loss_fn(
                    logp=logp,
                    target=batch_targets,
                    length=batch_lengths,
                    mean=mean,
                    logv=logv,
                    anneal_fn=self.model_config['SVAE']['anneal_function'],
                    step=step,
                    k=self.model_config['SVAE']['k'],
                    x0=self.model_config['SVAE']['x0'])
                svae_loss = (NLL_loss + KL_weight * KL_loss) / batch_size
                self.svae_optim.zero_grad()
                svae_loss.backward()
                self.svae_optim.step()

                tb_writer.add_scalar('Loss/train/KLL', KL_loss, step)
                tb_writer.add_scalar('Loss/train/NLL', NLL_loss, step)
                tb_writer.add_scalar('Loss/train/Total', svae_loss, step)
                tb_writer.add_scalar('Utils/train/KL_weight', KL_weight, step)

                # Increment step after each batch of data
                step += 1

            if epoch % 1 == 0:
                print(
                    f'{datetime.now()}: Epoch {epoch} Loss {svae_loss:0.2f} Step {step}'
                )

            if epoch % 5 == 0:
                # Perform inference
                self.svae.eval()
                try:
                    samples, z = self.svae.inference(n=2)
                    print(*self._idx2word_inf(samples,
                                              i2w=self.idx2word,
                                              pad_idx=self.config['Utils']
                                              ['special_token2idx']['<PAD>']),
                          sep='\n')
                except Exception:
                    traceback.print_exc(file=sys.stdout)
                self.svae.train()

        # Save final model
        save_path = os.getcwd() + '/best models/svae.pt'
        os.makedirs(os.path.dirname(save_path), exist_ok=True)  # ensure the target directory exists
        torch.save(self.svae.state_dict(), save_path)
        print(f'{datetime.now()}: Model saved')

        print(f'{datetime.now()}: Training finished')
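
# A minimal usage sketch for the pre-training routine above (assuming the config
# file, the hard-coded data paths and the project-local Sampler, SVAE and
# RealDataset modules are all available):
#
#   trainer = ModularTrainer()
#   trainer._pretrain_svae()   # loads the data, initialises the SVAE, trains and saves it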