Example #1
File: train.py  Project: SCUZPP/ENAS
def train_model(args, encoder_seq, encoder_para, decoder_seq, decoder_para):

    parser = options.get_training_parser()
    args = options.parse_args_and_arch(parser)
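    # NOTE: the two lines above re-parse the actual command line, replacing
    # the `args` parameter; only the encoder/decoder fields below come from
    # the caller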
    args.ddp_backend = 'no_c10d'
    args.encoder_seq = encoder_seq
    args.encoder_para = encoder_para

    args.decoder_seq = decoder_seq
    args.decoder_para = decoder_para

    if args.distributed_port > 0 or args.distributed_init_method is not None:
        from distributed_train import main as distributed_main

        distributed_main(args)
    elif args.distributed_world_size > 1:
        from multiprocessing_train import main as multiprocessing_main

        multiprocessing_main(args)
    else:
        main(args)
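All of these snippets are fairseq-based entry points: `options` here is `fairseq.options`, and `main`, `distributed_main`, and `multiprocessing_main` are defined elsewhere in each project's train.py. This ENAS variant attaches the sampled encoder/decoder genotypes to the freshly parsed args before dispatching. A minimal, hypothetical invocation (the genotype values are made up; the first argument is ignored because the function re-parses the command line):

if __name__ == '__main__':
    # illustrative genotypes only; a real ENAS controller would sample these
    encoder_seq, encoder_para = [0, 1, 1], [2, 0, 1]
    decoder_seq, decoder_para = [1, 0, 2], [0, 2, 1]
    train_model(None, encoder_seq, encoder_para, decoder_seq, decoder_para)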
Example #2
def cli_main():
    # Horrible hack, please close your eyes and don't look
    cli_args = set(sys.argv)
    print("Command line argumetns")
    print(cli_args)
    if "--arch" not in cli_args and "-a" not in cli_args:
        sys.argv.append("--arch")
        sys.argv.append("transformer_iwslt_de_en")
    print(sys.argv)  # show the (possibly updated) argv, not the stale set
    # It's over now you can look
    parser = options.get_training_parser()
    add_multiobj_args(parser)
    args = options.parse_args_and_arch(parser)

    if args.distributed_port > 0 or args.distributed_init_method is not None:
        raise NotImplementedError(
            "Multitask doesn't support distributed training yet")
        from distributed_train import main as distributed_main

        distributed_main(args)
    elif args.distributed_world_size > 1:
        raise NotImplementedError(
            "Multitask doesn't support multiprocessing yet")
        from multiprocessing_train import main as multiprocessing_main

        multiprocessing_main(args)
    else:
        main(args)
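The hack above injects a default `--arch` by mutating `sys.argv` before fairseq parses it; note that the `set` snapshot is only good for membership tests and is not refreshed after the append. A tidier sketch of the same idea in plain Python (nothing fairseq-specific assumed; `ensure_default_arch` is an illustrative name):

import sys

def ensure_default_arch(default_arch="transformer_iwslt_de_en"):
    """Append --arch <default> unless the command line already sets one."""
    if not any(tok in ("--arch", "-a") for tok in sys.argv[1:]):
        sys.argv += ["--arch", default_arch]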
Example #3
def main(args):
    if args.distributed_port > 0 \
            or args.distributed_init_method is not None:
        distributed_main(args)
    elif args.distributed_world_size > 1:
        multiprocessing_main(args)
    else:
        singleprocess_main(args)
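This is the same three-way dispatch distilled: an explicit port or init method selects distributed training, a world size above one without rendezvous info selects multiprocessing, and everything else stays in-process. A self-contained sketch with stub mains, showing that the distributed check takes precedence (stub bodies are illustrative; only the branching comes from the example):

from argparse import Namespace

def distributed_main(args):     print("distributed")
def multiprocessing_main(args): print("multiprocessing")
def singleprocess_main(args):   print("single process")

def main(args):
    if args.distributed_port > 0 or args.distributed_init_method is not None:
        distributed_main(args)        # explicit rendezvous info wins
    elif args.distributed_world_size > 1:
        multiprocessing_main(args)    # local workers, no rendezvous
    else:
        singleprocess_main(args)

main(Namespace(distributed_port=12345, distributed_init_method=None,
               distributed_world_size=1))   # prints "distributed"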
Example #4
File: train.py  Project: sk210892/fairseq

def load_dataset_splits(args, task, splits):
    for split in splits:
        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else '')
            try:
                task.load_dataset(split_k)
                print('| {} {} {} examples'.format(args.data, split_k,
                                                   len(task.dataset(split_k))))
            except FileNotFoundError as e:
                if k > 0:
                    break
                raise e


if __name__ == '__main__':
    parser = options.get_training_parser()
    args = options.parse_args_and_arch(parser)

    if args.distributed_port > 0 or args.distributed_init_method is not None:
        from distributed_train import main as distributed_main

        distributed_main(args)
    elif args.distributed_world_size > 1:
        from multiprocessing_train import main as multiprocessing_main

        multiprocessing_main(args)
    else:
        main(args)
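`load_dataset_splits` discovers sharded splits named train, train1, train2, ... via `itertools.count()`, stopping at the first gap; a missing base split is re-raised as a real error. A standalone illustration with a fake loader (every name here is illustrative):

import itertools

available = {"train", "train1", "train2", "valid"}

def fake_load(name):              # stands in for task.load_dataset
    if name not in available:
        raise FileNotFoundError(name)

for split in ("train", "valid"):
    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')
        try:
            fake_load(split_k)
            print("loaded", split_k)
        except FileNotFoundError:
            if k > 0:
                break     # ran past the last shard: this split is done
            raise         # the base split itself is missing: a real error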
Example #5
def load_checkpoint(args, trainer, epoch_itr):
    """Load a checkpoint and replay dataloader to match."""
    os.makedirs(os.path.join(args.save_dir, 'checkpoints'), exist_ok=True)
    checkpoint_path = os.path.join(args.save_dir, 'checkpoints',
                                   args.restore_file)
    if os.path.isfile(checkpoint_path):
        extra_state = trainer.load_checkpoint(checkpoint_path)
        if extra_state is not None:
            # replay train iterator to match checkpoint
            epoch_itr.load_state_dict(extra_state['train_iterator'])

            print('| loaded checkpoint {} (epoch {} @ {} updates)'.format(
                checkpoint_path, epoch_itr.epoch, trainer.get_num_updates()))

            trainer.lr_step(epoch_itr.epoch)
            trainer.lr_step_update(trainer.get_num_updates())
            if 'best' in extra_state:
                save_checkpoint.best = extra_state['best']


if __name__ == '__main__':
    parser = options.get_training_parser()
    ARGS = options.parse_args_and_arch(parser)

    if ARGS.distributed_port > 0 or ARGS.distributed_init_method is not None:
        from distributed_train import main as distributed_main

        distributed_main(ARGS)
    else:
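        # note: unlike Examples 1-4 there is no multiprocessing branch here;
        # non-distributed runs fall through to main() regardless of world size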
        main(ARGS)
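Besides restoring model and optimizer state through `trainer.load_checkpoint`, the function replays the epoch iterator and re-steps the LR scheduler so training resumes exactly where it left off. The `save_checkpoint.best` line uses Python's function-attribute idiom: mutable state hangs off the function object between calls. A minimal sketch of that idiom (the scoring is hypothetical, not fairseq's actual save_checkpoint):

def save_checkpoint(val_loss):
    # track the best (lowest) validation loss seen so far on the function itself
    best = getattr(save_checkpoint, "best", float("inf"))
    save_checkpoint.best = min(best, val_loss)

save_checkpoint(0.50)
save_checkpoint(0.31)
print(save_checkpoint.best)   # 0.31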