예제 #1
0
    # Mandatory path of the serialized model that this script loads.
    parser.add_argument('--load',
                        type=str,
                        required=True,
                        help='where to load a model from')
    args = parser.parse_args()
    # Echo the parsed options so they end up in the run's output.
    print(args)

    # NOTE(review): presumably seeds the random generators (including the
    # CUDA ones when args.cuda is set) -- confirm against init_seeds.
    init_seeds(args.seed, args.cuda)

    print("loading model...")
    # NOTE: torch.load unpickles the file, which can execute arbitrary code;
    # only point --load at checkpoints from a trusted source.
    # Without --cuda, remap every stored tensor to the CPU so that a
    # checkpoint saved from a GPU still loads on a CPU-only host and ends up
    # on the same device as the (non-CUDA) data stream.  With --cuda the
    # stored device placement is kept, exactly as before.
    lm = torch.load(args.load, map_location=None if args.cuda else 'cpu')
    if args.cuda:
        lm.cuda()
    print(lm.model)

    print("preparing data...")

    def temp_splits_from_fn(fn):
        """Return the TemporalSplits built from the tokens of file ``fn``.

        The file is tokenized with ``lm.vocab`` (``randomize=False``); the
        splits are parameterized by ``lm.model.in_len`` and
        ``args.target_seq_len``.
        """
        return TemporalSplits(
            tokens_from_file(fn, lm.vocab, randomize=False),
            lm.model.in_len,
            args.target_seq_len,
        )

    # Build the split objects for the files named by args.file_list, using
    # temp_splits_from_fn as the per-file factory.
    tss = filelist_to_objects(args.file_list, temp_splits_from_fn)
    # discard_h is enabled unless --concat_articles was given.
    # NOTE(review): presumably controls whether hidden state is dropped
    # between articles -- confirm against BatchBuilder.
    data = BatchBuilder(tss,
                        args.batch_size,
                        discard_h=not args.concat_articles)
    if args.cuda:
        # Wrap the batch stream in CudaStream when --cuda is set.
        data = CudaStream(data)

    loss = evaluate(lm, data, use_ivecs=False)
    # Perplexity is reported as exp(loss).
    # NOTE(review): assumes loss is a scalar in nats -- confirm in evaluate.
    print('loss {:5.2f} | ppl {:8.2f}'.format(loss, math.exp(loss)))
예제 #2
0
        lm.cuda()
    print(lm.model)

    print("preparing data...")

    def temp_splits_from_fn(fn):
        """Tokenize file ``fn`` and wrap the tokens in a TemporalSplits.

        Tokenization uses ``lm.vocab`` with ``randomize=False``; the splits
        receive ``lm.model.in_len`` and ``args.target_seq_len``.
        """
        return TemporalSplits(
            tokens_from_file(fn, lm.vocab, randomize=False),
            lm.model.in_len,
            args.target_seq_len,
        )

    print("\ttraining...")
    # Build the training split objects from args.train_list, using
    # temp_splits_from_fn as the per-file factory.
    train_tss = filelist_to_objects(args.train_list, temp_splits_from_fn)
    # discard_h is enabled unless --concat_articles was given.
    # NOTE(review): presumably controls whether hidden state is dropped
    # between articles -- confirm against BatchBuilder.
    train_data = BatchBuilder(train_tss,
                              args.batch_size,
                              discard_h=not args.concat_articles)
    if args.cuda:
        # Wrap the batch stream in CudaStream when --cuda is set.
        train_data = CudaStream(train_data)

    print("\tvalidation...")
    # Validation data is built exactly like the training data, only from
    # args.valid_list.
    valid_tss = filelist_to_objects(args.valid_list, temp_splits_from_fn)
    valid_data = BatchBuilder(valid_tss,
                              args.batch_size,
                              discard_h=not args.concat_articles)
    if args.cuda:
        valid_data = CudaStream(valid_data)

    print("training...")
    # Working copy of the learning rate, initialised from --lr.
    lr = args.lr
    # No validation loss has been recorded yet.
    # NOTE(review): presumably updated inside the epoch loop -- confirm below.
    best_val_loss = None

    for epoch in range(1, args.epochs + 1):
        if args.keep_shuffling: