Exemple #1
0
def main(args):
    """Compute loss and perplexity on the training and validation sets.

    Seeds the torch and numpy random number generators, loads both
    dictionaries and both dataset splits, rebuilds the model, asks
    ``utils.load_checkpoint`` to restore the latest checkpoint and finally
    runs ``validate`` once per split.

    Args:
        args: Parsed command-line namespace. This function reads ``data``,
            ``source_lang``, ``target_lang``, ``train_on_tiny``, ``cuda`` and
            ``lr``; the namespace is also forwarded to the project helpers.
    """
    torch.manual_seed(42)
    np.random.seed(42)

    # Configure logging *before* the first record is emitted. A module-level
    # logging call on an unconfigured root logger installs a default handler,
    # which can turn a later basicConfig()-style setup into a no-op.
    utils.init_logging(args)
    logging.info('Commencing Validation!')

    # Load dictionaries [for each language]
    src_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{:s}'.format(args.source_lang)))
    logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format(
        args.source_lang, len(src_dict)))
    tgt_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{:s}'.format(args.target_lang)))
    logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format(
        args.target_lang, len(tgt_dict)))

    # Load datasets
    def load_data(split):
        """Build the parallel dataset stored under ``<split>.<lang>``."""
        return Seq2SeqDataset(
            src_file=os.path.join(args.data,
                                  '{:s}.{:s}'.format(split, args.source_lang)),
            tgt_file=os.path.join(args.data,
                                  '{:s}.{:s}'.format(split, args.target_lang)),
            src_dict=src_dict,
            tgt_dict=tgt_dict)

    train_split = 'tiny_train' if args.train_on_tiny else 'train'
    train_dataset = load_data(split=train_split)
    valid_dataset = load_data(split='valid')

    # Build model and optimization criterion
    model = models.build_model(args, src_dict, tgt_dict)
    logging.info('Built a model with {:d} parameters'.format(
        sum(p.numel() for p in model.parameters())))
    # NOTE(review): the ignored padding index comes from the *source*
    # dictionary - confirm it matches the padding used for the targets.
    criterion = nn.CrossEntropyLoss(ignore_index=src_dict.pad_idx,
                                    reduction='sum')

    if torch.cuda.is_available() and args.cuda:
        model = model.cuda()

    # The optimizer is never stepped here; it is only handed to
    # load_checkpoint together with the model.
    optimizer = torch.optim.Adam(model.parameters(), args.lr)

    # Load last checkpoint if one exists
    utils.load_checkpoint(args, model, optimizer)

    # Evaluate both splits. The results are not used further in this
    # function; presumably validate() reports them itself - TODO confirm.
    validate(args, model, criterion, train_dataset, 0)
    validate(args, model, criterion, valid_dataset, 0)
Exemple #2
0
def main(args):
    """Train the translation model for several epochs.

    Each epoch optionally re-applies BPE dropout, reloads the training data,
    runs one pass of Adam updates with gradient-norm clipping, validates,
    optionally steps the plateau scheduler, saves a checkpoint every
    ``args.save_interval`` epochs and stops early once the validation
    perplexity has not improved for ``args.patience`` epochs.

    Args:
        args: Parsed command-line namespace. This function reads ``data``,
            ``source_lang``, ``target_lang``, ``cuda``, ``lr``, ``max_epoch``,
            ``bpe_dropout``, ``train_on_tiny``, ``max_tokens``,
            ``batch_size``, ``clip_norm``, ``adaptive_lr``, ``save_interval``
            and ``patience``; it is also forwarded to the project helpers.
    """
    torch.manual_seed(42)

    # Configure logging *before* the first record is emitted. A module-level
    # logging call on an unconfigured root logger installs a default handler,
    # which can turn a later basicConfig()-style setup into a no-op.
    utils.init_logging(args)
    logging.info('Commencing training!')

    # Load dictionaries
    src_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{:s}'.format(args.source_lang)))
    logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format(
        args.source_lang, len(src_dict)))
    tgt_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{:s}'.format(args.target_lang)))
    logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format(
        args.target_lang, len(tgt_dict)))

    # Load datasets
    def load_data(split):
        """Build the parallel dataset stored under ``<split>.<lang>``."""
        return Seq2SeqDataset(
            src_file=os.path.join(args.data,
                                  '{:s}.{:s}'.format(split, args.source_lang)),
            tgt_file=os.path.join(args.data,
                                  '{:s}.{:s}'.format(split, args.target_lang)),
            src_dict=src_dict,
            tgt_dict=tgt_dict)

    valid_dataset = load_data(split='valid')

    # Build model and optimization criterion
    model = models.build_model(args, src_dict, tgt_dict)
    logging.info('Built a model with {:d} parameters'.format(
        sum(p.numel() for p in model.parameters())))
    # NOTE(review): the loss below scores target tokens, yet the ignored
    # padding index comes from the *source* dictionary - confirm that this
    # is the index the collater pads the targets with.
    criterion = nn.CrossEntropyLoss(ignore_index=src_dict.pad_idx,
                                    reduction='sum')
    if args.cuda:
        model = model.cuda()
        criterion = criterion.cuda()

    # Instantiate optimizer and learning rate scheduler
    optimizer = torch.optim.Adam(model.parameters(), args.lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           'min',
                                                           factor=0.5,
                                                           patience=1)

    # Load last checkpoint if one exists
    state_dict = utils.load_checkpoint(args, model, optimizer,
                                       scheduler)  # lr_scheduler
    last_epoch = state_dict['last_epoch'] if state_dict is not None else -1

    # Track validation performance for early stopping
    bad_epochs = 0
    best_validate = float('inf')

    train_split = 'tiny_train' if args.train_on_tiny else 'train'

    for epoch in range(last_epoch + 1, args.max_epoch):
        ## BPE Dropout
        # Set the seed to be equal to the epoch
        # (this way we guarantee same seeds over multiple training runs, but not for each training epoch)
        seed = epoch

        bpe_dropout_if_needed(seed, args.bpe_dropout)

        # Load the BPE (dropout-ed) training data
        train_dataset = load_data(split=train_split)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            num_workers=1,
            collate_fn=train_dataset.collater,
            batch_sampler=BatchSampler(train_dataset, args.max_tokens,
                                       args.batch_size, 1, 0,
                                       shuffle=True, seed=42))

        # One switch to training mode per epoch is enough: nothing inside the
        # batch loop changes the mode, and it is restored after validation.
        model.train()
        stats = OrderedDict()
        stats['loss'] = 0
        stats['lr'] = 0
        stats['num_tokens'] = 0
        stats['batch_size'] = 0
        stats['grad_norm'] = 0
        stats['clip'] = 0
        # Display progress
        progress_bar = tqdm(train_loader,
                            desc='| Epoch {:03d}'.format(epoch),
                            leave=False,
                            disable=False)

        # Iterate over the training set
        for i, sample in enumerate(progress_bar):
            # Skip empty batches before paying for a device transfer.
            if len(sample) == 0:
                continue
            if args.cuda:
                sample = utils.move_to_cuda(sample)

            num_sentences = len(sample['src_lengths'])

            output, _ = model(sample['src_tokens'], sample['src_lengths'],
                              sample['tgt_inputs'])
            # Summed token loss, averaged over the sentences of the batch.
            loss = criterion(output.view(-1, output.size(-1)),
                             sample['tgt_tokens'].view(-1)) / num_sentences
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       args.clip_norm)
            optimizer.step()
            optimizer.zero_grad()

            # Update statistics for progress bar
            total_loss = loss.item()
            num_tokens = sample['num_tokens']
            batch_size = len(sample['src_tokens'])
            # Undo the per-sentence averaging to report a per-token loss.
            stats['loss'] += total_loss * num_sentences / num_tokens
            stats['lr'] += optimizer.param_groups[0]['lr']
            stats['num_tokens'] += num_tokens / batch_size
            stats['batch_size'] += batch_size
            stats['grad_norm'] += grad_norm
            stats['clip'] += 1 if grad_norm > args.clip_norm else 0
            progress_bar.set_postfix(
                {
                    key: '{:.4g}'.format(value / (i + 1))
                    for key, value in stats.items()
                },
                refresh=True)

        # Guard the averages against an empty loader (zero batches).
        num_batches = max(len(progress_bar), 1)
        logging.info('Epoch {:03d}: {}'.format(
            epoch, ' | '.join(key + ' {:.4g}'.format(value / num_batches)
                              for key, value in stats.items())))

        # Calculate validation loss
        valid_perplexity, valid_loss = validate(args, model, criterion,
                                                valid_dataset, epoch)
        model.train()

        # Scheduler step
        if args.adaptive_lr:
            scheduler.step(valid_loss)

        # Save checkpoints
        if epoch % args.save_interval == 0:
            utils.save_checkpoint(args, model, optimizer, scheduler, epoch,
                                  valid_perplexity)  # lr_scheduler

        # Check whether to terminate training
        if valid_perplexity < best_validate:
            best_validate = valid_perplexity
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs >= args.patience:
            logging.info(
                'No validation set improvements observed for {:d} epochs. Early stop!'
                .format(args.patience))
            break
