Esempio n. 1
0
def get_args():
    """Define and parse training-specific hyper-parameters.

    Parsing happens in two passes. The first pass only needs ``--arch`` so
    that the selected model class can register its own options through
    ``ARCH_MODEL_REGISTRY[arch].add_args``. The second pass parses the full
    command line (``sys.argv``), including those model options.

    Returns:
        argparse.Namespace: the parsed arguments, after they have been passed
        to ``ARCH_CONFIG_REGISTRY[args.arch]``.
    """
    # add_help=False: -h/--help is registered only after the model options
    # exist. Otherwise the first (partial) parse below would print help
    # without the architecture-specific options and exit.
    parser = argparse.ArgumentParser('Sequence to Sequence Model', add_help=False)

    # Add data arguments
    parser.add_argument('--data', default='europarl_prepared', help='path to data directory')
    parser.add_argument('--source-lang', default='de', help='source language')
    parser.add_argument('--target-lang', default='en', help='target language')
    parser.add_argument('--max-tokens', default=None, type=int, help='maximum number of tokens in a batch')
    parser.add_argument('--batch-size', default=10, type=int, help='maximum number of sentences in a batch')
    parser.add_argument('--train-on-tiny', action='store_true', help='train model on a tiny dataset')

    # Add model arguments
    parser.add_argument('--device', default='cpu', choices=['cpu', 'cuda'], help='the device to carry out the training '
                                                                                'either cpu or cuda aka gpu')
    parser.add_argument('--arch', default='lstm', choices=ARCH_MODEL_REGISTRY.keys(), help='model architecture')

    # Add optimization arguments
    parser.add_argument('--max-epoch', default=100, type=int, help='force stop training at specified epoch')
    parser.add_argument('--clip-norm', default=4.0, type=float, help='clip threshold of gradients')
    parser.add_argument('--lr', default=0.0003, type=float, help='learning rate')
    parser.add_argument('--patience', default=10, type=int,
                        help='number of epochs without improvement on validation set before early stopping')

    # Add checkpoint arguments
    parser.add_argument('--log-file', default=None, help='path to save logs')
    parser.add_argument('--save-dir', default='checkpoints', help='path to save checkpoints')
    parser.add_argument('--restore-file', default='checkpoint_last.pt', help='filename to load checkpoint')
    parser.add_argument('--save-interval', type=int, default=1, help='save a checkpoint every N epochs')
    parser.add_argument('--no-save', action='store_true', help='don\'t save models or checkpoints')
    parser.add_argument('--epoch-checkpoints', action='store_true', help='store all epoch checkpoints')

    # Parse twice as model arguments are not known the first time.
    # parse_known_args tolerates the not-yet-registered model options
    # (and -h/--help) and returns them as extras, which are ignored here.
    args, _ = parser.parse_known_args()
    # SUPPRESS: model options not given on the command line are not set on the
    # namespace at all. Presumably this lets the architecture config function
    # fill in its own defaults; confirm against ARCH_CONFIG_REGISTRY.
    model_parser = parser.add_argument_group(argument_default=argparse.SUPPRESS)
    ARCH_MODEL_REGISTRY[args.arch].add_args(model_parser)
    # Register help now that every option is known, so --help is complete.
    parser.add_argument('-h', '--help', action='help', default=argparse.SUPPRESS,
                        help='show this help message and exit')
    args = parser.parse_args()
    ARCH_CONFIG_REGISTRY[args.arch](args)
    return args
Esempio n. 2
0
def get_args():
    """Define and parse training-specific hyper-parameters from ``sys.argv``.

    Returns:
        argparse.Namespace: the parsed command-line arguments.
    """
    parser = argparse.ArgumentParser('Sequence to Sequence Model')
    parser.add_argument('--seed',
                        default=42,
                        type=int,
                        help='pseudo random number generator seed')
    # type=int so that a value given on the command line is an int, like the
    # default (torch.cuda.device_count() returns an int), instead of a str.
    parser.add_argument('--distributed-world-size',
                        default=torch.cuda.device_count(),
                        type=int,
                        help='distributed world size')
    parser.add_argument('--distributed-backend',
                        default='nccl',
                        help='distributed backend')

    # Add data arguments
    parser.add_argument('--data',
                        default='data-bin',
                        help='path to data directory')
    parser.add_argument('--source-lang', default=None, help='source language')
    parser.add_argument('--target-lang', default=None, help='target language')
    parser.add_argument('--max-tokens',
                        default=16000,
                        type=int,
                        help='maximum number of tokens in a batch')
    parser.add_argument('--batch-size',
                        default=None,
                        type=int,
                        help='maximum number of sentences in a batch')
    parser.add_argument('--num-workers',
                        default=4,
                        type=int,
                        help='number of data workers')

    # Add model arguments
    parser.add_argument('--arch',
                        default='lstm',
                        choices=ARCH_MODEL_REGISTRY.keys(),
                        help='model architecture')

    # Add optimization arguments
    parser.add_argument('--max-epoch',
                        default=100,
                        type=int,
                        help='force stop training at specified epoch')
    parser.add_argument('--clip-norm',
                        default=0.1,
                        type=float,
                        help='clip threshold of gradients')
    parser.add_argument('--lr', default=0.25, type=float, help='learning rate')
    parser.add_argument('--momentum',
                        default=0.99,
                        type=float,
                        help='momentum factor')
    parser.add_argument('--weight-decay',
                        default=0.0,
                        type=float,
                        help='weight decay')
    parser.add_argument('--lr-shrink',
                        default=0.1,
                        type=float,
                        help='learning rate shrink factor for annealing')
    parser.add_argument('--min-lr',
                        default=1e-5,
                        type=float,
                        help='minimum learning rate')

    # Add checkpoint arguments
    parser.add_argument('--log-file', default=None, help='path to save logs')
    parser.add_argument('--save-dir',
                        default='checkpoints',
                        help='path to save checkpoints')
    parser.add_argument('--restore-file',
                        default='checkpoint_last.pt',
                        help='filename to load checkpoint')
    parser.add_argument('--save-interval',
                        type=int,
                        default=1,
                        help='save a checkpoint every N epochs')
    parser.add_argument('--no-save',
                        action='store_true',
                        help='don\'t save models or checkpoints')
    parser.add_argument('--epoch-checkpoints',
                        action='store_true',
                        help='store all epoch checkpoints')

    return parser.parse_args()
