Exemple #1
0
def single_main(args, init_distributed=False):
    """Train a model in the current process, validating and checkpointing per epoch.

    The function resets the metrics, selects the CUDA device (when CUDA is
    available and not disabled), seeds via ``set_seed``, optionally
    initialises distributed training, builds the task, model, criterion and
    trainer, restores the latest checkpoint and finally runs the epoch loop.

    Args:
        args: Nested configuration mapping. The sections read here are
            ``dataset``, ``common``, ``distributed_training``, ``checkpoint``,
            ``model``, ``optimization`` and ``task``.
        init_distributed: When true, ``distributed_utils.distributed_init`` is
            called and its result is stored in
            ``args['distributed_training']['distributed_rank']``.

    Raises:
        AssertionError: If both ``max_tokens`` and ``max_sentences`` in
            ``args['dataset']`` are ``None``.
    """
    dataset_cfg = args['dataset']
    assert not (dataset_cfg['max_tokens'] is None and dataset_cfg['max_sentences'] is None), \
        'Must specify batch size either with --max-tokens or --max-sentences'
    metrics.reset()

    # 0. Initialize CUDA and distributed training
    use_cuda = torch.cuda.is_available() and not args['common']['cpu']
    if use_cuda:
        torch.cuda.set_device(args['distributed_training']['device_id'])
    set_seed.set_seed(args['common']['seed'])
    if init_distributed:
        rank = distributed_utils.distributed_init(args)
        args['distributed_training']['distributed_rank'] = rank

    # Verify checkpoint directory (done by the master process only)
    if distributed_utils.is_master(args):
        ckpt_dir = args['checkpoint']['save_dir']
        checkpoint_utils.verify_checkpoint_directory(ckpt_dir)
        # this code will remove pre-trained models
        # NOTE(review): presumably PathManager.rm expands the '*.pt' pattern -- confirm
        stale_checkpoints = os.path.join(ckpt_dir, '*.pt')
        PathManager.rm(stale_checkpoints)

    # 1. Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # 2. Load valid dataset (training data is loaded later, based on the latest checkpoint)
    task.load_dataset(args['dataset']['valid_subset'], combine=False, epoch=1)

    # 3. Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    LOGGER.info(model)
    arch_name = args['model']['arch']
    criterion_name = criterion.__class__.__name__
    LOGGER.info('model {}, criterion {}'.format(arch_name, criterion_name))
    n_params = sum(p.numel() for p in model.parameters())
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    LOGGER.info('num. model params: {} (num. trained: {})'.format(n_params, n_trainable))

    # 4. Build trainer
    trainer = Trainer(args, task, model, criterion)
    world_size = args['distributed_training']['distributed_world_size']
    LOGGER.info('training on {} GPUs'.format(world_size))
    batch_limits = (args['dataset']['max_tokens'], args['dataset']['max_sentences'])
    LOGGER.info(
        'max tokens per GPU = {} and max sentences per GPU = {}'.format(*batch_limits))

    # 5. Load the latest checkpoint if one is available and restore the corresponding train iterator
    _, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer, combine=False)

    # 6. Train until the learning rate gets too small
    # A falsy max_epoch / max_update means "no limit".
    max_epoch = args['optimization']['max_epoch'] or math.inf
    max_update = args['optimization']['max_update'] or math.inf
    lr = trainer.get_lr()
    train_meter = meters.StopwatchMeter()
    train_meter.start()
    valid_subsets = args['dataset']['valid_subset'].split(',')
    while True:
        budget_left = (
            lr > args['optimization']['min_lr']
            and epoch_itr.next_epoch_idx <= max_epoch
            and trainer.get_num_updates() < max_update
        )
        if not budget_left:
            break

        # train for one epoch
        train(args, trainer, task, epoch_itr)

        validation_due = (
            not args['dataset']['disable_validation']
            and epoch_itr.epoch % args['dataset']['validate_interval'] == 0
        )
        if validation_due:
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
        else:
            valid_losses = [None]

        # only the first validation loss drives the learning rate, the
        # checkpoint and the early-stopping decision
        first_loss = valid_losses[0]
        finished_epoch = epoch_itr.epoch
        lr = trainer.lr_step(finished_epoch, first_loss)

        # save checkpoint
        if finished_epoch % args['checkpoint']['save_interval'] == 0:
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, first_loss)

        # early stop
        if should_stop_early(args, first_loss):
            patience = args['checkpoint']['patience']
            LOGGER.info(
                "early stop since valid performance hasn't improved for last {} runs"
                .format(patience))
            break

        # sharded data: get train iterator for next epoch
        upcoming_epoch = epoch_itr.next_epoch_idx
        data_is_sharded = os.pathsep in args['task']['data']
        epoch_itr = trainer.get_train_iterator(
            upcoming_epoch,
            combine=False,  # TODO to be checked
            load_dataset=data_is_sharded,
        )

    train_meter.stop()
    LOGGER.info('done training in {:.1f} seconds'.format(train_meter.sum))
Exemple #2
0
def single_main(args, init_distributed=False):
    """Train a model in the current process with accuracy-driven LR decay.

    The function resets the metrics, selects the CUDA device (when CUDA is
    available and not disabled), seeds the Python, NumPy and torch RNGs,
    optionally initialises distributed training, builds the task, model,
    criterion and trainer, restores the latest checkpoint and then runs the
    epoch loop. After each validated epoch the validation accuracy is
    recorded; once more than 10 accuracies are on record, the learning rate
    is multiplied by ``lr_shrink`` whenever the accuracy recorded five
    validations earlier is at least the current one.

    Args:
        args: Nested configuration mapping. The sections read here are
            ``dataset``, ``common``, ``distributed_training``, ``checkpoint``,
            ``model``, ``optimization`` and ``task``.
        init_distributed: When true, ``distributed_utils.distributed_init`` is
            called and its result is stored in
            ``args['distributed_training']['distributed_rank']``.

    Raises:
        AssertionError: If both ``max_tokens`` and ``max_sentences`` in
            ``args['dataset']`` are ``None``.
    """
    assert args['dataset']['max_tokens'] is not None or args['dataset']['max_sentences'] is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'
    metrics.reset()

    # 0. Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args['common']['cpu']:
        torch.cuda.set_device(args['distributed_training']['device_id'])
    seed = args['common']['seed']
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    if init_distributed:
        args['distributed_training'][
            'distributed_rank'] = distributed_utils.distributed_init(args)

    # Verify checkpoint directory (done by the master process only)
    if distributed_utils.is_master(args):
        save_dir = args['checkpoint']['save_dir']
        checkpoint_utils.verify_checkpoint_directory(save_dir)
        remove_files(save_dir,
                     'pt')  # this code will remove pre-trained models

    # 1. Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # 2. Load valid dataset (we load training data below, based on the latest checkpoint)
    # The validation accuracy is what drives the learning-rate decay below.
    task.load_dataset(args['dataset']['valid_subset'], combine=False, epoch=1)

    # 3. Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    LOGGER.info(model)
    LOGGER.info('model {}, criterion {}'.format(args['model']['arch'],
                                                criterion.__class__.__name__))
    LOGGER.info('num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # 4. Build trainer
    trainer = Trainer(args, task, model, criterion)
    LOGGER.info('training on {} GPUs'.format(
        args['distributed_training']['distributed_world_size']))
    LOGGER.info(
        'max tokens per GPU = {} and max sentences per GPU = {}'.format(
            args['dataset']['max_tokens'],
            args['dataset']['max_sentences'],
        ))

    # 5. Load the latest checkpoint if one is available and restore the corresponding train iterator
    _, epoch_itr = checkpoint_utils.load_checkpoint(args,
                                                    trainer,
                                                    combine=False)

    # 6. Train until the learning rate gets too small
    # A falsy max_epoch / max_update means "no limit".
    max_epoch = args['optimization']['max_epoch'] or math.inf
    max_update = args['optimization']['max_update'] or math.inf
    lr = trainer.get_lr()
    train_meter = meters.StopwatchMeter()
    train_meter.start()
    valid_subsets = args['dataset']['valid_subset'].split(',')
    dev_subsets = args['dataset']['dev_subset'].split(',')
    # Accuracies of all validated epochs so far, oldest first.
    valid_acc_history = []
    while (lr > args['optimization']['min_lr']
           and epoch_itr.next_epoch_idx <= max_epoch
           and trainer.get_num_updates() < max_update):
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if not args['dataset']['disable_validation'] and epoch_itr.epoch % args[
                'dataset']['validate_interval'] == 0:
            # validate() returns (valid accuracy, dev score); only the
            # accuracy is used in this function.
            valid_acc, _ = validate(args,
                                    trainer,
                                    task,
                                    epoch_itr,
                                    valid_subsets,
                                    dev_subsets,
                                    dev_refs=None)
        else:
            valid_acc = None

        # Learning-rate decay: shrink the rate when validation accuracy has
        # not improved over the value recorded five validations ago; the
        # loop condition above ends training once lr reaches min_lr.
        # Epochs without a validation result neither trigger a decay nor
        # enter the history.
        if valid_acc is not None:
            if (len(valid_acc_history) > 10
                    and valid_acc_history[-5] >= valid_acc):
                trainer.set_lr(lr * trainer.args['optimization']['lr_shrink'])
                # Re-read the rate instead of relying on set_lr's return
                # value, which is not visible from this function.
                lr = trainer.get_lr()
            # Without this append the history stays empty and the decay
            # above can never fire.
            valid_acc_history.append(valid_acc)

        # save checkpoint
        if epoch_itr.epoch % args['checkpoint']['save_interval'] == 0:
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_acc)

        epoch_itr = trainer.get_train_iterator(
            epoch_itr.next_epoch_idx,
            combine=False,  # TODO to be checked
            # sharded data: get train iterator for next epoch
            load_dataset=(os.pathsep in args['task']['data']),
        )

    train_meter.stop()
    LOGGER.info('done training in {:.1f} seconds'.format(train_meter.sum))