Exemple #3
0
def main(args):
    """Translate the test set with greedy decoding.

    Restores the model and its training arguments from
    ``args.checkpoint_path``, decodes every test batch for ``args.max_len``
    steps (taking the second-best token whenever the best one is the unknown
    token), cuts each hypothesis at its first end-of-sentence token and, if
    ``args.output`` is set, writes one translation per line ordered by
    sentence id.

    Args:
        args: Parsed command-line namespace. Values stored in the checkpoint
            override the command-line ones, except for ``data``, which is
            always taken from the command line.
    """
    # Load arguments from checkpoint
    torch.manual_seed(args.seed)
    state_dict = torch.load(
        args.checkpoint_path,
        map_location=lambda s, l: default_restore_location(s, 'cpu'))
    args_loaded = argparse.Namespace(**{
        **vars(args),
        **vars(state_dict['args'])
    })
    args_loaded.data = args.data
    args = args_loaded
    utils.init_logging(args)

    # Load dictionaries
    src_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{:s}'.format(args.source_lang)))
    logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format(
        args.source_lang, len(src_dict)))
    tgt_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{:s}'.format(args.target_lang)))
    logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format(
        args.target_lang, len(tgt_dict)))

    # Load dataset
    test_dataset = Seq2SeqDataset(
        src_file=os.path.join(args.data, 'test.{:s}'.format(args.source_lang)),
        tgt_file=os.path.join(args.data, 'test.{:s}'.format(args.target_lang)),
        src_dict=src_dict,
        tgt_dict=tgt_dict)

    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              num_workers=1,
                                              collate_fn=test_dataset.collater,
                                              batch_sampler=BatchSampler(
                                                  test_dataset,
                                                  9999999,
                                                  args.batch_size,
                                                  1,
                                                  0,
                                                  shuffle=False,
                                                  seed=args.seed))
    # Build model and criterion
    model = models.build_model(args, src_dict, tgt_dict)
    if args.cuda:
        model = model.cuda()
    model.eval()
    model.load_state_dict(state_dict['model'])
    logging.info('Loaded a model from checkpoint {:s}'.format(
        args.checkpoint_path))
    progress_bar = tqdm(test_loader, desc='| Generation', leave=False)

    # Iterate over the test set
    all_hyps = {}
    for sample in progress_bar:
        # The model may live on the GPU, so the batch has to follow it;
        # feeding CPU tensors to a CUDA model raises a device mismatch.
        if args.cuda:
            sample = utils.move_to_cuda(sample)
        src_tokens = sample['src_tokens']

        with torch.no_grad():
            # Compute the encoder output
            encoder_out = model.encoder(src_tokens, sample['src_lengths'])

            # One EOS token per sentence starts the decoder; new_full keeps
            # the dtype and device of the source batch.
            go_slice = src_tokens.new_full((src_tokens.shape[0], 1),
                                           tgt_dict.eos_idx)
            prev_words = go_slice
            next_words = None

            for _ in range(args.max_len):
                # Compute the decoder output by repeatedly feeding it the decoded sentence prefix
                decoder_out, _ = model.decoder(prev_words, encoder_out)
                # Suppress <UNK>s
                _, next_candidates = torch.topk(decoder_out, 2, dim=-1)
                best_candidates = next_candidates[:, :, 0]
                backoff_candidates = next_candidates[:, :, 1]
                next_words = torch.where(best_candidates == tgt_dict.unk_idx,
                                         backoff_candidates, best_candidates)
                prev_words = torch.cat([go_slice, next_words], dim=1)

        # Segment into sentences (one row per sentence)
        decoded_batch = next_words.cpu().numpy()
        sentence_ids = sample['id'].data
        assert decoded_batch.shape[0] == len(sentence_ids)

        # Cut every hypothesis at its first EOS. A hypothesis that never
        # produced EOS within max_len is kept whole instead of being
        # replaced by an empty translation.
        output_sentences = []
        for sent in decoded_batch:
            eos_positions = np.where(sent == tgt_dict.eos_idx)[0]
            if len(eos_positions) > 0:
                sent = sent[:eos_positions[0]]
            output_sentences.append(sent)

        # Convert arrays of indices into strings of words and save them
        for sent_id, sent in zip(sentence_ids, output_sentences):
            all_hyps[int(sent_id)] = tgt_dict.string(sent)

    # Write to file
    if args.output is not None:
        # Explicit UTF-8 so the result does not depend on the locale; ids are
        # written in ascending order without assuming they start at zero.
        with open(args.output, 'w', encoding='utf-8') as out_file:
            for sent_id in sorted(all_hyps):
                out_file.write(all_hyps[sent_id] + '\n')
def main(args):
    """Translate the test set with beam search and print the hypotheses.

    Restores the model and its training arguments from
    ``args.checkpoint_path``, runs ``SequenceGenerator`` over every test
    batch and, unless ``args.quiet`` is set, prints the source (``S-``),
    reference (``T-``), hypothesis (``H-``), positional scores (``P-``) and
    alignment (``A-``) lines for up to ``args.num_hypo`` hypotheses.

    The model and every batch are moved to the GPU unconditionally.

    Args:
        args: Parsed command-line namespace; values stored in the checkpoint
            override the command-line ones.
    """
    # Load arguments from checkpoint
    torch.manual_seed(args.seed)
    state_dict = torch.load(args.checkpoint_path, map_location=lambda s, l: default_restore_location(s, 'cpu'))
    args = argparse.Namespace(**{**vars(args), **vars(state_dict['args'])})
    utils.init_logging(args)

    # Load dictionaries
    src_dict = Dictionary.load(os.path.join(args.data, 'dict.{}'.format(args.source_lang)))
    logging.info('Loaded a source dictionary ({}) with {} words'.format(args.source_lang, len(src_dict)))
    tgt_dict = Dictionary.load(os.path.join(args.data, 'dict.{}'.format(args.target_lang)))
    logging.info('Loaded a target dictionary ({}) with {} words'.format(args.target_lang, len(tgt_dict)))

    # Load dataset
    test_dataset = Seq2SeqDataset(
        src_file=os.path.join(args.data, 'test.{}'.format(args.source_lang)),
        tgt_file=os.path.join(args.data, 'test.{}'.format(args.target_lang)),
        src_dict=src_dict, tgt_dict=tgt_dict)

    test_loader = torch.utils.data.DataLoader(
        test_dataset, num_workers=args.num_workers, collate_fn=test_dataset.collater,
        batch_sampler=BatchSampler(
            test_dataset, args.max_tokens, args.batch_size, args.distributed_world_size,
            args.distributed_rank, shuffle=False, seed=args.seed))

    # Build model and criterion
    model = models.build_model(args, src_dict, tgt_dict).cuda()
    model.load_state_dict(state_dict['model'])
    logging.info('Loaded a model from checkpoint {}'.format(args.checkpoint_path))

    # SECURITY NOTE(review): eval() executes whatever expression these option
    # strings contain. They are presumably 'True'/'False' flags; consider
    # parsing them without eval if the options can come from untrusted input.
    translator = SequenceGenerator(
        model, tgt_dict, beam_size=args.beam_size, maxlen=args.max_len, stop_early=eval(args.stop_early),
        normalize_scores=eval(args.normalize_scores), len_penalty=args.len_penalty, unk_penalty=args.unk_penalty,
    )

    progress_bar = tqdm(test_loader, desc='| Generation', leave=False)
    for sample in progress_bar:
        sample = utils.move_to_cuda(sample)
        with torch.no_grad():
            batch_hypos = translator.generate(sample['src_tokens'], sample['src_lengths'])

        # Whether references are available is a property of the whole batch.
        has_target = sample['tgt_tokens'] is not None

        # Distinct names per loop level: row in the batch, then rank in the beam.
        for row, (sample_id, hypos) in enumerate(zip(sample['id'].data, batch_hypos)):
            # NOTE(review): the source side is stripped with the *target*
            # dictionary's pad index - confirm both dictionaries share it.
            src_tokens = utils.strip_pad(sample['src_tokens'].data[row, :], tgt_dict.pad_idx)
            if has_target:
                target_tokens = utils.strip_pad(sample['tgt_tokens'].data[row, :], tgt_dict.pad_idx).int().cpu()
            else:
                target_tokens = None

            src_str = src_dict.string(src_tokens, args.remove_bpe)
            target_str = tgt_dict.string(target_tokens, args.remove_bpe) if has_target else ''

            if not args.quiet:
                print('S-{}\t{}'.format(sample_id, src_str))
                if has_target:
                    print('T-{}\t{}'.format(sample_id, colored(target_str, 'green')))

            # Process top predictions (slicing already caps at the list length)
            for rank, hypo in enumerate(hypos[:args.num_hypo]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'].int().cpu(),
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                )

                if not args.quiet:
                    print('H-{}\t{}'.format(sample_id, colored(hypo_str, 'blue')))
                    print('P-{}\t{}'.format(sample_id, ' '.join(map(lambda x: '{:.4f}'.format(x), hypo['positional_scores'].tolist()))))
                    print('A-{}\t{}'.format(sample_id, ' '.join(map(lambda x: str(x.item()), alignment))))

                # Score only the top hypothesis
                if has_target and rank == 0:
                    # Convert back to tokens for evaluation with unk replacement and/or without BPE
                    # NOTE(review): the re-binarized reference is not consumed
                    # by anything visible here; the call is kept because it is
                    # made with add_if_not_exist=True.
                    target_tokens = tgt_dict.binarize(target_str, word_tokenize, add_if_not_exist=True)