Esempio n. 3
0
def get_args():
    """Define and parse training-specific hyper-parameters.

    Parsing happens in two passes. The first pass only needs ``--arch`` so
    that the selected model class can register its own options through
    ``ARCH_MODEL_REGISTRY[arch].add_args``. The second pass parses the full
    command line (``sys.argv``), including those model options.

    Returns:
        argparse.Namespace: the parsed arguments, after they have been passed
        to ``ARCH_CONFIG_REGISTRY[args.arch]``.
    """
    # add_help=False: -h/--help is registered only after the model options
    # exist. Otherwise the first (partial) parse below would print help
    # without the architecture-specific options and exit.
    parser = argparse.ArgumentParser('Sequence to Sequence Model',
                                     add_help=False)
    parser.add_argument('--seed',
                        default=42,
                        type=int,
                        help='pseudo random number generator seed')
    # type=int so that a value given on the command line is an int, like the
    # default (torch.cuda.device_count() returns an int), instead of a str.
    parser.add_argument('--distributed-world-size',
                        default=torch.cuda.device_count(),
                        type=int,
                        help='distributed world size')
    parser.add_argument('--distributed-backend',
                        default='nccl',
                        help='distributed backend')

    # Add data arguments
    parser.add_argument('--data',
                        default='data-bin/how2',
                        help='path to data directory')
    parser.add_argument('--max-tokens',
                        default=16000,
                        type=int,
                        help='maximum number of tokens in a batch')
    parser.add_argument('--train_video_file',
                        default='../how2data/text/sum_train/tr_action.txt',
                        help='name of train video file')
    parser.add_argument('--val_video_file',
                        default='../how2data/text/sum_cv/cv_action.txt',
                        help='name of val video file')
    parser.add_argument('--video_dir',
                        default='../how2data/video_action_features',
                        help='path of video features')
    parser.add_argument('--batch-size',
                        default=8,
                        type=int,
                        help='maximum number of sentences in a batch')
    parser.add_argument('--num-workers',
                        default=2,
                        type=int,
                        help='number of data workers')

    # Add model arguments
    parser.add_argument('--arch',
                        default='MFtransformer',
                        choices=ARCH_MODEL_REGISTRY.keys(),
                        help='model architecture')

    # Add optimization arguments
    parser.add_argument('--max-epoch',
                        default=10,
                        type=int,
                        help='force stop training at specified epoch')
    parser.add_argument('--clip-norm',
                        default=0.1,
                        type=float,
                        help='clip threshold of gradients')
    parser.add_argument('--lr',
                        default=0.00015,
                        type=float,
                        help='learning rate')
    parser.add_argument('--momentum',
                        default=0.99,
                        type=float,
                        help='momentum factor')
    parser.add_argument('--weight-decay',
                        default=0.0,
                        type=float,
                        help='weight decay')
    parser.add_argument('--lr-shrink',
                        default=0.1,
                        type=float,
                        help='learning rate shrink factor for annealing')
    parser.add_argument('--min-lr',
                        default=1e-5,
                        type=float,
                        help='minimum learning rate')

    # Add checkpoint arguments
    parser.add_argument('--log-file',
                        default='./log.txt',
                        help='path to save logs')
    parser.add_argument('--save-dir',
                        default='checkpoints/mfn',
                        help='path to save checkpoints')
    parser.add_argument('--restore-file',
                        default='checkpoint_last.pt',
                        help='filename to load checkpoint')
    parser.add_argument('--save-interval',
                        type=int,
                        default=1,
                        help='save a checkpoint every N epochs')
    parser.add_argument('--no-save',
                        action='store_true',
                        help='don\'t save models or checkpoints')
    parser.add_argument('--epoch-checkpoints',
                        action='store_true',
                        help='store all epoch checkpoints')

    # Parse twice as model arguments are not known the first time.
    # parse_known_args tolerates the not-yet-registered model options
    # (and -h/--help) and returns them as extras, which are ignored here.
    args, _ = parser.parse_known_args()
    # SUPPRESS: model options not given on the command line are not set on the
    # namespace at all. Presumably this lets the architecture config function
    # fill in its own defaults; confirm against ARCH_CONFIG_REGISTRY.
    model_parser = parser.add_argument_group(
        argument_default=argparse.SUPPRESS)
    ARCH_MODEL_REGISTRY[args.arch].add_args(model_parser)
    # Register help now that every option is known, so --help is complete.
    parser.add_argument('-h', '--help',
                        action='help',
                        default=argparse.SUPPRESS,
                        help='show this help message and exit')
    args = parser.parse_args()
    ARCH_CONFIG_REGISTRY[args.arch](args)
    return args