Example #1
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch and return validation losses."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    if getattr(args, "tpu", False):
        itr = tpu_data_loader(args, itr)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(
            args.tensorboard_logdir if distributed_utils.is_master(args) else None
        ),
        default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
    )

    trainer.begin_epoch(epoch_itr.epoch)

    valid_subsets = args.valid_subset.split(",")
    should_stop = False
    for i, samples in enumerate(progress):
        with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function("train_step-%d" % i):
            log_output = trainer.train_step(samples)
            if log_output is None:  # OOM, overflow, ...
                continue

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(metrics.get_smoothed_values("train_inner"))
            progress.log(stats, tag="train_inner", step=num_updates)

            # reset mid-epoch stats after each log interval
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters("train_inner")

        end_of_epoch = not itr.has_next()
        valid_losses, should_stop = validate_and_save(
            args, trainer, task, epoch_itr, valid_subsets, end_of_epoch
        )
        if should_stop:
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values("train"))
    progress.print(stats, tag="train", step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters("train")
    return valid_losses, should_stop
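
Note: validate_and_save is called above but not shown. In fairseq it ties checkpointing and validation to the update count and the end-of-epoch flag. A minimal sketch, assuming the standard save_interval, save_interval_updates, validate_interval, max_update and disable_validation options (not a verbatim copy of the library code):

import math

def validate_and_save(args, trainer, task, epoch_itr, valid_subsets, end_of_epoch):
    num_updates = trainer.get_num_updates()
    max_update = args.max_update or math.inf

    do_save = (
        (end_of_epoch and epoch_itr.epoch % args.save_interval == 0)
        or num_updates >= max_update
        or (args.save_interval_updates > 0
            and num_updates > 0
            and num_updates % args.save_interval_updates == 0)
    )
    do_validate = (
        not args.disable_validation
        and ((end_of_epoch and epoch_itr.epoch % args.validate_interval == 0)
             or do_save)
    )

    valid_losses = [None]
    if do_validate:
        valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)

    should_stop = num_updates >= max_update or should_stop_early(args, valid_losses[0])

    if do_save or should_stop:
        checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
    return valid_losses, should_stop
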
Example #2
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (args.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(args.update_freq) else
                   args.update_freq[-1])
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(args.tensorboard_logdir
                            if distributed_utils.is_master(args) else None),
        default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
    )

    # task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    for samples in progress:
        with metrics.aggregate('train_inner'):
            log_output = trainer.train_step(samples)
            if log_output is None:  # OOM, overflow, ...
                continue

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(
                metrics.get_smoothed_values('train_inner'))
            progress.log(stats, tag='train_inner', step=num_updates)

            # reset mid-epoch stats after each log interval
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters('train_inner')

        if (not args.disable_validation and args.save_interval_updates > 0
                and num_updates % args.save_interval_updates == 0
                and num_updates > 0):
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')
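
For context, the per-epoch train() above is normally driven by an outer loop in the fairseq entry point that steps the learning-rate scheduler once per epoch. A rough sketch; the wrapper name run_training is made up, while get_lr, lr_step and get_train_iterator follow the usual fairseq Trainer API:

def run_training(args, trainer, task, epoch_itr, max_epoch, valid_subsets):
    # Hypothetical driver: train epoch by epoch until the LR floor or max_epoch.
    lr = trainer.get_lr()
    while lr > args.min_lr and epoch_itr.next_epoch_idx <= max_epoch:
        train(args, trainer, task, epoch_itr)

        # validate once per epoch and let the scheduler see the first loss
        valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # fetch a fresh iterator for the next epoch
        epoch_itr = trainer.get_train_iterator(epoch_itr.next_epoch_idx)
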
Example #3
def train(args, trainer, task, epoch_itr, max_update=math.inf):
    """Train the model for one epoch and return validation losses."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )

    update_freq = (args.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(args.update_freq) else
                   args.update_freq[-1])
    itr = iterators.GroupedIterator(itr, update_freq)
    if getattr(args, 'tpu', False):
        itr = tpu_data_loader(args, itr)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(args.tensorboard_logdir
                            if distributed_utils.is_master(args) else None),
        default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
    )
    progress.log_args(args, tag='train')

    trainer.begin_epoch(epoch_itr.epoch)

    valid_subsets = args.valid_subset.split(',')
    for samples in progress:
        with metrics.aggregate('train_inner'):
            log_output = trainer.train_step(samples)
            if log_output is None:  # OOM, overflow, ...
                continue

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(
                metrics.get_smoothed_values('train_inner'))
            progress.log(stats, tag='train_inner', step=num_updates)

            # reset mid-epoch stats after each log interval
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters('train_inner')

        end_of_epoch = not itr.has_next()
        valid_losses = validate_and_save(args, trainer, task, epoch_itr,
                                         valid_subsets, end_of_epoch)
        if should_stop_early(args,
                             valid_losses[0]) or num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')
    return valid_losses
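
should_stop_early above implements patience-based early stopping on the first validation loss. A minimal sketch along the lines of fairseq's helper, assuming the --patience and --maximize-best-checkpoint-metric options:

def should_stop_early(args, valid_loss):
    # Skip the check if validation did not run in this interval.
    if valid_loss is None:
        return False
    if args.patience <= 0:
        return False

    def is_better(a, b):
        return a > b if args.maximize_best_checkpoint_metric else a < b

    prev_best = getattr(should_stop_early, "best", None)
    if prev_best is None or is_better(valid_loss, prev_best):
        should_stop_early.best = valid_loss
        should_stop_early.num_runs = 0
        return False
    should_stop_early.num_runs += 1
    return should_stop_early.num_runs >= args.patience
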
Example #4
def tmp():
    # Build fairseq args plus a DeepSpeed config, then set up the task,
    # the DeepSpeed-backed trainer, and the epoch/batch iterator.
    fs_args, ds_config = gen_ds_fairseq_arg()
    set_seed(fs_args.seed)
    task = tasks.setup_task(fs_args)
    trainer = DsFairseqTrainer(fs_args, ds_config, task)
    batch_itr = BatchIterator(fs_args, task)
    for epoch in batch_itr.train_epoch():
        train(batch_itr, trainer)
        log_dist(
            f'Finish epoch {epoch}, \
            {view_log(metrics.get_smoothed_values("train"))}',
            [0],
        )
        metrics.reset_meters("train")
Example #5
    def train_step(self, sample, is_dummy_batch):
        self.model.train()
        self.model.zero_grad()

        loss, sample_size, logging_output = self.model(sample)

        # A dummy batch only exists to keep the distributed collectives below
        # in step; zero its loss and sample size so it contributes nothing.
        if is_dummy_batch:
            if torch.is_tensor(sample_size):
                sample_size.zero_()
            else:
                sample_size *= 0.0
            loss *= 0.0
        if torch.is_tensor(sample_size):
            sample_size = sample_size.float()
        else:
            sample_size = float(sample_size)

        logging_outputs, (sample_size, ) = torch_reduce_sum(
            self.model.device, [logging_output],
            sample_size,
            ignore=is_dummy_batch)

        # Scale so gradients end up normalized by the global sample size
        # (the world_size factor compensates for the engine averaging
        # gradients across workers).
        final_loss = loss * (dist.get_world_size() / sample_size)
        self.model.backward(final_loss)
        self.model.step()

        logging_output = self.reduce_log(logging_outputs, sample_size)

        if self.model.global_steps % self.model.steps_per_print() != 0:
            return

        log_dist(
            f'Step: {self.model.global_steps}, \
            {view_log(metrics.get_smoothed_values("train_inner"))}',
            [0],
        )
        metrics.reset_meters("train_inner")
Example #6
def train(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask,
          epoch_itr) -> Tuple[List[Optional[float]], bool]:
    """Train the model for one epoch and return validation losses."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=cfg.distributed_training.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > cfg.dataset.curriculum),
    )
    update_freq = (cfg.optimization.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(cfg.optimization.update_freq) else
                   cfg.optimization.update_freq[-1])
    itr = iterators.GroupedIterator(itr, update_freq)
    if cfg.common.tpu:
        itr = utils.tpu_data_loader(itr)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(cfg.common.tensorboard_logdir
                            if distributed_utils.is_master(
                                cfg.distributed_training) else None),
        default_log_format=("tqdm"
                            if not cfg.common.no_progress_bar else "simple"),
        wandb_project=(cfg.common.wandb_project if distributed_utils.is_master(
            cfg.distributed_training) else None),
        wandb_run_name=os.environ.get(
            "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)),
        azureml_logging=(cfg.common.azureml_logging
                         if distributed_utils.is_master(
                             cfg.distributed_training) else False),
    )
    progress.update_config(_flatten_config(cfg))

    trainer.begin_epoch(epoch_itr.epoch)

    valid_subsets = cfg.dataset.valid_subset.split(",")
    should_stop = False
    num_updates = trainer.get_num_updates()
    logger.info("Start iterating over samples")
    for i, samples in enumerate(progress):
        with metrics.aggregate(
                "train_inner"), torch.autograd.profiler.record_function(
                    "train_step-%d" % i):
            log_output = trainer.train_step(samples)

        if log_output is not None:  # not OOM, overflow, ...
            # log mid-epoch stats
            num_updates = trainer.get_num_updates()
            if num_updates % cfg.common.log_interval == 0:
                stats = get_training_stats(
                    metrics.get_smoothed_values("train_inner"))
                progress.log(stats, tag="train_inner", step=num_updates)

                # reset mid-epoch stats after each log interval
                # the end-of-epoch stats will still be preserved
                metrics.reset_meters("train_inner")

        end_of_epoch = not itr.has_next()
        valid_losses, should_stop = validate_and_save(cfg, trainer, task,
                                                      epoch_itr, valid_subsets,
                                                      end_of_epoch)

        if should_stop:
            break

    # log end-of-epoch stats
    logger.info("end of epoch {} (average epoch stats below)".format(
        epoch_itr.epoch))
    stats = get_training_stats(metrics.get_smoothed_values("train"))
    progress.print(stats, tag="train", step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters("train")
    return valid_losses, should_stop
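
_flatten_config feeds progress.update_config a plain dict: it converts the Hydra DictConfig to a container and folds any leftover argparse Namespace under a single "args" key. Roughly:

from argparse import Namespace
from omegaconf import DictConfig, OmegaConf

def _flatten_config(cfg: DictConfig):
    config = OmegaConf.to_container(cfg)
    # Remove any legacy Namespace entries and replace them with a single "args" dict.
    namespace = None
    for k, v in list(config.items()):
        if isinstance(v, Namespace):
            namespace = v
            del config[k]
    if namespace is not None:
        config["args"] = vars(namespace)
    return config
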
Example #7
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (args.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(args.update_freq) else
                   args.update_freq[-1])
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(args.tensorboard_logdir
                            if distributed_utils.is_master(args) else None),
        default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
    )

    # task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    should_end_training = False
    for samples in progress:
        with metrics.aggregate('train_inner'):
            try:
                log_output = trainer.train_step(samples)

            except ResetTrainerException:
                trainer._wrapped_criterion = None
                trainer._wrapped_model = None
                trainer._optimizer = None

                logger.info("reset the trainer at {}".format(
                    trainer.get_num_updates()))
                log_output = trainer.train_step(samples)

            if log_output is None:  # OOM, overflow, ...
                continue

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(
                metrics.get_smoothed_values('train_inner'))
            progress.log(stats, tag='train_inner', step=num_updates)

            # reset mid-epoch stats after each log interval
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters('train_inner')

        valid_losses = validate_and_save(args, trainer, task, epoch_itr,
                                         valid_subsets)
        if should_stop_early(args,
                             valid_losses[0]) or num_updates >= max_update:
            should_end_training = True
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')
    return should_end_training
Example #8
def train(args, trainer, task, epoch_itr, max_update=math.inf, model=None):
    """Train the model for one epoch and return validation losses."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(
            args.tensorboard_logdir if distributed_utils.is_master(args) else None
        ),
        default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
    )

    # task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args.valid_subset.split(',')
    for i, samples in enumerate(progress):
        with metrics.aggregate('train_inner'):
            log_output = trainer.train_step(samples)
            if log_output is None:  # OOM, overflow, ...
                continue

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(metrics.get_smoothed_values('train_inner'))
            progress.log(stats, tag='train_inner', step=num_updates)

            # reset mid-epoch stats after each log interval
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters('train_inner')
        # On the first batch of each epoch, dump the attention norms recorded
        # by any module that exposes them.
        if i == 0:
            print('epoch: ', epoch_itr.epoch)
            endeattn_norm = []
            selfattn_norm = []
            for m in model.modules():
                if hasattr(m, 'selfattn_norm') and m.selfattn_norm is not None:
                    selfattn_norm.append(m.selfattn_norm)
                if hasattr(m, 'endeattn_norm') and m.endeattn_norm is not None:
                    endeattn_norm.append(m.endeattn_norm)
            print('self attention norms: ', selfattn_norm)
            print('en/decoder attn norms:', endeattn_norm)
        valid_losses = validate_and_save(args, trainer, task, epoch_itr, valid_subsets)
        if should_stop_early(args, valid_losses[0]) or num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')
    return valid_losses
Example #9
def train(args,
          trainer,
          task,
          epoch_itr,
          model,
          experiment_path,
          total_samples=None,
          last_epoch_num=0,
          restore=None):
    """Train the model for one epoch and return validation losses."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (args.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(args.update_freq) else
                   args.update_freq[-1])
    itr = iterators.GroupedIterator(itr, update_freq)
    if getattr(args, "tpu", False):
        itr = tpu_data_loader(args, itr)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(args.tensorboard_logdir
                            if distributed_utils.is_master(args) else None),
        default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
    )

    num_heads = args.decoder_attention_heads
    head_dim = args.decoder_embed_dim // num_heads
    if experiment_path is not None:
        with open(experiment_path, 'r') as f:
            swaps = json.load(f)
        mhr(model, swaps, head_dim, num_heads, epoch_itr.epoch)

    trainer.begin_epoch(epoch_itr.epoch)

    valid_subsets = args.valid_subset.split(",")
    should_stop = False

    conf = {
        "encoder": [{
            "self_attn": []
        } for i in range(args.encoder_layers)],
        "decoder": [{
            "self_attn": [],
            "enc_attn": []
        } for i in range(args.decoder_layers)]
    }
    attentions = {
        "decoder": [{
            "self_attn": []
        } for i in range(args.decoder_layers)]
    }

    batch_regression = 1.0 - (total_samples / (160239 * 50))
    for i, samples in enumerate(progress):
        with metrics.aggregate(
                "train_inner"), torch.autograd.profiler.record_function(
                    "train_step-%d" % i):
            log_output = trainer.train_step(samples,
                                            batch_num=batch_regression)

            if log_output is None:  # OOM, overflow, ...
                continue
        total_samples += model.decoder.layers[0].self_attn.bsz
        batch_regression = 1.0 - (
            total_samples / (160239 * 40)
        )  # need to find more generic way to find total samples and epoch num.

        # Get Confidence for each Head.
        if args.head_confidence_method is not None:
            conf = get_batch_confs(model, conf, args)

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(
                metrics.get_smoothed_values("train_inner"))
            progress.log(stats, tag="train_inner", step=num_updates)

            # reset mid-epoch stats after each log interval
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters("train_inner")

        end_of_epoch = not itr.has_next()
        valid_losses, should_stop, val_conf = validate_and_save(
            args, trainer, task, epoch_itr, valid_subsets, end_of_epoch)

        if should_stop:
            break

    if args.head_confidence_method is not None:
        conf = convert_confs(conf, args)

        # Dump the per-head confidences next to the checkpoints, one pickle per epoch.
        path = args.save_dir.replace("checkpoints", "confs") + "-method={0}".format(
            args.head_confidence_method)
        try:
            os.mkdir(path, 0o775)
        except OSError:
            pass
        with open(path + "/epoch-{0}.pkl".format(epoch_itr.epoch), 'wb') as fd:
            pickle.dump(conf, fd, protocol=3)

    if args.dynamic_type is not None and args.head_confidence_method is not None:
        conf = val_conf

        restore['enc_self_attn'], last_epoch_num[
            'enc_self_attn'] = dynamic_mhr(model,
                                           int(args.start_dynamic_mhr[0]),
                                           "encoder",
                                           "self_attn",
                                           restore['enc_self_attn'],
                                           int(args.dynamic_swap_frequency[0]),
                                           last_epoch_num['enc_self_attn'],
                                           epoch_itr.epoch + 1,
                                           int(args.dynamic_max_switches[0]),
                                           conf[0],
                                           num_heads,
                                           head_dim,
                                           args.encoder_layers,
                                           local_only=False,
                                           d_type=args.dynamic_type[0],
                                           rest=int(args.dynamic_rest[0]),
                                           end_epoch=int(
                                               args.dynamic_end_epoch[0]))

        restore['dec_self_attn'], last_epoch_num[
            'dec_self_attn'] = dynamic_mhr(model,
                                           int(args.start_dynamic_mhr[1]),
                                           "decoder",
                                           "self_attn",
                                           restore['dec_self_attn'],
                                           int(args.dynamic_swap_frequency[1]),
                                           last_epoch_num['dec_self_attn'],
                                           epoch_itr.epoch + 1,
                                           int(args.dynamic_max_switches[1]),
                                           conf[1],
                                           num_heads,
                                           head_dim,
                                           args.encoder_layers,
                                           local_only=False,
                                           d_type=args.dynamic_type[1],
                                           rest=int(args.dynamic_rest[1]),
                                           end_epoch=int(
                                               args.dynamic_end_epoch[1]))
        restore['dec_enc_attn'], last_epoch_num['dec_enc_attn'] = dynamic_mhr(
            model,
            int(args.start_dynamic_mhr[2]),
            "decoder",
            "encoder_attn",
            restore['dec_enc_attn'],
            int(args.dynamic_swap_frequency[2]),
            last_epoch_num['dec_enc_attn'],
            epoch_itr.epoch + 1,
            int(args.dynamic_max_switches[2]),
            conf[2],
            num_heads,
            head_dim,
            args.encoder_layers,
            local_only=False,
            d_type=args.dynamic_type[2],
            rest=int(args.dynamic_rest[2]),
            end_epoch=int(args.dynamic_end_epoch[2]))

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values("train"))
    progress.print(stats, tag="train", step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters("train")
    return valid_losses, should_stop, total_samples, restore, last_epoch_num
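
The confidence pickles written above can be read back by mirroring the writer's path scheme; a small hypothetical helper:

import pickle

def load_head_confidences(save_dir, method, epoch):
    # Mirrors the writer: <save_dir with "checkpoints" -> "confs">-method=<m>/epoch-<n>.pkl
    path = save_dir.replace("checkpoints", "confs") + "-method={0}".format(method)
    with open(path + "/epoch-{0}.pkl".format(epoch), "rb") as fd:
        return pickle.load(fd)
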
Example #10
def train(args, trainer, task, epoch_itr, m_mle=None):
    """Train the model for one epoch."""
    global model_old
    global model_mle
    # Snapshot the current model as the "old" policy (and keep an MLE reference
    # model); the snapshot is refreshed periodically in the loop below.
    model_old = copy.deepcopy(trainer.model)
    if m_mle is None:
        model_mle = model_old
    else:
        model_mle = m_mle

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(
            args.tensorboard_logdir if distributed_utils.is_master(args) else None
        ),
        default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
    )

    # task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    should_end_training = False
    for samples in progress:
        if True: # warning
            valid_losses = validate_and_save(args, trainer, task, epoch_itr, valid_subsets)

        with metrics.aggregate('train_inner'):
            # Debug: training goes here
            log_output = trainer.train_step(samples)
            if log_output is None:  # OOM, overflow, ...
                continue

        # Log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(metrics.get_smoothed_values('train_inner'))
            progress.log(stats, tag='train_inner', step=num_updates)

            # Reset mid-epoch stats after each log interval
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters('train_inner')

        if num_updates > 2 and num_updates % (args.policy_update_per_k_epoch) == 0:  # warning
            del model_old
            torch.cuda.empty_cache()
            model_old = copy.deepcopy(trainer.model)

        valid_losses = validate_and_save(args, trainer, task, epoch_itr, valid_subsets)
        if should_stop_early(args, valid_losses[0]) or num_updates >= max_update:
            should_end_training = True
            break

    # Log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # Reset epoch-level meters
    metrics.reset_meters('train')
    return should_end_training
Example #11
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch and return validation losses."""
    logger.info("begin training epoch {}".format(epoch_itr.epoch))
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (args.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(args.update_freq) else
                   args.update_freq[-1])
    itr = iterators.GroupedIterator(itr, update_freq)
    if getattr(args, "tpu", False):
        itr = tpu_data_loader(args, itr)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(args.tensorboard_logdir
                            if distributed_utils.is_master(args) else None),
        default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
    )

    trainer.begin_epoch(epoch_itr.epoch)

    valid_subsets = args.valid_subset.split(",")
    should_stop = False
    time_cost = 0
    for i, samples in enumerate(progress):

        # Optional throughput measurement over a fixed window of iterations.
        if args.validate_training_performance:
            performance_end_its = args.performance_begin_its + args.performance_its_count - 1
        if args.validate_training_performance and i == args.performance_begin_its:
            processed_tokens = 0

        with metrics.aggregate(
                "train_inner"), torch.autograd.profiler.record_function(
                    "train_step-%d" % i):
            time_begin = time.time()
            log_output = trainer.train_step(samples)
            time_end = time.time()
            if args.validate_training_performance and i >= args.performance_begin_its and i <= performance_end_its:
                time_cost = time_cost + (time_end - time_begin)
            if log_output is None:  # OOM, overflow, ...
                continue
        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(
                metrics.get_smoothed_values("train_inner"))
            progress.log(stats, tag="train_inner", step=num_updates)

            # reset mid-epoch stats after each log interval
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters("train_inner")

        end_of_epoch = not itr.has_next()
        valid_losses, should_stop = validate_and_save(args, trainer, task,
                                                      epoch_itr, valid_subsets,
                                                      end_of_epoch)
        if args.validate_training_performance and i >= args.performance_begin_its:
            for sample in samples:
                net_input = sample['net_input']
                bs, src_lens = net_input['src_tokens'].shape
                processed_tokens += bs * src_lens
        if args.validate_training_performance and i == performance_end_its:
            logger.info("Performance info:")
            logger.info("Begin iteration:{}".format(
                args.performance_begin_its))
            logger.info("End iteration: {}".format(performance_end_its))
            logger.info("Processed_tokens: {}".format(processed_tokens))
            logger.info("Time cost: {} s".format(time_cost))
            logger.info("Throughput:{} tokens/s".format(processed_tokens /
                                                        (time_cost)))
            should_stop = True
        if should_stop:
            break
    # log end-of-epoch stats
    logger.info("end of epoch {} (average epoch stats below)".format(
        epoch_itr.epoch))
    stats = get_training_stats(metrics.get_smoothed_values("train"))
    progress.print(stats, tag="train", step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters("train")
    return valid_losses, should_stop
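
Note that the reported throughput counts padded source tokens (bs * src_lens comes from the padded src_tokens shape), so it slightly overstates real-token throughput. As a worked example with made-up numbers: 2,000 measured iterations at batch shape (64, 25) give processed_tokens = 2,000 * 64 * 25 = 3,200,000; with time_cost = 128 s the log would report 25,000 tokens/s.
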
Example #12
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch and return validation losses."""
    if isinstance(epoch_itr, list):
        itrs = []
        for itr in epoch_itr:
            # Initialize data iterators
            itrs.append(
                itr.next_epoch_itr(
                    fix_batches_to_gpus=args.fix_batches_to_gpus,
                    shuffle=(itr.next_epoch_idx > args.curriculum),
                ))

        update_freq = (args.update_freq[epoch_itr[0].epoch - 1]
                       if epoch_itr[0].epoch <= len(args.update_freq) else
                       args.update_freq[-1])

        grouped_itrs = []
        for itr in itrs:
            grouped_itrs.append(iterators.GroupedIterator(itr, update_freq))

        # not supported
        # if getattr(args, "tpu", False):
        #     itr = utils.tpu_data_loader(itr)

        progress = progress_bar.progress_bar(
            grouped_itrs,
            log_format=args.log_format,
            log_interval=args.log_interval,
            epoch=epoch_itr[0].epoch,
            tensorboard_logdir=(args.tensorboard_logdir if
                                distributed_utils.is_master(args) else None),
            default_log_format=("simplecluster"),
        )

        trainer.begin_epoch(epoch_itr[0].epoch)

    else:
        # Initialize data iterators
        itr = epoch_itr.next_epoch_itr(
            fix_batches_to_gpus=args.fix_batches_to_gpus,
            shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
        )

        update_freq = (args.update_freq[epoch_itr.epoch - 1]
                       if epoch_itr.epoch <= len(args.update_freq) else
                       args.update_freq[-1])

        itr = iterators.GroupedIterator(itr, update_freq)

        # not supported
        # if getattr(args, "tpu", False):
        #     itr = utils.tpu_data_loader(itr)

        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            epoch=epoch_itr.epoch,
            tensorboard_logdir=(args.tensorboard_logdir if
                                distributed_utils.is_master(args) else None),
            default_log_format=("tqdm"
                                if not args.no_progress_bar else "simple"),
        )

        trainer.begin_epoch(epoch_itr.epoch)

    valid_subsets = args.valid_subset.split(",")
    should_stop = False
    num_updates = trainer.get_num_updates()
    for i, samples in enumerate(progress):

        if 'cluster_ids' not in samples[0]['net_input']:
            samples[0]['net_input']['cluster_ids'] = numpy.full(
                (1,), 0, dtype=numpy.single)

        with metrics.aggregate(
                "train_inner"), torch.autograd.profiler.record_function(
                    "train_step-%d" % i):
            log_output = trainer.train_step(samples)

        if log_output is not None:  # not OOM, overflow, ...
            # log mid-epoch stats
            num_updates = trainer.get_num_updates()
            if num_updates % args.log_interval == 0:
                stats = get_training_stats(
                    metrics.get_smoothed_values("train_inner"))
                progress.log(stats, tag="train_inner", step=num_updates)

                # reset mid-epoch stats after each log interval
                # the end-of-epoch stats will still be preserved
                metrics.reset_meters("train_inner")

        if isinstance(itr, list):
            end_of_epoch = not itr[0].has_next()
            valid_losses, should_stop = validate_and_save(
                args, trainer, task, epoch_itr[0], valid_subsets, end_of_epoch)

            if should_stop:
                break
        else:
            end_of_epoch = not itr.has_next()
            valid_losses, should_stop = validate_and_save(
                args, trainer, task, epoch_itr, valid_subsets, end_of_epoch)

            if should_stop:
                break

    # log end-of-epoch stats
    logger.info("end of epoch {} (average epoch stats below)".format(
        epoch_itr[0].epoch))
    stats = get_training_stats(metrics.get_smoothed_values("train"))
    progress.print(stats, tag="train", step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters("train")
    return valid_losses, should_stop