Exemple #5
0
def main(args):
    """Train the translation model with Nesterov SGD, optionally on several GPUs.

    Every epoch runs one pass over the training set, normalizes the summed
    gradients by the number of tokens, clips them, validates, lets the
    plateau scheduler adjust the learning rate and saves a checkpoint every
    ``args.save_interval`` epochs. Training stops once the learning rate has
    dropped to ``args.min_lr`` or ``args.max_epoch`` is reached.

    Args:
        args: Parsed command-line namespace (seed, device and distributed
            settings, data location, optimizer hyper-parameters, ...).

    Raises:
        NotImplementedError: If no CUDA device is available.
    """
    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported.')
    torch.manual_seed(args.seed)
    torch.cuda.set_device(args.device_id)
    utils.init_logging(args)

    if args.distributed_world_size > 1:
        torch.distributed.init_process_group(
            backend=args.distributed_backend,
            init_method=args.distributed_init_method,
            world_size=args.distributed_world_size,
            rank=args.distributed_rank)

    # Load dictionaries
    src_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{}'.format(args.source_lang)))
    logging.info('Loaded a source dictionary ({}) with {} words'.format(
        args.source_lang, len(src_dict)))
    tgt_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{}'.format(args.target_lang)))
    logging.info('Loaded a target dictionary ({}) with {} words'.format(
        args.target_lang, len(tgt_dict)))

    # Load datasets
    def load_data(split):
        """Build the parallel dataset stored under ``<split>.<lang>``."""
        return Seq2SeqDataset(
            src_file=os.path.join(args.data,
                                  '{}.{}'.format(split, args.source_lang)),
            tgt_file=os.path.join(args.data,
                                  '{}.{}'.format(split, args.target_lang)),
            src_dict=src_dict,
            tgt_dict=tgt_dict)

    train_dataset = load_data(split='train')
    valid_dataset = load_data(split='valid')

    # Build model and criterion
    model = models.build_model(args, src_dict, tgt_dict).cuda()
    logging.info('Built a model with {} parameters'.format(
        sum(p.numel() for p in model.parameters())))
    # NOTE(review): the loss below scores target tokens, yet the ignored
    # padding index comes from the *source* dictionary - confirm that this
    # is the index the collater pads the targets with.
    criterion = nn.CrossEntropyLoss(ignore_index=src_dict.pad_idx,
                                    reduction='sum').cuda()

    # Build an optimizer and a learning rate schedule
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, patience=0, min_lr=args.min_lr, factor=args.lr_shrink)

    # Load last checkpoint if one exists
    state_dict = utils.load_checkpoint(args, model, optimizer, lr_scheduler)
    last_epoch = state_dict['last_epoch'] if state_dict is not None else -1

    for epoch in range(last_epoch + 1, args.max_epoch):
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            num_workers=args.num_workers,
            collate_fn=train_dataset.collater,
            batch_sampler=BatchSampler(train_dataset,
                                       args.max_tokens,
                                       args.batch_size,
                                       args.distributed_world_size,
                                       args.distributed_rank,
                                       shuffle=True,
                                       seed=args.seed))

        model.train()
        stats = {
            'loss': 0.,
            'lr': 0.,
            'num_tokens': 0.,
            'batch_size': 0.,
            'grad_norm': 0.,
            'clip': 0.
        }
        # Only the first rank draws a progress bar.
        progress_bar = tqdm(train_loader,
                            desc='| Epoch {:03d}'.format(epoch),
                            leave=False,
                            disable=(args.distributed_rank != 0))

        for i, sample in enumerate(progress_bar):
            # Skip empty batches before paying for a device transfer.
            if len(sample) == 0:
                continue
            sample = utils.move_to_cuda(sample)

            # Forward and backward pass
            output, _ = model(sample['src_tokens'], sample['src_lengths'],
                              sample['tgt_inputs'])
            loss = criterion(output.view(-1, output.size(-1)),
                             sample['tgt_tokens'].view(-1))
            optimizer.zero_grad()
            loss.backward()

            local_batch_size = len(sample['src_tokens'])

            # Reduce gradients across all GPUs
            if args.distributed_world_size > 1:
                utils.reduce_grads(model.parameters())
                # Sum loss, token count and batch size over the gathered
                # per-worker triples.
                total_loss, num_tokens, batch_size = list(
                    map(
                        sum,
                        zip(*utils.all_gather_list([
                            loss.item(), sample['num_tokens'],
                            local_batch_size
                        ]))))
            else:
                total_loss = loss.item()
                num_tokens = sample['num_tokens']
                batch_size = local_batch_size

            # Normalize gradients by number of tokens and perform clipping.
            # Parameters that took no part in this step have no gradient and
            # must be skipped instead of dereferenced.
            for param in model.parameters():
                if param.grad is not None:
                    param.grad.data.div_(num_tokens)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       args.clip_norm)
            optimizer.step()

            # Update statistics for progress bar (loss is reported in bits)
            stats['loss'] += total_loss / num_tokens / math.log(2)
            stats['lr'] += optimizer.param_groups[0]['lr']
            stats['num_tokens'] += num_tokens / local_batch_size
            stats['batch_size'] += batch_size
            stats['grad_norm'] += grad_norm
            stats['clip'] += 1 if grad_norm > args.clip_norm else 0
            progress_bar.set_postfix(
                {
                    key: '{:.4g}'.format(value / (i + 1))
                    for key, value in stats.items()
                },
                refresh=True)

        # Guard the averages against an empty loader (zero batches).
        num_batches = max(len(progress_bar), 1)
        logging.info('Epoch {:03d}: {}'.format(
            epoch, ' | '.join(key + ' {:.4g}'.format(value / num_batches)
                              for key, value in stats.items())))

        # Adjust learning rate based on validation loss
        valid_loss = validate(args, model, criterion, valid_dataset, epoch)
        lr_scheduler.step(valid_loss)

        # Save checkpoints
        if epoch % args.save_interval == 0:
            utils.save_checkpoint(args, model, optimizer, lr_scheduler, epoch,
                                  valid_loss)
        if optimizer.param_groups[0]['lr'] <= args.min_lr:
            logging.info('Done training!')
            break
