Example #1
File: train.py  Project: mir-am/naturalcc
def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on the validation set(s) and return the losses."""

    if args['dataset']['fixed_validation_seed'] is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(args['dataset']['fixed_validation_seed'])

    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args['dataset']['max_tokens_valid'],
            max_sentences=args['dataset']['max_sentences_valid'],
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args['dataset']['skip_invalid_size_inputs_valid_test'],
            required_batch_size_multiple=args['dataset']['required_batch_size_multiple'],
            seed=args['common']['seed'],
            num_shards=args['distributed_training']['distributed_world_size'],
            shard_id=args['distributed_training']['distributed_rank'],
            num_workers=args['dataset']['num_workers'],
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args['common']['log_format'],
            log_interval=args['common']['log_interval'],
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(args['common']['tensorboard_logdir'] if
                                distributed_utils.is_master(args) else None),
            default_log_format=('tqdm' if not args['common']['no_progress_bar']
                                else 'simple'),
        )

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters)
        with metrics.aggregate(new_root=True) as agg:
            for sample in progress:
                trainer.valid_step(sample)

        # log validation stats
        stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        valid_losses.append(
            stats[args['checkpoint']['best_checkpoint_metric']])

    return valid_losses
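
A note on the metrics.aggregate(new_root=True) pattern above: it gives validation its own root aggregator so validation metrics never leak into the training meters. Below is a from-scratch sketch of the idea; aggregate and log_scalar here are illustrative re-implementations, not the naturalcc metrics API, which is considerably richer.

from contextlib import contextmanager
from collections import defaultdict

_stack = [defaultdict(list)]  # the default root aggregator

@contextmanager
def aggregate(new_root=False):
    agg = defaultdict(list)
    if new_root:
        _stack.append(agg)  # hide all outer aggregators while active
    try:
        yield agg
    finally:
        if new_root:
            _stack.pop()

def log_scalar(key, value):
    _stack[-1][key].append(value)  # route to the innermost root only

log_scalar('train_loss', 0.7)
with aggregate(new_root=True) as valid_agg:
    log_scalar('valid_loss', 0.4)  # lands in valid_agg, not the train meters
assert 'valid_loss' not in _stack[-1]  # training meters untouched
print({k: sum(v) / len(v) for k, v in valid_agg.items()})  # {'valid_loss': 0.4}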
Example #2
File: train.py  Project: mir-am/naturalcc
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args['distributed_training']['fix_batches_to_gpus'],
        shuffle=(epoch_itr.next_epoch_idx > args['dataset']['curriculum']),
    )
    update_freq = (args['optimization']['update_freq'][epoch_itr.epoch - 1] if
                   epoch_itr.epoch <= len(args['optimization']['update_freq'])
                   else args['optimization']['update_freq'][-1])
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args['common']['log_format'],
        log_interval=args['common']['log_interval'],
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(args['common']['tensorboard_logdir']
                            if distributed_utils.is_master(args) else None),
        default_log_format=('tqdm' if not args['common']['no_progress_bar']
                            else 'simple'),
    )

    # task-specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args['dataset']['valid_subset'].split(',')
    max_update = args['optimization']['max_update'] or math.inf
    for samples in progress:
        with metrics.aggregate('train_inner'):
            log_output = trainer.train_step(samples)
            if log_output is None:  # OOM, overflow, ...
                continue

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args['common']['log_interval'] == 0:
            stats = get_training_stats(
                metrics.get_smoothed_values('train_inner'))
            progress.log(stats, tag='train_inner', step=num_updates)

            # reset epoch-level meters
            metrics.reset_meters('train_inner')

        if (not args['dataset']['disable_validation']
                and args['checkpoint']['save_interval_updates'] > 0 and
                num_updates % args['checkpoint']['save_interval_updates'] == 0
                and num_updates > 0):
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')
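
For context, iterators.GroupedIterator(itr, update_freq) above chunks the batch stream so that each trainer.train_step call receives update_freq mini-batches and can accumulate gradients before a single optimizer step. A minimal, self-contained sketch of that grouping (illustration only; the real class also tracks epoch state and length):

from itertools import islice

def grouped(iterable, chunk_size):
    """Yield lists of up to chunk_size consecutive items."""
    it = iter(iterable)
    while True:
        chunk = list(islice(it, chunk_size))
        if not chunk:
            return
        yield chunk

update_freq = 3
for samples in grouped(range(7), update_freq):
    print(samples)  # [0, 1, 2] -> [3, 4, 5] -> [6]: three optimizer steps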
Example #3
def main(args, **unused_kwargs):
    assert args['eval']['path'] is not None, '--path required for evaluation!'

    if torch.cuda.is_available() and not args['common']['cpu']:
        torch.cuda.set_device(args['distributed_training']['device_id'])

    LOGGER.info(args)
    # during evaluation, set fraction_using_func_name = 0, i.e., do not sample from func_name
    args['task']['fraction_using_func_name'] = 0.
    use_cuda = torch.cuda.is_available() and not args['common']['cpu']
    task = tasks.setup_task(args)

    # Load ensemble
    LOGGER.info('loading model(s) from {}'.format(args['eval']['path']))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        utils.split_paths(args['eval']['path']),
        arg_overrides=eval(args['eval']['model_overrides']),
        task=task,
    )

    # Load dataset splits
    task.load_dataset(args['dataset']['gen_subset'])
    dataset = task.dataset(args['dataset']['gen_subset'])

    # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
    for model in models:
        model.make_generation_fast_()
        if args['common']['fp16']:
            model.half()
        if use_cuda:
            model.cuda()

    assert len(models) > 0

    LOGGER.info('num. model params: {}'.format(
        sum(p.numel() for p in models[0].parameters())))

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args['dataset']['max_tokens'] or 36000,
        max_sentences=args['eval']['max_sentences'],
        max_positions=utils.resolve_max_positions(
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=True,
        num_shards=args['dataset']['num_shards'],
        shard_id=args['dataset']['shard_id'],
        num_workers=args['dataset']['num_workers'],
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args['common']['log_format'],
        log_interval=args['common']['log_interval'],
        default_log_format=('tqdm' if not args['common']['no_progress_bar']
                            else 'none'),
    )

    code_reprs, query_reprs = [], []
    for sample in progress:
        if 'net_input' not in sample:
            continue
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        batch_code_reprs, batch_query_reprs = models[0](**sample['net_input'])

        code_reprs.extend(batch_code_reprs.tolist())
        query_reprs.extend(batch_query_reprs.tolist())
    code_reprs = np.asarray(code_reprs, dtype=np.float32)
    query_reprs = np.asarray(query_reprs, dtype=np.float32)

    assert code_reprs.shape == query_reprs.shape, (code_reprs.shape,
                                                   query_reprs.shape)
    eval_size = (len(code_reprs) if args['eval']['eval_size'] == -1
                 else args['eval']['eval_size'])

    k, MRR, topk_idx, topk_prob = 3, [], [], []
    for idx in range(len(dataset) // eval_size):
        # take the idx-th block of eval_size rows (a block offset, not the
        # sliding window the original slice produced)
        code_emb = torch.from_numpy(
            code_reprs[idx * eval_size:(idx + 1) * eval_size, :]).cuda()
        query_emb = torch.from_numpy(
            query_reprs[idx * eval_size:(idx + 1) * eval_size, :]).cuda()
        logits = query_emb @ code_emb.t()

        # src_emb_norm = torch.norm(code_emb, dim=-1, keepdim=True) + 1e-10
        # tgt_emb_norm = torch.norm(query_emb, dim=-1, keepdim=True) + 1e-10
        # logits = (query_emb / tgt_emb_norm) @ (code_emb / src_emb_norm).t()

        correct_scores = logits.diag()
        compared_scores = logits >= correct_scores.unsqueeze(dim=-1)
        mrr = 1 / compared_scores.sum(dim=-1).float()
        MRR.extend(mrr.tolist())
        batch_topk_prob, batch_topk_idx = logits.softmax(dim=-1).topk(k)
        batch_topk_idx = batch_topk_idx + idx * eval_size
        topk_idx.extend(batch_topk_idx.tolist())
        topk_prob.extend(batch_topk_prob.tolist())

    if len(dataset) % eval_size:
        code_emb = torch.from_numpy(code_reprs[-eval_size:, :]).cuda()
        query_emb = torch.from_numpy(query_reprs[-eval_size:, :]).cuda()
        logits = query_emb @ code_emb.t()

        # src_emb_norm = torch.norm(code_emb, dim=-1, keepdim=True) + 1e-10
        # tgt_emb_norm = torch.norm(query_emb, dim=-1, keepdim=True) + 1e-10
        # logits = (query_emb / tgt_emb_norm) @ (code_emb / src_emb_norm).t()

        correct_scores = logits.diag()
        compared_scores = logits >= correct_scores.unsqueeze(dim=-1)
        last_ids = len(code_reprs) % eval_size
        mrr = 1 / compared_scores.sum(dim=-1).float()[-last_ids:]
        MRR.extend(mrr.tolist())
        batch_topk_prob, batch_topk_idx = logits[-last_ids:].softmax(
            dim=-1).topk(k)
        batch_topk_idx = batch_topk_idx + len(code_reprs) - eval_size
        topk_idx.extend(batch_topk_idx.tolist())
        topk_prob.extend(batch_topk_prob.tolist())

    print('mrr: {:.4f}'.format(np.mean(MRR)))

    for idx, mrr in enumerate(MRR):
        if mrr == 1.0 and topk_prob[idx][0] > 0.8:
            print(
                np.asarray(topk_idx[idx]) + 1,
                [round(prob, 4) for prob in topk_prob[idx]])
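
To make the ranking arithmetic above concrete, here is a toy run of the same MRR computation. Each query's matching code snippet sits on the diagonal of the query-by-code score matrix, so the reciprocal rank is 1 over the number of scores in the row that tie or beat the diagonal entry (the score values below are made up):

import torch

logits = torch.tensor([[0.9, 0.1, 0.3],   # query 0: correct item ranked 1st
                       [0.8, 0.5, 0.2],   # query 1: correct item ranked 2nd
                       [0.1, 0.2, 0.4]])  # query 2: correct item ranked 1st
correct_scores = logits.diag()
compared_scores = logits >= correct_scores.unsqueeze(dim=-1)
mrr = 1 / compared_scores.sum(dim=-1).float()
print(mrr.tolist())       # [1.0, 0.5, 1.0]
print(mrr.mean().item())  # 0.8333...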
Example #4
def _main(args, output_file):
    if args['dataset']['max_tokens'] is None and args['dataset']['max_sentences'] is None:
        args['dataset']['max_tokens'] = 12000
    LOGGER.info(args)

    use_cuda = torch.cuda.is_available() and not args['common']['cpu']

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args['dataset']['gen_subset'])

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    LOGGER.info('loading model(s) from {}'.format(args['eval']['path']))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        utils.split_paths(args['eval']['path']),
        arg_overrides=eval(args['eval']['model_overrides']),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None
            if args['eval']['no_beamable_mm'] else args['eval']['beam'],
            need_attn=args['eval']['print_alignment'],
        )
        if _model_args['common']['fp16']:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args['eval']['replace_unk'])

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args['dataset']['gen_subset']),
        max_tokens=args['dataset']['max_tokens'],
        max_sentences=args['eval']['max_sentences'],
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=_model_args['dataset']['skip_invalid_size_inputs_valid_test'],
        required_batch_size_multiple=_model_args['dataset']['required_batch_size_multiple'],
        num_shards=_model_args['dataset']['num_shards'],
        shard_id=_model_args['dataset']['shard_id'],
        num_workers=_model_args['dataset']['num_workers'],
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=_model_args['common']['log_format'],
        log_interval=_model_args['common']['log_interval'],
        default_log_format=('tqdm'
                            if not _model_args['common']['no_progress_bar']
                            else 'none'),
    )

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    num_sentences = 0
    has_target = True
    wps_meter = TimeMeter()
    sources, hypotheses, references = dict(), dict(), dict()

    for sample in progress:
        torch.cuda.empty_cache()
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        if 'net_input' not in sample:
            continue

        # prefix_tokens = None
        # if args['eval']['prefix_size'] > 0:
        #     prefix_tokens = sample['target'][:, :args['eval']['prefix_size']]

        gen_timer.start()
        hypos = task.inference_step(generator, models, sample)
        # gen_out = task.sequence_generator.generate(model, sample)
        num_generated_tokens = sum(len(h[0]['tokens'])
                                   for h in hypos)  # TODO: warning
        gen_timer.stop(num_generated_tokens)

        for i, sample_id in enumerate(sample['id'].tolist()):
            has_target = sample['target'] is not None

            # Remove padding
            src_tokens = utils.strip_pad(
                sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
            target_tokens = None
            if has_target:
                target_tokens = utils.strip_pad(sample['target'][i, :],
                                                tgt_dict.pad()).int().cpu()

            hypos_tokens = utils.strip_eos(hypos[i][0]['tokens'],
                                           tgt_dict.eos()).int().cpu()
            # Either retrieve the original sentences or regenerate them from tokens.
            # if align_dict is not None:
            #     src_str = task.dataset(args['dataset']['gen_subset']).src.get_original_text(sample_id)
            #     target_str = task.dataset(args['dataset']['gen_subset']).tgt.get_original_text(sample_id)
            # else:
            if src_dict is not None:
                src_str = src_dict.string(src_tokens,
                                          args['eval']['remove_bpe'])
            else:
                src_str = ""
            target_str = ""  # keep defined even when the sample has no target
            if has_target:
                target_str = tgt_dict.string(target_tokens,
                                             args['eval']['remove_bpe'],
                                             escape_unk=True)

            # hypo_tokens = tgt_dict.encode_line(hypo_str, add_if_not_exist=True)
            hypo_str = tgt_dict.string(hypos_tokens,
                                       args['eval']['remove_bpe'])

            sources[sample_id] = [src_str]
            hypotheses[sample_id] = [hypo_str]
            references[sample_id] = [target_str]

            if not args['eval']['quiet']:
                if src_dict is not None:
                    print('S-{}\t{}'.format(sample_id, src_str),
                          file=output_file)
                if has_target:
                    print('T-{}\t{}'.format(sample_id, target_str),
                          file=output_file)

                print('H-{}\t{}'.format(sample_id, hypo_str), file=output_file)

    filename = os.path.join(os.path.dirname(__file__), 'config',
                            'predict.json')
    LOGGER.info('write predicted file at {}'.format(filename))
    bleu, rouge_l, meteor = eval_utils.eval_accuracies(hypotheses,
                                                       references,
                                                       filename=filename,
                                                       mode='test')
    LOGGER.info('BLEU: {:.2f}\t ROUGE-L: {:.2f}\t METEOR: {:.2f}'.format(
        bleu, rouge_l, meteor))
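
The padding-removal step above leans on utils.strip_pad; the following is a minimal sketch of the assumed fairseq-style behavior (keep every position whose token id differs from the pad index), with made-up token values:

import torch

def strip_pad(tensor, pad_idx):
    # keep only the positions that are not the padding symbol
    return tensor[tensor.ne(pad_idx)]

tokens = torch.tensor([5, 9, 9, 2, 1, 1, 1])  # assume pad index 1
print(strip_pad(tokens, pad_idx=1).tolist())  # [5, 9, 9, 2]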
Example #5
def _main(args, output_file):
    if args['dataset']['max_tokens'] is None and args['dataset']['max_sentences'] is None:
        args['dataset']['max_tokens'] = 12000
    LOGGER.info(args)

    use_cuda = torch.cuda.is_available() and not args['common']['cpu']

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args['dataset']['gen_subset'])

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    LOGGER.info('loading model(s) from {}'.format(args['eval']['path']))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        utils.split_paths(args['eval']['path']),
        arg_overrides=eval(args['eval']['model_overrides']),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        if _model_args['common']['fp16']:
            model.half()
        if use_cuda:
            model.cuda()

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args['dataset']['gen_subset']),
        max_tokens=args['dataset']['max_tokens'],
        max_sentences=args['dataset']['max_sentences'],
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]
        ),
        ignore_invalid_inputs=_model_args['dataset']['skip_invalid_size_inputs_valid_test'],
        required_batch_size_multiple=_model_args['dataset']['required_batch_size_multiple'],
        num_shards=_model_args['dataset']['num_shards'],
        shard_id=_model_args['dataset']['shard_id'],
        num_workers=_model_args['dataset']['num_workers'],
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=_model_args['common']['log_format'],
        log_interval=_model_args['common']['log_interval'],
        default_log_format=('tqdm' if not _model_args['common']['no_progress_bar'] else 'none'),
    )

    """
    nohup python -m run.completion.seqrnn.eval > run/completion/seqrnn/case.log 2>&1 &
    """
    # evaluate with the first loaded model; the original code relied on the
    # `model` loop variable leaking out of the optimization loop above
    model = models[0]
    sequence_completor = task.build_completor([model], args)
    for sample in progress:
        torch.cuda.empty_cache()
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        if 'net_input' not in sample:
            continue

        non_pad_idx = sample['net_input']['src_tokens'] > task.target_dictionary.pad()

        with torch.no_grad():
            net_output = sequence_completor.generate([model], sample, prefix_tokens=None)
        lprobs = model.get_normalized_probs(net_output, log_probs=True)

        rank = torch.argmax(lprobs, dim=-1)
        target = model.get_targets(sample, net_output)
        accuracy = 1.0 * ((rank == target) & non_pad_idx).sum(dim=-1) / non_pad_idx.sum(dim=-1)
        for idx, (data_idx, acc) in enumerate(zip(sample['id'], accuracy)):
            if acc > 0.9:
                LOGGER.info(f"{data_idx}: {task.target_dictionary.string(sample['net_input']['src_tokens'][idx, :])}")