Example #1
def main(args):
    with open(args.dataset_info, 'rb') as rf:
        dataset_info = pickle.load(rf)
    tokenizer = MarianTokenizer.from_pretrained(args.model_string)
    tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
    pad_id = tokenizer.encode(PAD_TOKEN)[0]
    model = MarianMTModel.from_pretrained(args.model_string, return_dict=True).to(args.device)
    model.eval()

    checkpoint = torch.load(args.ckpt, map_location=args.device)
    model_args = checkpoint['args']
    conditioning_model = Model(model_args, pad_id, len(dataset_info.index2word)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
    conditioning_model.load_state_dict(checkpoint['state_dict'])
    conditioning_model = conditioning_model.to(args.device)
    conditioning_model.eval()
    print("=> loaded checkpoint '{}' (epoch {})"
            .format(args.ckpt, checkpoint['epoch']))
    print('num params', num_params(conditioning_model))

    while True:
        results = predict_formality(model, 
                        tokenizer, 
                        conditioning_model, 
                        [args.input_text], 
                        dataset_info, 
                        precondition_topk=args.precondition_topk,
                        do_sample=args.do_sample,
                        length_cutoff=args.length_cutoff,
                        condition_lambda=args.condition_lambda,
                        device=args.device)
        print(results)
        import pdb; pdb.set_trace()  # drop into the debugger to inspect results between queries
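
The helper predict_formality is defined elsewhere in the repository; a minimal sketch of the FUDGE-style top-k rescoring these arguments suggest (hypothetical tensor shapes, not the repository's code):

import torch
import torch.nn.functional as F

def fudge_rescore(base_logits, cond_logits, condition_lambda=1.0):
    # base_logits: (batch, precondition_topk) base-model logits for the top-k
    #   candidate next tokens
    # cond_logits: (batch, precondition_topk) conditioning-model logits that
    #   each candidate continuation satisfies the target attribute
    # Fuse the two in log space, then renormalize over the k candidates.
    fused = torch.log_softmax(base_logits, dim=-1) \
            + condition_lambda * F.logsigmoid(cond_logits)
    return torch.softmax(fused, dim=-1)
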
Example #2
def main(args):
    with open(args.dataset_info, 'rb') as rf:
        dataset_info = pickle.load(rf)
    for cw in args.condition_words.split():
        assert cw in dataset_info.word2index
    gpt_tokenizer = AutoTokenizer.from_pretrained(args.model_string)
    gpt_tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
    gpt_pad_id = gpt_tokenizer.encode(PAD_TOKEN)[0]
    gpt_model = AutoModelWithLMHead.from_pretrained(args.model_string).to(args.device)
    gpt_model.eval()

    checkpoint = torch.load(args.ckpt, map_location=args.device)
    model_args = checkpoint['args']
    conditioning_model = Model(model_args, gpt_pad_id, len(dataset_info.index2word)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
    conditioning_model.load_state_dict(checkpoint['state_dict'])
    conditioning_model = conditioning_model.to(args.device)
    conditioning_model.eval()
    print("=> loaded checkpoint '{}' (epoch {})"
            .format(args.ckpt, checkpoint['epoch']))
    print('num params', num_params(conditioning_model))

    while True:
        results = predict(gpt_model, 
                        gpt_tokenizer, 
                        conditioning_model, 
                        [args.input_text], 
                        args.condition_words, 
                        dataset_info, 
                        args.precondition_topk,
                        args.topk, 
                        args.length_cutoff,
                        condition_lambda=args.condition_lambda,
                        device=args.device)
        print(results)
        import pdb; pdb.set_trace()
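
Both interactive examples expect an argparse-style namespace; a hypothetical invocation (flag values are placeholders inferred from the attributes read above, not values from the original repository):

from argparse import Namespace

args = Namespace(
    dataset_info='path/to/dataset_info',    # pickle written at preprocessing time
    model_string='gpt2-medium',             # placeholder model name
    ckpt='path/to/model_best.pth.tar',      # conditioning-model checkpoint
    input_text='The weather today',
    condition_words='science space',        # must all be in dataset_info.word2index
    precondition_topk=200,
    topk=10,
    length_cutoff=80,
    condition_lambda=1.0,
    device='cuda',
)
main(args)
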
Example #3
def main(args):
    with open(args.dataset_info, 'rb') as rf:
        dataset_info = pickle.load(rf)
    tokenizer = MarianTokenizer.from_pretrained(args.model_string)
    tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
    pad_id = tokenizer.encode(PAD_TOKEN)[0]
    model = MarianMTModel.from_pretrained(args.model_string,
                                          return_dict=True).to(args.device)
    if args.model_path is not None:
        if os.path.isdir(args.model_path):
            # use the first .ckpt file found under the directory
            for root, _, files in os.walk(args.model_path):
                ckpt_files = [fname for fname in files if fname.endswith('.ckpt')]
                if ckpt_files:
                    args.model_path = os.path.join(root, ckpt_files[0])
                    break
        ckpt = torch.load(args.model_path)
        try:
            model.load_state_dict(ckpt['state_dict'])
        except RuntimeError:
            # checkpoint keys are prefixed with 'model.'; strip the prefix and retry
            state_dict = {}
            for key in ckpt['state_dict'].keys():
                assert key.startswith('model.')
                state_dict[key[6:]] = ckpt['state_dict'][key]
            model.load_state_dict(state_dict)
    model.eval()

    checkpoint = torch.load(args.ckpt, map_location=args.device)
    model_args = checkpoint['args']
    conditioning_model = Model(
        model_args, pad_id, len(dataset_info.index2word)
    )  # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
    conditioning_model.load_state_dict(checkpoint['state_dict'])
    conditioning_model = conditioning_model.to(args.device)
    conditioning_model.eval()
    if args.verbose:
        print("=> loaded checkpoint '{}' (epoch {})".format(
            args.ckpt, checkpoint['epoch']))
        print('num params', num_params(conditioning_model))

    inputs = []
    with open(args.in_file, 'r') as rf:
        for line in rf:
            inputs.append(line.strip())

    for inp in tqdm(inputs, total=len(inputs)):
        results = predict_formality(model,
                                    tokenizer,
                                    conditioning_model, [inp],
                                    dataset_info,
                                    precondition_topk=args.precondition_topk,
                                    do_sample=args.do_sample,
                                    length_cutoff=args.length_cutoff,
                                    condition_lambda=args.condition_lambda,
                                    device=args.device)
        print(results[0])
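
The try/except above recovers from checkpoints whose keys carry a 'model.' prefix; the same logic as a reusable helper (a sketch, not code from the repository):

def strip_key_prefix(state_dict, prefix='model.'):
    # Remove a uniform key prefix (e.g. left by a wrapper module) so the
    # weights load into the bare model.
    out = {}
    for key, value in state_dict.items():
        assert key.startswith(prefix), key
        out[key[len(prefix):]] = value
    return out
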
Example #4
def main(args):
    with open(args.dataset_info, 'rb') as rf:
        dataset_info = pickle.load(rf)
    gpt_tokenizer = AutoTokenizer.from_pretrained(args.model_string)
    gpt_tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
    gpt_pad_id = gpt_tokenizer.encode(PAD_TOKEN)[0]
    gpt_model = AutoModelWithLMHead.from_pretrained(args.model_string).to(
        args.device)
    gpt_model.eval()

    checkpoint = torch.load(args.ckpt, map_location=args.device)
    model_args = checkpoint['args']
    model_args.iw = False
    conditioning_model = Model(
        model_args, gpt_pad_id, len(dataset_info.index2word)
    )  # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
    conditioning_model.load_state_dict(checkpoint['state_dict'])
    conditioning_model = conditioning_model.to(args.device)
    conditioning_model.eval()
    if args.verbose:
        print("=> loaded checkpoint '{}' (epoch {})".format(
            args.ckpt, checkpoint['epoch']))
        print('num params', num_params(conditioning_model))

    input_texts, conditions, categories = [], [], []

    if args.condition_file is not None:
        with open(args.condition_file, 'r') as rf:
            for line in rf:
                input_texts.append(line.strip().split('\t')[0])
                conditions.append(line.strip().split('\t')[1])
                categories.append(None)
                for cw in conditions[-1].split():
                    assert cw in dataset_info.word2index
    else:
        prefixes = []
        with open(args.prefix_file, 'r') as rf:
            for line in rf:
                prefixes.append(line.strip())
        condition_wordlists = []
        for root, _, files in os.walk(args.wordlist_dir):
            for fname in files:
                words = []
                with open(os.path.join(root, fname), 'r') as rf:
                    for line in rf:
                        word = line.strip()
                        if word in dataset_info.word2index:
                            words.append(word)
                        else:
                            if args.verbose:
                                print('word not found:', word)
                condition_wordlists.append(
                    (' '.join(words), fname.split('.')[0]))
        for p in prefixes:
            for c, category in condition_wordlists:
                input_texts.append(p)
                conditions.append(c)
                categories.append(category)

    all_cr = []
    pair_num = 0
    for input_text, condition_words, category in tqdm(
            zip(input_texts, conditions, categories), total=len(conditions)):
        predict_function = predict
        condition_results = []
        for i in range(0, args.sample_size, args.max_sample_batch):
            num_samples = min(args.max_sample_batch, args.sample_size - i)
            condition_results += predict_function(
                gpt_model,
                gpt_tokenizer,
                conditioning_model, [input_text for _ in range(num_samples)],
                condition_words,
                dataset_info,
                args.precondition_topk,
                args.topk,
                args.length_cutoff,
                condition_lambda=args.condition_lambda,
                device=args.device)
        all_cr.append((input_text, category, condition_results))
        pair_num += 1
        if args.max_pairs > 0 and pair_num >= args.max_pairs:
            break
    with open(args.log_file, 'w') as wf:
        writer = csv.DictWriter(
            wf, fieldnames=['category', 'input_text', 'generation'])
        writer.writeheader()
        for cr_group in all_cr:
            for cr in cr_group[2]:
                writer.writerow({
                    'category': cr_group[1],
                    'input_text': cr_group[0],
                    'generation': cr
                })
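
The CSV log written above can be read back with the matching csv.DictReader; a small sketch:

import csv
from collections import defaultdict

def load_generations(log_file):
    # Group (input_text, generation) pairs by category, mirroring the
    # fieldnames written by the loop above.
    by_category = defaultdict(list)
    with open(log_file, 'r') as rf:
        for row in csv.DictReader(rf):
            by_category[row['category']].append(
                (row['input_text'], row['generation']))
    return by_category
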
Example #5
def main(args):
    with open(args.dataset_info, 'rb') as rf:
        dataset_info = pickle.load(rf)
    gpt_tokenizer = AutoTokenizer.from_pretrained(args.model_string)
    gpt_tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
    gpt_pad_id = gpt_tokenizer.encode(PAD_TOKEN)[0]
    gpt_model = AutoModelWithLMHead.from_pretrained(args.model_string).to(
        args.device)
    gpt_model.eval()

    checkpoint = torch.load(args.iambic_ckpt, map_location=args.device)
    model_args = checkpoint['args']
    iambic_model = Model(
        model_args, gpt_pad_id, len(dataset_info.index2word)
    )  # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
    iambic_model.load_state_dict(checkpoint['state_dict'])
    iambic_model = iambic_model.to(args.device)
    iambic_model.eval()
    print("=> loaded checkpoint '{}' (epoch {})".format(
        args.iambic_ckpt, checkpoint['epoch']))
    print('iambic model num params', num_params(iambic_model))

    with open(args.rhyme_info, 'rb') as rf:
        rhyme_info = pickle.load(rf)
    checkpoint = torch.load(args.rhyme_ckpt, map_location=args.device)
    model_args = checkpoint['args']
    rhyme_model = Model(
        model_args,
        gpt_pad_id,
        len(dataset_info.index2word),
        rhyme_group_size=len(rhyme_info.index2rhyme_group)
    )  # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
    rhyme_model.load_state_dict(checkpoint['state_dict'])
    rhyme_model = rhyme_model.to(args.device)
    rhyme_model.eval()
    print("=> loaded checkpoint '{}' (epoch {})".format(
        args.rhyme_ckpt, checkpoint['epoch']))
    print('rhyme model num params', num_params(rhyme_model))

    checkpoint = torch.load(args.newline_ckpt, map_location=args.device)
    model_args = checkpoint['args']
    newline_model = Model(
        model_args, gpt_pad_id, len(dataset_info.index2word)
    )  # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
    newline_model.load_state_dict(checkpoint['state_dict'])
    newline_model = newline_model.to(args.device)
    newline_model.eval()
    print("=> loaded checkpoint '{}' (epoch {})".format(
        args.newline_ckpt, checkpoint['epoch']))
    print('newline model num params', num_params(newline_model))

    while True:
        results = predict_couplet(gpt_model,
                                  gpt_tokenizer,
                                  iambic_model,
                                  rhyme_model,
                                  newline_model, [args.input_text],
                                  dataset_info,
                                  rhyme_info,
                                  args.precondition_topk,
                                  args.topk,
                                  condition_lambda=args.condition_lambda,
                                  device=args.device)
        for line in results:
            print(line)
        import pdb
        pdb.set_trace()
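
predict_couplet composes three discriminators (iambic meter, rhyme, and line breaks); in FUDGE-style decoding their contributions are summed in log space before renormalizing. A hedged sketch of that composition (hypothetical shapes, not the repository's code):

import torch

def fuse_constraints(base_logits, constraint_logprobs, lambdas):
    # base_logits: (batch, k) base-model logits over the top-k candidates
    # constraint_logprobs: one (batch, k) tensor per discriminator
    # lambdas: one weight per discriminator
    fused = torch.log_softmax(base_logits, dim=-1)
    for lam, lp in zip(lambdas, constraint_logprobs):
        fused = fused + lam * lp
    return torch.softmax(fused, dim=-1)
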
Example #6
def main(args):
    with open(args.dataset_info, 'rb') as rf:
        dataset_info = pickle.load(rf)
    gpt_tokenizer = AutoTokenizer.from_pretrained(args.model_string)
    gpt_tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
    gpt_pad_id = gpt_tokenizer.encode(PAD_TOKEN)[0]
    gpt_model = AutoModelWithLMHead.from_pretrained(args.model_string).to(
        args.device)
    gpt_model.eval()

    checkpoint = torch.load(args.ckpt, map_location=args.device)
    model_args = checkpoint['args']
    conditioning_model = Model(
        model_args, gpt_pad_id, len(dataset_info.index2word)
    )  # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
    conditioning_model.load_state_dict(checkpoint['state_dict'])
    conditioning_model = conditioning_model.to(args.device)
    conditioning_model.eval()
    if args.verbose:
        print("=> loaded checkpoint '{}' (epoch {})".format(
            args.ckpt, checkpoint['epoch']))
        print('num params', num_params(conditioning_model))

    input_texts, conditions, categories = [], [], []

    if args.condition_file is not None:
        with open(args.condition_file, 'r') as rf:
            lines = rf.read().strip().split('\n')
            lines = lines[int(0.8 * len(lines)):]  # keep only the last 20% (presumably the held-out split)

            for line in lines:
                input_texts.append(line.strip().split('\t')[0])
                cond_words = line.strip().split('\t')[1].split()
                cond_list = []
                for i, cw in enumerate(cond_words):
                    if cw in dataset_info.word2index:
                        cond_list.append(cw)
                conditions.append(' '.join(cond_list))
                categories.append(None)
    else:
        prefixes = []
        with open(args.prefix_file, 'r') as rf:
            for line in rf:
                prefixes.append(line.strip())
        condition_wordlists = []
        for root, _, files in os.walk(args.wordlist_dir):
            for fname in files:
                words = []
                with open(os.path.join(root, fname), 'r') as rf:
                    for line in rf:
                        word = line.strip()
                        if word in dataset_info.word2index:
                            words.append(word)
                        else:
                            if args.verbose:
                                print('word not found:', word)
                condition_wordlists.append(
                    (' '.join(words), fname.split('.')[0]))
        for p in prefixes:
            for c, category in condition_wordlists:
                input_texts.append(p)
                conditions.append(c)
                categories.append(category)

#    with open(args.log_file, 'w') as wf:
#        writer = csv.DictWriter(wf, fieldnames=['category', 'input_text', 'generation'])
#        writer.writeheader()

#        all_cr = []
    pair_num = 0

    a = datetime.now()
    print('Start time:', a)
    count = 0
    total_words = 0
    for input_text, condition_words, category in tqdm(
            zip(input_texts, conditions, categories), total=len(conditions)):
        count += 1
        if count > 10:  # time only the first 10 examples
            break

        predict_function = predict
        condition_results = predict_function(
            gpt_model,
            gpt_tokenizer,
            conditioning_model, [input_text],
            condition_words,
            dataset_info,
            args.precondition_topk,
            args.topk,
            args.length_cutoff,
            condition_lambda=args.condition_lambda,
            device=args.device)
        print(condition_results[0])
        total_words += len(condition_results[0].split())

    b = datetime.now()
    dec_time = (b - a).total_seconds()  # .seconds alone would drop whole minutes

    avg_dec = dec_time / total_words
    print('Avg decoding time (s/word):', avg_dec)
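
For wall-clock throughput, time.perf_counter() is simpler and more precise than the datetime round-trip used above; a sketch of the same measurement (generate and prompts are hypothetical stand-ins for the decoding call and its inputs):

import time

start = time.perf_counter()
outputs = [generate(p) for p in prompts]  # hypothetical decode call over hypothetical inputs
elapsed = time.perf_counter() - start
total_words = sum(len(o.split()) for o in outputs)
print('Avg decoding time (s/word):', elapsed / max(total_words, 1))
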
Example #7
def main(args):
    with open(args.dataset_info, 'rb') as rf:
        dataset_info = pickle.load(rf)
    gpt_tokenizer = AutoTokenizer.from_pretrained(args.model_string)
    gpt_tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
    gpt_pad_id = gpt_tokenizer.encode(PAD_TOKEN)[0]
    gpt_model = AutoModelWithLMHead.from_pretrained(args.model_string).to(args.device)
    gpt_model.eval()

    checkpoint = torch.load(args.iambic_ckpt, map_location=args.device)
    model_args = checkpoint['args']
    iambic_model = Model(model_args, gpt_pad_id, len(dataset_info.index2word)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
    iambic_model.load_state_dict(checkpoint['state_dict'])
    iambic_model = iambic_model.to(args.device)
    iambic_model.eval()
    if args.verbose:
        print("=> loaded checkpoint '{}' (epoch {})"
                .format(args.iambic_ckpt, checkpoint['epoch']))
        print('iambic model num params', num_params(iambic_model))

    with open(args.rhyme_info, 'rb') as rf:
        rhyme_info = pickle.load(rf)
    checkpoint = torch.load(args.rhyme_ckpt, map_location=args.device)
    model_args = checkpoint['args']
    rhyme_model = Model(model_args, gpt_pad_id, len(dataset_info.index2word), rhyme_group_size=len(rhyme_info.index2rhyme_group), verbose=args.verbose) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
    rhyme_model.load_state_dict(checkpoint['state_dict'])
    rhyme_model = rhyme_model.to(args.device)
    rhyme_model.eval()
    if args.verbose:
        print("=> loaded checkpoint '{}' (epoch {})"
                .format(args.rhyme_ckpt, checkpoint['epoch']))
        print('rhyme model num params', num_params(rhyme_model))
    
    checkpoint = torch.load(args.newline_ckpt, map_location=args.device)
    model_args = checkpoint['args']
    newline_model = Model(model_args, gpt_pad_id, len(dataset_info.index2word)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
    newline_model.load_state_dict(checkpoint['state_dict'])
    newline_model = newline_model.to(args.device)
    newline_model.eval()
    if args.verbose:
        print("=> loaded checkpoint '{}' (epoch {})"
                .format(args.newline_ckpt, checkpoint['epoch']))
        print('newline model num params', num_params(newline_model))

    with open(args.prefix_file, 'r') as rf:
        lines = rf.readlines()
    for line in tqdm(lines, total=len(lines)):
        couplet = predict_couplet(gpt_model, 
                gpt_tokenizer, 
                iambic_model, 
                rhyme_model,
                newline_model,
                [line], 
                dataset_info, 
                rhyme_info,
                args.precondition_topk,
                args.topk, 
                condition_lambda=args.condition_lambda,
                device=args.device)
        assert len(couplet) == 2
        print(couplet[1].strip().replace('\n', ''))
Example #8
def main(args):
    with open(args.dataset_info, 'rb') as rf:
        dataset_info = pickle.load(rf)
    tokenizer = GPT2Tokenizer.from_pretrained(args.model_string)
    tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
    pad_id = tokenizer.encode(PAD_TOKEN)[0]
    model = GPT2LMHeadModel.from_pretrained(args.model_string,
                                            return_dict=True).to(args.device)
    if args.model_path is not None:
        if os.path.isdir(args.model_path):
            # use the first .ckpt file found under the directory
            for root, _, files in os.walk(args.model_path):
                ckpt_files = [fname for fname in files if fname.endswith('.ckpt')]
                if ckpt_files:
                    args.model_path = os.path.join(root, ckpt_files[0])
                    break
        ckpt = torch.load(args.model_path)
        try:
            model.load_state_dict(ckpt['state_dict'])
        except RuntimeError:
            # checkpoint keys are prefixed with 'model.'; strip the prefix and retry
            state_dict = {}
            for key in ckpt['state_dict'].keys():
                assert key.startswith('model.')
                state_dict[key[6:]] = ckpt['state_dict'][key]
            model.load_state_dict(state_dict)
    model.eval()

    checkpoint = torch.load(args.ckpt, map_location=args.device)
    model_args = checkpoint['args']
    conditioning_model = Model(
        model_args, pad_id, len(dataset_info.index2word)
    )  # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
    conditioning_model.load_state_dict(checkpoint['state_dict'])
    conditioning_model = conditioning_model.to(args.device)
    conditioning_model.eval()
    if args.verbose:
        print("=> loaded checkpoint '{}' (epoch {})".format(
            args.ckpt, checkpoint['epoch']))
        print('num params', num_params(conditioning_model))

    inputs, labels = [], []
    intents = ["inform", "question", "directive", "commissive"]
    with open(args.in_file, 'r') as rf:
        for line in rf:
            items = line.split('\t')
            labels.append(intents.index(items[0].strip()))
            inputs.append(items[1].strip())

    a = datetime.now()
    print('Start time:', a)
    count = 0
    total_words = 0
    for inp, inp_label in zip(inputs, labels):
        count += 1
        if count > 10:  # time only the first 10 examples
            break
        results = predict_intent(model,
                                 tokenizer,
                                 conditioning_model, [inp],
                                 inp_label,
                                 dataset_info,
                                 precondition_topk=args.precondition_topk,
                                 do_sample=args.do_sample,
                                 length_cutoff=args.length_cutoff,
                                 condition_lambda=args.condition_lambda,
                                 device=args.device)
        print(results[0])
        total_words += len(results[0].split())

    b = datetime.now()
    dec_time = (b - a).total_seconds()  # .seconds alone would drop whole minutes

    avg_dec = dec_time / total_words
    print('Avg decoding time (s/word):', avg_dec)
Example #9
def main(args):
    # (reconstructed opening: this example was truncated in the listing; the
    # lines below assume predictions in args.pred_file and one reference file
    # per entry of args.ref_files)
    pred = []
    with open(args.pred_file, 'r') as rf:
        for line in rf:
            pred.append(line.strip())
    refs = []
    for ref_file in args.ref_files:
        ref = []
        with open(ref_file, 'r') as rf:
            for line in rf:
                ref.append(line.strip())
        assert len(ref) == len(pred)
        refs.append(ref)
    bleu = sacrebleu.corpus_bleu(pred, refs)
    print('BLEU score:', bleu.score)

    with open(args.dataset_info, 'rb') as rf:
        dataset_info = pickle.load(rf)

    tokenizer = MarianTokenizer.from_pretrained(args.model_string)
    tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
    pad_id = tokenizer.encode(PAD_TOKEN)[0]

    checkpoint = torch.load(args.ckpt, map_location=args.device)
    model_args = checkpoint['args']
    conditioning_model = Model(
        model_args, pad_id, len(dataset_info.index2word)
    )  # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
    conditioning_model.load_state_dict(checkpoint['state_dict'])
    conditioning_model = conditioning_model.to(args.device)
    conditioning_model.eval()
    print("=> loaded checkpoint '{}' (epoch {})".format(
        args.ckpt, checkpoint['epoch']))
    print('num params', num_params(conditioning_model))

    print(
        'avg formality prob according to model',
        avg_formality(pred, conditioning_model, tokenizer, device=args.device))
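
sacrebleu.corpus_bleu, used above, expects one reference stream per reference file, each aligned line-by-line with the predictions; a minimal self-contained usage:

import sacrebleu

pred = ['the cat sat on the mat']
refs = [['the cat sat on the mat'],        # reference stream 1
        ['a cat was sitting on the mat']]  # reference stream 2
print(sacrebleu.corpus_bleu(pred, refs).score)
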
Example #10
def main(args):
    dataset = Dataset(args)
    os.makedirs(args.save_dir, exist_ok=True)
    with open(os.path.join(args.save_dir, 'dataset_info'), 'wb') as wf:
        pickle.dump(dataset.dataset_info, wf)
    if args.task == 'rhyme':
        with open(os.path.join(args.save_dir, 'rhyme_info'), 'wb') as wf:
            pickle.dump(dataset.rhyme_info, wf)
    if args.ckpt:
        checkpoint = torch.load(args.ckpt, map_location=args.device)
        start_epoch = checkpoint['epoch'] + 1
        best_val_metric = checkpoint['best_metric']
        model_args = checkpoint['args']
        model = Model(
            model_args,
            dataset.gpt_pad_id,
            len(dataset.index2word),
            rhyme_group_size=len(dataset.index2rhyme_group)
            if args.task == 'rhyme' else None
        )  # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
        model.load_state_dict(checkpoint['state_dict'])
        model = model.to(args.device)
        optimizer = torch.optim.Adam(model.parameters(), lr=model_args.lr)
        optimizer.load_state_dict(checkpoint['optimizer'])
        data_start_index = checkpoint['data_start_index']
        print("=> loaded checkpoint '{}' (epoch {})".format(
            args.ckpt, checkpoint['epoch']))
        # NOTE: just import pdb after loading the model here if you want to play with it, it's easy
        # model.eval()
        # import pdb; pdb.set_trace()
    else:
        model = Model(args,
                      dataset.gpt_pad_id,
                      len(dataset.index2word),
                      rhyme_group_size=len(dataset.index2rhyme_group)
                      if args.task == 'rhyme' else None,
                      glove_embeddings=dataset.glove_embeddings)
        model = model.to(args.device)
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
        best_val_metric = 1e8  # lower is better for BCE
        data_start_index = 0
    print('num params', num_params(model))
    criterion = nn.BCEWithLogitsLoss().to(args.device)

    if args.evaluate:
        epoch = 0
        validate(model, dataset, criterion, epoch, args)
        return
    for epoch in range(args.epochs):
        print("TRAINING: Epoch {} at {}".format(epoch, time.ctime()))
        data_start_index = train(model, dataset, optimizer, criterion, epoch,
                                 args, data_start_index)
        if epoch % args.validation_freq == 0:
            print("VALIDATION: Epoch {} at {}".format(epoch, time.ctime()))
            metric = validate(model, dataset, criterion, epoch, args)

            if not args.debug:
                if metric < best_val_metric:
                    print('new best val metric', metric)
                    best_val_metric = metric
                    save_checkpoint(
                        {
                            'epoch': epoch,
                            'state_dict': model.state_dict(),
                            'best_metric': best_val_metric,
                            'optimizer': optimizer.state_dict(),
                            'data_start_index': data_start_index,
                            'args': args
                        }, os.path.join(args.save_dir, 'model_best.pth.tar'))
                save_checkpoint(
                    {
                        'epoch': epoch,
                        'state_dict': model.state_dict(),
                        'best_metric': metric,
                        'optimizer': optimizer.state_dict(),
                        'data_start_index': data_start_index,
                        'args': args
                    },
                    os.path.join(args.save_dir,
                                 'model_epoch' + str(epoch) + '.pth.tar'))
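
save_checkpoint is not shown in this listing; a plausible minimal implementation (an assumption, not the repository's code) is a thin torch.save wrapper:

import torch

def save_checkpoint(state, path):
    # Persist the training state (epoch, model/optimizer state dicts, args)
    # assembled by the loop above.
    torch.save(state, path)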