def main(args):
    """Render attention weights of the first test sentence pairs as heat-maps.

    Restores the model and its training arguments from
    ``args.checkpoint_path``, runs a forward pass over single-sentence test
    batches, collects the attention weights of the first ten non-empty
    batches and stores one ``sentence_<n>.png`` per pair in ``args.vis_dir``.

    Args:
        args: Parsed command-line namespace; values stored in the checkpoint
            override the command-line ones.
    """
    mpl.rc('font', family='VL Gothic')

    torch.manual_seed(42)
    state_dict = torch.load(
        args.checkpoint_path,
        map_location=lambda s, l: default_restore_location(s, 'cpu'))
    args = argparse.Namespace(**{**vars(args), **vars(state_dict['args'])})
    utils.init_logging(args)

    # Load dictionaries
    src_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{:s}'.format(args.source_lang)))
    print('Loaded a source dictionary ({:s}) with {:d} words'.format(
        args.source_lang, len(src_dict)))
    tgt_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{:s}'.format(args.target_lang)))
    print('Loaded a target dictionary ({:s}) with {:d} words'.format(
        args.target_lang, len(tgt_dict)))

    # Load dataset
    test_dataset = Seq2SeqDataset(
        src_file=os.path.join(args.data, 'test.{:s}'.format(args.source_lang)),
        tgt_file=os.path.join(args.data, 'test.{:s}'.format(args.target_lang)),
        src_dict=src_dict,
        tgt_dict=tgt_dict)

    # Batches of a single sentence, in corpus order.
    vis_loader = torch.utils.data.DataLoader(test_dataset,
                                             num_workers=1,
                                             collate_fn=test_dataset.collater,
                                             batch_sampler=BatchSampler(
                                                 test_dataset,
                                                 None,
                                                 1,
                                                 1,
                                                 0,
                                                 shuffle=False,
                                                 seed=42))

    # Build model and optimization criterion
    model = models.build_model(args, src_dict, tgt_dict)
    if args.cuda:
        model = model.cuda()
    model.load_state_dict(state_dict['model'])
    # Inference only: switch off training-time behaviour such as dropout so
    # the plotted attention is the deterministic one.
    model.eval()
    print('Loaded a model from checkpoint {:s}'.format(args.checkpoint_path))

    # Store attention weight arrays
    attn_records = list()
    max_records = 10

    # Iterate over the visualization set; no gradients are needed.
    with torch.no_grad():
        for sample in vis_loader:
            if len(sample) == 0:
                continue
            if args.cuda:
                sample = utils.move_to_cuda(sample)

            # Perform forward pass
            output, attn_weights = model(sample['src_tokens'],
                                         sample['src_lengths'],
                                         sample['tgt_inputs'])
            attn_records.append((sample, attn_weights))

            # Only visualize the first 10 sentence pairs
            if len(attn_records) >= max_records:
                break

    # Generate heat-maps and store them at the designated location
    os.makedirs(args.vis_dir, exist_ok=True)

    for record_id, record in enumerate(attn_records):
        # Unpack
        sample, attn_map = record
        # NOTE(review): the source side is stripped with the *target*
        # dictionary's pad index - confirm both dictionaries share it.
        src_ids = utils.strip_pad(sample['src_tokens'].data, tgt_dict.pad_idx)
        tgt_ids = utils.strip_pad(sample['tgt_inputs'].data, tgt_dict.pad_idx)
        # Convert indices into word tokens
        src_str = src_dict.string(src_ids).split(' ') + ['<EOS>']
        tgt_str = tgt_dict.string(tgt_ids).split(' ') + ['<EOS>']

        # Generate heat-maps: drop the batch axis and swap the two remaining
        # axes so that rows are source tokens and columns are target tokens.
        attn_map = attn_map.squeeze(dim=0).transpose(1,
                                                     0).cpu().detach().numpy()

        attn_df = pd.DataFrame(attn_map, index=src_str, columns=tgt_str)

        sns.heatmap(attn_df,
                    cmap='Blues',
                    linewidths=0.25,
                    vmin=0.0,
                    vmax=1.0,
                    xticklabels=True,
                    yticklabels=True,
                    fmt='.3f')
        plt.yticks(rotation=0)
        plot_path = os.path.join(args.vis_dir,
                                 'sentence_{:d}.png'.format(record_id))
        plt.savefig(plot_path, dpi='figure', pad_inches=1, bbox_inches='tight')
        # Clear the figure so the next heat-map starts from a blank canvas.
        plt.clf()

    print(
        'Done! Visualized attention maps have been saved to the \'{:s}\' directory!'
        .format(args.vis_dir))
    # NOTE(review): this is the tail of a function whose header lies outside
    # the visible source; `dictionary`, `unk_counter`, `input_file`,
    # `output_file`, `append_eos`, `nsent` and `ntok` come from that scope.
    def unk_consumer(word, idx):
        """Count *word* when it was mapped to the unknown index without being the unknown token itself."""
        if idx == dictionary.unk_idx and word != dictionary.unk_word:
            unk_counter.update([word])

    # Binarize the input file line by line, keeping running sentence and
    # token counts.
    tokens_list = []
    with open(input_file, 'r') as inf:
        for line in inf:
            tokens = dictionary.binarize(line.strip(),
                                         word_tokenize,
                                         append_eos,
                                         consumer=unk_consumer)
            nsent, ntok = nsent + 1, ntok + len(tokens)
            tokens_list.append(tokens.numpy())

    # Pickle the list of per-sentence arrays and report the statistics.
    with open(output_file, 'wb') as outf:
        pickle.dump(tokens_list, outf, protocol=pickle.HIGHEST_PROTOCOL)
        # NOTE(review): the message has four placeholders but receives five
        # arguments, so the trailing `dictionary.unk_word` is silently
        # ignored by str.format; the percentage also divides by `ntok`, which
        # raises ZeroDivisionError if no token was read.
        logging.info(
            'Built a binary dataset for {}: {} sentences, {} tokens, {:.3f}% replaced by unknown token'
            .format(input_file, nsent, ntok,
                    100.0 * sum(unk_counter.values()) / ntok,
                    dictionary.unk_word))


if __name__ == '__main__':
    # Script entry point: parse the options, set up logging, record how the
    # script was invoked and with which settings, then hand over to main().
    args = get_args()
    utils.init_logging(args)
    invocation = ' '.join(sys.argv)
    logging.info('COMMAND: %s' % invocation)
    settings = vars(args)
    logging.info('Arguments: {}'.format(settings))
    main(args)
Exemple #8
0
def main(args):
    """Translate the test set with beam search and write the n-best hypotheses.

    Restores the stored arguments and model weights from
    ``args.checkpoint_path``, decodes ``test.<source_lang>`` from ``args.data``
    batch by batch and, when ``args.output`` is set, writes one hypothesis per
    line, ordered by sentence id and then by rank within the sentence.

    Args:
        args: Parsed command-line namespace. Values stored in the checkpoint
            take precedence over the command-line ones, except for ``data``.
    """
    # Load arguments from checkpoint (checkpoint values override the CLI ones,
    # the data directory is always taken from the command line)
    torch.manual_seed(args.seed)
    state_dict = torch.load(args.checkpoint_path, map_location=lambda s, l: default_restore_location(s, 'cpu'))
    args_loaded = argparse.Namespace(**{**vars(args), **vars(state_dict['args'])})
    args_loaded.data = args.data
    args = args_loaded
    utils.init_logging(args)

    # Load dictionaries
    src_dict = Dictionary.load(os.path.join(args.data, 'dict.{:s}'.format(args.source_lang)))
    logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format(args.source_lang, len(src_dict)))
    tgt_dict = Dictionary.load(os.path.join(args.data, 'dict.{:s}'.format(args.target_lang)))
    logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format(args.target_lang, len(tgt_dict)))

    # Load dataset
    test_dataset = Seq2SeqDataset(
        src_file=os.path.join(args.data, 'test.{:s}'.format(args.source_lang)),
        tgt_file=os.path.join(args.data, 'test.{:s}'.format(args.target_lang)),
        src_dict=src_dict, tgt_dict=tgt_dict)

    test_loader = torch.utils.data.DataLoader(test_dataset, num_workers=1, collate_fn=test_dataset.collater,
                                              batch_sampler=BatchSampler(test_dataset, 9999999,
                                                                         args.batch_size, 1, 0, shuffle=False,
                                                                         seed=args.seed))
    # Build the model and restore its weights
    model = models.build_model(args, src_dict, tgt_dict)
    if args.cuda:
        model = model.cuda()
    model.eval()
    model.load_state_dict(state_dict['model'])
    logging.info('Loaded a model from checkpoint {:s}'.format(args.checkpoint_path))
    progress_bar = tqdm(test_loader, desc='| Generation', leave=False)

    # Maps (sentence id, rank) -> decoded hypothesis string
    all_hyps = {}
    for sample in progress_bar:

        # One beam search object per input sentence in the batch
        batch_size = sample['src_tokens'].shape[0]
        searches = [BeamSearch(args.beam_size, args.max_len - 1, tgt_dict.unk_idx) for _ in range(batch_size)]

        with torch.no_grad():
            # Compute the encoder output
            encoder_out = model.encoder(sample['src_tokens'], sample['src_lengths'])
            # First decoder input: a (batch_size, 1) column filled with the
            # EOS index, cast to the type of the source tokens
            go_slice = \
                torch.ones(sample['src_tokens'].shape[0], 1).fill_(tgt_dict.eos_idx).type_as(sample['src_tokens'])
            if args.cuda:
                go_slice = utils.move_to_cuda(go_slice)

            # Compute the decoder output at the first time step
            decoder_out, _ = model.decoder(go_slice, encoder_out)

            # Keep beam_size + 1 candidates: candidate j + 1 serves as the
            # backoff whenever candidate j is the unknown token (see below)
            log_probs, next_candidates = torch.topk(torch.log(torch.softmax(decoder_out, dim=2)),
                                                    args.beam_size + 1, dim=-1)

        # Create beam_size initial beam search nodes for every input sentence
        for i in range(batch_size):
            for j in range(args.beam_size):
                best_candidate = next_candidates[i, :, j]
                backoff_candidate = next_candidates[i, :, j + 1]
                best_log_p = log_probs[i, :, j]
                backoff_log_p = log_probs[i, :, j + 1]
                # Never emit UNK: fall back to the next-ranked candidate
                next_word = torch.where(best_candidate == tgt_dict.unk_idx, backoff_candidate, best_candidate)
                log_p = torch.where(best_candidate == tgt_dict.unk_idx, backoff_log_p, best_log_p)
                log_p = log_p[-1]

                # Store the encoder_out information for the current input sentence and beam
                emb = encoder_out['src_embeddings'][:, i, :]
                lstm_out = encoder_out['src_out'][0][:, i, :]
                final_hidden = encoder_out['src_out'][1][:, i, :]
                final_cell = encoder_out['src_out'][2][:, i, :]
                try:
                    mask = encoder_out['src_mask'][i, :]
                except TypeError:
                    # src_mask is not subscriptable (e.g. None): no mask
                    mask = None

                node = BeamSearchNode(searches[i], emb, lstm_out, final_hidden, final_cell,
                                      mask, torch.cat((go_slice[i], next_word)), log_p, 1)

                # Length-normalized score, then the rank-dependent adjustment
                # from diverse() (both helpers are defined elsewhere)
                lp = normalize(node.length)
                score = node.eval()/lp
                score = diverse(score, j)
                # Scores are negated on insertion -- presumably BeamSearch
                # keeps a min-ordered queue; confirm against BeamSearch
                searches[i].add(-score, node)

        # Start generating further tokens until max sentence length reached
        for _ in range(args.max_len - 1):

            # Get the current nodes to expand
            nodes = [n[1] for s in searches for n in s.get_current_beams()]
            if not nodes:
                break  # All beams ended in EOS

            # Reconstruct prev_words, encoder_out from current beam search nodes
            prev_words = torch.stack([node.sequence for node in nodes])
            encoder_out["src_embeddings"] = torch.stack([node.emb for node in nodes], dim=1)
            lstm_out = torch.stack([node.lstm_out for node in nodes], dim=1)
            final_hidden = torch.stack([node.final_hidden for node in nodes], dim=1)
            final_cell = torch.stack([node.final_cell for node in nodes], dim=1)
            encoder_out["src_out"] = (lstm_out, final_hidden, final_cell)
            try:
                encoder_out["src_mask"] = torch.stack([node.mask for node in nodes], dim=0)
            except TypeError:
                encoder_out["src_mask"] = None

            with torch.no_grad():
                # Compute the decoder output by feeding it the decoded sentence prefix
                decoder_out, _ = model.decoder(prev_words, encoder_out)

            # beam_size + 1 candidates again, for the UNK backoff
            log_probs, next_candidates = torch.topk(torch.log(torch.softmax(decoder_out, dim=2)), args.beam_size + 1,
                                                    dim=-1)

            # Create beam_size next nodes for every current node
            for i in range(log_probs.shape[0]):
                for j in range(args.beam_size):

                    best_candidate = next_candidates[i, :, j]
                    backoff_candidate = next_candidates[i, :, j + 1]
                    best_log_p = log_probs[i, :, j]
                    backoff_log_p = log_probs[i, :, j + 1]
                    next_word = torch.where(best_candidate == tgt_dict.unk_idx, backoff_candidate, best_candidate)
                    log_p = torch.where(best_candidate == tgt_dict.unk_idx, backoff_log_p, best_log_p)
                    log_p = log_p[-1]
                    # Prefix without its first token, followed by the new word
                    next_word = torch.cat((prev_words[i][1:], next_word[-1:]))

                    # Get parent node and beam search object for corresponding sentence
                    node = nodes[i]
                    search = node.search

                    # Store the node as final if EOS is generated; its logp and
                    # length are carried over unchanged from the parent
                    if next_word[-1] == tgt_dict.eos_idx:
                        node = BeamSearchNode(search, node.emb, node.lstm_out, node.final_hidden,
                                              node.final_cell, node.mask, torch.cat((prev_words[i][0].view([1]),
                                                                                     next_word)), node.logp,
                                              node.length)
                        lp = normalize(node.length)
                        score = node.eval()/lp
                        score = diverse(score, j)
                        search.add_final(-score, node)

                    # Otherwise keep the node among the current ones for the next iteration
                    else:
                        node = BeamSearchNode(search, node.emb, node.lstm_out, node.final_hidden,
                                              node.final_cell, node.mask, torch.cat((prev_words[i][0].view([1]),
                                                                                     next_word)), node.logp + log_p,
                                              node.length + 1)
                        lp = normalize(node.length)
                        score = node.eval()/lp
                        score = diverse(score, j)
                        search.add(-score, node)

            for search in searches:
                search.prune()

        # Collect the n-best hypotheses of every sentence together with their
        # (sentence id, rank) key. The key is derived from what get_best()
        # actually returns instead of assuming exactly three per sentence.
        hyp_keys = []
        hyp_seqs = []
        for sent_idx, search in enumerate(searches):
            sent_id = int(sample['id'].data[sent_idx])
            for rank, best in enumerate(search.get_best()):
                hyp_keys.append((sent_id, rank))
                hyp_seqs.append(best[1].sequence[1:])
        best_sents = torch.stack(hyp_seqs)
        print('n best sents', best_sents)

        # Tensors living on the GPU cannot be converted to numpy directly
        decoded_batch = best_sents.cpu().numpy()

        for key, sent in zip(hyp_keys, decoded_batch):
            # Cut each hypothesis at its first EOS, if it contains one
            first_eos = np.where(sent == tgt_dict.eos_idx)[0]
            if len(first_eos) > 0:
                sent = sent[:first_eos[0]]
            # Convert the array of indices into a string of words
            all_hyps[key] = tgt_dict.string(sent)

    # Write to file: hypotheses of one sentence together, best rank first
    if args.output is not None:
        with open(args.output, 'w') as out_file:
            for key in sorted(all_hyps):
                out_file.write(all_hyps[key] + '\n')
