Example #1
File: train.py Project: SouthDam/atmt
def validate(args, model, model_rev, criterion, valid_dataset, epoch):
    """ Validates model performance on a held-out development set. """
    valid_loader = \
        torch.utils.data.DataLoader(valid_dataset, num_workers=1, collate_fn=valid_dataset.collater,
                                    batch_sampler=BatchSampler(valid_dataset, args.max_tokens, args.batch_size, 1, 0,
                                                               shuffle=False, seed=42))
    model.eval()
    model_rev.eval()
    stats = OrderedDict()
    stats['valid_loss'] = 0
    stats['num_tokens'] = 0
    stats['batch_size'] = 0

    # Iterate over the validation set
    for i, sample in enumerate(valid_loader):
        if args.cuda:
            sample = utils.move_to_cuda(sample)
        if len(sample) == 0:
            continue
        with torch.no_grad():
            # Compute loss
            (output, attn_scores), src_out = model(sample['src_tokens'],
                                                   sample['src_lengths'],
                                                   sample['tgt_inputs'])

            # Build decoder inputs for the reverse model: circularly shift the first source
            # sequence right by one position, so its trailing token becomes the initial input
            src_inputs = sample['src_tokens'].clone()
            src_inputs[0, 1:] = sample['src_tokens'][0, :-1]
            src_inputs[0, 0] = sample['src_tokens'][0, -1]
            # Source lengths adjusted by the difference between the padded target and source widths
            tgt_lengths = sample['src_lengths'].clone()  # torch.tensor([sample['tgt_tokens'].size(1)])
            tgt_lengths += sample['tgt_inputs'].size(1) - sample['src_tokens'].size(1)
            (output_rev,
             attn_scores_rev), src_out_rev = model_rev(sample['tgt_tokens'],
                                                       tgt_lengths, src_inputs)

            d, d_rev = get_diff(attn_scores, src_out, attn_scores_rev,
                                src_out_rev)
            loss = criterion(output.view(-1, output.size(-1)), sample['tgt_tokens'].view(-1)) + d + \
                criterion(output_rev.view(-1, output_rev.size(-1)), sample['src_tokens'].view(-1)) / len(tgt_lengths) + d_rev
        # Update tracked statistics
        stats['valid_loss'] += loss.item()
        stats['num_tokens'] += sample['num_tokens']
        stats['batch_size'] += len(sample['src_tokens'])

    # Calculate validation perplexity
    stats['valid_loss'] = stats['valid_loss'] / stats['num_tokens']
    perplexity = np.exp(stats['valid_loss'])
    stats['num_tokens'] = stats['num_tokens'] / stats['batch_size']

    logging.info('Epoch {:03d}: {}'.format(
        epoch, ' | '.join(key + ' {:.3g}'.format(value)
                          for key, value in stats.items())) +
                 ' | valid_perplexity {:.3g}'.format(perplexity))

    return perplexity
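The perplexity reported above is simply the exponential of the per-token validation loss. A minimal standalone illustration with made-up numbers (not project data):

import numpy as np

total_nll = 5241.3   # summed validation loss (illustrative value)
num_tokens = 1800    # total number of target tokens (illustrative value)

per_token_nll = total_nll / num_tokens   # mirrors stats['valid_loss'] / stats['num_tokens']
perplexity = np.exp(per_token_nll)       # mirrors np.exp(stats['valid_loss'])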
Example #2
    def forward(self, src_tokens, src_lengths):
        """ Performs a single forward pass through the instantiated encoder sub-network. """
        # Embed tokens and apply dropout
        batch_size, src_time_steps = src_tokens.size()
        if self.is_cuda:
            src_tokens = utils.move_to_cuda(src_tokens)
        src_embeddings = self.embedding(src_tokens)
        _src_embeddings = F.dropout(src_embeddings,
                                    p=self.dropout_in,
                                    training=self.training)

        # Transpose batch: [batch_size, src_time_steps, num_features] -> [src_time_steps, batch_size, num_features]
        src_embeddings = _src_embeddings.transpose(0, 1)

        # Pack embedded tokens into a PackedSequence
        packed_source_embeddings = nn.utils.rnn.pack_padded_sequence(
            src_embeddings, src_lengths)

        # Pass source input through the recurrent layer(s)
        packed_outputs, (
            final_hidden_states,
            final_cell_states) = self.lstm(packed_source_embeddings)

        # Unpack LSTM outputs and apply dropout to the unpacked outputs
        lstm_output, _ = nn.utils.rnn.pad_packed_sequence(packed_outputs,
                                                          padding_value=0.)
        lstm_output = F.dropout(lstm_output,
                                p=self.dropout_out,
                                training=self.training)
        assert list(lstm_output.size()) == [
            src_time_steps, batch_size, self.output_dim
        ]  # sanity check

        if self.bidirectional:

            def combine_directions(outs):
                return torch.cat(
                    [outs[0:outs.size(0):2], outs[1:outs.size(0):2]], dim=2)

            final_hidden_states = combine_directions(final_hidden_states)
            final_cell_states = combine_directions(final_cell_states)

        # Generate mask zeroing-out padded positions in encoder inputs
        src_mask = src_tokens.eq(self.dictionary.pad_idx)
        # print('src_embeddings:', _src_embeddings)
        # print('final_hidden_states:', final_hidden_states)
        return {
            'src_embeddings': _src_embeddings.transpose(0, 1),
            'src_out': (lstm_output, final_hidden_states, final_cell_states),
            'src_mask': src_mask if src_mask.any() else None
        }
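The encoder above relies on PyTorch's packed-sequence utilities to skip padded positions inside the LSTM. A minimal, self-contained sketch of that pack/pad round-trip (shapes and sizes are illustrative, not taken from the project config; by default pack_padded_sequence expects time-major input sorted by decreasing length):

import torch
import torch.nn as nn

batch_size, max_len, num_features, hidden_size = 3, 5, 8, 16
lengths = torch.tensor([5, 4, 2])                    # sorted in decreasing order
x = torch.randn(max_len, batch_size, num_features)   # time-major, like src_embeddings above

lstm = nn.LSTM(num_features, hidden_size, bidirectional=True)
packed = nn.utils.rnn.pack_padded_sequence(x, lengths)
packed_out, (h_n, c_n) = lstm(packed)
out, out_lengths = nn.utils.rnn.pad_packed_sequence(packed_out, padding_value=0.)
print(out.shape)   # torch.Size([5, 3, 32]) -- 2 * hidden_size for a bidirectional LSTM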
Example #3
def validate(args, model, criterion, valid_dataset, epoch):
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        num_workers=args.num_workers,
        collate_fn=valid_dataset.collater,
        batch_sampler=BatchSampler(valid_dataset,
                                   args.max_tokens,
                                   args.batch_size,
                                   args.distributed_world_size,
                                   args.distributed_rank,
                                   shuffle=True,
                                   seed=args.seed))

    model.eval()
    stats = {'valid_loss': 0, 'num_tokens': 0, 'batch_size': 0}
    progress_bar = tqdm(valid_loader,
                        desc='| Epoch {:03d}'.format(epoch),
                        leave=False)

    for i, sample in enumerate(progress_bar):
        sample = utils.move_to_cuda(sample)
        if len(sample) == 0:
            continue
        with torch.no_grad():
            output, attn_scores = model(sample['src_tokens'],
                                        sample['src_lengths'],
                                        sample['tgt_inputs'],
                                        sample['video_inputs'])
            loss = criterion(output.view(-1, output.size(-1)),
                             sample['tgt_tokens'].view(-1))
        stats['valid_loss'] += loss.item() / sample['num_tokens'] / math.log(2)
        stats['num_tokens'] += sample['num_tokens'] / len(sample['src_tokens'])
        stats['batch_size'] += len(sample['src_tokens'])
        progress_bar.set_postfix(
            {
                key: '{:.3g}'.format(value / (i + 1))
                for key, value in stats.items()
            },
            refresh=True)

    logging.info('Epoch {:03d}: {}'.format(
        epoch, ' | '.join(key + ' {:.3g}'.format(value / len(progress_bar))
                          for key, value in stats.items())))
    return stats['valid_loss'] / len(progress_bar)
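The division by math.log(2) above converts the per-token loss from nats (the natural-log units produced by cross-entropy) into bits, so valid_loss is tracked in bits per token. For example:

import math

nll_nats = 2.772                     # per-token loss in nats (illustrative value)
nll_bits = nll_nats / math.log(2)    # roughly 4.0 bits per token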
Example #4
def main(args):
    """ Main training function. Trains the translation model over the course of several epochs, including dynamic
    learning rate adjustment and gradient clipping. """

    logging.info('Commencing training!')
    torch.manual_seed(42)

    utils.init_logging(args)

    # Load dictionaries
    src_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{:s}'.format(args.source_lang)))
    logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format(
        args.source_lang, len(src_dict)))
    tgt_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{:s}'.format(args.target_lang)))
    logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format(
        args.target_lang, len(tgt_dict)))

    # Load datasets
    def load_data(split):
        return Seq2SeqDataset(
            src_file=os.path.join(args.data,
                                  '{:s}.{:s}'.format(split, args.source_lang)),
            tgt_file=os.path.join(args.data,
                                  '{:s}.{:s}'.format(split, args.target_lang)),
            src_dict=src_dict,
            tgt_dict=tgt_dict)

    valid_dataset = load_data(split='valid')

    # Build model and optimization criterion
    model = models.build_model(args, src_dict, tgt_dict)
    logging.info('Built a model with {:d} parameters'.format(
        sum(p.numel() for p in model.parameters())))
    criterion = nn.CrossEntropyLoss(ignore_index=src_dict.pad_idx,
                                    reduction='sum')
    if args.cuda:
        model = model.cuda()
        criterion = criterion.cuda()

    # Instantiate optimizer and learning rate scheduler
    optimizer = torch.optim.Adam(model.parameters(), args.lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           'min',
                                                           factor=0.5,
                                                           patience=1)

    # Load last checkpoint if one exists
    state_dict = utils.load_checkpoint(args, model, optimizer,
                                       scheduler)  # lr_scheduler
    last_epoch = state_dict['last_epoch'] if state_dict is not None else -1

    # Track validation performance for early stopping
    bad_epochs = 0
    best_validate = float('inf')

    for epoch in range(last_epoch + 1, args.max_epoch):
        ## BPE Dropout
        # Set the seed equal to the epoch number
        # (this guarantees identical seeds across multiple training runs, while each epoch still gets its own seed)
        seed = epoch

        bpe_dropout_if_needed(seed, args.bpe_dropout)

        # Load the BPE (dropout-ed) training data
        train_dataset = (load_data(split='train') if not args.train_on_tiny
                         else load_data(split='tiny_train'))
        train_loader = \
            torch.utils.data.DataLoader(train_dataset, num_workers=1, collate_fn=train_dataset.collater,
                                        batch_sampler=BatchSampler(train_dataset, args.max_tokens, args.batch_size, 1,
                                                                   0, shuffle=True, seed=42))
        model.train()
        stats = OrderedDict()
        stats['loss'] = 0
        stats['lr'] = 0
        stats['num_tokens'] = 0
        stats['batch_size'] = 0
        stats['grad_norm'] = 0
        stats['clip'] = 0
        # Display progress
        progress_bar = tqdm(train_loader,
                            desc='| Epoch {:03d}'.format(epoch),
                            leave=False,
                            disable=False)

        # Iterate over the training set
        for i, sample in enumerate(progress_bar):
            if args.cuda:
                sample = utils.move_to_cuda(sample)
            if len(sample) == 0:
                continue
            model.train()

            output, _ = model(sample['src_tokens'], sample['src_lengths'],
                              sample['tgt_inputs'])
            loss = \
                criterion(output.view(-1, output.size(-1)), sample['tgt_tokens'].view(-1)) / len(sample['src_lengths'])
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       args.clip_norm)
            optimizer.step()
            optimizer.zero_grad()

            # Update statistics for progress bar
            total_loss, num_tokens, batch_size = \
                loss.item(), sample['num_tokens'], len(sample['src_tokens'])
            stats['loss'] += total_loss * len(sample['src_lengths']) / sample['num_tokens']
            stats['lr'] += optimizer.param_groups[0]['lr']
            stats['num_tokens'] += num_tokens / len(sample['src_tokens'])
            stats['batch_size'] += batch_size
            stats['grad_norm'] += grad_norm
            stats['clip'] += 1 if grad_norm > args.clip_norm else 0
            progress_bar.set_postfix(
                {
                    key: '{:.4g}'.format(value / (i + 1))
                    for key, value in stats.items()
                },
                refresh=True)

        logging.info('Epoch {:03d}: {}'.format(
            epoch, ' | '.join(key + ' {:.4g}'.format(value / len(progress_bar))
                              for key, value in stats.items())))

        # Calculate validation loss
        valid_perplexity, valid_loss = validate(args, model, criterion,
                                                valid_dataset, epoch)
        model.train()

        # Scheduler step
        if args.adaptive_lr:
            scheduler.step(valid_loss)

        # Save checkpoints
        if epoch % args.save_interval == 0:
            utils.save_checkpoint(args, model, optimizer, scheduler, epoch,
                                  valid_perplexity)  # lr_scheduler

        # Check whether to terminate training
        if valid_perplexity < best_validate:
            best_validate = valid_perplexity
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs >= args.patience:
            logging.info(
                'No validation set improvements observed for {:d} epochs. Early stop!'
                .format(args.patience))
            break
Example #5
def main(args):
    """ Main translation function' """
    # Load arguments from checkpoint
    torch.manual_seed(args.seed)
    state_dict = torch.load(
        args.checkpoint_path,
        map_location=lambda s, l: default_restore_location(s, 'cpu'))
    args_loaded = argparse.Namespace(**{
        **vars(args),
        **vars(state_dict['args'])
    })
    args_loaded.data = args.data
    args = args_loaded
    utils.init_logging(args)

    # Load dictionaries
    src_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{:s}'.format(args.source_lang)))
    logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format(
        args.source_lang, len(src_dict)))
    tgt_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{:s}'.format(args.target_lang)))
    logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format(
        args.target_lang, len(tgt_dict)))

    # Load dataset
    test_dataset = Seq2SeqDataset(
        src_file=os.path.join(args.data, 'test.{:s}'.format(args.source_lang)),
        tgt_file=os.path.join(args.data, 'test.{:s}'.format(args.target_lang)),
        src_dict=src_dict,
        tgt_dict=tgt_dict)

    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              num_workers=1,
                                              collate_fn=test_dataset.collater,
                                              batch_sampler=BatchSampler(
                                                  test_dataset,
                                                  9999999,
                                                  args.batch_size,
                                                  1,
                                                  0,
                                                  shuffle=False,
                                                  seed=args.seed))
    # Build model and criterion
    model = models.build_model(args, src_dict, tgt_dict)
    if args.cuda:
        model = model.cuda()
    model.eval()
    model.load_state_dict(state_dict['model'])
    logging.info('Loaded a model from checkpoint {:s}'.format(
        args.checkpoint_path))
    progress_bar = tqdm(test_loader, desc='| Generation', leave=False)

    # Iterate over the test set
    all_hyps = {}
    for i, sample in enumerate(progress_bar):
        with torch.no_grad():
            # Compute the encoder output
            encoder_out = model.encoder(sample['src_tokens'],
                                        sample['src_lengths'])
            go_slice = \
                torch.ones(sample['src_tokens'].shape[0], 1).fill_(tgt_dict.eos_idx).type_as(sample['src_tokens'])
            if args.cuda:
                go_slice = utils.move_to_cuda(go_slice)
            prev_words = go_slice
            next_words = None

        for _ in range(args.max_len):
            with torch.no_grad():
                # Compute the decoder output by repeatedly feeding it the decoded sentence prefix
                decoder_out, _ = model.decoder(prev_words, encoder_out)
            # Suppress <UNK>s
            _, next_candidates = torch.topk(decoder_out, 2, dim=-1)
            best_candidates = next_candidates[:, :, 0]
            backoff_candidates = next_candidates[:, :, 1]
            next_words = torch.where(best_candidates == tgt_dict.unk_idx,
                                     backoff_candidates, best_candidates)
            prev_words = torch.cat([go_slice, next_words], dim=1)

        # Segment into sentences
        decoded_batch = next_words.cpu().numpy()
        output_sentences = [
            decoded_batch[row, :] for row in range(decoded_batch.shape[0])
        ]
        assert (len(output_sentences) == len(sample['id'].data))

        # Remove padding
        temp = list()
        for sent in output_sentences:
            first_eos = np.where(sent == tgt_dict.eos_idx)[0]
            if len(first_eos) > 0:
                temp.append(sent[:first_eos[0]])
            else:
                temp.append([])
        output_sentences = temp

        # Convert arrays of indices into strings of words
        output_sentences = [tgt_dict.string(sent) for sent in output_sentences]

        # Save translations
        assert (len(output_sentences) == len(sample['id'].data))
        for ii, sent in enumerate(output_sentences):
            all_hyps[int(sample['id'].data[ii])] = sent

    # Write to file
    if args.output is not None:
        with open(args.output, 'w') as out_file:
            for sent_id in range(len(all_hyps.keys())):
                out_file.write(all_hyps[sent_id] + '\n')
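The <UNK> suppression above keeps the top two candidates at every position and falls back to the runner-up whenever the best candidate is the unknown token. A tiny standalone illustration of that torch.where pattern (the indices are made up for the example):

import torch

unk_idx = 3
best = torch.tensor([[7, 3, 12]])      # top-1 candidates per position (3 = <unk> here)
backoff = torch.tensor([[9, 5, 4]])    # top-2 candidates per position

next_words = torch.where(best == unk_idx, backoff, best)
print(next_words)   # tensor([[ 7,  5, 12]]) -- the <unk> at position 1 was replaced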
Example #6
def main(args):
    # Load arguments from checkpoint
    torch.manual_seed(args.seed)
    state_dict = torch.load(args.checkpoint_path, map_location=lambda s, l: default_restore_location(s, 'cpu'))
    args = argparse.Namespace(**{**vars(args), **vars(state_dict['args'])})
    utils.init_logging(args)

    # Load dictionaries
    src_dict = Dictionary.load(os.path.join(args.data, 'dict.{}'.format(args.source_lang)))
    logging.info('Loaded a source dictionary ({}) with {} words'.format(args.source_lang, len(src_dict)))
    tgt_dict = Dictionary.load(os.path.join(args.data, 'dict.{}'.format(args.target_lang)))
    logging.info('Loaded a target dictionary ({}) with {} words'.format(args.target_lang, len(tgt_dict)))

    # Load dataset
    test_dataset = Seq2SeqDataset(
        src_file=os.path.join(args.data, 'test.{}'.format(args.source_lang)),
        tgt_file=os.path.join(args.data, 'test.{}'.format(args.target_lang)),
        src_dict=src_dict, tgt_dict=tgt_dict)

    test_loader = torch.utils.data.DataLoader(
        test_dataset, num_workers=args.num_workers, collate_fn=test_dataset.collater,
        batch_sampler=BatchSampler(
            test_dataset, args.max_tokens, args.batch_size, args.distributed_world_size,
            args.distributed_rank, shuffle=False, seed=args.seed))

    # Build model and criterion
    model = models.build_model(args, src_dict, tgt_dict).cuda()
    model.load_state_dict(state_dict['model'])
    logging.info('Loaded a model from checkpoint {}'.format(args.checkpoint_path))

    translator = SequenceGenerator(
        model, tgt_dict, beam_size=args.beam_size, maxlen=args.max_len, stop_early=eval(args.stop_early),
        normalize_scores=eval(args.normalize_scores), len_penalty=args.len_penalty, unk_penalty=args.unk_penalty,
    )

    progress_bar = tqdm(test_loader, desc='| Generation', leave=False)
    for i, sample in enumerate(progress_bar):
        sample = utils.move_to_cuda(sample)
        with torch.no_grad():
            hypos = translator.generate(sample['src_tokens'], sample['src_lengths'])
        for i, (sample_id, hypos) in enumerate(zip(sample['id'].data, hypos)):
            src_tokens = utils.strip_pad(sample['src_tokens'].data[i, :], tgt_dict.pad_idx)
            has_target = sample['tgt_tokens'] is not None
            target_tokens = utils.strip_pad(sample['tgt_tokens'].data[i, :], tgt_dict.pad_idx).int().cpu() if has_target else None

            src_str = src_dict.string(src_tokens, args.remove_bpe)
            target_str = tgt_dict.string(target_tokens, args.remove_bpe) if has_target else ''

            if not args.quiet:
                print('S-{}\t{}'.format(sample_id, src_str))
                if has_target:
                    print('T-{}\t{}'.format(sample_id, colored(target_str, 'green')))

            # Process top predictions
            for i, hypo in enumerate(hypos[:min(len(hypos), args.num_hypo)]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'].int().cpu(),
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                )

                if not args.quiet:
                    print('H-{}\t{}'.format(sample_id, colored(hypo_str, 'blue')))
                    print('P-{}\t{}'.format(sample_id, ' '.join(map(lambda x: '{:.4f}'.format(x), hypo['positional_scores'].tolist()))))
                    print('A-{}\t{}'.format(sample_id, ' '.join(map(lambda x: str(x.item()), alignment))))

                # Score only the top hypothesis
                if has_target and i == 0:
                    # Convert back to tokens for evaluation with unk replacement and/or without BPE
                    target_tokens = tgt_dict.binarize(target_str, word_tokenize, add_if_not_exist=True)
Example #7
def main(args):
    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported.')
    torch.manual_seed(args.seed)
    torch.cuda.set_device(args.device_id)
    utils.init_logging(args)

    if args.distributed_world_size > 1:
        torch.distributed.init_process_group(
            backend=args.distributed_backend,
            init_method=args.distributed_init_method,
            world_size=args.distributed_world_size,
            rank=args.distributed_rank)

    # Load dictionaries
    src_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{}'.format(args.source_lang)))
    logging.info('Loaded a source dictionary ({}) with {} words'.format(
        args.source_lang, len(src_dict)))
    tgt_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{}'.format(args.target_lang)))
    logging.info('Loaded a target dictionary ({}) with {} words'.format(
        args.target_lang, len(tgt_dict)))

    # Load datasets
    def load_data(split):
        return Seq2SeqDataset(
            src_file=os.path.join(args.data,
                                  '{}.{}'.format(split, args.source_lang)),
            tgt_file=os.path.join(args.data,
                                  '{}.{}'.format(split, args.target_lang)),
            src_dict=src_dict,
            tgt_dict=tgt_dict)

    train_dataset = load_data(split='train')
    valid_dataset = load_data(split='valid')

    # Build model and criterion
    model = models.build_model(args, src_dict, tgt_dict).cuda()
    logging.info('Built a model with {} parameters'.format(
        sum(p.numel() for p in model.parameters())))
    criterion = nn.CrossEntropyLoss(ignore_index=src_dict.pad_idx,
                                    reduction='sum').cuda()

    # Build an optimizer and a learning rate schedule
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, patience=0, min_lr=args.min_lr, factor=args.lr_shrink)

    # Load last checkpoint if one exists
    state_dict = utils.load_checkpoint(args, model, optimizer, lr_scheduler)
    last_epoch = state_dict['last_epoch'] if state_dict is not None else -1

    for epoch in range(last_epoch + 1, args.max_epoch):
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            num_workers=args.num_workers,
            collate_fn=train_dataset.collater,
            batch_sampler=BatchSampler(train_dataset,
                                       args.max_tokens,
                                       args.batch_size,
                                       args.distributed_world_size,
                                       args.distributed_rank,
                                       shuffle=True,
                                       seed=args.seed))

        model.train()
        stats = {
            'loss': 0.,
            'lr': 0.,
            'num_tokens': 0.,
            'batch_size': 0.,
            'grad_norm': 0.,
            'clip': 0.
        }
        progress_bar = tqdm(train_loader,
                            desc='| Epoch {:03d}'.format(epoch),
                            leave=False,
                            disable=(args.distributed_rank != 0))

        for i, sample in enumerate(progress_bar):
            sample = utils.move_to_cuda(sample)
            if len(sample) == 0:
                continue

            # Forward and backward pass
            output, _ = model(sample['src_tokens'], sample['src_lengths'],
                              sample['tgt_inputs'])
            loss = criterion(output.view(-1, output.size(-1)),
                             sample['tgt_tokens'].view(-1))
            optimizer.zero_grad()
            loss.backward()

            # Reduce gradients across all GPUs
            if args.distributed_world_size > 1:
                utils.reduce_grads(model.parameters())
                total_loss, num_tokens, batch_size = list(
                    map(
                        sum,
                        zip(*utils.all_gather_list([
                            loss.item(), sample['num_tokens'],
                            len(sample['src_tokens'])
                        ]))))
            else:
                total_loss, num_tokens, batch_size = \
                    loss.item(), sample['num_tokens'], len(sample['src_tokens'])

            # Normalize gradients by number of tokens and perform clipping
            for param in model.parameters():
                param.grad.data.div_(num_tokens)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       args.clip_norm)
            optimizer.step()

            # Update statistics for progress bar
            stats['loss'] += total_loss / num_tokens / math.log(2)
            stats['lr'] += optimizer.param_groups[0]['lr']
            stats['num_tokens'] += num_tokens / len(sample['src_tokens'])
            stats['batch_size'] += batch_size
            stats['grad_norm'] += grad_norm
            stats['clip'] += 1 if grad_norm > args.clip_norm else 0
            progress_bar.set_postfix(
                {
                    key: '{:.4g}'.format(value / (i + 1))
                    for key, value in stats.items()
                },
                refresh=True)

        logging.info('Epoch {:03d}: {}'.format(
            epoch, ' | '.join(key + ' {:.4g}'.format(value / len(progress_bar))
                              for key, value in stats.items())))

        # Adjust learning rate based on validation loss
        valid_loss = validate(args, model, criterion, valid_dataset, epoch)
        lr_scheduler.step(valid_loss)

        # Save checkpoints
        if epoch % args.save_interval == 0:
            utils.save_checkpoint(args, model, optimizer, lr_scheduler, epoch,
                                  valid_loss)
        if optimizer.param_groups[0]['lr'] <= args.min_lr:
            logging.info('Done training!')
            break
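Because the criterion sums the per-token losses (reduction='sum'), dividing every parameter's gradient by num_tokens before the optimizer step (as above) gives the same update as optimizing the mean per-token loss directly. A small self-contained check of that equivalence (arbitrary shapes, padding ignored for simplicity):

import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(6, 10, requires_grad=True)   # 6 tokens, vocabulary of 10
targets = torch.randint(0, 10, (6,))

# Gradient of the mean per-token loss
loss_mean = nn.CrossEntropyLoss(reduction='mean')(logits, targets)
grad_mean, = torch.autograd.grad(loss_mean, logits)

# Gradient of the summed loss, divided by the token count afterwards
logits2 = logits.detach().clone().requires_grad_(True)
loss_sum = nn.CrossEntropyLoss(reduction='sum')(logits2, targets)
grad_sum, = torch.autograd.grad(loss_sum, logits2)

print(torch.allclose(grad_mean, grad_sum / targets.numel()))   # True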
Example #8
def main(args):
    """ Main function. Visualizes attention weight arrays as nifty heat-maps. """
    mpl.rc('font', family='VL Gothic')

    torch.manual_seed(42)
    state_dict = torch.load(
        args.checkpoint_path,
        map_location=lambda s, l: default_restore_location(s, 'cpu'))
    args = argparse.Namespace(**{**vars(args), **vars(state_dict['args'])})
    utils.init_logging(args)

    # Load dictionaries
    src_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{:s}'.format(args.source_lang)))
    print('Loaded a source dictionary ({:s}) with {:d} words'.format(
        args.source_lang, len(src_dict)))
    tgt_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{:s}'.format(args.target_lang)))
    print('Loaded a target dictionary ({:s}) with {:d} words'.format(
        args.target_lang, len(tgt_dict)))

    # Load dataset
    test_dataset = Seq2SeqDataset(
        src_file=os.path.join(args.data, 'test.{:s}'.format(args.source_lang)),
        tgt_file=os.path.join(args.data, 'test.{:s}'.format(args.target_lang)),
        src_dict=src_dict,
        tgt_dict=tgt_dict)

    vis_loader = torch.utils.data.DataLoader(test_dataset,
                                             num_workers=1,
                                             collate_fn=test_dataset.collater,
                                             batch_sampler=BatchSampler(
                                                 test_dataset,
                                                 None,
                                                 1,
                                                 1,
                                                 0,
                                                 shuffle=False,
                                                 seed=42))

    # Build model and optimization criterion
    model = models.build_model(args, src_dict, tgt_dict)
    if args.cuda:
        model = model.cuda()
    model.load_state_dict(state_dict['model'])
    model.eval()  # disable dropout for attention visualization
    print('Loaded a model from checkpoint {:s}'.format(args.checkpoint_path))

    # Store attention weight arrays
    attn_records = list()

    # Iterate over the visualization set
    for i, sample in enumerate(vis_loader):
        if args.cuda:
            sample = utils.move_to_cuda(sample)
        if len(sample) == 0:
            continue

        # Perform forward pass
        output, attn_weights = model(sample['src_tokens'],
                                     sample['src_lengths'],
                                     sample['tgt_inputs'])
        attn_records.append((sample, attn_weights))

        # Only visualize the first 10 sentence pairs
        if i >= 9:
            break

    # Generate heat-maps and store them at the designated location
    if not os.path.exists(args.vis_dir):
        os.makedirs(args.vis_dir)

    for record_id, record in enumerate(attn_records):
        # Unpack
        sample, attn_map = record
        src_ids = utils.strip_pad(sample['src_tokens'].data, tgt_dict.pad_idx)
        tgt_ids = utils.strip_pad(sample['tgt_inputs'].data, tgt_dict.pad_idx)
        # Convert indices into word tokens
        src_str = src_dict.string(src_ids).split(' ') + ['<EOS>']
        tgt_str = tgt_dict.string(tgt_ids).split(' ') + ['<EOS>']

        # Generate heat-maps
        attn_map = attn_map.squeeze(dim=0).transpose(1, 0).cpu().detach().numpy()

        attn_df = pd.DataFrame(attn_map, index=src_str, columns=tgt_str)

        sns.heatmap(attn_df,
                    cmap='Blues',
                    linewidths=0.25,
                    vmin=0.0,
                    vmax=1.0,
                    xticklabels=True,
                    yticklabels=True,
                    fmt='.3f')
        plt.yticks(rotation=0)
        plot_path = os.path.join(args.vis_dir,
                                 'sentence_{:d}.png'.format(record_id))
        plt.savefig(plot_path, dpi='figure', pad_inches=1, bbox_inches='tight')
        plt.clf()

    print(
        'Done! Visualized attention maps have been saved to the \'{:s}\' directory!'
        .format(args.vis_dir))
Example #9
def main(args):
    """ Main translation function' """
    # Load arguments from checkpoint
    torch.manual_seed(args.seed)  # set the seed for PyTorch's random number generators
    state_dict = torch.load(args.checkpoint_path, map_location=lambda s, l: default_restore_location(s, 'cpu'))
    args_loaded = argparse.Namespace(**{**vars(args), **vars(state_dict['args'])})
    args_loaded.data = args.data
    args = args_loaded
    utils.init_logging(args)

    # Load dictionaries
    src_dict = Dictionary.load(os.path.join(args.data, 'dict.{:s}'.format(args.source_lang)))
    logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format(args.source_lang, len(src_dict)))
    tgt_dict = Dictionary.load(os.path.join(args.data, 'dict.{:s}'.format(args.target_lang)))
    logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format(args.target_lang, len(tgt_dict)))

    # Load dataset
    test_dataset = Seq2SeqDataset(
        src_file=os.path.join(args.data, 'test.{:s}'.format(args.source_lang)),
        tgt_file=os.path.join(args.data, 'test.{:s}'.format(args.target_lang)),
        src_dict=src_dict, tgt_dict=tgt_dict)

    test_loader = torch.utils.data.DataLoader(test_dataset, num_workers=1, collate_fn=test_dataset.collater,
                                              batch_sampler=BatchSampler(test_dataset, 9999999,
                                                                         args.batch_size, 1, 0, shuffle=False,
                                                                         seed=args.seed))
    # Build model and criterion
    model = models.build_model(args, src_dict, tgt_dict)
    if args.cuda:
        model = model.cuda()
    model.eval()
    model.load_state_dict(state_dict['model'])
    logging.info('Loaded a model from checkpoint {:s}'.format(args.checkpoint_path))
    progress_bar = tqdm(test_loader, desc='| Generation', leave=False)

    # Iterate over the test set
    all_hyps = {}
    for i, sample in enumerate(progress_bar):

        # Create a beam search object for every input sentence in the batch
        batch_size = sample['src_tokens'].shape[0]  # number of sentences (rows) in sample['src_tokens']
        searches = [BeamSearch(args.beam_size, args.max_len - 1, tgt_dict.unk_idx) for i in range(batch_size)]
        # i.e. one BeamSearch per sentence, parameterized by beam size, maximum sequence length and the <unk> index

        with torch.no_grad():  # disables gradient calculation
            # Compute the encoder output
            encoder_out = model.encoder(sample['src_tokens'], sample['src_lengths'])
            # __QUESTION 1: What is "go_slice" used for and what do its dimensions represent?
            # (the model's forward pass is: encoder_out = self.encoder(src_tokens, src_lengths);
            #  decoder_out = self.decoder(tgt_inputs, encoder_out))
            go_slice = \
                torch.ones(sample['src_tokens'].shape[0], 1).fill_(tgt_dict.eos_idx).type_as(sample['src_tokens'])
            # a tensor with sample['src_tokens'].shape[0] rows and one column, filled with eos_idx and cast to the
            # same type as sample['src_tokens']
            if args.cuda:
                go_slice = utils.move_to_cuda(go_slice)

            # Compute the decoder output at the first time step
            decoder_out, _ = model.decoder(go_slice, encoder_out)  # decoder out = decoder(tgt_inputs, encoder_out)

            # __QUESTION 2: Why do we keep one top candidate more than the beam size?
            log_probs, next_candidates = torch.topk(torch.log(torch.softmax(decoder_out, dim=2)),
                                                    args.beam_size + 1, dim=-1)
            # torch.topk returns the k largest elements (here beam_size + 1) of the log-softmax scores along the last
            # dimension; the extra candidate provides a backoff in case the best candidate is <unk>

        #  Create number of beam_size beam search nodes for every input sentence
        for i in range(batch_size):
            for j in range(args.beam_size):
                best_candidate = next_candidates[i, :, j]
                backoff_candidate = next_candidates[i, :, j + 1]
                best_log_p = log_probs[i, :, j]
                backoff_log_p = log_probs[i, :, j + 1]
                next_word = torch.where(best_candidate == tgt_dict.unk_idx, backoff_candidate, best_candidate)
                log_p = torch.where(best_candidate == tgt_dict.unk_idx, backoff_log_p, best_log_p)
                log_p = log_p[-1]

                # Store the encoder_out information for the current input sentence and beam
                emb = encoder_out['src_embeddings'][:, i, :]
                lstm_out = encoder_out['src_out'][0][:, i, :]
                final_hidden = encoder_out['src_out'][1][:, i, :]
                final_cell = encoder_out['src_out'][2][:, i, :]
                try:
                    mask = encoder_out['src_mask'][i, :]
                except TypeError:
                    mask = None

                node = BeamSearchNode(searches[i], emb, lstm_out, final_hidden, final_cell,
                                      mask, torch.cat((go_slice[i], next_word)), log_p, 1)

                # add normalization here according to paper
                lp = normalize(node.length)
                score = node.eval()/lp
                # Add diverse
                score = diverse(score, j)
                # __QUESTION 3: Why do we add the node with a negative score?
                searches[i].add(-score, node)

        # Start generating further tokens until max sentence length reached
        for _ in range(args.max_len - 1):

            # Get the current nodes to expand
            nodes = [n[1] for s in searches for n in s.get_current_beams()]
            if nodes == []:
                break  # All beams ended in EOS

            # Reconstruct prev_words, encoder_out from current beam search nodes
            prev_words = torch.stack([node.sequence for node in nodes])
            encoder_out["src_embeddings"] = torch.stack([node.emb for node in nodes], dim=1)
            lstm_out = torch.stack([node.lstm_out for node in nodes], dim=1)
            final_hidden = torch.stack([node.final_hidden for node in nodes], dim=1)
            final_cell = torch.stack([node.final_cell for node in nodes], dim=1)
            encoder_out["src_out"] = (lstm_out, final_hidden, final_cell)
            try:
                encoder_out["src_mask"] = torch.stack([node.mask for node in nodes], dim=0)
            except TypeError:
                encoder_out["src_mask"] = None

            with torch.no_grad():
                # Compute the decoder output by feeding it the decoded sentence prefix
                decoder_out, _ = model.decoder(prev_words, encoder_out)

            # see __QUESTION 2
            log_probs, next_candidates = torch.topk(torch.log(torch.softmax(decoder_out, dim=2)), args.beam_size + 1,
                                                    dim=-1)

            #  Create number of beam_size next nodes for every current node
            for i in range(log_probs.shape[0]):
                for j in range(args.beam_size):

                    best_candidate = next_candidates[i, :, j]
                    backoff_candidate = next_candidates[i, :, j + 1]
                    best_log_p = log_probs[i, :, j]
                    backoff_log_p = log_probs[i, :, j + 1]
                    next_word = torch.where(best_candidate == tgt_dict.unk_idx, backoff_candidate, best_candidate)
                    log_p = torch.where(best_candidate == tgt_dict.unk_idx, backoff_log_p, best_log_p)
                    log_p = log_p[-1]
                    next_word = torch.cat((prev_words[i][1:], next_word[-1:]))

                    # Get parent node and beam search object for corresponding sentence
                    node = nodes[i]
                    search = node.search

                    # __QUESTION 4: How are "add" and "add_final" different? What would happen if we did not make this distinction?

                    # Store the node as final if EOS is generated
                    if next_word[-1] == tgt_dict.eos_idx:
                        node = BeamSearchNode(search, node.emb, node.lstm_out, node.final_hidden,
                                              node.final_cell, node.mask, torch.cat((prev_words[i][0].view([1]),
                                                                                     next_word)), node.logp,
                                              node.length)
                        # Add length normalization
                        lp = normalize(node.length)
                        score = node.eval()/lp
                        # add diverse
                        score = diverse(score, j)
                        search.add_final(-score, node)

                    # Add the node to current nodes for next iteration
                    else:
                        node = BeamSearchNode(search, node.emb, node.lstm_out, node.final_hidden,
                                              node.final_cell, node.mask, torch.cat((prev_words[i][0].view([1]),
                                                                                     next_word)), node.logp + log_p,
                                              node.length + 1)
                        # Add length normalization
                        lp = normalize(node.length)
                        score = node.eval()/lp
                        # add diverse
                        score = diverse(score, j)
                        search.add(-score, node)

            # __QUESTION 5: What happens internally when we prune our beams?
            # How do we know we always maintain the best sequences?
            for search in searches:
                search.prune()

        # Segment into 1 best sentences
        #best_sents = torch.stack([search.get_best()[1].sequence[1:].cpu() for search in searches])

        # segment 3 best oneliner
        best_sents = torch.stack([n[1].sequence[1:] for s in searches for n in s.get_best()])

        # segment into n best sentences
        #for s in searches:
        #    for n in s.get_best():
        #        best_sents = torch.stack([n[1].sequence[1:].cpu()])
        print('n best sents', best_sents)

        # torch.stack concatenates the per-beam sequences; here we keep the n best (3 best) hypotheses per sentence
        decoded_batch = best_sents.numpy()

        output_sentences = [decoded_batch[row, :] for row in range(decoded_batch.shape[0])]

        # __QUESTION 6: What is the purpose of this for loop?
        temp = list()
        for sent in output_sentences:
            first_eos = np.where(sent == tgt_dict.eos_idx)[0]  # positions of <eos> tokens in the sentence
            if len(first_eos) > 0:  # truncate at the first <eos> if one was generated
                temp.append(sent[:first_eos[0]])
            else:
                temp.append(sent)
        output_sentences = temp

        # Convert arrays of indices into strings of words
        output_sentences = [tgt_dict.string(sent) for sent in output_sentences]

        # adapted to the 3-best (n-best) case: ii / 3 gives the sentence id and ii % 3 the hypothesis rank
        for ii, sent in enumerate(output_sentences):
            # all_hyps[int(sample['id'].data[ii])] = sent
            # variant for 3-best
            all_hyps[(int(sample['id'].data[int(ii / 3)]), int(ii % 3))] = sent

    # Write to file (write 3 best per sentence together)
    if args.output is not None:
        with open(args.output, 'w') as out_file:
            for sent_id in range(len(all_hyps.keys())):
                # variant for 1-best
                # out_file.write(all_hyps[sent_id] + '\n')
                # variant for 3-best
                out_file.write(all_hyps[(int(sent_id / 3), int(sent_id % 3))] + '\n')
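The normalize and diverse helpers used above are project-specific and not shown here; judging from the normalizer written out in the next example, the length penalty appears to follow the GNMT formulation lp(Y) = (5 + |Y|)^alpha / (5 + 1)^alpha. A minimal sketch of such a penalty, under that assumption (alpha = 0.6 is a common but assumed default):

def gnmt_length_penalty(length, alpha=0.6):
    """GNMT-style length penalty: lp(Y) = (5 + |Y|)^alpha / (5 + 1)^alpha."""
    return ((5.0 + length) ** alpha) / ((5.0 + 1.0) ** alpha)

# A length-normalized beam score would then be, e.g.:
#     score = accumulated_log_prob / gnmt_length_penalty(sequence_length)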
Example #10
def main(args):
    """ Main translation function' """
    # Load arguments from checkpoint
    torch.manual_seed(args.seed)
    state_dict = torch.load(
        args.checkpoint_path,
        map_location=lambda s, l: default_restore_location(s, 'cpu'))
    args_loaded = argparse.Namespace(**{
        **vars(args),
        **vars(state_dict['args'])
    })
    args_loaded.data = args.data
    args = args_loaded
    utils.init_logging(args)

    # Load dictionaries
    src_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{:s}'.format(args.source_lang)))
    logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format(
        args.source_lang, len(src_dict)))
    tgt_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{:s}'.format(args.target_lang)))
    logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format(
        args.target_lang, len(tgt_dict)))

    # Load dataset
    test_dataset = Seq2SeqDataset(
        src_file=os.path.join(args.data, 'test.{:s}'.format(args.source_lang)),
        tgt_file=os.path.join(args.data, 'test.{:s}'.format(args.target_lang)),
        src_dict=src_dict,
        tgt_dict=tgt_dict)

    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              num_workers=1,
                                              collate_fn=test_dataset.collater,
                                              batch_sampler=BatchSampler(
                                                  test_dataset,
                                                  9999999,
                                                  args.batch_size,
                                                  1,
                                                  0,
                                                  shuffle=False,
                                                  seed=args.seed))
    # Build model and criterion
    model = models.build_model(args, src_dict, tgt_dict)
    if args.cuda:
        model = model.cuda()
    model.eval()
    model.load_state_dict(state_dict['model'])
    logging.info('Loaded a model from checkpoint {:s}'.format(
        args.checkpoint_path))
    progress_bar = tqdm(test_loader, desc='| Generation', leave=False)

    # Iterate over the test set
    all_hyps = {}
    for i, sample in enumerate(progress_bar):

        # Create a beam search object for every input sentence in the batch
        batch_size = sample['src_tokens'].shape[0]
        #print("batch:", batch_size)
        searches = [
            BeamSearch(args.beam_size, args.max_len - 1, tgt_dict.unk_idx)
            for i in range(batch_size)
        ]
        #print(searches)

        with torch.no_grad():
            # Compute the encoder output
            encoder_out = model.encoder(sample['src_tokens'],
                                        sample['src_lengths'])
            #print(encoder_out)
            # __QUESTION 1: What is "go_slice" used for and what do its dimensions represent?

            go_slice = \
                torch.ones(sample['src_tokens'].shape[0], 1).fill_(tgt_dict.eos_idx).type_as(sample['src_tokens'])
            #print(go_slice)
            #print(go_slice.size())
            if args.cuda:
                go_slice = utils.move_to_cuda(go_slice)

            # Compute the decoder output at the first time step
            decoder_out, _ = model.decoder(go_slice, encoder_out)

            #print(decoder_out)

            # __QUESTION 2: Why do we keep one top candidate more than the beam size?
            # ANS: so that, if the best candidate is <unk>, we can back off to the next-best candidate without losing a beam
            log_probs, next_candidates = torch.topk(torch.log(
                torch.softmax(decoder_out, dim=2)),
                                                    args.beam_size + 1,
                                                    dim=-1)

            #print(log_probs)
            #print(next_candidates)

        # Create number of beam_size beam search nodes for every input sentence
        for i in range(batch_size):
            for j in range(args.beam_size):
                best_candidate = next_candidates[i, :, j]
                backoff_candidate = next_candidates[i, :, j + 1]
                best_log_p = log_probs[i, :, j]
                backoff_log_p = log_probs[i, :, j + 1]
                next_word = torch.where(best_candidate == tgt_dict.unk_idx,
                                        backoff_candidate, best_candidate)
                log_p = torch.where(best_candidate == tgt_dict.unk_idx,
                                    backoff_log_p, best_log_p)
                log_p = log_p[-1]

                # Store the encoder_out information for the current input sentence and beam
                emb = encoder_out['src_embeddings'][:, i, :]
                lstm_out = encoder_out['src_out'][0][:, i, :]
                final_hidden = encoder_out['src_out'][1][:, i, :]
                final_cell = encoder_out['src_out'][2][:, i, :]
                try:
                    mask = encoder_out['src_mask'][i, :]
                except TypeError:
                    mask = None

                # Each node corresponds to one partial hypothesis (a decoded prefix) within its beam search
                node = BeamSearchNode(searches[i], emb, lstm_out, final_hidden,
                                      final_cell, mask,
                                      torch.cat(
                                          (go_slice[i], next_word)), log_p, 1)

                #print(node)
                #exit()
                #print(next_word)
                # __QUESTION 3: Why do we add the node with a negative score?

                #normalizer = (((5+len(next_candidates))**args.alpha)) / ((5+1)**args.alpha)

                #length_norm_results = log_probs/normalizer

                searches[i].add(-(node.eval()), node)
                #print(node.eval()*(log_p / (((5+len(next_candidates))**args.alpha)) / ((5+1)**args.alpha)))
                #exit()
                #print(searches[i].add(-node.eval(), node))

        # Start generating further tokens until max sentence length reached
        for _ in range(args.max_len - 1):

            # Get the current nodes to expand
            nodes = [n[1] for s in searches for n in s.get_current_beams()]
            if nodes == []:
                break  # All beams ended in EOS

            # Reconstruct prev_words, encoder_out from current beam search nodes
            prev_words = torch.stack([node.sequence for node in nodes])
            encoder_out["src_embeddings"] = torch.stack(
                [node.emb for node in nodes], dim=1)
            lstm_out = torch.stack([node.lstm_out for node in nodes], dim=1)
            final_hidden = torch.stack([node.final_hidden for node in nodes],
                                       dim=1)
            final_cell = torch.stack([node.final_cell for node in nodes],
                                     dim=1)
            encoder_out["src_out"] = (lstm_out, final_hidden, final_cell)
            try:
                encoder_out["src_mask"] = torch.stack(
                    [node.mask for node in nodes], dim=0)
            except TypeError:
                encoder_out["src_mask"] = None

            with torch.no_grad():
                # Compute the decoder output by feeding it the decoded sentence prefix
                decoder_out, _ = model.decoder(prev_words, encoder_out)

            # see __QUESTION 2
            log_probs, next_candidates = torch.topk(torch.log(
                torch.softmax(decoder_out, dim=2)),
                                                    args.beam_size + 1,
                                                    dim=-1)
            #print(decoder_out)
            #normalizer = (((5+len(next_candidates))**args.alpha)) / ((5+1)**args.alpha)

            #length_norm_results = log_probs/normalizer

            #print(length_norm_results)

            # Create number of beam_size next nodes for every current node
            # ---- i think I have to add length norm here?
            for i in range(log_probs.shape[0]):
                for j in range(args.beam_size):

                    best_candidate = next_candidates[i, :, j]
                    #print(best_candidate)
                    backoff_candidate = next_candidates[i, :, j + 1]
                    best_log_p = log_probs[i, :, j]
                    backoff_log_p = log_probs[i, :, j + 1]
                    next_word = torch.where(best_candidate == tgt_dict.unk_idx,
                                            backoff_candidate, best_candidate)
                    log_p = torch.where(best_candidate == tgt_dict.unk_idx,
                                        backoff_log_p, best_log_p)
                    log_p = log_p[-1]
                    next_word = torch.cat((prev_words[i][1:], next_word[-1:]))

                    # Get parent node and beam search object for corresponding sentence
                    node = nodes[i]
                    search = node.search

                    # __QUESTION 4: How are "add" and "add_final" different? What would happen if we did not make this distinction?

                    # Store the node as final if EOS is generated
                    if next_word[-1] == tgt_dict.eos_idx:
                        node = BeamSearchNode(
                            search, node.emb, node.lstm_out, node.final_hidden,
                            node.final_cell, node.mask,
                            torch.cat((prev_words[i][0].view([1]), next_word)),
                            node.logp, node.length)
                        #print(node)
                        search.add_final(
                            -node.eval() *
                            (log_p /
                             (((5 + len(next_candidates))**args.alpha)) /
                             ((5 + 1)**args.alpha)), node)

                    # Add the node to current nodes for next iteration
                    else:
                        #This is where I'll add the gamma for adapted beam search?
                        node = BeamSearchNode(
                            search, node.emb, node.lstm_out, node.final_hidden,
                            node.final_cell, node.mask,
                            torch.cat((prev_words[i][0].view([1]), next_word)),
                            node.logp + log_p, node.length + 1)
                        search.add(-node.eval(), node)

            # __QUESTION 5: What happens internally when we prune our beams?
            # How do we know we always maintain the best sequences?

            for search in searches:
                search.prune()
                #print(searches)

        # Segment into sentences
        best_sents = torch.stack(
            [search.get_best()[1].sequence[1:].cpu() for search in searches])
        decoded_batch = best_sents.numpy()

        output_sentences = [
            decoded_batch[row, :] for row in range(decoded_batch.shape[0])
        ]

        # __QUESTION 6: What is the purpose of this for loop?
        temp = list()
        for sent in output_sentences:
            first_eos = np.where(sent == tgt_dict.eos_idx)[0]
            if len(first_eos) > 0:
                temp.append(sent[:first_eos[0]])
            else:
                temp.append(sent)
        output_sentences = temp
        #print(output_sentences)

        # Convert arrays of indices into strings of words
        output_sentences = [tgt_dict.string(sent) for sent in output_sentences]

        for ii, sent in enumerate(output_sentences):
            all_hyps[int(sample['id'].data[ii])] = sent

    # Write to file
    if args.output is not None:
        with open(args.output, 'w') as out_file:
            for sent_id in range(len(all_hyps.keys())):
                out_file.write(all_hyps[sent_id] + '\n')
Example #11
File: train.py Project: keanuk/NLU_CW2
def main(args):
    """ Main training function. Trains the translation model over the course of several epochs, including dynamic
    learning rate adjustment and gradient clipping. """

    logging.info('Commencing training!')
    torch.manual_seed(42)

    utils.init_logging(args)

    # Load dictionaries
    src_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{:s}'.format(args.source_lang)))
    logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format(
        args.source_lang, len(src_dict)))
    tgt_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{:s}'.format(args.target_lang)))
    logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format(
        args.target_lang, len(tgt_dict)))

    # Load datasets
    def load_data(split):
        return Seq2SeqDataset(
            src_file=os.path.join(args.data,
                                  '{:s}.{:s}'.format(split, args.source_lang)),
            tgt_file=os.path.join(args.data,
                                  '{:s}.{:s}'.format(split, args.target_lang)),
            src_dict=src_dict,
            tgt_dict=tgt_dict)

    train_dataset = load_data(
        split='train') if not args.train_on_tiny else load_data(
            split='tiny_train')
    valid_dataset = load_data(split='valid')

    # Build model and optimization criterion
    model = models.build_model(args, src_dict, tgt_dict)
    logging.info('Built a model with {:d} parameters'.format(
        sum(p.numel() for p in model.parameters())))
    criterion = nn.CrossEntropyLoss(ignore_index=src_dict.pad_idx,
                                    reduction='sum')
    if args.cuda:
        model = model.cuda()
        criterion = criterion.cuda()

    # Instantiate optimizer and learning rate scheduler
    optimizer = torch.optim.Adam(model.parameters(), args.lr)

    # Load last checkpoint if one exists
    state_dict = utils.load_checkpoint(args, model, optimizer)  # lr_scheduler
    last_epoch = state_dict['last_epoch'] if state_dict is not None else -1

    # Track validation performance for early stopping
    bad_epochs = 0
    best_validate = float('inf')

    for epoch in range(last_epoch + 1, args.max_epoch):
        train_loader = \
            torch.utils.data.DataLoader(train_dataset, num_workers=1, collate_fn=train_dataset.collater,
                                        batch_sampler=BatchSampler(train_dataset, args.max_tokens, args.batch_size, 1,
                                                                   0, shuffle=True, seed=42))
        model.train()
        stats = OrderedDict()
        stats['loss'] = 0
        stats['lr'] = 0
        stats['num_tokens'] = 0
        stats['batch_size'] = 0
        stats['grad_norm'] = 0
        stats['clip'] = 0
        # Display progress
        progress_bar = tqdm(train_loader,
                            desc='| Epoch {:03d}'.format(epoch),
                            leave=False,
                            disable=False)

        # Iterate over the training set
        for i, sample in enumerate(progress_bar):
            if args.cuda:
                sample = utils.move_to_cuda(sample)
            if len(sample) == 0:
                continue
            model.train()
            '''
            ___QUESTION-1-DESCRIBE-F-START___
            Describe what the following lines of code do.
            '''
            '''
            These lines perform one optimisation step. A forward pass through the
            encoder-decoder model produces output logits for every target position. The
            cross-entropy loss between these logits and the reference target tokens is
            computed and normalised by the number of sentences in the batch. Calling
            loss.backward() backpropagates the error and computes gradients for all model
            parameters, clip_grad_norm_ rescales those gradients so their norm does not
            exceed args.clip_norm (guarding against exploding gradients), optimizer.step()
            updates the parameters, and optimizer.zero_grad() resets the gradients to zero
            for the next batch.
            '''
            output, _ = model(sample['src_tokens'], sample['src_lengths'],
                              sample['tgt_inputs'])
            loss = \
                criterion(output.view(-1, output.size(-1)), sample['tgt_tokens'].view(-1)) / len(sample['src_lengths'])
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       args.clip_norm)
            optimizer.step()
            optimizer.zero_grad()
            '''___QUESTION-1-DESCRIBE-F-END___'''

            # Update statistics for progress bar
            total_loss = loss.item()
            num_tokens = sample['num_tokens']
            batch_size = len(sample['src_tokens'])
            stats['loss'] += total_loss * len(
                sample['src_lengths']) / sample['num_tokens']
            stats['lr'] += optimizer.param_groups[0]['lr']
            stats['num_tokens'] += num_tokens / len(sample['src_tokens'])
            stats['batch_size'] += batch_size
            stats['grad_norm'] += grad_norm
            stats['clip'] += 1 if grad_norm > args.clip_norm else 0
            progress_bar.set_postfix(
                {
                    key: '{:.4g}'.format(value / (i + 1))
                    for key, value in stats.items()
                },
                refresh=True)

        logging.info('Epoch {:03d}: {}'.format(
            epoch, ' | '.join(key + ' {:.4g}'.format(value / len(progress_bar))
                              for key, value in stats.items())))

        # Calculate validation loss
        valid_perplexity = validate(args, model, criterion, valid_dataset,
                                    epoch)
        model.train()

        # Save checkpoints
        if epoch % args.save_interval == 0:
            utils.save_checkpoint(args, model, optimizer, epoch,
                                  valid_perplexity)  # lr_scheduler

        # Check whether to terminate training
        if valid_perplexity < best_validate:
            best_validate = valid_perplexity
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs >= args.patience:
            logging.info(
                'No validation set improvements observed for {:d} epochs. Early stop!'
                .format(args.patience))
            break
Example #12
0
File: lstm.py Project: SouthDam/atmt
    def forward(self, tgt_inputs, encoder_out, incremental_state=None):
        """ Performs the forward pass through the instantiated model. """
        # Optionally, feed decoder input token-by-token
        if incremental_state is not None:
            tgt_inputs = tgt_inputs[:, -1:]

        # __LEXICAL: Following code is to assist with the LEXICAL MODEL implementation
        # Recover encoder input
        src_embeddings = encoder_out['src_embeddings']

        src_out, src_hidden_states, src_cell_states = encoder_out['src_out']
        src_mask = encoder_out['src_mask']
        src_time_steps = src_out.size(0)

        # Embed target tokens and apply dropout
        batch_size, tgt_time_steps = tgt_inputs.size()
        tgt_embeddings = self.embedding(tgt_inputs)
        tgt_embeddings = F.dropout(tgt_embeddings, p=self.dropout_in, training=self.training)

        # Transpose batch: [batch_size, tgt_time_steps, num_features] -> [tgt_time_steps, batch_size, num_features]
        tgt_embeddings = tgt_embeddings.transpose(0, 1)

        # Initialize previous states (or retrieve from cache during incremental generation)
        cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state')
        if cached_state is not None:
            tgt_hidden_states, tgt_cell_states, input_feed = cached_state
        else:
            tgt_hidden_states = [torch.zeros(tgt_inputs.size()[0], self.hidden_size) for i in range(len(self.layers))]
            tgt_cell_states = [torch.zeros(tgt_inputs.size()[0], self.hidden_size) for i in range(len(self.layers))]
            input_feed = tgt_embeddings.data.new(batch_size, self.hidden_size).zero_()

        if self.layers[0].weight_ih.is_cuda:
            tgt_hidden_states = utils.move_to_cuda(tgt_hidden_states)
            tgt_cell_states = utils.move_to_cuda(tgt_cell_states)

        # Initialize attention output node
        attn_weights = tgt_embeddings.data.new(batch_size, tgt_time_steps, src_time_steps).zero_()
        rnn_outputs = []

        # __LEXICAL: Following code is to assist with the LEXICAL MODEL implementation
        # Cache lexical context vectors per translation time-step
        lexical_contexts = []

        for j in range(tgt_time_steps):
            # Concatenate the current token embedding with output from previous time step (i.e. 'input feeding')
            lstm_input = torch.cat([tgt_embeddings[j, :, :], input_feed], dim=1)

            for layer_id, rnn_layer in enumerate(self.layers):
                # Pass target input through the recurrent layer(s)
                tgt_hidden_states[layer_id], tgt_cell_states[layer_id] = \
                    rnn_layer(lstm_input, (tgt_hidden_states[layer_id], tgt_cell_states[layer_id]))

                # Current hidden state becomes input to the subsequent layer; apply dropout
                lstm_input = F.dropout(tgt_hidden_states[layer_id], p=self.dropout_out, training=self.training)

            if self.attention is None:
                input_feed = tgt_hidden_states[-1]
            else:
                input_feed, step_attn_weights = self.attention(tgt_hidden_states[-1], src_out, src_mask)
                attn_weights[:, j, :] = step_attn_weights

                if self.use_lexical_model:
                    # __LEXICAL: Compute and collect LEXICAL MODEL context vectors here
                    # TODO: --------------------------------------------------------------------- CUT
                    pass
                    # TODO: --------------------------------------------------------------------- /CUT
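                    # (One common approach, not part of this snippet: use step_attn_weights to
                    # take a weighted sum over src_embeddings for the current target step, pass
                    # the result through a tanh feed-forward layer, and append it to
                    # lexical_contexts for use in the final prediction.)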

            input_feed = F.dropout(input_feed, p=self.dropout_out, training=self.training)
            rnn_outputs.append(input_feed)

        # Cache previous states (only used during incremental, auto-regressive generation)
        utils.set_incremental_state(
            self, incremental_state, 'cached_state', (tgt_hidden_states, tgt_cell_states, input_feed))

        # Collect outputs across time steps
        decoder_output = torch.cat(rnn_outputs, dim=0).view(tgt_time_steps, batch_size, self.hidden_size)

        # Transpose batch back: [tgt_time_steps, batch_size, num_features] -> [batch_size, tgt_time_steps, num_features]
        decoder_output = decoder_output.transpose(0, 1)

        # Final projection
        decoder_output = self.final_projection(decoder_output)

        if self.use_lexical_model:
            # __LEXICAL: Incorporate the LEXICAL MODEL into the prediction of target tokens here
            pass
            # TODO: --------------------------------------------------------------------- /CUT
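            # (One common approach, not part of this snippet: project the cached
            # lexical_contexts to the output vocabulary and add that projection to
            # decoder_output before returning, so lexical information directly
            # contributes to target-token prediction.)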


        return decoder_output, attn_weights
Example #13
0
File: train.py Project: SouthDam/atmt
def main(args):
    """ Main training function. Trains the translation model over the course of several epochs, including dynamic
    learning rate adjustment and gradient clipping. """

    logging.info('Commencing training!')
    torch.manual_seed(42)

    utils.init_logging(args)

    # Load dictionaries
    src_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{:s}'.format(args.source_lang)))
    logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format(
        args.source_lang, len(src_dict)))
    tgt_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{:s}'.format(args.target_lang)))
    logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format(
        args.target_lang, len(tgt_dict)))

    # Load datasets
    def load_data(split):
        return Seq2SeqDataset(
            src_file=os.path.join(args.data,
                                  '{:s}.{:s}'.format(split, args.source_lang)),
            tgt_file=os.path.join(args.data,
                                  '{:s}.{:s}'.format(split, args.target_lang)),
            src_dict=src_dict,
            tgt_dict=tgt_dict)

    train_dataset = load_data(
        split='train') if not args.train_on_tiny else load_data(
            split='tiny_train')
    valid_dataset = load_data(split='valid')

    # Build model and optimization criterion
    model = models.build_model(args, src_dict, tgt_dict)
    model_rev = models.build_model(args, tgt_dict, src_dict)
    logging.info('Built a model with {:d} parameters'.format(
        sum(p.numel() for p in model.parameters())))
    criterion = nn.CrossEntropyLoss(ignore_index=src_dict.pad_idx,
                                    reduction='sum')
    criterion2 = nn.MSELoss(reduction='sum')
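    # Note: criterion2 is not used in the training loop shown below and, unlike
    # criterion, is not moved to the GPU when args.cuda is set.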
    if args.cuda:
        model = model.cuda()
        model_rev = model_rev.cuda()
        criterion = criterion.cuda()

    # Instantiate optimizer and learning rate scheduler
    optimizer = torch.optim.Adam(model.parameters(), args.lr)
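    # Note: only the forward model's parameters are registered with the optimizer;
    # model_rev receives gradients through the shared loss below, but optimizer.step()
    # does not update its weights and optimizer.zero_grad() does not clear its gradients.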

    # Load last checkpoint if one exists
    state_dict = utils.load_checkpoint(args, model, optimizer)  # lr_scheduler
    utils.load_checkpoint_rev(args, model_rev, optimizer)  # lr_scheduler
    last_epoch = state_dict['last_epoch'] if state_dict is not None else -1

    # Track validation performance for early stopping
    bad_epochs = 0
    best_validate = float('inf')

    for epoch in range(last_epoch + 1, args.max_epoch):
        train_loader = \
            torch.utils.data.DataLoader(train_dataset, num_workers=1, collate_fn=train_dataset.collater,
                                        batch_sampler=BatchSampler(train_dataset, args.max_tokens, args.batch_size, 1,
                                                                   0, shuffle=True, seed=42))
        model.train()
        model_rev.train()
        stats = OrderedDict()
        stats['loss'] = 0
        stats['lr'] = 0
        stats['num_tokens'] = 0
        stats['batch_size'] = 0
        stats['grad_norm'] = 0
        stats['clip'] = 0
        # Display progress
        progress_bar = tqdm(train_loader,
                            desc='| Epoch {:03d}'.format(epoch),
                            leave=False,
                            disable=False)

        # Iterate over the training set
        for i, sample in enumerate(progress_bar):
            if args.cuda:
                sample = utils.move_to_cuda(sample)
            if len(sample) == 0:
                continue
            model.train()

            (output, att), src_out = model(sample['src_tokens'],
                                           sample['src_lengths'],
                                           sample['tgt_inputs'])
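            # Build decoder inputs for the reverse model by rotating the source tokens of
            # the first batch row one position to the right (the last token becomes the
            # first input); note that only row 0 of the batch is rotated here.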
            src_inputs = sample['src_tokens'].clone()
            src_inputs[0, 1:src_inputs.size(1)] = sample['src_tokens'][0, 0:(
                src_inputs.size(1) - 1)]
            src_inputs[0, 0] = sample['src_tokens'][0, src_inputs.size(1) - 1]
            tgt_lengths = sample['src_lengths'].clone()
            tgt_lengths += sample['tgt_inputs'].size(
                1) - sample['src_tokens'].size(1)
            (output_rev,
             att_rev), src_out_rev = model_rev(sample['tgt_tokens'],
                                               tgt_lengths, src_inputs)

            # Note: these outputs already come without masks applied
            d, d_rev = get_diff(att, src_out, att_rev, src_out_rev)

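            # The total loss combines the forward cross-entropy (normalised by the number of
            # source sequences), the reverse-direction cross-entropy (normalised by the number
            # of target sequences), and the terms d and d_rev from get_diff, which presumably
            # penalise disagreement between the two translation directions.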
            loss = \
                criterion(output.view(-1, output.size(-1)), sample['tgt_tokens'].view(-1)) / len(sample['src_lengths']) + d + \
                criterion(output_rev.view(-1, output_rev.size(-1)), sample['src_tokens'].view(-1)) / len(tgt_lengths) + d_rev
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       args.clip_norm)
            optimizer.step()
            optimizer.zero_grad()

            # Update statistics for progress bar
            total_loss = (loss - d - d_rev).item()
            num_tokens = sample['num_tokens']
            batch_size = len(sample['src_tokens'])
            stats['loss'] += total_loss * len(
                sample['src_lengths']) / sample['num_tokens']
            stats['lr'] += optimizer.param_groups[0]['lr']
            stats['num_tokens'] += num_tokens / len(sample['src_tokens'])
            stats['batch_size'] += batch_size
            stats['grad_norm'] += grad_norm
            stats['clip'] += 1 if grad_norm > args.clip_norm else 0
            progress_bar.set_postfix(
                {
                    key: '{:.4g}'.format(value / (i + 1))
                    for key, value in stats.items()
                },
                refresh=True)

        logging.info('Epoch {:03d}: {}'.format(
            epoch, ' | '.join(key + ' {:.4g}'.format(value / len(progress_bar))
                              for key, value in stats.items())))

        # Calculate validation loss
        valid_perplexity = validate(args, model, model_rev, criterion,
                                    valid_dataset, epoch)
        model.train()
        model_rev.train()

        # Save checkpoints
        if epoch % args.save_interval == 0:
            utils.save_checkpoint(args, model, model_rev, optimizer, epoch,
                                  valid_perplexity)  # lr_scheduler

        # Check whether to terminate training
        if valid_perplexity < best_validate:
            best_validate = valid_perplexity
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs >= args.patience:
            logging.info(
                'No validation set improvements observed for {:d} epochs. Early stop!'
                .format(args.patience))
            break