Example #1
def tokenize(path, output_path, buffer=1000):
    ''' Tokenize a file asynchronously '''
    with tempfile.TemporaryDirectory() as tmpdir:
        paths = []
        results = []
        pool = Pool()
        word_counts = Counter()
        basename = os.path.basename(output_path)
        language = os.path.splitext(basename)[1][1:]
        file_chunks = file_utils.split(path, os.path.join(tmpdir, ''), buffer)
        for chunk in sorted(file_chunks):
            output_chunk = f'{chunk}{basename}'
            results.append(pool.apply_async(_tokenize, [language, chunk, output_chunk]))
            paths.append(output_chunk)
        pool.close()

        results = tqdm(
            results,
            unit='chunk',
            dynamic_ncols=True,
            desc=f'Tokenizing {basename}',
            file=sys.stdout # needed to make tqdm_wrap_stdout work
        )
        with tqdm_wrap_stdout():
            for result in results:       
                word_counts += result.get()

            pool.join()

        file_utils.join(paths, output_path)
        return word_counts
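
Every snippet on this page pairs a tqdm bar created with file=sys.stdout and the tqdm_wrap_stdout() context manager, whose definition is not included in the snippets. A minimal sketch of a compatible helper, assuming it does nothing more than reroute writes to sys.stdout through tqdm.write so that ordinary print() calls do not corrupt an active progress bar (the project's actual implementation may differ):

import sys
from contextlib import contextmanager
from tqdm import tqdm

class _TqdmStdout:
    ''' File-like shim that forwards stdout writes to tqdm.write '''
    def write(self, text):
        # tqdm.write appends its own newline, so skip writes that are only whitespace
        if text.strip():
            tqdm.write(text, file=sys.__stdout__)

    def flush(self):
        sys.__stdout__.flush()

@contextmanager
def tqdm_wrap_stdout():
    ''' Temporarily redirect sys.stdout while progress bars are active '''
    original, sys.stdout = sys.stdout, _TqdmStdout()
    try:
        yield
    finally:
        sys.stdout = original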
Example #2
    def evaluate_epoch(self, epoch, experiment, stats_file, verbose=0):
        ''' Evaluate a single epoch '''
        neg_log_likelihood = metrics.Metric('nll', metrics.format_float)

        def get_description():
            mode_name = 'Test' if self.dataset.split == 'test' else 'Validate'
            description = f'{mode_name} #{epoch}'
            if verbose > 0:
                description += f' {neg_log_likelihood}'
            if verbose > 1:
                description += f' [{profile.mem_stat_string(["allocated"])}]'
            return description

        batches = tqdm(
            self.dataloader,
            unit='batch',
            dynamic_ncols=True,
            desc=get_description(),
            file=sys.stdout # needed to make tqdm_wrap_stdout work
        )
        with tqdm_wrap_stdout():
            for batch in batches:
                # run the data through the model
                batches.set_description_str(get_description())
                nll, length, stats = self.evaluate(batch)
                self.update_stats(stats, self.stats, self.count)
                if length:
                    neg_log_likelihood.update(nll / length)

        experiment.log_metric('nll', neg_log_likelihood.average)
        self.save_stats(stats_file)
        return neg_log_likelihood.average
Example #3
def get_parse_vocab(path, segmenters, buffer=1000):
    ''' Extract the vocabulary of a file asynchronously '''
    with tempfile.TemporaryDirectory() as tmpdir:
        results = []
        pool = Pool()
        vocab = set()
        basename = os.path.basename(path)
        file_chunks = file_utils.split(path, os.path.join(tmpdir, ''), buffer)
        for chunk in sorted(file_chunks):
            results.append(
                pool.apply_async(_get_parse_vocab, [chunk, segmenters]))
        pool.close()

        results = tqdm(
            results,
            unit='chunk',
            dynamic_ncols=True,
            desc=f'Extracting vocab: {basename}',
            file=sys.stdout  # needed to make tqdm_wrap_stdout work
        )
        with tqdm_wrap_stdout():
            for result in results:
                vocab.update(result.get())

            pool.join()

        return vocab
Example #4
def apply_bpe(bpe_path, path, output_path, buffer=1000):
    ''' Apply BPE to a file asynchronously '''
    with tempfile.TemporaryDirectory() as tmpdir:
        paths = []
        results = []
        pool = Pool()
        vocab = set()
        basename = os.path.basename(output_path)
        file_chunks = file_utils.split(path, os.path.join(tmpdir, ''), buffer)
        for chunk in sorted(file_chunks):
            output_chunk = f'{chunk}{basename}'
            results.append(pool.apply_async(_apply_bpe, [bpe_path, chunk, output_chunk]))
            paths.append(output_chunk)
        pool.close()

        results = tqdm(
            results,
            unit='chunk',
            dynamic_ncols=True,
            desc=f'BPE encoding {basename}',
            file=sys.stdout # needed to make tqdm_wrap_stdout work
        )
        with tqdm_wrap_stdout():
            for result in results:
                vocab.update(result.get())

            pool.join()

        file_utils.join(paths, output_path)
        return vocab
Example #5
def parse(path, output_path, buffer=1000):
    ''' Parse a file asynchronously '''
    with tempfile.TemporaryDirectory() as tmpdir:
        paths = []
        results = []
        pool = Pool()
        basename = os.path.basename(path)
        file_chunks = file_utils.split(path, os.path.join(tmpdir, ''), buffer)
        for chunk in sorted(file_chunks):
            renamed_chunk = f'{chunk}{basename}'

            os.rename(chunk, renamed_chunk)
            results.append(pool.apply_async(_parse, [renamed_chunk]))
            paths.append(renamed_chunk.replace('bpe.32000.', '') + '.parse')
        pool.close()

        results = tqdm(
            results,
            unit='chunk',
            dynamic_ncols=True,
            desc=f'Parsing {basename}',
            file=sys.stdout  # needed to make tqdm_wrap_stdout work
        )
        with tqdm_wrap_stdout():
            for result in results:
                result.get()

            pool.join()

        file_utils.join(paths, output_path)
Example #6
    def translate_all(self, output_file, epoch, experiment, verbose=0):
        ''' Generate all predictions from the dataset '''
        def get_description():
            description = f'Generate #{epoch}'
            if verbose > 1:
                description += f' [{profile.mem_stat_string(["allocated"])}]'
            return description

        batches = tqdm(
            self.dataloader,
            unit='batch',
            dynamic_ncols=True,
            desc=get_description(),
            file=sys.stdout  # needed to make tqdm_wrap_stdout work
        )

        with tqdm_wrap_stdout():
            ordered_outputs = []
            for batch in batches:
                # run the data through the model
                batches.set_description_str(get_description())

                start_event = torch.cuda.Event(enable_timing=True)
                end_event = torch.cuda.Event(enable_timing=True)
                start_event.record()
                sequences = self.translator.translate(
                    batch)  # step to be profiled
                end_event.record()
                torch.cuda.synchronize()
                self.time_profile.append(start_event.elapsed_time(end_event))

                if self.config.timed:
                    continue

                target_sequences = next(iter(sequences.values()))
                for i, example_id in enumerate(batch['example_ids']):
                    outputs = []
                    if verbose > 1:
                        trim = verbose < 2
                        join = verbose < 3
                        for key in sequences.keys():
                            sequence = sequences[key][i]
                            sequence = ' '.join(
                                self.dataset.decode(sequence, join, trim))
                            outputs.append(f'{key}: {sequence}\n')
                        outputs.append(f'+++++++++++++++++++++++++++++\n')
                    else:
                        sequence = target_sequences[i]
                        decoded = ' '.join(
                            self.dataset.decode(sequence, trim=not verbose))
                        outputs.append(f'{decoded}\n')

                    if self.config.order_output:
                        ordered_outputs.append((example_id, outputs))
                    else:
                        output_file.writelines(outputs)

            for _, outputs in sorted(ordered_outputs, key=lambda x: x[0]):  # pylint:disable=consider-using-enumerate
                output_file.writelines(outputs)
Example #7
    async def __call__(self):
        """
        Run the generation!
        """
        entries = self.dataset.entries
        if self.args.data.max_entries:
            entries = entries[:self.args.data.max_entries]

        batch_iterator = tqdm(
            as_completed([self.scheduler.generate(entry)
                          for entry in entries]),
            unit="entry",
            initial=1,
            dynamic_ncols=True,
            desc="Generating",
            total=len(entries),
            file=sys.stdout,  # needed to make tqdm_wrap_stdout work
        )

        sep = "*******\n"
        with tqdm_wrap_stdout():
            example_id = 0
            for result in batch_iterator:
                sample, batch, summary = await result
                summary_length = len(summary["tokens"])
                context = self.generator.tokenizer.decode(
                    summary["tokens"].tolist())
                original = self.generator.tokenizer.decode(
                    batch["tokens"][summary_length:].tolist())
                logging.info(
                    "#%d:\n%scontext\n%s%s\n%soriginal\n%s%s\n%ssample\n%s%s",
                    example_id,
                    sep,
                    sep,
                    context,
                    sep,
                    sep,
                    original,
                    sep,
                    sep,
                    sample,
                )
                example_id += 1

            batch_iterator.close()
Example #8
    def __call__(self) -> float:
        """
        Run the evaluation!
        """
        dataloader = get_dataloader(self.args.data,
                                    self.dataset,
                                    num_devices=len(self.model.device_ids))

        def get_description():
            return f"Eval {self.metric_store}"

        batch_iterator = tqdm(
            dataloader,
            unit="batch",
            initial=1,
            dynamic_ncols=True,
            desc=get_description(),
            file=sys.stdout,  # needed to make tqdm_wrap_stdout work
        )

        with ExitStack() as stack:
            # pylint:disable=no-member
            stack.enter_context(tqdm_wrap_stdout())
            stack.enter_context(chunked_scattering())
            # pylint:enable=no-member

            for batch in batch_iterator:
                try:
                    self.eval_step(batch)
                except RuntimeError as rte:
                    if "out of memory" in str(rte):
                        self.metric_store["oom"].update(1)
                        logging.warning(str(rte))
                    else:
                        batch_iterator.close()
                        raise rte

                batch_iterator.set_description_str(get_description())

            batch_iterator.close()

        return self.metric_store["nll"].average
Example #9
File: train.py  Project: fallcat/synst
    def train_epoch(self, epoch, experiment, verbose=0):
        ''' Run one training epoch '''
        oom = self.metric_store['oom']
        learning_rate = self.metric_store['lr']
        num_tokens = self.metric_store['num_tok']
        neg_log_likelihood = self.metric_store['nll']

        def try_optimize(i, last=False):
            # optimize if:
            #  1) last and remainder
            #  2) not last and not remainder
            remainder = bool(i % self.config.accumulate_steps)
            if not last ^ remainder:
                next_lr = self.optimize()

                learning_rate.update(next_lr)
                experiment.log_metric('learning_rate', next_lr)
                return True

            return False

        def get_description():
            description = f'Train #{epoch}'
            if verbose > 0:
                description += f' {self.metric_store}'
            if verbose > 1:
                description += f' [{profile.mem_stat_string(["allocated"])}]'
            return description

        batches = tqdm(
            self.dataloader,
            unit='batch',
            dynamic_ncols=True,
            desc=get_description(),
            file=sys.stdout  # needed to make tqdm_wrap_stdout work
        )
        with tqdm_wrap_stdout():
            i = 1
            nll_per_update = 0.
            length_per_update = 0
            num_tokens_per_update = 0
            for i, batch in enumerate(batches, 1):

                try:

                    nll, length = self.calculate_gradient(batch)
                    did_optimize = try_optimize(i)

                    # record the effective number of tokens
                    num_tokens_per_update += int(sum(batch['input_lens']))
                    num_tokens_per_update += int(sum(batch['target_lens']))

                    if length:
                        # record length and nll
                        nll_per_update += nll
                        length_per_update += length

                    if did_optimize:
                        # advance the experiment step
                        experiment.set_step(experiment.curr_step + 1)

                        num_tokens.update(num_tokens_per_update)
                        neg_log_likelihood.update(nll_per_update /
                                                  length_per_update)

                        experiment.log_metric('num_tokens',
                                              num_tokens_per_update)
                        experiment.log_metric('nll',
                                              neg_log_likelihood.last_value)
                        # experiment.log_metric('max_memory_alloc', torch.cuda.max_memory_allocated()//1024//1024)
                        # experiment.log_metric('max_memory_cache', torch.cuda.max_memory_cached()//1024//1024)

                        nll_per_update = 0.
                        length_per_update = 0
                        num_tokens_per_update = 0

                except RuntimeError as rte:
                    if 'out of memory' in str(rte):
                        torch.cuda.empty_cache()

                        oom.update(1)
                        experiment.log_metric('oom', oom.total)
                        #exit(-1)
                    else:
                        batches.close()
                        raise rte

                if self.should_checkpoint():
                    new_best = False
                    if self.config.early_stopping:
                        with tqdm_unwrap_stdout():
                            new_best = self.evaluate(experiment, epoch,
                                                     verbose)

                    self.checkpoint(epoch, experiment.curr_step, new_best)

                batches.set_description_str(get_description())
                if self.is_done(experiment, epoch):
                    batches.close()
                    break

            try_optimize(i, last=True)
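
The "not last ^ remainder" test in try_optimize above is easy to misread. A standalone illustration (not part of the source), assuming accumulate_steps = 4: gradients are applied on every fourth step while iterating, and once more at the very end only if the final step was not itself a multiple of four, so leftover accumulated gradients are flushed exactly once.

accumulate_steps = 4

def should_optimize(i, last=False):
    ''' Mirrors the optimize condition from try_optimize above '''
    remainder = bool(i % accumulate_steps)
    return not (last ^ remainder)

assert [i for i in range(1, 11) if should_optimize(i)] == [4, 8]
assert should_optimize(10, last=True)      # 10 is not a multiple of 4: flush leftovers
assert not should_optimize(8, last=True)   # step 8 already optimized, do not repeat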
Example #10
    def translate_all(self,
                      output_file,
                      stats_file,
                      example_file,
                      epoch,
                      experiment,
                      verbose=0):
        ''' Generate all predictions from the dataset '''
        def get_description():
            description = f'Generate #{epoch}'
            if verbose > 1:
                description += f' [{profile.mem_stat_string(["allocated"])}]'
            return description

        batches = tqdm(
            self.dataloader,
            unit='batch',
            dynamic_ncols=True,
            desc=get_description(),
            file=sys.stdout  # needed to make tqdm_wrap_stdout work
        )

        with tqdm_wrap_stdout():
            ordered_outputs = []
            with torch.no_grad():
                self.model.eval()
                count = 0
                for batch in batches:
                    # run the data through the model
                    batches.set_description_str(get_description())

                    result = self.model(batch)
                    #
                    # stats
                    encoder_stats = probe(
                        result['encoder_attn_weights_tensor'])
                    decoder_stats = probe(
                        result['decoder_attn_weights_tensor'])
                    enc_dec_stats = probe(
                        result['enc_dec_attn_weights_tensor'])

                    train_stats = {
                        'encoder_stats': {
                            stats_type: encoder_stats[stats_type].view(
                                self.num_layers, self.num_heads,
                                -1).cpu().numpy()
                            for stats_type in STATS_TYPES
                        },
                        'decoder_stats': {
                            stats_type: decoder_stats[stats_type].view(
                                self.num_layers, self.num_heads,
                                -1).cpu().numpy()
                            for stats_type in STATS_TYPES
                        },
                        'enc_dec_stats': {
                            stats_type: enc_dec_stats[stats_type].view(
                                self.num_layers, self.num_heads,
                                -1).cpu().numpy()
                            for stats_type in STATS_TYPES
                        }
                    }

                    self.update_stats2(train_stats, self.train_stats,
                                       self.train_count)

                    sequences, test_stats = self.translator.translate(batch)

                    # self.update_stats(test_stats, self.test_stats, self.test_count)

                    if self.config.timed:
                        continue

                    target_sequences = next(iter(sequences.values()))
                    new_targets = []
                    for i, example_id in enumerate(batch['example_ids']):
                        # print("example_id", example_id)
                        if example_id == 430:
                            print("saved")
                            train_tensors = {
                                'encoder':
                                result['encoder_attn_weights_tensor'].cpu(
                                ).numpy().tolist(),
                                'decoder':
                                result['decoder_attn_weights_tensor'].cpu(
                                ).numpy().tolist(),
                                'enc_dec':
                                result['enc_dec_attn_weights_tensor'].cpu().
                                numpy().tolist()
                            }
                            json.dump(train_tensors, example_file)
                        outputs = []
                        if verbose > 1:
                            trim = verbose < 2
                            join = verbose < 3
                            for key in sequences.keys():
                                sequence = sequences[key][i]
                                sequence = ' '.join(
                                    self.dataset.decode(sequence, join, trim))
                                outputs.append(f'{key}: {sequence}\n')
                            outputs.append(f'+++++++++++++++++++++++++++++\n')
                        else:
                            sequence = target_sequences[i]
                            new_targets.append(torch.LongTensor(sequence))
                            decoded = ' '.join(
                                self.dataset.decode(sequence,
                                                    trim=not verbose))
                            outputs.append(f'{decoded}\n')

                        if self.config.order_output:
                            ordered_outputs.append((example_id, outputs))
                        else:
                            output_file.writelines(outputs)

                    self.dataset.collate_field(batch, 'target', new_targets)
                    result = self.model(batch)

                    # stats
                    encoder_stats = probe(
                        result['encoder_attn_weights_tensor'])
                    decoder_stats = probe(
                        result['decoder_attn_weights_tensor'])
                    enc_dec_stats = probe(
                        result['enc_dec_attn_weights_tensor'])
                    test_stats = {
                        'encoder_stats': {
                            stats_type: encoder_stats[stats_type].view(
                                self.num_layers, self.num_heads,
                                -1).cpu().numpy()
                            for stats_type in STATS_TYPES
                        },
                        'decoder_stats': {
                            stats_type: decoder_stats[stats_type].view(
                                self.num_layers, self.num_heads,
                                -1).cpu().numpy()
                            for stats_type in STATS_TYPES
                        },
                        'enc_dec_stats': {
                            stats_type: enc_dec_stats[stats_type].view(
                                self.num_layers, self.num_heads,
                                -1).cpu().numpy()
                            for stats_type in STATS_TYPES
                        }
                    }

                    self.update_stats2(test_stats, self.test_stats,
                                       self.test_count)
                    count += 1
                    if count == 50:
                        break

            for _, outputs in sorted(ordered_outputs, key=lambda x: x[0]):  # pylint:disable=consider-using-enumerate
                output_file.writelines(outputs)

        self.save_stats2(stats_file)
Example #11
    def translate_all(self, output_file, epoch, experiment, verbose=0):
        ''' Generate all predictions from the dataset '''
        def get_description():
            description = f'Generate #{epoch}'
            if verbose > 1:
                description += f' [{profile.mem_stat_string(["allocated"])}]'
            return description

        batches = tqdm(
            self.dataloader,
            unit='batch',
            dynamic_ncols=True,
            desc=get_description(),
            file=sys.stdout # needed to make tqdm_wrap_stdout work
        )

        with tqdm_wrap_stdout():
            self.model.eval()
            ordered_outputs = []
            for batch in batches:
                # print("in probe new translate", flush=True)
                # run the data through the model
                batches.set_description_str(get_description())
                sequences, attn_weights_tensors_dict = self.translator.translate(batch)

                if self.config.timed:
                    continue

                target_sequences = next(iter(sequences.values()))
                encoder_attn_weights_tensor = attn_weights_tensors_dict['encoder_attn_weights_tensor']
                new_targets = []
                output_sentences = []
                source_sentences = []
                for i, example_id in enumerate(batch['example_ids']):
                    outputs = []
                    if verbose > 1:
                        trim = verbose < 2
                        join = verbose < 3
                        for key in sequences.keys():
                            sequence = sequences[key][i]
                            sequence = ' '.join(self.dataset.decode(sequence, join, trim))
                            outputs.append(f'{key}: {sequence}\n')
                        outputs.append(f'+++++++++++++++++++++++++++++\n')
                    else:
                        sequence = target_sequences[i]
                        new_targets.append(torch.LongTensor(sequence))
                        decoded = ' '.join(self.dataset.decode(sequence, trim=not verbose))
                        outputs.append(f'{decoded}\n')
                        output_sentence = ' '.join(self.dataset.decode(sequence, join=False, trim=not verbose))
                        output_sentences.append(output_sentence)
                        source_sentence = ' '.join(self.dataset.decode(batch['inputs'][i], join=False, trim=not verbose))
                        source_sentences.append(source_sentence)

                        # Encoder heatmap
                        # print("saving encoder heatmap")
                        for j in range(encoder_attn_weights_tensor.shape[0]):
                            for k in range(encoder_attn_weights_tensor.shape[1]):
                                attn_filename = f'encoder_attn_weights{example_id}_l{j}_h{k}.png'
                                attn_path = os.path.join(self.config.output_directory, attn_filename)
                                save_attention(source_sentence, source_sentence,
                                               encoder_attn_weights_tensor[j][k].cpu().numpy(), attn_path)

                    if self.config.order_output:
                        ordered_outputs.append((example_id, outputs))
                    else:
                        output_file.writelines(outputs)

                self.dataset.collate_field(batch, 'target', new_targets)
                result = self.model(batch)
                # Decoder heatmap
                # print("saving decoder heatmap")
                for i, example_id in enumerate(batch['example_ids']):
                    for j in range(result['decoder_attn_weights_tensor'].shape[0]):
                        for k in range(result['decoder_attn_weights_tensor'].shape[1]):
                            attn_filename = f'decoder_attn_weights{example_id}_l{j}_h{k}.png'
                            attn_path = os.path.join(self.config.output_directory, attn_filename)
                            save_attention(output_sentences[i], output_sentences[i],
                                           result['decoder_attn_weights_tensor'][j][k].cpu().numpy(), attn_path)
                            attn_filename = f'enc_dec_attn_weights{example_id}_l{j}_h{k}.png'
                            attn_path = os.path.join(self.config.output_directory, attn_filename)
                            save_attention(source_sentences[i], '<SOS> ' + output_sentences[i],
                                           result['enc_dec_attn_weights_tensor'][j][k].cpu().numpy(), attn_path)

            for _, outputs in sorted(ordered_outputs, key=lambda x: x[0]): # pylint:disable=consider-using-enumerate
                output_file.writelines(outputs)
Example #12
    def translate_all(self, output_file, enc_off_diagonal_output_file, dec_off_diagonal_output_file, epoch, experiment, verbose=0):
        ''' Generate all predictions from the dataset '''
        def get_description():
            description = f'Generate #{epoch}'
            if verbose > 1:
                description += f' [{profile.mem_stat_string(["allocated"])}]'
            return description

        batches = tqdm(
            self.dataloader,
            unit='batch',
            dynamic_ncols=True,
            desc=get_description(),
            file=sys.stdout # needed to make tqdm_wrap_stdout work
        )

        with tqdm_wrap_stdout():
            self.model.eval()
            ordered_outputs = []
            for batch in batches:
                # print("in probe new translate", flush=True)
                # run the data through the model
                batches.set_description_str(get_description())
                sequences = self.translator.translate(batch)

                if self.config.timed:
                    continue

                target_sequences = next(iter(sequences.values()))
                new_targets = []
                output_sentences = []
                source_sentences = []
                for i, example_id in enumerate(batch['example_ids']):
                    outputs = []
                    if verbose > 1:
                        trim = verbose < 2
                        join = verbose < 3
                        for key in sequences.keys():
                            sequence = sequences[key][i]
                            sequence = ' '.join(self.dataset.decode(sequence, join, trim))
                            outputs.append(f'{key}: {sequence}\n')
                        outputs.append(f'+++++++++++++++++++++++++++++\n')
                    else:
                        sequence = target_sequences[i]
                        new_targets.append(torch.LongTensor(sequence))
                        decoded = ' '.join(self.dataset.decode(sequence, trim=not verbose))
                        outputs.append(f'{decoded}\n')
                        output_sentence = ' '.join(self.dataset.decode(sequence, join=False, trim=not verbose))
                        output_sentences.append(output_sentence)
                        source_sentence = ' '.join(self.dataset.decode(batch['inputs'][i], join=False, trim=not verbose))
                        source_sentences.append(source_sentence)

                        # Encoder heatmap
                        # print("saving encoder heatmap")
                        # for j in range(encoder_attn_weights_tensor.shape[0]):
                        #     for k in range(encoder_attn_weights_tensor.shape[1]):
                        #         attn_filename = f'encoder_attn_weights{example_id}_l{j}_h{k}.png'
                        #         attn_path = os.path.join(self.config.output_directory, attn_filename)
                        #         save_attention(source_sentence, source_sentence,
                        #                        encoder_attn_weights_tensor[j][k].cpu().numpy(), attn_path)

                    if self.config.order_output:
                        ordered_outputs.append((example_id, outputs))
                    else:
                        output_file.writelines(outputs)

                self.dataset.collate_field(batch, 'target', new_targets)
                result = self.model(batch)
                # Decoder heatmap
                # print("saving decoder heatmap")
                for i, example_id in enumerate(batch['example_ids']):
                    # print("result.keys()", result.keys())
                    # print("result['encoder_attn_weights_tensor']", result['encoder_attn_weights_tensor'].shape)
                    for coder in ['encoder', 'decoder']:
                        attn_weights_shape = result[coder + '_attn_weights_tensor'].shape
                        attn_weights = result[coder + '_attn_weights_tensor'].view(-1,
                                                                                  attn_weights_shape[2],
                                                                                  attn_weights_shape[3])
                        indices_q = torch.round(torch.arange(attn_weights_shape[2], dtype=torch.float32,
                                                 device=attn_weights.get_device()).view(1, -1) * self.dataset.word_count_ratio)
                        argmax_weights = torch.argmax(attn_weights, dim=2)
                        # print("argmax_weights", argmax_weights)
                        max_weights = torch.max(attn_weights, dim=2)[0]  #attn_weights[argmax_weights]
                        # print("max_weights", max_weights.shape)
                        distance = torch.abs(argmax_weights.type_as(indices_q) - indices_q)
                        # print("attn_weights", attn_weights.shape)
                        # print("distance", distance.shape)
                        # print(distance)
                        # print("distance >= threshold", (distance >= self.config.off_diagonal_distance_threshold).shape)
                        # print(distance >= self.config.off_diagonal_distance_threshold, torch.sum(distance >= self.config.off_diagonal_distance_threshold))
                        # print("max_weights[distance >= threshold]", max_weights[distance >= self.config.off_diagonal_distance_threshold].shape, max_weights[distance >= 1])

                        max_prob = torch.max(max_weights[distance >= self.config.off_diagonal_distance_threshold])
                        argmax_offset = torch.max(distance)
                        number = torch.sum(distance >= self.config.off_diagonal_distance_threshold)
                        #self.config.off_diagonal_threshold_param

                        if self.config.off_diagonal_threshold_type == "number":
                            # print("number")
                            idx = int(torch.round(number.to(torch.float32) / float(attn_weights.shape[0] *
                                                                               attn_weights.shape[1]) *
                                              self.config.off_diagonal_bins).cpu().item())
                            self.number_frac_dict[coder][idx] += 1
                            self.number_frac_list_dict[coder][idx].append(example_id)
                        # elif self.config.off_diagonal_threshold_type == "offset":
                        #     print("offset")
                        #     if argmax_offset >= self.config.off_diagonal_threshold_param:
                        #         self.off_diagonal.append(example_id)
                        #     else:
                        #         self.non_off_diagonal.append(example_id)
                        # else:  # prob
                        #     print("prob")
                        #     if max_prob >= self.config.off_diagonal_threshold_param:
                        #         self.off_diagonal.append(example_id)
                        #     else:
                        #         self.non_off_diagonal.append(example_id)

                    #self.dataset.word_count_ratio

                    # for j in range(result['decoder_attn_weights_tensor'].shape[0]):
                    #     for k in range(result['decoder_attn_weights_tensor'].shape[1]):
                    #         attn_filename = f'decoder_attn_weights{example_id}_l{j}_h{k}.png'
                    #         attn_path = os.path.join(self.config.output_directory, attn_filename)
                    #         save_attention(output_sentences[i], output_sentences[i],
                    #                        result['decoder_attn_weights_tensor'][j][k].cpu().numpy(), attn_path)
                    #         attn_filename = f'enc_dec_attn_weights{example_id}_l{j}_h{k}.png'
                    #         attn_path = os.path.join(self.config.output_directory, attn_filename)
                    #         save_attention(source_sentences[i], '<PAD>' + output_sentences[i],
                    #                        result['enc_dec_attn_weights_tensor'][j][k].cpu().numpy(), attn_path)

            # print("num off diagonal", len(self.off_diagonal))
            # print("num non off diagonal", len(self.non_off_diagonal))

            pp = pprint.PrettyPrinter()
            print("---------Encoder---------")
            print("number_dict")
            pp.pprint([(k, self.number_dict['encoder'][k]) for k in sorted(self.number_dict['encoder'].keys())])
            print("number_frac_dict")
            pp.pprint([(k, self.number_frac_dict['encoder'][k]) for k in sorted(self.number_frac_dict['encoder'].keys())])
            print("---------Decoder---------")
            print("number_dict")
            pp.pprint([(k, self.number_dict['decoder'][k]) for k in sorted(self.number_dict['decoder'].keys())])
            print("number_frac_dict")
            pp.pprint(
                [(k, self.number_frac_dict['decoder'][k]) for k in sorted(self.number_frac_dict['decoder'].keys())])

            for k in sorted(self.number_frac_dict['encoder'].keys()):
                enc_off_diagonal_output_file.write(str(k) + "\t" + " ".join(str(x) for x in self.number_frac_list_dict['encoder'][k]) + "\n")
            for k in sorted(self.number_frac_dict['decoder'].keys()):
                dec_off_diagonal_output_file.write(
                    str(k) + "\t" + " ".join(str(x) for x in self.number_frac_list_dict['decoder'][k]) + "\n")
            # off_diagonal_output_file.write(str(len(self.off_diagonal)) + "\t" + " ".join([str(x) for x in self.off_diagonal]) + "\n")
            # off_diagonal_output_file.write(str(len(self.non_off_diagonal)) + "\t" + " ".join([str(x) for x in self.non_off_diagonal]) + "\n")

            for _, outputs in sorted(ordered_outputs, key=lambda x: x[0]): # pylint:disable=consider-using-enumerate
                output_file.writelines(outputs)
Example #13
    def train_epoch(self, epoch, experiment, verbose=0):
        ''' Run one training epoch '''
        oom = self.metric_store['oom']
        learning_rate = self.metric_store['lr']
        num_tokens = self.metric_store['num_tok']
        neg_log_likelihood = self.metric_store['nll']
        perplexity = self.metric_store['ppl']

        def try_optimize(i, curr_step, last=False):
            # optimize if:
            #  1) last and remainder
            #  2) not last and not remainder
            remainder = bool(i % self.config.accumulate_steps)
            if not last ^ remainder:
                next_lr = self.optimize(curr_step)

                learning_rate.update(next_lr)
                experiment.log_metric('learning_rate', next_lr)
                return True

            return False

        def get_description():
            description = f'Train #{epoch}'
            if verbose > 0:
                description += f' {self.metric_store}'
            if verbose > 1:
                description += f' [{profile.mem_stat_string(["allocated"])}]'
            return description

        batches = tqdm(
            self.dataloader,
            unit='batch',
            dynamic_ncols=True,
            desc=get_description(),
            file=sys.stdout 
        )
        with tqdm_wrap_stdout():
            i = 1
            nll_per_update = 0.
            length_per_update = 0
            num_tokens_per_update = 0
            cnter = 0
            for i, batch in enumerate(batches, 1):

                if type(batch) is not torch.Tensor:
                    # concatenated dataset
                    batch = ([b.squeeze(0) for b in batch])
                else:
                    batch = (batch,)

                try:
                    # if batch.shape[0] == self.config.batch_length:
                    if True:
                        
                        nll, length = self.calculate_gradient(batch)
                        did_optimize = try_optimize(i, experiment.curr_step )

                        # record the effective number of tokens
                        num_tokens_per_update += length

                        if length:
                            # record length and nll
                            nll_per_update += nll
                            length_per_update += length

                        if did_optimize:
                            # advance the experiment step
                            experiment.set_step(experiment.curr_step + 1)

                            num_tokens.update(num_tokens_per_update)
                            nll = nll_per_update / length_per_update
                            neg_log_likelihood.update(nll)
                            perplexity.update(np.exp(nll))

                            experiment.log_metric('num_tokens', num_tokens_per_update)
                            experiment.log_metric('nll', neg_log_likelihood.last_value)
                            experiment.log_metric('ppl', perplexity.last_value)

                            nll_per_update = 0.
                            length_per_update = 0
                            num_tokens_per_update = 0

                except RuntimeError as rte:
                    if 'out of memory' in str(rte):
                        torch.cuda.empty_cache()

                        oom.update(1)
                        experiment.log_metric('oom', oom.total)
                        #exit(-1)
                        raise rte
                    else:
                        batches.close()
                        raise rte

                if self.should_checkpoint():
                    new_best = False
                    if self.config.early_stopping:
                        with tqdm_unwrap_stdout():
                            new_best = self.evaluate(experiment, epoch, verbose)

                    self.checkpoint(epoch, experiment.curr_step, new_best)

                batches.set_description_str(get_description())
                if self.is_done(experiment, epoch):
                    batches.close()
                    break

            try_optimize(i, experiment.curr_step, last=True)
Example #14
    def translate_all(self,
                      output_file,
                      stats_file,
                      epoch,
                      experiment,
                      verbose=0):
        ''' Generate all predictions from the dataset '''
        def get_description():
            description = f'Generate #{epoch}'
            if verbose > 1:
                description += f' [{profile.mem_stat_string(["allocated"])}]'
            return description

        batches = tqdm(
            self.dataloader,
            unit='batch',
            dynamic_ncols=True,
            desc=get_description(),
            file=sys.stdout  # needed to make tqdm_wrap_stdout work
        )

        with tqdm_wrap_stdout():
            ordered_outputs = []
            with torch.no_grad():
                self.model.eval()
                for batch in batches:
                    # run the data through the model
                    batches.set_description_str(get_description())

                    result = self.model(batch)

                    enc_dec_stats = probe(
                        result['enc_dec_attn_weights_tensor'])
                    train_stats = {
                        'enc_dec_stats': {
                            stats_type: enc_dec_stats[stats_type].view(
                                self.num_layers, self.num_heads, -1)
                            for stats_type in STATS_TYPES
                        }
                    }

                    self.update_stats(train_stats, self.train_stats,
                                      self.train_count)

                    # translate
                    sequences, attn_weights_tensors = self.translator.translate(
                        batch)
                    target_sequences = next(iter(sequences.values()))
                    new_targets = [
                        torch.LongTensor(sequence)
                        for sequence in target_sequences
                    ]

                    # new data with new targets
                    self.dataset.collate_field(batch, 'target', new_targets)

                    # model(new-batch)
                    result = self.model(batch)

                    enc_dec_stats = probe(
                        result['enc_dec_attn_weights_tensor'])
                    test_stats = {
                        'enc_dec_stats': {
                            stats_type: enc_dec_stats[stats_type].view(
                                self.num_layers, self.num_heads, -1)
                            for stats_type in STATS_TYPES
                        }
                    }

                    self.update_stats(test_stats, self.test_stats,
                                      self.test_count)

                    if self.config.timed:
                        continue

        self.save_stats(stats_file)
Example #15
    def __call__(self):
        """
        Run the training!
        """
        # Must be called first
        self.try_init_amp()

        model = self.modules["model"]
        optimizer = self.modules["optimizer"]
        scheduler = self.modules["scheduler"]

        if self.args.optim.use_gradient_checkpointing:
            model.enable_gradient_checkpointing()

        model = nn.DataParallel(model)
        dataloader = get_dataloader(
            self.args.data,
            self.dataset,
            num_devices=len(model.device_ids),
            shuffle=True,
        )

        def get_description():
            return f"Train {self.metric_store}"

        max_steps = self.args.optim.max_steps
        accumulation_steps = self.args.optim.gradient_accumulation_steps
        progress = tqdm(
            unit="step",
            initial=self.step,
            dynamic_ncols=True,
            desc=get_description(),
            total=max_steps,
            file=sys.stdout,  # needed to make tqdm_wrap_stdout work
        )

        with ExitStack() as stack:
            # pylint:disable=no-member
            stack.enter_context(tqdm_wrap_stdout())
            stack.enter_context(chunked_scattering())
            stack.enter_context(self.experiment.train())
            # pylint:enable=no-member

            if self.args.optim.early_stopping:
                # If using early stopping, must evaluate regularly to determine
                # if training should stop early, so setup an Evaluator
                eval_args = copy.deepcopy(self.args)
                eval_args.data.batch_size = self.args.optim.eval_batch_size

                evaluator = Evaluator(eval_args)
                evaluator.model = model
                evaluator.load_dataset("validation")
                evaluator.initialize_experiment(experiment=self.experiment)

                # Make sure we are tracking validation nll
                self.metric_store.add(
                    metrics.Metric("vnll", "format_float", "g(m)"))

                # And store a local variable for easy access
                vnll_metric = self.metric_store["vnll"]

            loss = 0
            num_tokens = 0
            for step, batch in enumerate(cycle(dataloader), 1):
                try:
                    step_loss = self.compute_gradients_and_loss(
                        batch, model, optimizer)
                    run_optimizer = (step % accumulation_steps) == 0

                    if run_optimizer:
                        # Run an optimization step
                        optimizer.step()
                        scheduler.step()  # Update learning rate schedule
                        model.zero_grad()

                    # Update loss and num tokens after running an optimization
                    # step, in case it results in an out of memory error
                    loss += step_loss
                    num_tokens += batch["num_tokens"]

                    if run_optimizer:
                        # Since we ran the optimizer, increment current step
                        self.step += 1
                        self.experiment.set_step(self.step)
                        progress.update()

                        # update our metrics as well
                        self.update_metrics(
                            loss / accumulation_steps,
                            num_tokens,
                            scheduler.get_lr()[0],
                        )
                        num_tokens = 0
                        loss = 0

                        # and finally check if we should save
                        if (self.args.save_steps > 0
                                and self.step % self.args.save_steps == 0):
                            # First save the current checkpoint
                            self.save()

                            # Then if we are implementing early stopping, see
                            # if we achieved a new best
                            if self.args.optim.early_stopping:
                                evaluator.reset_metrics()
                                with ExitStack() as eval_stack:
                                    # pylint:disable=no-member
                                    eval_stack.enter_context(
                                        tqdm_unwrap_stdout())
                                    eval_stack.enter_context(
                                        release_cuda_memory(
                                            collect_tensors(optimizer.state)))
                                    # pylint:enable=no-member

                                    vnll = evaluator()
                                    vnll_metric.update(vnll)

                                    # Save the updated metrics
                                    self.save_metrics()

                                    if vnll == vnll_metric.min:
                                        self.on_new_best()

                                # Try to combat OOM errors caused by doing evaluation
                                # in the same loop with training. This manifests in out
                                # of memory errors after the first or second evaluation
                                # run.
                                refresh_cuda_memory()

                            if not self.prune_checkpoints():
                                logging.info("Stopping early")
                                break

                            if self.step >= max_steps:
                                logging.info("Finished training")
                                break

                except RuntimeError as rte:
                    if "out of memory" in str(rte):
                        self.metric_store["oom"].update(1)
                        logging.warning(str(rte))
                    else:
                        progress.close()
                        raise rte

                progress.set_description_str(get_description())

            progress.close()
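
Stripped of the project-specific work, every example above reduces to the same pattern. A minimal sketch, assuming tqdm_wrap_stdout is importable (for instance the sketch after Example #1 or the project's own utils module); the items and the loop body are placeholders:

import sys
from tqdm import tqdm

def process_all(items):
    ''' Minimal version of the shared progress-bar pattern '''
    progress = tqdm(
        items,
        unit='item',
        dynamic_ncols=True,
        desc='Processing',
        file=sys.stdout,  # needed to make tqdm_wrap_stdout work
    )
    with tqdm_wrap_stdout():
        for item in progress:
            # print() inside the loop goes through tqdm.write and does not break the bar
            print(f'handled {item}')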