def main(args):
    """Translate the test set with beam search and write the best hypotheses.

    Restores the stored arguments and model weights from
    ``args.checkpoint_path``, decodes ``test.<source_lang>`` from ``args.data``
    batch by batch and, when ``args.output`` is set, writes the best hypothesis
    of every sentence, one per line, in sentence-id order.

    Args:
        args: Parsed command-line namespace. Values stored in the checkpoint
            take precedence over the command-line ones, except for ``data``.
    """
    # Load arguments from checkpoint (checkpoint values override the CLI ones,
    # the data directory is always taken from the command line)
    torch.manual_seed(args.seed)
    state_dict = torch.load(
        args.checkpoint_path,
        map_location=lambda s, l: default_restore_location(s, 'cpu'))
    args_loaded = argparse.Namespace(**{
        **vars(args),
        **vars(state_dict['args'])
    })
    args_loaded.data = args.data
    args = args_loaded
    utils.init_logging(args)

    # Load dictionaries
    src_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{:s}'.format(args.source_lang)))
    logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format(
        args.source_lang, len(src_dict)))
    tgt_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{:s}'.format(args.target_lang)))
    logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format(
        args.target_lang, len(tgt_dict)))

    # Load dataset
    test_dataset = Seq2SeqDataset(
        src_file=os.path.join(args.data, 'test.{:s}'.format(args.source_lang)),
        tgt_file=os.path.join(args.data, 'test.{:s}'.format(args.target_lang)),
        src_dict=src_dict,
        tgt_dict=tgt_dict)

    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              num_workers=1,
                                              collate_fn=test_dataset.collater,
                                              batch_sampler=BatchSampler(
                                                  test_dataset,
                                                  9999999,
                                                  args.batch_size,
                                                  1,
                                                  0,
                                                  shuffle=False,
                                                  seed=args.seed))
    # Build the model and restore its weights
    model = models.build_model(args, src_dict, tgt_dict)
    if args.cuda:
        model = model.cuda()
    model.eval()
    model.load_state_dict(state_dict['model'])
    logging.info('Loaded a model from checkpoint {:s}'.format(
        args.checkpoint_path))
    progress_bar = tqdm(test_loader, desc='| Generation', leave=False)

    # Maps sentence id -> decoded hypothesis string
    all_hyps = {}
    for sample in progress_bar:

        # One beam search object per input sentence in the batch
        batch_size = sample['src_tokens'].shape[0]
        searches = [
            BeamSearch(args.beam_size, args.max_len - 1, tgt_dict.unk_idx)
            for _ in range(batch_size)
        ]

        with torch.no_grad():
            # Compute the encoder output
            encoder_out = model.encoder(sample['src_tokens'],
                                        sample['src_lengths'])
            # First decoder input: a (batch_size, 1) column filled with the
            # EOS index, cast to the type of the source tokens
            go_slice = \
                torch.ones(sample['src_tokens'].shape[0], 1).fill_(tgt_dict.eos_idx).type_as(sample['src_tokens'])

            # Compute the decoder output at the first time step
            decoder_out, _ = model.decoder(go_slice, encoder_out)

            # Apply the length normalization helper (defined elsewhere) once,
            # then keep beam_size + 1 candidates: candidate j + 1 serves as
            # the backoff whenever candidate j is the unknown token
            decoder_out = length_normalization(decoder_out)
            log_probs, next_candidates = torch.topk(torch.log(
                torch.softmax(decoder_out, dim=2)),
                                                    args.beam_size + 1,
                                                    dim=-1)

        # Create beam_size initial beam search nodes for every input sentence
        for i in range(batch_size):
            for j in range(args.beam_size):
                best_candidate = next_candidates[i, :, j]
                backoff_candidate = next_candidates[i, :, j + 1]
                best_log_p = log_probs[i, :, j]
                backoff_log_p = log_probs[i, :, j + 1]
                # Never emit UNK: fall back to the next-ranked candidate
                next_word = torch.where(best_candidate == tgt_dict.unk_idx,
                                        backoff_candidate, best_candidate)
                log_p = torch.where(best_candidate == tgt_dict.unk_idx,
                                    backoff_log_p, best_log_p)
                log_p = log_p[-1]

                # Store the encoder_out information for the current input sentence and beam
                emb = encoder_out['src_embeddings'][:, i, :]
                lstm_out = encoder_out['src_out'][0][:, i, :]
                final_hidden = encoder_out['src_out'][1][:, i, :]
                final_cell = encoder_out['src_out'][2][:, i, :]
                try:
                    mask = encoder_out['src_mask'][i, :]
                except TypeError:
                    # src_mask is not subscriptable (e.g. None): no mask
                    mask = None

                node = BeamSearchNode(searches[i], emb, lstm_out, final_hidden,
                                      final_cell, mask,
                                      torch.cat(
                                          (go_slice[i], next_word)), log_p, 1)
                # Scores are negated on insertion -- presumably BeamSearch
                # keeps a min-ordered queue; confirm against BeamSearch
                searches[i].add(-node.eval(), node)

        # Start generating further tokens until max sentence length reached
        for _ in range(args.max_len - 1):

            # Get the current nodes to expand
            nodes = [n[1] for s in searches for n in s.get_current_beams()]
            if not nodes:
                break  # All beams ended in EOS

            # Reconstruct prev_words, encoder_out from current beam search nodes
            prev_words = torch.stack([node.sequence for node in nodes])
            encoder_out["src_embeddings"] = torch.stack(
                [node.emb for node in nodes], dim=1)
            lstm_out = torch.stack([node.lstm_out for node in nodes], dim=1)
            final_hidden = torch.stack([node.final_hidden for node in nodes],
                                       dim=1)
            final_cell = torch.stack([node.final_cell for node in nodes],
                                     dim=1)
            encoder_out["src_out"] = (lstm_out, final_hidden, final_cell)
            try:
                encoder_out["src_mask"] = torch.stack(
                    [node.mask for node in nodes], dim=0)
            except TypeError:
                encoder_out["src_mask"] = None

            with torch.no_grad():
                # Compute the decoder output by feeding it the decoded sentence prefix
                decoder_out, _ = model.decoder(prev_words, encoder_out)

            # Length normalization is applied exactly once, matching the first
            # decoding step above (it used to be applied twice here)
            decoder_out = length_normalization(decoder_out)
            log_probs, next_candidates = torch.topk(torch.log(
                torch.softmax(decoder_out, dim=2)),
                                                    args.beam_size + 1,
                                                    dim=-1)

            # Create beam_size next nodes for every current node
            for i in range(log_probs.shape[0]):
                for j in range(args.beam_size):

                    best_candidate = next_candidates[i, :, j]
                    backoff_candidate = next_candidates[i, :, j + 1]
                    best_log_p = log_probs[i, :, j]
                    backoff_log_p = log_probs[i, :, j + 1]
                    next_word = torch.where(best_candidate == tgt_dict.unk_idx,
                                            backoff_candidate, best_candidate)
                    log_p = torch.where(best_candidate == tgt_dict.unk_idx,
                                        backoff_log_p, best_log_p)
                    log_p = log_p[-1]
                    # Prefix without its first token, followed by the new word
                    next_word = torch.cat((prev_words[i][1:], next_word[-1:]))

                    # Get parent node and beam search object for corresponding sentence
                    node = nodes[i]
                    search = node.search

                    # Store the node as final if EOS is generated; its logp and
                    # length are carried over unchanged from the parent
                    if next_word[-1] == tgt_dict.eos_idx:
                        node = BeamSearchNode(
                            search, node.emb, node.lstm_out, node.final_hidden,
                            node.final_cell, node.mask,
                            torch.cat((prev_words[i][0].view([1]), next_word)),
                            node.logp, node.length)
                        search.add_final(-node.eval(), node)

                    # Otherwise keep the node among the current ones for the next iteration
                    else:
                        node = BeamSearchNode(
                            search, node.emb, node.lstm_out, node.final_hidden,
                            node.final_cell, node.mask,
                            torch.cat((prev_words[i][0].view([1]), next_word)),
                            node.logp + log_p, node.length + 1)
                        search.add(-node.eval(), node)

            for search in searches:
                search.prune()

        # Best hypothesis per sentence, without the leading start token
        best_sents = torch.stack(
            [search.get_best()[1].sequence[1:] for search in searches])
        # .cpu() is a no-op for CPU tensors and required for GPU ones
        decoded_batch = best_sents.cpu().numpy()

        for ii, sent in enumerate(decoded_batch):
            # Cut each hypothesis at its first EOS, if it contains one
            first_eos = np.where(sent == tgt_dict.eos_idx)[0]
            if len(first_eos) > 0:
                sent = sent[:first_eos[0]]
            # Convert the array of indices into a string of words
            all_hyps[int(sample['id'].data[ii])] = tgt_dict.string(sent)

    # Write to file in sentence-id order (ids are expected to be 0..N-1)
    if args.output is not None:
        with open(args.output, 'w') as out_file:
            for sent_id in range(len(all_hyps)):
                out_file.write(all_hyps[sent_id] + '\n')
