Example #1
def main(args):
    """ Main translation function' """
    # Load arguments from checkpoint
    torch.manual_seed(args.seed)
    state_dict = torch.load(
        args.checkpoint_path,
        map_location=lambda s, l: default_restore_location(s, 'cpu'))
    args_loaded = argparse.Namespace(**{
        **vars(args),
        **vars(state_dict['args'])
    })
    args_loaded.data = args.data
    args = args_loaded
    utils.init_logging(args)

    # Load dictionaries
    src_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{:s}'.format(args.source_lang)))
    logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format(
        args.source_lang, len(src_dict)))
    tgt_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{:s}'.format(args.target_lang)))
    logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format(
        args.target_lang, len(tgt_dict)))

    # Load dataset
    test_dataset = Seq2SeqDataset(
        src_file=os.path.join(args.data, 'test.{:s}'.format(args.source_lang)),
        tgt_file=os.path.join(args.data, 'test.{:s}'.format(args.target_lang)),
        src_dict=src_dict,
        tgt_dict=tgt_dict)

    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              num_workers=1,
                                              collate_fn=test_dataset.collater,
                                              batch_sampler=BatchSampler(
                                                  test_dataset,
                                                  9999999,
                                                  args.batch_size,
                                                  1,
                                                  0,
                                                  shuffle=False,
                                                  seed=args.seed))
    # Build model and criterion
    model = models.build_model(args, src_dict, tgt_dict)
    if args.cuda:
        model = model.cuda()
    model.eval()
    model.load_state_dict(state_dict['model'])
    logging.info('Loaded a model from checkpoint {:s}'.format(
        args.checkpoint_path))
    progress_bar = tqdm(test_loader, desc='| Generation', leave=False)

    # Iterate over the test set
    all_hyps = {}
    for i, sample in enumerate(progress_bar):

        # Create a beam search object for every input sentence in the batch
        batch_size = sample['src_tokens'].shape[0]
        #searches = [BeamSearch(args.beam_size, args.max_len - 1, tgt_dict.unk_idx) for i in range(batch_size)]
        #Task 3: Normalization
        searches = [
            BeamSearch(args.beam_size, args.max_len - 1, tgt_dict.unk_idx,
                       args.alpha_i / args.alpha_max)
            for i in range(batch_size)
        ]
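        # (Assumption) the extra argument args.alpha_i / args.alpha_max is presumably the
        # length-normalization exponent alpha used inside BeamSearch; a GNMT-style normalizer
        # would rank hypotheses by log P(Y|X) / lp(Y) with lp(Y) = ((5 + |Y|)**alpha) / (6**alpha).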

        with torch.no_grad():
            # Compute the encoder output
            encoder_out = model.encoder(sample['src_tokens'],
                                        sample['src_lengths'])
            go_slice = \
                torch.ones(sample['src_tokens'].shape[0], 1).fill_(tgt_dict.eos_idx).type_as(sample['src_tokens'])

            # Compute the decoder output at the first time step
            decoder_out, _ = model.decoder(go_slice, encoder_out)

            # __QUESTION 1: What happens here and what do 'log_probs' and 'next_candidates' contain?
            log_probs, next_candidates = torch.topk(torch.log(
                torch.softmax(decoder_out, dim=2)),
                                                    args.beam_size + 1,
                                                    dim=-1)
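            # (Answer sketch) torch.softmax followed by torch.log turns the decoder logits into
            # log-probabilities; torch.topk then keeps the beam_size + 1 largest of them, so
            # 'log_probs' holds those top log-probabilities and 'next_candidates' the
            # corresponding vocabulary indices for each position.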

        # Create beam_size beam search nodes for every input sentence
        for i in range(batch_size):
            for j in range(args.beam_size):
                # __QUESTION 2: Why do we need backoff candidates?
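                # (Answer sketch) The (beam_size + 1)-th candidate acts as a backoff: whenever the
                # best candidate is <UNK>, the next-best token and its log-probability are used
                # instead, so <UNK> never enters a hypothesis.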
                best_candidate = next_candidates[i, :, j]
                backoff_candidate = next_candidates[i, :, j + 1]
                best_log_p = log_probs[i, :, j]
                backoff_log_p = log_probs[i, :, j + 1]
                next_word = torch.where(best_candidate == tgt_dict.unk_idx,
                                        backoff_candidate, best_candidate)
                log_p = torch.where(best_candidate == tgt_dict.unk_idx,
                                    backoff_log_p, best_log_p)
                log_p = log_p[-1]

                # Store the encoder_out information for the current input sentence and beam
                emb = encoder_out['src_embeddings'][:, i, :]
                lstm_out = encoder_out['src_out'][0][:, i, :]
                final_hidden = encoder_out['src_out'][1][:, i, :]
                final_cell = encoder_out['src_out'][2][:, i, :]
                try:
                    mask = encoder_out['src_mask'][i, :]
                except TypeError:
                    mask = None

                # __QUESTION 3: What happens internally when we add a new beam search node?
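                # (Answer sketch, assuming BeamSearch keeps a priority queue) the node is pushed
                # onto the sentence's queue keyed by its negated score, so the most probable
                # partial hypotheses are retrieved first.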
                node = BeamSearchNode(searches[i], emb, lstm_out, final_hidden,
                                      final_cell, mask,
                                      torch.cat(
                                          (go_slice[i], next_word)), log_p, 1)
                searches[i].add(-node.eval(), node)

        # Start generating further tokens until max sentence length reached
        for _ in range(args.max_len - 1):

            # Get the current nodes to expand
            nodes = [n[1] for s in searches for n in s.get_current_beams()]
            if nodes == []:
                break  # All beams ended in EOS

            # Reconstruct prev_words, encoder_out from current beam search nodes
            prev_words = torch.stack([node.sequence for node in nodes])
            encoder_out["src_embeddings"] = torch.stack(
                [node.emb for node in nodes], dim=1)
            lstm_out = torch.stack([node.lstm_out for node in nodes], dim=1)
            final_hidden = torch.stack([node.final_hidden for node in nodes],
                                       dim=1)
            final_cell = torch.stack([node.final_cell for node in nodes],
                                     dim=1)
            encoder_out["src_out"] = (lstm_out, final_hidden, final_cell)
            try:
                encoder_out["src_mask"] = torch.stack(
                    [node.mask for node in nodes], dim=0)
            except TypeError:
                encoder_out["src_mask"] = None

            with torch.no_grad():
                # Compute the decoder output by feeding it the decoded sentence prefix
                decoder_out, _ = model.decoder(prev_words, encoder_out)

            # see __QUESTION 1
            log_probs, next_candidates = torch.topk(torch.log(
                torch.softmax(decoder_out, dim=2)),
                                                    args.beam_size + 1,
                                                    dim=-1)

            # Create beam_size next nodes for every current node
            for i in range(log_probs.shape[0]):
                for j in range(args.beam_size):

                    # see __QUESTION 2
                    best_candidate = next_candidates[i, :, j]
                    backoff_candidate = next_candidates[i, :, j + 1]
                    best_log_p = log_probs[i, :, j]
                    backoff_log_p = log_probs[i, :, j + 1]
                    next_word = torch.where(best_candidate == tgt_dict.unk_idx,
                                            backoff_candidate, best_candidate)
                    log_p = torch.where(best_candidate == tgt_dict.unk_idx,
                                        backoff_log_p, best_log_p)
                    log_p = log_p[-1]
                    next_word = torch.cat((prev_words[i][1:], next_word[-1:]))

                    # Get parent node and beam search object for corresponding sentence
                    node = nodes[i]
                    search = node.search

                    # __QUESTION 4: Why do we treat nodes that generated the end-of-sentence token differently?

                    # Store the node as final if EOS is generated
                    if next_word[-1] == tgt_dict.eos_idx:
                        node = BeamSearchNode(
                            search, node.emb, node.lstm_out, node.final_hidden,
                            node.final_cell, node.mask,
                            torch.cat((prev_words[i][0].view([1]), next_word)),
                            node.logp, node.length)
                        search.add_final(-node.eval(), node)

                    # Add the node to current nodes for next iteration
                    else:
                        node = BeamSearchNode(
                            search, node.emb, node.lstm_out, node.final_hidden,
                            node.final_cell, node.mask,
                            torch.cat((prev_words[i][0].view([1]), next_word)),
                            node.logp + log_p, node.length + 1)
                        #search.add(-node.eval(), node)
                        #Task 4: Diversity of Beam Search
                        diversity_penalty = j * args.gamma_i / args.gamma_max
                        search.add(-node.eval() + diversity_penalty, node)
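                        # (Assumption) the rank-proportional penalty j * gamma_i / gamma_max pushes
                        # lower-ranked siblings of the same parent down the queue, encouraging the
                        # beam to keep expansions from different parents (diverse beam search).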

            # __QUESTION 5: What happens internally when we prune our beams?
            # How do we know we always maintain the best sequences?
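            # (Answer sketch) prune() presumably discards all but the beam_size best-scoring nodes
            # per sentence; since nodes are kept in a queue ordered by negated score, the retained
            # nodes are always the current best partial hypotheses.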
            for search in searches:
                search.prune()

        # Segment into sentences
        best_sents = torch.stack(
            [search.get_best()[1].sequence[1:] for search in searches])
        decoded_batch = best_sents.numpy()

        output_sentences = [
            decoded_batch[row, :] for row in range(decoded_batch.shape[0])
        ]

        # __QUESTION 6: What is the purpose of this for loop?
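        # (Answer sketch) The loop truncates every decoded sentence at its first <EOS> token,
        # discarding anything generated after end-of-sentence.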
        temp = list()
        for sent in output_sentences:
            first_eos = np.where(sent == tgt_dict.eos_idx)[0]
            if len(first_eos) > 0:
                temp.append(sent[:first_eos[0]])
            else:
                temp.append(sent)
        output_sentences = temp

        # Convert arrays of indices into strings of words
        output_sentences = [tgt_dict.string(sent) for sent in output_sentences]

        for ii, sent in enumerate(output_sentences):
            all_hyps[int(sample['id'].data[ii])] = sent

    # Write to file
    if args.output is not None:
        with open(args.output, 'w') as out_file:
            for sent_id in range(len(all_hyps.keys())):
                out_file.write(all_hyps[sent_id] + '\n')
Example #2
def main(args):
    """ Main translation function' """
    # Load arguments from checkpoint
    torch.manual_seed(args.seed)
    state_dict = torch.load(
        args.checkpoint_path,
        map_location=lambda s, l: default_restore_location(s, 'cpu'))
    args_loaded = argparse.Namespace(**{
        **vars(args),
        **vars(state_dict['args'])
    })
    args_loaded.data = args.data
    args = args_loaded
    utils.init_logging(args)

    # Load dictionaries
    src_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{:s}'.format(args.source_lang)))
    logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format(
        args.source_lang, len(src_dict)))
    tgt_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{:s}'.format(args.target_lang)))
    logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format(
        args.target_lang, len(tgt_dict)))

    # Load dataset
    test_dataset = Seq2SeqDataset(
        src_file=os.path.join(args.data, 'test.{:s}'.format(args.source_lang)),
        tgt_file=os.path.join(args.data, 'test.{:s}'.format(args.target_lang)),
        src_dict=src_dict,
        tgt_dict=tgt_dict)

    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              num_workers=1,
                                              collate_fn=test_dataset.collater,
                                              batch_sampler=BatchSampler(
                                                  test_dataset,
                                                  9999999,
                                                  args.batch_size,
                                                  1,
                                                  0,
                                                  shuffle=False,
                                                  seed=args.seed))
    # Build model and criterion
    model = models.build_model(args, src_dict, tgt_dict)
    if args.cuda:
        model = model.cuda()
    model.eval()
    model.load_state_dict(state_dict['model'])
    logging.info('Loaded a model from checkpoint {:s}'.format(
        args.checkpoint_path))
    progress_bar = tqdm(test_loader, desc='| Generation', leave=False)

    # Iterate over the test set
    all_hyps = {}
    for i, sample in enumerate(progress_bar):
        with torch.no_grad():
            # Compute the encoder output
            encoder_out = model.encoder(sample['src_tokens'],
                                        sample['src_lengths'])
            go_slice = \
                torch.ones(sample['src_tokens'].shape[0], 1).fill_(tgt_dict.eos_idx).type_as(sample['src_tokens'])
            if args.cuda:
                go_slice = utils.move_to_cuda(go_slice)
            prev_words = go_slice
            next_words = None

        for _ in range(args.max_len):
            with torch.no_grad():
                # Compute the decoder output by repeatedly feeding it the decoded sentence prefix
                decoder_out, _ = model.decoder(prev_words, encoder_out)
            # Suppress <UNK>s
            _, next_candidates = torch.topk(decoder_out, 2, dim=-1)
            best_candidates = next_candidates[:, :, 0]
            backoff_candidates = next_candidates[:, :, 1]
            next_words = torch.where(best_candidates == tgt_dict.unk_idx,
                                     backoff_candidates, best_candidates)
            prev_words = torch.cat([go_slice, next_words], dim=1)
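            # Greedy decoding: each step re-feeds the <EOS>/go symbol plus all words chosen so far,
            # and the output grows by one position per iteration.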

        # Segment into sentences
        decoded_batch = next_words.cpu().numpy()
        output_sentences = [
            decoded_batch[row, :] for row in range(decoded_batch.shape[0])
        ]
        assert (len(output_sentences) == len(sample['id'].data))

        # Remove padding
        temp = list()
        for sent in output_sentences:
            first_eos = np.where(sent == tgt_dict.eos_idx)[0]
            if len(first_eos) > 0:
                temp.append(sent[:first_eos[0]])
            else:
                temp.append([])
        output_sentences = temp

        # Convert arrays of indices into strings of words
        output_sentences = [tgt_dict.string(sent) for sent in output_sentences]

        # Save translations
        assert (len(output_sentences) == len(sample['id'].data))
        for ii, sent in enumerate(output_sentences):
            all_hyps[int(sample['id'].data[ii])] = sent

    # Write to file
    if args.output is not None:
        with open(args.output, 'w') as out_file:
            for sent_id in range(len(all_hyps.keys())):
                out_file.write(all_hyps[sent_id] + '\n')
Example #3
File: train.py Project: h-schaller/atmt
def main(args):
    """ Main training function. Trains the translation model over the course of several epochs, including dynamic
    learning rate adjustment and gradient clipping. """
    if args.dropout:
        args.encoder_dropout_in = args.dropout
        args.encoder_dropout_out = args.dropout
        args.decoder_dropout_in = args.dropout
        args.decoder_dropout_out = args.dropout
    print(args)
    logging.info('Commencing training!')
    torch.manual_seed(42)

    utils.init_logging(args)

    # Load dictionaries
    src_dict = Dictionary.load(
        os.path.join(
            args.data, 'dict{:s}.{:s}'.format(args.bpe_vocab_size,
                                              args.source_lang)))
    logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format(
        args.source_lang, len(src_dict)))
    tgt_dict = Dictionary.load(
        os.path.join(
            args.data, 'dict{:s}.{:s}'.format(args.bpe_vocab_size,
                                              args.target_lang)))
    logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format(
        args.target_lang, len(tgt_dict)))

    # Load datasets
    def load_data(split):
        return Seq2SeqDataset(
            src_file=os.path.join(args.data,
                                  '{:s}.{:s}'.format(split, args.source_lang)),
            tgt_file=os.path.join(args.data,
                                  '{:s}.{:s}'.format(split, args.target_lang)),
            src_dict=src_dict,
            tgt_dict=tgt_dict)

    train_dataset = load_data(
        split='train') if not args.train_on_tiny else load_data(
            split='tiny_train')
    valid_dataset = load_data(split='valid')

    # Build model and optimization criterion
    model = models.build_model(args, src_dict, tgt_dict)
    logging.info('Built a model with {:d} parameters'.format(
        sum(p.numel() for p in model.parameters())))
    criterion = nn.CrossEntropyLoss(ignore_index=src_dict.pad_idx,
                                    reduction='sum')
    if args.cuda:
        model = model.cuda()
        criterion = criterion.cuda()

    # Instantiate optimizer and learning rate scheduler
    optimizer = torch.optim.Adam(model.parameters(), args.lr)

    # Load last checkpoint if one exists
    state_dict = utils.load_checkpoint(args, model, optimizer)  # lr_scheduler
    last_epoch = state_dict['last_epoch'] if state_dict is not None else -1

    # Track validation performance for early stopping
    bad_epochs = 0
    best_validate = float('inf')

    for epoch in range(last_epoch + 1, args.max_epoch):
        train_loader = \
            torch.utils.data.DataLoader(train_dataset, num_workers=1, collate_fn=train_dataset.collater,
                                        batch_sampler=BatchSampler(train_dataset, args.max_tokens, args.batch_size, 1,
                                                                   0, shuffle=True, seed=42))
        model.train()
        stats = OrderedDict()
        stats['loss'] = 0
        stats['lr'] = 0
        stats['num_tokens'] = 0
        stats['batch_size'] = 0
        stats['grad_norm'] = 0
        stats['clip'] = 0
        # Display progress
        progress_bar = tqdm(train_loader,
                            desc='| Epoch {:03d}'.format(epoch),
                            leave=False,
                            disable=False)

        # Iterate over the training set
        for i, sample in enumerate(progress_bar):
            if args.cuda:
                sample = utils.move_to_cuda(sample)
            if len(sample) == 0:
                continue
            model.train()

            output, _ = model(sample['src_tokens'], sample['src_lengths'],
                              sample['tgt_inputs'])
            loss = \
                criterion(output.view(-1, output.size(-1)), sample['tgt_tokens'].view(-1)) / len(sample['src_lengths'])
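            # Summed cross-entropy over all target tokens, averaged over the sentences in the batch
            # via the division by len(sample['src_lengths']).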
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       args.clip_norm)
            optimizer.step()
            optimizer.zero_grad()

            # Update statistics for progress bar
            total_loss, num_tokens, batch_size = loss.item(
            ), sample['num_tokens'], len(sample['src_tokens'])
            stats['loss'] += total_loss * len(
                sample['src_lengths']) / sample['num_tokens']
            stats['lr'] += optimizer.param_groups[0]['lr']
            stats['num_tokens'] += num_tokens / len(sample['src_tokens'])
            stats['batch_size'] += batch_size
            stats['grad_norm'] += grad_norm
            stats['clip'] += 1 if grad_norm > args.clip_norm else 0
            progress_bar.set_postfix(
                {
                    key: '{:.4g}'.format(value / (i + 1))
                    for key, value in stats.items()
                },
                refresh=True)

        logging.info('Epoch {:03d}: {}'.format(
            epoch, ' | '.join(key + ' {:.4g}'.format(value / len(progress_bar))
                              for key, value in stats.items())))

        # Calculate validation loss
        valid_perplexity = validate(args, model, criterion, valid_dataset,
                                    epoch)
        model.train()

        # Save checkpoints
        if epoch % args.save_interval == 0:
            utils.save_checkpoint(args, model, optimizer, epoch,
                                  valid_perplexity)  # lr_scheduler

        # Check whether to terminate training
        if valid_perplexity < best_validate:
            best_validate = valid_perplexity
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs >= args.patience:
            logging.info(
                'No validation set improvements observed for {:d} epochs. Early stop!'
                .format(args.patience))
            break
Example #4
def main(args):
    """ Main function. Visualizes attention weight arrays as nifty heat-maps. """
    mpl.rc('font', family='VL Gothic')

    torch.manual_seed(42)
    state_dict = torch.load(
        args.checkpoint_path,
        map_location=lambda s, l: default_restore_location(s, 'cpu'))
    args = argparse.Namespace(**{**vars(args), **vars(state_dict['args'])})
    utils.init_logging(args)

    # Load dictionaries
    src_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{:s}'.format(args.source_lang)))
    print('Loaded a source dictionary ({:s}) with {:d} words'.format(
        args.source_lang, len(src_dict)))
    tgt_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{:s}'.format(args.target_lang)))
    print('Loaded a target dictionary ({:s}) with {:d} words'.format(
        args.target_lang, len(tgt_dict)))

    # Load dataset
    test_dataset = Seq2SeqDataset(
        src_file=os.path.join(args.data, 'test.{:s}'.format(args.source_lang)),
        tgt_file=os.path.join(args.data, 'test.{:s}'.format(args.target_lang)),
        src_dict=src_dict,
        tgt_dict=tgt_dict)

    vis_loader = torch.utils.data.DataLoader(test_dataset,
                                             num_workers=1,
                                             collate_fn=test_dataset.collater,
                                             batch_sampler=BatchSampler(
                                                 test_dataset,
                                                 None,
                                                 1,
                                                 1,
                                                 0,
                                                 shuffle=False,
                                                 seed=42))

    # Build model and optimization criterion
    model = models.build_model(args, src_dict, tgt_dict)
    if args.cuda:
        model = model.cuda()
    model.load_state_dict(state_dict['model'])
    print('Loaded a model from checkpoint {:s}'.format(args.checkpoint_path))

    # Store attention weight arrays
    attn_records = list()

    # Iterate over the visualization set
    for i, sample in enumerate(vis_loader):
        if args.cuda:
            sample = utils.move_to_cuda(sample)
        if len(sample) == 0:
            continue

        # Perform forward pass
        output, attn_weights = model(sample['src_tokens'],
                                     sample['src_lengths'],
                                     sample['tgt_inputs'])
        attn_records.append((sample, attn_weights))

        # Only visualize the first 10 sentence pairs
        if i >= 10:
            break

    # Generate heat-maps and store them at the designated location
    if not os.path.exists(args.vis_dir):
        os.makedirs(args.vis_dir)

    for record_id, record in enumerate(attn_records):
        # Unpack
        sample, attn_map = record
        src_ids = utils.strip_pad(sample['src_tokens'].data, tgt_dict.pad_idx)
        tgt_ids = utils.strip_pad(sample['tgt_inputs'].data, tgt_dict.pad_idx)
        # Convert indices into word tokens
        src_str = src_dict.string(src_ids).split(' ') + ['<EOS>']
        tgt_str = tgt_dict.string(tgt_ids).split(' ') + ['<EOS>']

        # Generate heat-maps
        attn_map = attn_map.squeeze(dim=0).transpose(1, 0).detach().numpy()
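        # The attention matrix is transposed so that rows correspond to source tokens and columns
        # to target tokens, matching the DataFrame's index and columns below.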

        attn_df = pd.DataFrame(attn_map, index=src_str, columns=tgt_str)

        sns.heatmap(attn_df,
                    cmap='Blues',
                    linewidths=0.25,
                    vmin=0.0,
                    vmax=1.0,
                    xticklabels=True,
                    yticklabels=True,
                    fmt='.3f')
        plt.yticks(rotation=0)
        plot_path = os.path.join(args.vis_dir,
                                 'sentence_{:d}.png'.format(record_id))
        plt.savefig(plot_path, dpi='figure', pad_inches=1, bbox_inches='tight')
        plt.clf()

    print(
        'Done! Visualized attention maps have been saved to the \'{:s}\' directory!'
        .format(args.vis_dir))
Example #5
def main(args):
    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported.')
    torch.manual_seed(args.seed)
    torch.cuda.set_device(args.device_id)
    utils.init_logging(args)

    if args.distributed_world_size > 1:
        torch.distributed.init_process_group(
            backend=args.distributed_backend,
            init_method=args.distributed_init_method,
            world_size=args.distributed_world_size,
            rank=args.distributed_rank)

    # Load dictionaries
    src_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{}'.format(args.source_lang)))
    logging.info('Loaded a source dictionary ({}) with {} words'.format(
        args.source_lang, len(src_dict)))
    tgt_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{}'.format(args.target_lang)))
    logging.info('Loaded a target dictionary ({}) with {} words'.format(
        args.target_lang, len(tgt_dict)))

    # Load datasets
    def load_data(split):
        return Seq2SeqDataset(
            src_file=os.path.join(args.data,
                                  '{}.{}'.format(split, args.source_lang)),
            tgt_file=os.path.join(args.data,
                                  '{}.{}'.format(split, args.target_lang)),
            src_dict=src_dict,
            tgt_dict=tgt_dict)

    train_dataset = load_data(split='train')
    valid_dataset = load_data(split='valid')

    # Build model and criterion
    model = models.build_model(args, src_dict, tgt_dict).cuda()
    logging.info('Built a model with {} parameters'.format(
        sum(p.numel() for p in model.parameters())))
    criterion = nn.CrossEntropyLoss(ignore_index=src_dict.pad_idx,
                                    reduction='sum').cuda()

    # Build an optimizer and a learning rate schedule
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, patience=20, min_lr=args.min_lr, factor=args.lr_shrink)

    # Load last checkpoint if one exists
    state_dict = utils.load_checkpoint(args, model, optimizer, lr_scheduler)
    last_epoch = state_dict['last_epoch'] if state_dict is not None else -1

    for epoch in range(last_epoch + 1, args.max_epoch):
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            num_workers=args.num_workers,
            collate_fn=train_dataset.collater,
            batch_sampler=BatchSampler(train_dataset,
                                       args.max_tokens,
                                       args.batch_size,
                                       args.distributed_world_size,
                                       args.distributed_rank,
                                       shuffle=True,
                                       seed=args.seed))

        model.train()
        stats = {
            'loss': 0.,
            'lr': 0.,
            'num_tokens': 0.,
            'batch_size': 0.,
            'grad_norm': 0.,
            'clip': 0.
        }
        progress_bar = tqdm(train_loader,
                            desc='| Epoch {:03d}'.format(epoch),
                            leave=False,
                            disable=(args.distributed_rank != 0))

        for i, sample in enumerate(progress_bar):
            sample = utils.move_to_cuda(sample)
            if len(sample) == 0:
                continue

            # Forward and backward pass
            output, _ = model(sample['src_tokens'], sample['src_lengths'],
                              sample['tgt_inputs'])
            loss = criterion(output.view(-1, output.size(-1)),
                             sample['tgt_tokens'].view(-1))
            optimizer.zero_grad()
            loss.backward()

            # Reduce gradients across all GPUs
            if args.distributed_world_size > 1:
                utils.reduce_grads(model.parameters())
                total_loss, num_tokens, batch_size = list(
                    map(
                        sum,
                        zip(*utils.all_gather_list([
                            loss.item(), sample['num_tokens'],
                            len(sample['src_tokens'])
                        ]))))
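                # Loss, token count and batch size are gathered from all workers and summed, so the
                # logged statistics reflect the global (distributed) batch.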
            else:
                total_loss, num_tokens, batch_size = loss.item(
                ), sample['num_tokens'], len(sample['src_tokens'])

            # Normalize gradients by number of tokens and perform clipping
            for param in model.parameters():
                param.grad.data.div_(num_tokens)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       args.clip_norm)
            optimizer.step()

            # Update statistics for progress bar
            stats['loss'] += total_loss / num_tokens / math.log(2)
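            # Per-token loss converted from nats to bits via the division by log(2).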
            stats['lr'] += optimizer.param_groups[0]['lr']
            stats['num_tokens'] += num_tokens / len(sample['src_tokens'])
            stats['batch_size'] += batch_size
            stats['grad_norm'] += grad_norm
            stats['clip'] += 1 if grad_norm > args.clip_norm else 0
            progress_bar.set_postfix(
                {
                    key: '{:.4g}'.format(value / (i + 1))
                    for key, value in stats.items()
                },
                refresh=True)

        logging.info('Epoch {:03d}: {}'.format(
            epoch, ' | '.join(key + ' {:.4g}'.format(value / len(progress_bar))
                              for key, value in stats.items())))

        # Adjust learning rate based on validation loss
        valid_loss = validate(args, model, criterion, valid_dataset, epoch)
        lr_scheduler.step(valid_loss)

        # Save checkpoints
        if epoch % args.save_interval == 0:
            utils.save_checkpoint(args, model, optimizer, lr_scheduler, epoch,
                                  valid_loss)
        if optimizer.param_groups[0]['lr'] <= args.min_lr:
            logging.info('Done training!')
            break
Example #6
def main(args):
    os.environ["CUDA_VISIBLE_DEVICES"] = '0'
    # Load arguments from checkpoint
    torch.manual_seed(args.seed)
    state_dict = torch.load(
        args.checkpoint_path,
        map_location=lambda s, l: default_restore_location(s, 'cpu'))
    args = argparse.Namespace(**{**vars(args), **vars(state_dict['args'])})
    utils.init_logging(args)

    # Load dictionaries
    all_dict = Dictionary.load(os.path.join(args.data,
                                            'dict.{}'.format('all')))
    logging.info('Loaded a source dictionary with {} words'.format(
        len(all_dict)))

    # Load dataset
    test_dataset = How2Dataset(src_file=os.path.join(args.data,
                                                     'test.{}'.format('tran')),
                               tgt_file=os.path.join(args.data,
                                                     'test.{}'.format('desc')),
                               all_dict=all_dict,
                               video_file=args.video_file,
                               video_dir=args.video_dir)

    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        num_workers=args.num_workers,
        collate_fn=test_dataset.collater,
        batch_sampler=BatchSampler(
            test_dataset,
            args.max_tokens,
            1,
            1,  #args.batch_size,args.distributed_world_size,
            args.distributed_rank,
            shuffle=False,
            seed=args.seed))

    # Build model and criterion
    model = models.build_model(args, all_dict).cuda()

    model.load_state_dict(state_dict['model'])
    logging.info('Loaded a model from checkpoint {}'.format(
        args.checkpoint_path))

    translator = SequenceGenerator(
        model,
        all_dict,
        beam_size=args.beam_size,
        maxlen=args.max_len,
        stop_early=eval(args.stop_early),
        normalize_scores=eval(args.normalize_scores),
        len_penalty=args.len_penalty,
        unk_penalty=args.unk_penalty,
    )

    progress_bar = tqdm(test_loader, desc='| Generation', leave=False)
    for i, sample in enumerate(progress_bar):
        #logging.info(sample)
        sample = utils.move_to_cuda(sample)

        with torch.no_grad():
            hypos = translator.generate(sample['src_tokens'],
                                        sample['src_lengths'],
                                        sample['video_inputs'])
        for i, (sample_id, hypos) in enumerate(zip(sample['id'].data, hypos)):
            src_tokens = utils.strip_pad(sample['src_tokens'].data[i, :],
                                         all_dict.pad_idx)
            has_target = sample['tgt_tokens'] is not None
            target_tokens = utils.strip_pad(
                sample['tgt_tokens'].data[i, :],
                all_dict.pad_idx).int().cpu() if has_target else None

            src_str = all_dict.string(src_tokens, args.remove_bpe)
            target_str = all_dict.string(target_tokens,
                                         args.remove_bpe) if has_target else ''

            if not args.quiet:
                print('S-{}\t{}'.format(sample_id, src_str))
                if has_target:
                    print('T-{}\t{}'.format(sample_id,
                                            colored(target_str, 'green')))

            # Process top predictions
            for i, hypo in enumerate(hypos[:min(len(hypos), args.num_hypo)]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'].int().cpu(),
                    tgt_dict=all_dict,
                    remove_bpe=args.remove_bpe,
                )

                if not args.quiet:
                    print('H-{}\t{}'.format(sample_id,
                                            colored(hypo_str, 'blue')))
                    print('P-{}\t{}'.format(
                        sample_id, ' '.join(
                            map(lambda x: '{:.4f}'.format(x),
                                hypo['positional_scores'].tolist()))))
                    print('A-{}\t{}'.format(
                        sample_id,
                        ' '.join(map(lambda x: str(x.item()), alignment))))

                # Score only the top hypothesis
                if has_target and i == 0:
                    # Convert back to tokens for evaluation with unk replacement and/or without BPE
                    target_tokens = all_dict.binarize(target_str,
                                                      word_tokenize,
                                                      add_if_not_exist=True)
                    print(target_tokens)
Example #7
def main(args):
    """ Main translation function' """
    # Load arguments from checkpoint
    torch.manual_seed(args.seed)
    state_dict = torch.load(args.checkpoint_path, map_location=lambda s, l: default_restore_location(s, 'cpu'))
    args_loaded = argparse.Namespace(**{**vars(args), **vars(state_dict['args'])})
    args_loaded.data = args.data
    args = args_loaded
    utils.init_logging(args)

    # Load dictionaries
    src_dict = Dictionary.load(os.path.join(args.data, 'dict.{:s}'.format(args.source_lang)))
    logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format(args.source_lang, len(src_dict)))
    tgt_dict = Dictionary.load(os.path.join(args.data, 'dict.{:s}'.format(args.target_lang)))
    logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format(args.target_lang, len(tgt_dict)))

    # Load dataset
    test_dataset = Seq2SeqDataset(
        src_file=os.path.join(args.data, 'test.{:s}'.format(args.source_lang)),
        tgt_file=os.path.join(args.data, 'test.{:s}'.format(args.target_lang)),
        src_dict=src_dict, tgt_dict=tgt_dict)

    test_loader = torch.utils.data.DataLoader(test_dataset, num_workers=1, collate_fn=test_dataset.collater,
                                              batch_sampler=BatchSampler(test_dataset, 9999999,
                                                                         args.batch_size, 1, 0, shuffle=False,
                                                                         seed=args.seed))
    # Build model and criterion
    model = models.build_model(args, src_dict, tgt_dict)
    if args.cuda:
        model = model.cuda()
    model.eval()
    model.load_state_dict(state_dict['model'])
    logging.info('Loaded a model from checkpoint {:s}'.format(args.checkpoint_path))
    progress_bar = tqdm(test_loader, desc='| Generation', leave=False)

    # Iterate over the test set
    all_hyps = {}
    adapted_beam_nbest = []
    for i, sample in enumerate(progress_bar):

        # Create a beam search object for every input sentence in the batch
        batch_size = sample['src_tokens'].shape[0]
        #print("batch:", batch_size)
        searches = [BeamSearch(args.beam_size, args.max_len - 1, tgt_dict.unk_idx) for i in range(batch_size)]
        #print(searches)

        with torch.no_grad():
            # Compute the encoder output
            encoder_out = model.encoder(sample['src_tokens'], sample['src_lengths'])
            #print(encoder_out)
            # __QUESTION 1: What is "go_slice" used for and what do its dimensions represent?

            go_slice = \
                torch.ones(sample['src_tokens'].shape[0], 1).fill_(tgt_dict.eos_idx).type_as(sample['src_tokens'])
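            # (Answer sketch) go_slice has shape [batch_size, 1] and is filled with the <EOS> index;
            # it serves as the start-of-sequence input fed to the decoder at the first time step.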
            #print(go_slice)
            #print(go_slice.size())
            if args.cuda:
                go_slice = utils.move_to_cuda(go_slice)

            # Compute the decoder output at the first time step
            decoder_out, _ = model.decoder(go_slice, encoder_out)

            #print(decoder_out)

            # __QUESTION 2: Why do we keep one top candidate more than the beam size?
            # ANS: the extra candidate serves as a backoff, so that a top-ranked <UNK> can be
            # replaced by the next-best token.
            log_probs, next_candidates = torch.topk(torch.log(torch.softmax(decoder_out, dim=2)),
                                                    args.beam_size+1, dim=-1)
            #log_probs, next_candidates = torch.topk(torch.log(torch.softmax(decoder_out, dim=2)),
                                                    #k=args.nbest+1, dim=-1)

            #print(log_probs)
            #print(next_candidates)

        # Create beam_size beam search nodes for every input sentence
        for i in range(batch_size):
            for j in range(args.beam_size):
                best_candidate = next_candidates[i, :, j]
                backoff_candidate = next_candidates[i, :, j+1]
                best_log_p = log_probs[i, :, j]
                backoff_log_p = log_probs[i, :, j+1]
                next_word = torch.where(best_candidate == tgt_dict.unk_idx, backoff_candidate, best_candidate)
                log_p = torch.where(best_candidate == tgt_dict.unk_idx, backoff_log_p, best_log_p)
                log_p = log_p[-1]

                # Store the encoder_out information for the current input sentence and beam
                emb = encoder_out['src_embeddings'][:,i,:]
                lstm_out = encoder_out['src_out'][0][:,i,:]
                final_hidden = encoder_out['src_out'][1][:,i,:]
                final_cell = encoder_out['src_out'][2][:,i,:]
                try:
                    mask = encoder_out['src_mask'][i,:]
                except TypeError:
                    mask = None

                #Node here is equivalent to one hypothesis (predicted word)
                node = BeamSearchNode(searches[i], emb, lstm_out, final_hidden, final_cell,
                                      mask, torch.cat((go_slice[i], next_word)), log_p, 1)


                #print(node)
                #exit()
                #print(next_word)
                # __QUESTION 3: Why do we add the node with a negative score?
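                # (Answer sketch, assuming the search pops the smallest key first) negating the
                # score turns the highest-probability hypotheses into the smallest keys, so they
                # are expanded first.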

                #normalizer = (((5+len(next_candidates))**args.alpha)) / ((5+1)**args.alpha)

                #length_norm_results = log_probs/normalizer

                searches[i].add(-(node.eval()), node)
                #print(node.eval()*(log_p / (((5+len(next_candidates))**args.alpha)) / ((5+1)**args.alpha)))
                #exit()
                #print(searches[i].add(-node.eval(), node))

        # Start generating further tokens until max sentence length reached
        for _ in range(args.max_len-1):

            # Get the current nodes to expand
            nodes = [n[1] for s in searches for n in s.get_current_beams()]
            if nodes == []:
                break # All beams ended in EOS

            # Reconstruct prev_words, encoder_out from current beam search nodes
            prev_words = torch.stack([node.sequence for node in nodes])
            encoder_out["src_embeddings"] = torch.stack([node.emb for node in nodes], dim=1)
            lstm_out = torch.stack([node.lstm_out for node in nodes], dim=1)
            final_hidden = torch.stack([node.final_hidden for node in nodes], dim=1)
            final_cell = torch.stack([node.final_cell for node in nodes], dim=1)
            encoder_out["src_out"] = (lstm_out, final_hidden, final_cell)
            try:
                encoder_out["src_mask"] = torch.stack([node.mask for node in nodes], dim=0)
            except TypeError:
                encoder_out["src_mask"] = None

            with torch.no_grad():
                # Compute the decoder output by feeding it the decoded sentence prefix
                decoder_out, _ = model.decoder(prev_words, encoder_out)

            # see __QUESTION 2
            log_probs, next_candidates = torch.topk(torch.log(torch.softmax(decoder_out, dim=2)), args.beam_size+1, dim=-1)
            #print(decoder_out)
            #normalizer = (((5+len(next_candidates))**args.alpha)) / ((5+1)**args.alpha)

            #length_norm_results = log_probs/normalizer

            #print(length_norm_results)

            # Create beam_size next nodes for every current node
            for i in range(log_probs.shape[0]):
                for j in range(args.beam_size):

                    best_candidate = next_candidates[i, :, j]
                    #print(best_candidate)
                    backoff_candidate = next_candidates[i, :, j+1]
                    best_log_p = log_probs[i, :, j]
                    backoff_log_p = log_probs[i, :, j+1]
                    #print(backoff_log_p)
                    next_word = torch.where(best_candidate == tgt_dict.unk_idx, backoff_candidate, best_candidate)
                    log_p = torch.where(best_candidate == tgt_dict.unk_idx, backoff_log_p, best_log_p)
                    log_p = log_p[-1]
                    next_word = torch.cat((prev_words[i][1:], next_word[-1:]))

                    # Get parent node and beam search object for corresponding sentence
                    node = nodes[i]
                    search = node.search

                    # __QUESTION 4: How are "add" and "add_final" different? What would happen if we did not make this distinction?
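                    # (Answer sketch) add() keeps a node in the active beam so it is expanded
                    # further, while add_final() stores a finished (<EOS>-terminated) hypothesis in
                    # a separate pool; without the distinction, finished hypotheses would either be
                    # extended past <EOS> or be lost when the beam is pruned.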


                    # Store the node as final if EOS is generated
                    if next_word[-1] == tgt_dict.eos_idx:
                        node = BeamSearchNode(search, node.emb, node.lstm_out, node.final_hidden,
                                              node.final_cell, node.mask, torch.cat((prev_words[i][0].view([1]),
                                              next_word)), node.logp, node.length)
                        #print(node)
                        search.add_final(-node.eval()*(log_p / (((5+len(next_candidates))**args.alpha)) / ((5+1)**args.alpha)), node)

                    # Add the node to current nodes for next iteration
                    else:
                        #This is where I'll add the gamma for adapted beam search?
                        node = BeamSearchNode(search, node.emb, node.lstm_out, node.final_hidden,
                                              node.final_cell, node.mask, torch.cat((prev_words[i][0].view([1]),
                                              next_word)), node.logp + log_p, node.length + 1)
                        search.add(-node.eval(), node)
                        #search.add(-node.eval() * args.gamma, node)

            # __QUESTION 5: What happens internally when we prune our beams?
            #Question 5: How do we know we always maintain the best sequences?
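            # (Answer sketch) see Example #1: prune() presumably keeps only the beam_size
            # best-scoring active nodes per sentence, which is safe because nodes are ordered by
            # their negated scores.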

            for search in searches:
                search.prune()
                #print(searches)

        # Segment into sentences
        # -- get top n search?
        #best_sents = torch.stack([search.get_best()[1].sequence[1:].cpu() for search in searches])
        #print(search.get_best(args.nbest))
        #n_best_sents = torch.stack([sent[1].sequence[1:].cpu() for search in searches for sent in search.get_best(args.nbest)])
        n_best_sents = torch.stack([search.get_best()[1].sequence[1:].cpu()[:args.nbest] for search in searches])
        #n_best_sents = torch.stack([search.get_best()[1].sequence[1:].cpu() for search in searches[:args.nbest]])
        #print(best_sents)
        #exit()
        decoded_batch = n_best_sents.numpy()

        output_sentences = [decoded_batch[row, :] for row in range(decoded_batch.shape[0])]

        # __QUESTION 6: What is the purpose of this for loop?
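        # (Answer sketch) as in Example #1, this loop cuts each output at its first <EOS> token.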
        temp = list()
        for sent in output_sentences:
            first_eos = np.where(sent == tgt_dict.eos_idx)[0]
            if len(first_eos) > 0:
                temp.append(sent[:first_eos[0]])
            else:
                temp.append(sent)
        output_sentences = temp
        #print(output_sentences)

        # Convert arrays of indices into strings of words
        output_sentences = [tgt_dict.string(sent) for sent in output_sentences]

        if args.nbest > 1:
            adapted_beam_nbest.extend(output_sentences)
        else:
            for ii, sent in enumerate(output_sentences):
                all_hyps[int(sample['id'].data[ii])] = sent

        #print(adapted_beam_nbest)


    # Write to file
    if args.output is not None:
        with open(args.output, 'w') as out_file:
            if args.nbest > 1:
                for sent in adapted_beam_nbest:
                    out_file.write(sent + '\n')
            else:
                for sent_id in range(len(all_hyps.keys())):
                    out_file.write(all_hyps[sent_id] + '\n')
Example #8
def main(args):
    """ Main training function. Trains the translation model over the course of several epochs, including dynamic
    learning rate adjustment and gradient clipping. """
    logging.info('Commencing training!')
    torch.manual_seed(42)
    np.random.seed(42)

    utils.init_logging(args)

    # Load dictionaries
    src_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{:s}'.format(args.source_lang)))
    logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format(
        args.source_lang, len(src_dict)))
    tgt_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{:s}'.format(args.target_lang)))
    logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format(
        args.target_lang, len(tgt_dict)))

    # Load datasets
    def load_data(split):
        return Seq2SeqDataset(
            src_file=os.path.join(args.data,
                                  '{:s}.{:s}'.format(split, args.source_lang)),
            tgt_file=os.path.join(args.data,
                                  '{:s}.{:s}'.format(split, args.target_lang)),
            src_dict=src_dict,
            tgt_dict=tgt_dict)

    train_dataset = load_data(
        split='train') if not args.train_on_tiny else load_data(
            split='tiny_train')
    valid_dataset = load_data(split='valid')

    # Check CUDA availability
    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda:0")

    # Build model and optimization criterion
    model = models.build_model(args, src_dict, tgt_dict).to(device)
    logging.info('Built a model with {:d} parameters'.format(
        sum(p.numel() for p in model.parameters())))
    criterion = nn.CrossEntropyLoss(ignore_index=src_dict.pad_idx,
                                    reduction='sum').to(device)

    # Instantiate optimizer and learning rate scheduler
    optimizer = torch.optim.Adam(model.parameters(), args.lr)

    # Load last checkpoint if one exists
    state_dict = utils.load_checkpoint(args, model, optimizer)  # lr_scheduler
    last_epoch = state_dict['last_epoch'] if state_dict is not None else -1

    # Track validation performance for early stopping
    bad_epochs = 0
    best_validate = float('inf')

    for epoch in range(last_epoch + 1, args.max_epoch):
        train_loader = \
            torch.utils.data.DataLoader(train_dataset, collate_fn=train_dataset.collater, # num_workers=1,
                                        batch_sampler=BatchSampler(train_dataset, args.max_tokens, args.batch_size, 1,
                                                                   0, shuffle=True, seed=42))
        model.train()
        stats = OrderedDict()
        stats['loss'] = 0
        stats['lr'] = 0
        stats['num_tokens'] = 0
        stats['batch_size'] = 0
        stats['grad_norm'] = 0
        stats['clip'] = 0
        # Display progress
        progress_bar = tqdm(train_loader,
                            desc='| Epoch {:03d}'.format(epoch),
                            leave=False,
                            disable=False)

        # Iterate over the training set
        for i, sample in enumerate(progress_bar):
            if len(sample) == 0:
                continue
            model.train()

            for k, v in sample.items():
                if isinstance(v, torch.Tensor):
                    sample[k] = v.to(device)
            '''
            ___QUESTION-1-DESCRIBE-F-START___
            Describe what the following lines of code do.
            
            The following lines perform a single training iteration. The model first predicts a batch,
            then the cross-entropy loss of the predictions is computed and backpropagated through the
            network. The gradients are clipped by their global norm to guard against exploding
            gradients. Finally, the network parameters are updated with the Adam optimizer, and the
            gradients are reset to zero for the next iteration.
            '''
            output, _ = model(sample['src_tokens'], sample['src_lengths'],
                              sample['tgt_inputs'])

            loss = \
                criterion(output.view(-1, output.size(-1)), sample['tgt_tokens'].view(-1)) / len(sample['src_lengths'])
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       args.clip_norm)
            optimizer.step()
            optimizer.zero_grad()
            '''___QUESTION-1-DESCRIBE-F-END___'''

            # Update statistics for progress bar
            total_loss, num_tokens, batch_size = loss.item(
            ), sample['num_tokens'], len(sample['src_tokens'])
            stats['loss'] += total_loss * len(
                sample['src_lengths']) / sample['num_tokens']
            stats['lr'] += optimizer.param_groups[0]['lr']
            stats['num_tokens'] += num_tokens / len(sample['src_tokens'])
            stats['batch_size'] += batch_size
            stats['grad_norm'] += grad_norm
            stats['clip'] += 1 if grad_norm > args.clip_norm else 0
            progress_bar.set_postfix(
                {
                    key: '{:.4g}'.format(value / (i + 1))
                    for key, value in stats.items()
                },
                refresh=True)

        logging.info('Epoch {:03d}: {}'.format(
            epoch, ' | '.join(key + ' {:.4g}'.format(value / len(progress_bar))
                              for key, value in stats.items())))

        # Calculate validation loss
        valid_perplexity = validate(args, model, criterion, valid_dataset,
                                    epoch, device)
        model.train()

        # Save checkpoints
        if epoch % args.save_interval == 0:
            utils.save_checkpoint(args, model, optimizer, epoch,
                                  valid_perplexity)  # lr_scheduler

        # Check whether to terminate training
        if valid_perplexity < best_validate:
            best_validate = valid_perplexity
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs >= args.patience:
            logging.info(
                'No validation set improvements observed for {:d} epochs. Early stop!'
                .format(args.patience))
            break