Exemple #10
0
def main(args):
    """Train the translation model for up to ``args.max_epoch`` epochs.

    Every epoch runs one pass over the training data with gradient-norm
    clipping, evaluates on the validation split, saves a checkpoint every
    ``args.save_interval`` epochs and stops early once the validation
    perplexity has not improved for ``args.patience`` epochs in a row.

    Args:
        args: Parsed command-line namespace with data, model and
            optimization settings.
    """
    logging.info('Commencing training!')
    torch.manual_seed(42)
    np.random.seed(42)

    utils.init_logging(args)

    # Source and target vocabularies
    src_dict_path = os.path.join(args.data, 'dict.{:s}'.format(args.source_lang))
    src_dict = Dictionary.load(src_dict_path)
    logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format(args.source_lang, len(src_dict)))
    tgt_dict_path = os.path.join(args.data, 'dict.{:s}'.format(args.target_lang))
    tgt_dict = Dictionary.load(tgt_dict_path)
    logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format(args.target_lang, len(tgt_dict)))

    def load_data(split):
        """Build the parallel dataset for the given split name."""
        src_path = os.path.join(args.data, '{:s}.{:s}'.format(split, args.source_lang))
        tgt_path = os.path.join(args.data, '{:s}.{:s}'.format(split, args.target_lang))
        return Seq2SeqDataset(src_file=src_path, tgt_file=tgt_path, src_dict=src_dict, tgt_dict=tgt_dict)

    train_dataset = load_data(split='tiny_train' if args.train_on_tiny else 'train')
    valid_dataset = load_data(split='valid')

    # Pick the GPU only when it is both available and requested
    use_cuda = torch.cuda.is_available() and args.device == 'cuda'
    device = torch.device("cuda" if use_cuda else "cpu")
    print("===> Using %s" % device)

    # Model and loss both live on the selected device
    model = models.build_model(args, src_dict, tgt_dict).to(device)
    param_count = sum(p.numel() for p in model.parameters())
    logging.info('Built a model with {:d} parameters'.format(param_count))
    criterion = nn.CrossEntropyLoss(ignore_index=src_dict.pad_idx, reduction='sum').to(device)

    optimizer = torch.optim.Adam(model.parameters(), args.lr)

    # Resume from the last checkpoint if there is one
    state_dict = utils.load_checkpoint(args, model, optimizer)  # lr_scheduler
    last_epoch = -1 if state_dict is None else state_dict['last_epoch']

    # Early-stopping bookkeeping
    bad_epochs = 0
    best_validate = float('inf')

    stat_names = ('loss', 'lr', 'num_tokens', 'batch_size', 'grad_norm', 'clip')
    tensor_fields = ('src_tokens', 'src_lengths', 'tgt_inputs', 'tgt_tokens')

    for epoch in range(last_epoch + 1, args.max_epoch):
        batch_sampler = BatchSampler(train_dataset, args.max_tokens, args.batch_size, 1,
                                     0, shuffle=True, seed=42)
        train_loader = torch.utils.data.DataLoader(train_dataset, num_workers=1,
                                                   collate_fn=train_dataset.collater,
                                                   batch_sampler=batch_sampler)
        model.train()
        stats = OrderedDict((name, 0) for name in stat_names)

        # Display progress
        progress_bar = tqdm(train_loader, desc='| Epoch {:03d}'.format(epoch), leave=False, disable=False)

        for i, sample in enumerate(progress_bar):

            if len(sample) == 0:
                continue
            model.train()

            # ___QUESTION-1-DESCRIBE-F-START___
            # Describe what the following lines of code do.

            # Move the batch tensors to the training device
            for field in tensor_fields:
                sample[field] = sample[field].to(device)

            output, _ = model(sample['src_tokens'], sample['src_lengths'], sample['tgt_inputs'])

            # Summed token-level cross entropy, averaged over the sentences
            num_sentences = len(sample['src_lengths'])
            flat_logits = output.view(-1, output.size(-1))
            flat_targets = sample['tgt_tokens'].view(-1)
            loss = criterion(flat_logits, flat_targets) / num_sentences
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_norm)
            optimizer.step()
            optimizer.zero_grad()
            # ___QUESTION-1-DESCRIBE-F-END___

            # Running totals shown (as per-batch averages) in the progress bar
            batch_size = len(sample['src_tokens'])
            stats['loss'] += loss.item() * num_sentences / sample['num_tokens']
            stats['lr'] += optimizer.param_groups[0]['lr']
            stats['num_tokens'] += sample['num_tokens'] / batch_size
            stats['batch_size'] += batch_size
            stats['grad_norm'] += grad_norm
            if grad_norm > args.clip_norm:
                stats['clip'] += 1
            running = {key: '{:.4g}'.format(value / (i + 1)) for key, value in stats.items()}
            progress_bar.set_postfix(running, refresh=True)

        epoch_summary = ' | '.join(
            key + ' {:.4g}'.format(value / len(progress_bar)) for key, value in stats.items())
        logging.info('Epoch {:03d}: {}'.format(epoch, epoch_summary))

        # Validation pass; switch back to training mode afterwards
        valid_perplexity = validate(args, model, criterion, valid_dataset, epoch)
        model.train()

        # Periodic checkpoint
        if epoch % args.save_interval == 0:
            utils.save_checkpoint(args, model, optimizer, epoch, valid_perplexity)  # lr_scheduler

        # Early stopping on validation perplexity
        if valid_perplexity < best_validate:
            best_validate = valid_perplexity
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs >= args.patience:
            logging.info('No validation set improvements observed for {:d} epochs. Early stop!'.format(args.patience))
            break
Exemple #11
0
def main(args):
    """ Main training function. Trains the translation model over the course of several epochs, including dynamic
    learning rate adjustment and gradient clipping. """

    logging.info('Commencing training!')
    torch.manual_seed(42)

    utils.init_logging(args)

    # Load dictionaries
    src_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{:s}'.format(args.source_lang)))
    logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format(
        args.source_lang, len(src_dict)))
    tgt_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{:s}'.format(args.target_lang)))
    logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format(
        args.target_lang, len(tgt_dict)))

    # Load datasets
    def load_data(split):
        return Seq2SeqDataset(
            src_file=os.path.join(args.data,
                                  '{:s}.{:s}'.format(split, args.source_lang)),
            tgt_file=os.path.join(args.data,
                                  '{:s}.{:s}'.format(split, args.target_lang)),
            src_dict=src_dict,
            tgt_dict=tgt_dict)

    train_dataset = load_data(
        split='train') if not args.train_on_tiny else load_data(
            split='tiny_train')
    valid_dataset = load_data(split='valid')

    # Build model and optimization criterion
    model = models.build_model(args, src_dict, tgt_dict)
    model_rev = models.build_model(args, tgt_dict, src_dict)
    logging.info('Built a model with {:d} parameters'.format(
        sum(p.numel() for p in model.parameters())))
    criterion = nn.CrossEntropyLoss(ignore_index=src_dict.pad_idx,
                                    reduction='sum')
    criterion2 = nn.MSELoss(reduction='sum')
    if args.cuda:
        model = model.cuda()
        model_rev = model_rev.cuda()
        criterion = criterion.cuda()

    # Instantiate optimizer and learning rate scheduler
    optimizer = torch.optim.Adam(model.parameters(), args.lr)

    # Load last checkpoint if one exists
    state_dict = utils.load_checkpoint(args, model, optimizer)  # lr_scheduler
    utils.load_checkpoint_rev(args, model_rev, optimizer)  # lr_scheduler
    last_epoch = state_dict['last_epoch'] if state_dict is not None else -1

    # Track validation performance for early stopping
    bad_epochs = 0
    best_validate = float('inf')

    for epoch in range(last_epoch + 1, args.max_epoch):
        train_loader = \
            torch.utils.data.DataLoader(train_dataset, num_workers=1, collate_fn=train_dataset.collater,
                                        batch_sampler=BatchSampler(train_dataset, args.max_tokens, args.batch_size, 1,
                                                                   0, shuffle=True, seed=42))
        model.train()
        model_rev.train()
        stats = OrderedDict()
        stats['loss'] = 0
        stats['lr'] = 0
        stats['num_tokens'] = 0
        stats['batch_size'] = 0
        stats['grad_norm'] = 0
        stats['clip'] = 0
        # Display progress
        progress_bar = tqdm(train_loader,
                            desc='| Epoch {:03d}'.format(epoch),
                            leave=False,
                            disable=False)

        # Iterate over the training set
        for i, sample in enumerate(progress_bar):
            if args.cuda:
                sample = utils.move_to_cuda(sample)
            if len(sample) == 0:
                continue
            model.train()

            (output, att), src_out = model(sample['src_tokens'],
                                           sample['src_lengths'],
                                           sample['tgt_inputs'])
            # print(sample['src_lengths'])
            # print(sample['tgt_inputs'].size())
            # print(sample['src_tokens'].size())
            src_inputs = sample['src_tokens'].clone()
            src_inputs[0, 1:src_inputs.size(1)] = sample['src_tokens'][0, 0:(
                src_inputs.size(1) - 1)]
            src_inputs[0, 0] = sample['src_tokens'][0, src_inputs.size(1) - 1]
            tgt_lengths = sample['src_lengths'].clone(
            )  #torch.tensor([sample['tgt_tokens'].size(1)])
            tgt_lengths += sample['tgt_inputs'].size(
                1) - sample['src_tokens'].size(1)
            # print(tgt_lengths)
            # print(sample['num_tokens'])

            # if args.cuda:
            #     tgt_lengths = tgt_lengths.cuda()
            (output_rev,
             att_rev), src_out_rev = model_rev(sample['tgt_tokens'],
                                               tgt_lengths, src_inputs)

            # notice that those are without masks already
            # print(sample['tgt_tokens'].view(-1))
            d, d_rev = get_diff(att, src_out, att_rev, src_out_rev)

            # print(sample['src_tokens'].size())
            # print(sample['tgt_inputs'].size())
            # print(att.size())
            # print(src_out.size())
            # print(acontext.size())
            # print(src_out_rev.size())
            # # print(sample['tgt_inputs'].dtype)
            # # print(sample['src_lengths'])
            # # print(sample['src_tokens'])
            # # print('output %s' % str(output.size()))
            # # print(att)
            # # print(len(sample['src_lengths']))
            # print(d)
            # print(d_rev)
            # print(criterion(output.view(-1, output.size(-1)), sample['tgt_tokens'].view(-1)) / len(sample['src_lengths']))
            # print(att2)
            # output=output.cpu().detach().numpy()
            # output=torch.from_numpy(output).cuda()
            # output_rev=output_rev.cpu().detach().numpy()
            # output_rev=torch.from_numpy(output_rev).cuda()
            loss = \
                criterion(output.view(-1, output.size(-1)), sample['tgt_tokens'].view(-1)) / len(sample['src_lengths'])  + d +\
                criterion(output_rev.view(-1, output_rev.size(-1)), sample['src_tokens'].view(-1)) / len(tgt_lengths) +d_rev
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       args.clip_norm)
            # loss_rev = \
            #     criterion(output_rev.view(-1, output_rev.size(-1)), sample['src_tokens'].view(-1)) / len(tgt_lengths)
            # loss_rev.backward()
            # grad_norm_rev = torch.nn.utils.clip_grad_norm_(model_rev.parameters(), args.clip_norm)
            optimizer.step()
            optimizer.zero_grad()

            # Update statistics for progress bar
            total_loss, num_tokens, batch_size = (
                loss - d - d_rev).item(), sample['num_tokens'], len(
                    sample['src_tokens'])
            stats['loss'] += total_loss * len(
                sample['src_lengths']) / sample['num_tokens']
            # stats['loss_rev'] += loss_rev.item() * len(sample['src_lengths']) / sample['src_tokens'].size(0) / sample['src_tokens'].size(1)
            stats['lr'] += optimizer.param_groups[0]['lr']
            stats['num_tokens'] += num_tokens / len(sample['src_tokens'])
            stats['batch_size'] += batch_size
            stats['grad_norm'] += grad_norm
            stats['clip'] += 1 if grad_norm > args.clip_norm else 0
            progress_bar.set_postfix(
                {
                    key: '{:.4g}'.format(value / (i + 1))
                    for key, value in stats.items()
                },
                refresh=True)

        logging.info('Epoch {:03d}: {}'.format(
            epoch, ' | '.join(key + ' {:.4g}'.format(value / len(progress_bar))
                              for key, value in stats.items())))

        # Calculate validation loss
        valid_perplexity = validate(args, model, model_rev, criterion,
                                    valid_dataset, epoch)
        model.train()
        model_rev.train()

        # Save checkpoints
        if epoch % args.save_interval == 0:
            utils.save_checkpoint(args, model, model_rev, optimizer, epoch,
                                  valid_perplexity)  # lr_scheduler

        # Check whether to terminate training
        if valid_perplexity < best_validate:
            best_validate = valid_perplexity
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs >= args.patience:
            logging.info(
                'No validation set improvements observed for {:d} epochs. Early stop!'
                .format(args.patience))
            break
Exemple #12
0
def main(args):
    """Translate the test set with beam search and write the hypotheses to a file.

    The model is rebuilt from ``args.checkpoint_path`` and every batch of the
    test split is decoded with one beam search per input sentence. Candidates
    that would emit the unknown token are replaced by the next-best candidate.
    Scores can be length-normalised (``args.alpha``) and penalised by their
    rank inside the expansion (``args.gamma``); with both set to 0 the score is
    just the node's own ``eval()`` value.

    Args:
        args: Parsed command-line namespace. Values stored in the checkpoint
            override the command-line ones, except ``data``, which is always
            taken from the command line. If ``args.output`` is not ``None``
            the hypotheses are written there, one per line, for sentence ids
            ``0 .. len(hypotheses) - 1``.

    Raises:
        KeyError: while writing, if the collected sentence ids are not exactly
            ``0 .. n - 1``.
    """
    # Load arguments from checkpoint
    torch.manual_seed(args.seed)
    state_dict = torch.load(args.checkpoint_path, map_location=lambda s, l: default_restore_location(s, 'cpu'))
    args_loaded = argparse.Namespace(**{**vars(args), **vars(state_dict['args'])})
    args_loaded.data = args.data
    args = args_loaded
    utils.init_logging(args)

    def length_penalty(seq_len):
        """Return the length-normalisation divisor (5 + seq_len)^alpha / 6^alpha.

        With ``args.alpha == 0`` this is 1.0, i.e. no normalisation.
        """
        return math.pow(5 + seq_len, args.alpha) / math.pow(5 + 1, args.alpha)

    # Load dictionaries
    src_dict = Dictionary.load(os.path.join(args.data, 'dict.{:s}'.format(args.source_lang)))
    logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format(args.source_lang, len(src_dict)))
    tgt_dict = Dictionary.load(os.path.join(args.data, 'dict.{:s}'.format(args.target_lang)))
    logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format(args.target_lang, len(tgt_dict)))

    # Load dataset
    test_dataset = Seq2SeqDataset(
        src_file=os.path.join(args.data, 'test.{:s}'.format(args.source_lang)),
        tgt_file=os.path.join(args.data, 'test.{:s}'.format(args.target_lang)),
        src_dict=src_dict, tgt_dict=tgt_dict)

    test_loader = torch.utils.data.DataLoader(test_dataset, num_workers=1, collate_fn=test_dataset.collater,
                                              batch_sampler=BatchSampler(test_dataset, 9999999,
                                                                         args.batch_size, 1, 0, shuffle=False,
                                                                         seed=args.seed))
    # Build model and restore its weights
    model = models.build_model(args, src_dict, tgt_dict)
    if args.cuda:
        model = model.cuda()
    model.eval()
    model.load_state_dict(state_dict['model'])
    logging.info('Loaded a model from checkpoint {:s}'.format(args.checkpoint_path))
    progress_bar = tqdm(test_loader, desc='| Generation', leave=False)

    # Maps sentence id -> decoded hypothesis string
    all_hyps = {}

    # Iterate over the test set
    for sample in progress_bar:
        # NOTE(review): only go_slice is moved to the GPU below; the sample
        # tensors are passed to the encoder as loaded. Confirm this works when
        # args.cuda is set.

        # One beam search object per input sentence in the batch
        batch_size = sample['src_tokens'].shape[0]
        searches = [BeamSearch(args.beam_size, args.max_len - 1, tgt_dict.unk_idx) for _ in range(batch_size)]

        with torch.no_grad():
            # Compute the encoder output
            encoder_out = model.encoder(sample['src_tokens'], sample['src_lengths'])

            # First decoder input: a (batch_size, 1) tensor filled with the
            # EOS index, which acts as the start-of-sentence token.
            go_slice = \
                torch.ones(batch_size, 1).fill_(tgt_dict.eos_idx).type_as(sample['src_tokens'])

            if args.cuda:
                go_slice = utils.move_to_cuda(go_slice)

            # Compute the decoder output at the first time step
            decoder_out, _ = model.decoder(go_slice, encoder_out)

            # Keep beam_size + 1 candidates so that rank j always has a
            # back-off (rank j + 1) when the rank-j candidate is UNK.
            # log_softmax is the numerically stable form of log(softmax(x)).
            log_probs, next_candidates = torch.topk(torch.log_softmax(decoder_out, dim=2),
                                                    args.beam_size + 1, dim=-1)

        # Length normalisation divisor for this step (task 3)
        lp = length_penalty(log_probs.shape[1])

        # Create beam_size beam search nodes for every input sentence
        for i in range(batch_size):
            # Encoder information of sentence i, shared by all of its beams
            emb = encoder_out['src_embeddings'][:, i, :]
            lstm_out = encoder_out['src_out'][0][:, i, :]
            final_hidden = encoder_out['src_out'][1][:, i, :]
            final_cell = encoder_out['src_out'][2][:, i, :]
            try:
                mask = encoder_out['src_mask'][i, :]
            except TypeError:
                # src_mask is not subscriptable (e.g. None): no mask
                mask = None

            for j in range(args.beam_size):
                best_candidate = next_candidates[i, :, j]
                backoff_candidate = next_candidates[i, :, j + 1]
                best_log_p = log_probs[i, :, j]
                backoff_log_p = log_probs[i, :, j + 1]

                # Replace UNK by the next-best candidate
                is_unk = best_candidate == tgt_dict.unk_idx
                next_word = torch.where(is_unk, backoff_candidate, best_candidate)
                log_p = torch.where(is_unk, backoff_log_p, best_log_p)
                log_p = log_p[-1]

                node = BeamSearchNode(searches[i], emb, lstm_out, final_hidden, final_cell,
                                      mask, torch.cat((go_slice[i], next_word)), log_p, 1)

                # The score is negated before it is handed to the search;
                # presumably BeamSearch orders nodes lowest-first (confirm
                # against its implementation).
                # alpha == 0 and gamma == 0: plain beam search
                # alpha != 0 and gamma == 0: length normalisation (task 3)
                # gamma != 0: rank penalty for diversity, (j + 1) being the
                #             rank of the candidate (task 4)
                searches[i].add(-(node.eval() / lp - (j + 1) * args.gamma), node)

        # Generate further tokens until the maximum sentence length is reached
        for _ in range(args.max_len - 1):

            # Get the current nodes to expand
            nodes = [n[1] for s in searches for n in s.get_current_beams()]

            if not nodes:
                break  # All beams ended in EOS

            # Reconstruct prev_words, encoder_out from current beam search nodes
            prev_words = torch.stack([node.sequence for node in nodes])

            encoder_out["src_embeddings"] = torch.stack([node.emb for node in nodes], dim=1)
            lstm_out = torch.stack([node.lstm_out for node in nodes], dim=1)
            final_hidden = torch.stack([node.final_hidden for node in nodes], dim=1)
            final_cell = torch.stack([node.final_cell for node in nodes], dim=1)
            encoder_out["src_out"] = (lstm_out, final_hidden, final_cell)
            try:
                encoder_out["src_mask"] = torch.stack([node.mask for node in nodes], dim=0)
            except TypeError:
                # Masks are not tensors (e.g. None): run without a mask
                encoder_out["src_mask"] = None

            with torch.no_grad():
                # Compute the decoder output by feeding it the decoded sentence prefix
                decoder_out, _ = model.decoder(prev_words, encoder_out)

            # As above: one spare candidate per rank for the UNK back-off
            log_probs, next_candidates = torch.topk(torch.log_softmax(decoder_out, dim=2),
                                                    args.beam_size + 1, dim=-1)

            # Length normalisation divisor for this step (task 3)
            lp = length_penalty(log_probs.shape[1])

            for i in range(log_probs.shape[0]):
                # Parent node and the beam search object of its sentence
                parent = nodes[i]
                search = parent.search

                for j in range(args.beam_size):
                    best_candidate = next_candidates[i, :, j]
                    backoff_candidate = next_candidates[i, :, j + 1]
                    best_log_p = log_probs[i, :, j]
                    backoff_log_p = log_probs[i, :, j + 1]

                    # Replace UNK by the next-best candidate
                    is_unk = best_candidate == tgt_dict.unk_idx
                    next_word = torch.where(is_unk, backoff_candidate, best_candidate)
                    log_p = torch.where(is_unk, backoff_log_p, best_log_p)
                    log_p = log_p[-1]

                    # Prefix without its first token, plus the newly chosen token
                    next_word = torch.cat((prev_words[i][1:], next_word[-1:]))
                    sequence = torch.cat((prev_words[i][0].view([1]), next_word))

                    if next_word[-1] == tgt_dict.eos_idx:
                        # Finished hypothesis: stored via add_final instead of
                        # add. It keeps the parent's log-probability and
                        # length, i.e. the EOS step itself is not scored.
                        node = BeamSearchNode(search, parent.emb, parent.lstm_out, parent.final_hidden,
                                              parent.final_cell, parent.mask, sequence,
                                              parent.logp, parent.length)
                        # (j + 1) * gamma is the rank penalty of task 4
                        search.add_final(-(node.eval() / lp - (j + 1) * args.gamma), node)
                    else:
                        # Unfinished hypothesis: candidate for the next iteration
                        node = BeamSearchNode(search, parent.emb, parent.lstm_out, parent.final_hidden,
                                              parent.final_cell, parent.mask, sequence,
                                              parent.logp + log_p, parent.length + 1)
                        # (j + 1) * gamma is the rank penalty of task 4
                        search.add(-(node.eval() / lp - (j + 1) * args.gamma), node)

            # Prune every sentence's beam before the next step
            for search in searches:
                search.prune()

        # Best hypothesis per sentence, without the leading start token.
        # To emit n-best lists instead (task 4), collect
        # search.get_top_n(args.beam_size) for every search and key all_hyps by
        # a running counter rather than by sample id.
        best_sents = torch.stack([search.get_best()[1].sequence[1:].cpu() for search in searches])
        decoded_batch = best_sents.numpy()

        output_sentences = [decoded_batch[row, :] for row in range(decoded_batch.shape[0])]

        # Cut every hypothesis at its first EOS token (if it contains one)
        trimmed = []
        for sent in output_sentences:
            eos_positions = np.where(sent == tgt_dict.eos_idx)[0]
            if len(eos_positions) > 0:
                trimmed.append(sent[:eos_positions[0]])
            else:
                trimmed.append(sent)
        output_sentences = trimmed

        # Convert arrays of indices into strings of words
        output_sentences = [tgt_dict.string(sent) for sent in output_sentences]

        for ii, sent in enumerate(output_sentences):
            all_hyps[int(sample['id'].data[ii])] = sent

    # Write to file, ordered by sentence id
    if args.output is not None:
        with open(args.output, 'w') as out_file:
            for sent_id in range(len(all_hyps)):
                out_file.write(all_hyps[sent_id] + '\n')