Example #1
def train(model, train_loader, epoch, save_dir, optimizer, style_bank):
    from numpy import random
    datalen = len(train_loader)
    t0 = time.time()
    for i, inputs in enumerate(train_loader):
        if i == datalen - 1:
            break
        bank = random.randint(120)
        style = style_bank[bank]
        prev_state1 = None
        prev_state2 = None
        contents = inputs
        frame_i = []
        frame_o = []
        for t in range(10):
            frame_i.append(contents[:, t, :, :, :])
        loss = 0
        for t1 in range(9):
            loss_c, loss_s, loss_t, out, prev_state1, prev_state2 = model(
                frame_i[t1].to(device), frame_i[t1 + 1].to(device),
                style.to(device), prev_state1, prev_state2, bank)
            prev_state1 = repackage_hidden(prev_state1)
            prev_state2 = repackage_hidden(prev_state2)

            frame_o.append(out.detach())
            loss_c = loss_c * args.content_weight
            loss_s = loss_s * args.style_weight
            loss_t = loss_t * args.short_weight
            loss = (loss_c + loss_s + loss_t)
            for t2 in range(1, t1 - 4):
                loss_t = model.temporal_loss(frame_i[t2].to(device),
                                             frame_i[t1 + 1].to(device),
                                             frame_o[t2 - 1].to(device),
                                             out.to(device))
                loss_t = loss_t * args.long_weight / (t1 - 5)
                loss += loss_t
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        t_end = time.time()

        if (i + 1) % args.print_freq == 0:
            logger.info(
                'Epoch [%d] Iter: [%d/%d] LR: %f Time: %.3f Loss: %.5f LossContent: %.5f LossStyle: %.5f LossTemporal: %.5f'
                % (epoch, i + 1, datalen, args.lr, t_end - t0,
                   loss.item(), loss_c.item(),
                   loss_s.item(), loss_t.item()))
            t0 = t_end
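Every example in this listing calls a repackage_hidden helper that is imported from a project-specific utils module and is not shown here. As a point of reference only, below is a minimal sketch of the standard version from the PyTorch word-language-model example, which matches how the helper is used in these snippets (a hidden tensor, or a nested tuple of tensors, is returned detached from the autograd graph); the project's own helper may differ in detail.

import torch

def repackage_hidden(h):
    """Detach hidden state tensors from their computation history so that
    backpropagation does not reach back through previous batches."""
    if isinstance(h, torch.Tensor):
        return h.detach()
    # LSTM-style states arrive as (h, c) tuples; detach each element recursively.
    return tuple(repackage_hidden(v) for v in h)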
Example #2
    def init_hidden(self, batch_shape, hidden=None):
        weight = next(self.parameters()).data
        if not isinstance(batch_shape, Iterable):
            batch_shape = [batch_shape]
        bsz = reduce(lambda x, y: x * y, batch_shape, 1)
        if hidden is not None and hidden.shape[1] == bsz:
            return utils.repackage_hidden(hidden)
        else:
            return weight.new(1, bsz, self.rnn_units).zero_()

    def forward(self, content1, content2, style, prev_state1, prev_state2, bank):
        with torch.no_grad():
            contents = self.concat(content1, content2)
            flowout = self.flownet(contents)
            mask = self.mask_occlusion(content1, self.warp(content2, flowout))
            g_t1, return_state1, return_state2 = self.styler(vgg_norm(content1), prev_state1, prev_state2, bank)

        g_t2, prev_state1, prev_state2 = self.styler(vgg_norm(content2), repackage_hidden(return_state1), repackage_hidden(return_state2), bank)

        content_feat = self.vgg(vgg_norm(Variable(content2.data, requires_grad=False)))[2]
        style_feats = self.vgg(vgg_norm(style))
        output_feats = self.vgg(vgg_norm(g_t2))

        loss_c = self.calc_content_loss(output_feats[2], Variable(content_feat.data, requires_grad=False))
        loss_s = 0
        for i in range(4):
            loss_s += self.calc_style_loss(output_feats[i], style_feats[i].data)
        loss_t = self.calc_temporal_loss(g_t1, g_t2, flowout, mask)

        return loss_c, loss_s, loss_t, g_t1, return_state1, return_state2
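The calc_content_loss, calc_style_loss, and calc_temporal_loss helpers used above are not included in this snippet. As a rough illustration only, here is a minimal sketch of the content and style losses as they commonly appear in AdaIN-style PyTorch repositories (MSE on VGG features for content, MSE on per-channel feature statistics for style); the actual helpers in this project may be implemented differently.

import torch
import torch.nn.functional as F

def calc_mean_std(feat, eps=1e-5):
    # Per-channel mean and std over the spatial dimensions of an NCHW feature map.
    n, c = feat.size()[:2]
    var = feat.view(n, c, -1).var(dim=2) + eps
    std = var.sqrt().view(n, c, 1, 1)
    mean = feat.view(n, c, -1).mean(dim=2).view(n, c, 1, 1)
    return mean, std

def calc_content_loss(output_feat, content_feat):
    # Content loss: MSE between the stylized and content feature maps.
    return F.mse_loss(output_feat, content_feat)

def calc_style_loss(output_feat, style_feat):
    # Style loss: MSE between per-channel feature statistics.
    o_mean, o_std = calc_mean_std(output_feat)
    s_mean, s_std = calc_mean_std(style_feat)
    return F.mse_loss(o_mean, s_mean) + F.mse_loss(o_std, s_std)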
Example #4
        os.mkdir(args.log_dir)

    content_tf = train_transform()
    style_tf = train_transform()

    print('loading dataset done', flush=True)

    # style_bank = styleInput()
    from utils.utils import repackage_hidden
    print(styler, flush=True)
    avg = []
    for bank in range(120):
        prev_state1 = None
        prev_state2 = None
        for i in range(1, 10):
            path = '%05d.jpg'%(i)
            cimg = Image.open(os.path.join('/home/gaowei/IJCAI/videvo/videvo/test/WaterFall2/', path)).convert('RGB')

            cimg = content_tf(cimg).unsqueeze(0).cuda()
            cimg = vgg_norm(cimg)
            with torch.no_grad():
                out, prev_state1, prev_state2 = styler(cimg, prev_state1, prev_state2, bank)


            prev_state1 = repackage_hidden(prev_state1)
            prev_state2 = repackage_hidden(prev_state2)
            save_image(out, 'output/%06d.jpg'%(i-1 + bank * 49))
    # mmcv.frames2video('output', 'mst_cat_flow.avi', fps=6)
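The commented-out line above suggests that the frames written to output/ are meant to be stitched into a video with mmcv. A minimal sketch of that step, assuming mmcv is installed and the frames keep the %06d.jpg naming used by save_image above:

import mmcv

# Assemble output/000000.jpg, output/000001.jpg, ... into one video file at
# 6 frames per second, mirroring the commented-out call above.
mmcv.frames2video('output', 'mst_cat_flow.avi', fps=6, filename_tmpl='{:06d}.jpg')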


Example #5
def main(params):

    dp = DataProvider(params)

    # Create vocabulary and author index
    if params['resume'] is None:
        if params['atoms'] == 'char':
            char_to_ix, ix_to_char = dp.createCharVocab(
                params['vocab_threshold'])
        else:
            char_to_ix, ix_to_char = dp.createWordVocab(
                params['vocab_threshold'])
        auth_to_ix, ix_to_auth = dp.createAuthorIdx()
    else:
        saved_model = torch.load(params['resume'])
        char_to_ix = saved_model['char_to_ix']
        auth_to_ix = saved_model['auth_to_ix']
        ix_to_auth = saved_model['ix_to_auth']
        ix_to_char = saved_model['ix_to_char']

    params['vocabulary_size'] = len(char_to_ix)
    params['num_output_layers'] = len(auth_to_ix)

    model = CharTranslator(params)
    # set to train mode, this activates dropout
    model.train()
    #Initialize the RMSprop optimizer

    if params['use_sgd']:
        optim = torch.optim.SGD(model.parameters(),
                                lr=params['learning_rate'],
                                momentum=params['decay_rate'])
    else:
        optim = torch.optim.RMSprop(model.parameters(),
                                    lr=params['learning_rate'],
                                    alpha=params['decay_rate'],
                                    eps=params['smooth_eps'])
    # Loss function
    if params['mode'] == 'generative':
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = nn.NLLLoss()

    # Restore saved checkpoint
    if params['resume'] is not None:
        model.load_state_dict(saved_model['state_dict'])
        optim.load_state_dict(saved_model['optimizer'])

    total_loss = 0.
    start_time = time.time()
    hidden = model.init_hidden(params['batch_size'])
    hidden_zeros = model.init_hidden(params['batch_size'])
    # Initialize the cache
    if params['randomize_batches']:
        dp.set_hid_cache(range(len(dp.data['docs'])), hidden_zeros)

    # Compute the iteration parameters
    epochs = params['max_epochs']
    total_seqs = dp.get_num_sents(split='train')
    iter_per_epoch = total_seqs // params['batch_size']
    total_iters = iter_per_epoch * epochs
    best_loss = 1000000.
    best_val = 1000.
    eval_every = int(iter_per_epoch * params['eval_interval'])

    #val_score = eval_model(dp, model, params, char_to_ix, auth_to_ix, split='val', max_docs = params['num_eval'])
    val_score = 0.  #eval_model(dp, model, params, char_to_ix, auth_to_ix, split='val', max_docs = params['num_eval'])
    val_rank = 1000

    eval_function = (eval_translator if params['mode'] == 'generative'
                     else eval_classify)
    leakage = 0.  #params['leakage']

    print total_iters
    for i in xrange(total_iters):
        #TODO
        if params['split_generators']:
            c_aid = ix_to_auth[np.random.choice(auth_to_ix.values())]
        else:
            c_aid = None

        batch = dp.get_sentence_batch(params['batch_size'],
                                      split='train',
                                      atoms=params['atoms'],
                                      aid=c_aid,
                                      sample_by_len=params['sample_by_len'])
        inps, targs, auths, lens = dp.prepare_data(
            batch, char_to_ix, auth_to_ix, maxlen=params['max_seq_len'])
        # Reset the hidden states for which new docs have been sampled

        # Starting each batch, we detach the hidden state from how it was previously produced.
        # If we didn't, the model would try backpropagating all the way to start of the dataset.
        hidden = repackage_hidden(hidden)
        optim.zero_grad()
        #TODO
        if params['mode'] == 'generative':
            output, _ = model.forward_mltrain(inps,
                                              lens,
                                              inps,
                                              lens,
                                              hidden_zeros,
                                              auths=auths)
            targets = pack_padded_sequence(Variable(targs).cuda(), lens)
            loss = criterion(pack_padded_sequence(output, lens)[0], targets[0])
        else:
            # for classifier auths is the target
            output, hidden = model.forward_classify(inps,
                                                    hidden,
                                                    compute_softmax=True)
            targets = Variable(auths).cuda()
            loss = criterion(output, targets)
        loss.backward()

        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm(model.parameters(), params['grad_clip'])

        # Take an optimization step
        optim.step()

        total_loss += loss.data.cpu().numpy()[0]

        # Save the hidden states in cache for later use
        if i % eval_every == 0 and i > 0:
            val_rank, val_score = eval_function(dp,
                                                model,
                                                params,
                                                char_to_ix,
                                                auth_to_ix,
                                                split='val')

        #if i % iter_per_epoch == 0 and i > 0 and leakage > params['leakage_min']:
        #    leakage = leakage * params['leakage_decay']

        #if (i % iter_per_epoch == 0) and ((i//iter_per_epoch) >= params['lr_decay_st']):
        if i % params['log_interval'] == 0 and i > 0:
            cur_loss = total_loss / params['log_interval']
            elapsed = time.time() - start_time
            print(
                '| epoch {:2.2f} | {:5d}/{:5d} batches | lr {:02.2e} | ms/batch {:5.2f} | '
                'loss {:5.2f} | ppl {:8.2f}'.format(
                    float(i) / iter_per_epoch, i, total_iters,
                    params['learning_rate'],
                    elapsed * 1000 / params['log_interval'], cur_loss,
                    math.exp(cur_loss)))
            total_loss = 0.

            if val_rank <= best_val:
                save_checkpoint(
                    {
                        'iter': i,
                        'arch': params,
                        'val_loss': val_rank,
                        'val_pplx': val_score,
                        'char_to_ix': char_to_ix,
                        'ix_to_char': ix_to_char,
                        'auth_to_ix': auth_to_ix,
                        'ix_to_auth': ix_to_auth,
                        'state_dict': model.state_dict(),
                        'loss': cur_loss,
                        'optimizer': optim.state_dict(),
                    },
                    fappend=params['fappend'],
                    outdir=params['checkpoint_output_directory'])
                best_val = val_rank
            start_time = time.time()
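This training loop (like the one in Example #6 below) is written against the pre-0.4 PyTorch API: Variable wrappers, loss.data.cpu().numpy()[0] to read scalars, and torch.nn.utils.clip_grad_norm for clipping. On PyTorch 0.4 and later the equivalents are loss.item() and the in-place clip_grad_norm_ with a trailing underscore, which Example #8 already uses. A minimal self-contained sketch of the newer calls, with a placeholder model:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()

# Modern equivalents of the deprecated calls used in Examples #5 and #6.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scalar_loss = loss.item()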
Example #6
def main(params):
    dp = DataProvider(params)

    # Create vocabulary and author index
    if params['resume'] is None:
        if params['atoms'] == 'char':
            char_to_ix, ix_to_char = dp.create_char_vocab(
                params['vocab_threshold'])
        else:
            char_to_ix, ix_to_char = dp.create_word_vocab(
                params['vocab_threshold'])
        auth_to_ix, ix_to_auth = dp.create_author_idx()
    else:
        saved_model = torch.load(params['resume'])
        char_to_ix = saved_model['char_to_ix']
        auth_to_ix = saved_model['auth_to_ix']
        ix_to_char = saved_model['ix_to_char']

    params['vocabulary_size'] = len(char_to_ix)
    params['num_output_layers'] = len(auth_to_ix)
    print params['vocabulary_size'], params['num_output_layers']

    model = get_classifier(params)
    # set to train mode, this activates dropout
    model.train()
    # Initialize the RMSprop optimizer

    if params['use_sgd']:
        optim = torch.optim.SGD(model.parameters(),
                                lr=params['learning_rate'],
                                momentum=params['decay_rate'])
    else:
        optim = torch.optim.RMSprop([{
            'params':
            [p[1] for p in model.named_parameters() if p[0] != 'decoder_W']
        }, {
            'params': model.decoder_W,
            'weight_decay': 0.000
        }],
                                    lr=params['learning_rate'],
                                    alpha=params['decay_rate'],
                                    eps=params['smooth_eps'])
    # Loss function
    if len(params['balance_loss']) == 0:
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = nn.CrossEntropyLoss(
            torch.FloatTensor(params['balance_loss']).cuda())

    # Restore saved checkpoint
    if params['resume'] is not None:
        model.load_state_dict(saved_model['state_dict'])
        # optim.load_state_dict(saved_model['optimizer'])

    total_loss = 0.
    class_loss = 0.
    start_time = time.time()
    hidden = model.init_hidden(params['batch_size'])
    hidden_zeros = model.init_hidden(params['batch_size'])
    # Initialize the cache
    if params['randomize_batches']:
        dp.set_hid_cache(range(len(dp.data['docs'])), hidden_zeros)

    # Compute the iteration parameters
    epochs = params['max_epochs']
    total_seqs = dp.get_num_sents(split='train')
    iter_per_epoch = total_seqs // params['batch_size']
    total_iters = iter_per_epoch * epochs
    best_loss = 0.
    best_val = 1000.
    eval_every = int(iter_per_epoch * params['eval_interval'])

    # val_score = eval_model(dp, model, params, char_to_ix, auth_to_ix, split='val', max_docs = params['num_eval'])
    val_score = 0.  # eval_model(dp, model, params, char_to_ix, auth_to_ix, split='val', max_docs = params['num_eval'])
    val_rank = 0

    eval_function = (eval_model if params['mode'] == 'generative'
                     else eval_classify)

    leakage = params['leakage']
    for i in xrange(total_iters):
        # TODO
        if params['randomize_batches']:
            batch, reset_next = dp.get_rand_doc_batch(params['batch_size'],
                                                      split='train')
            b_ids = [b['id'] for b in batch]
            hidden = dp.get_hid_cache(b_ids, hidden)
        elif params['use_sentences']:
            c_aid = None  # ix_to_auth[np.random.choice(auth_to_ix.values())]
            batch = dp.get_sentence_batch(
                params['batch_size'],
                split='train',
                aid=c_aid,
                atoms=params['atoms'],
                sample_by_len=params['sample_by_len'])
            hidden = hidden_zeros
        else:
            batch, reset_h = dp.get_doc_batch(split='train')
            if len(reset_h) > 0:
                hidden[0].data.index_fill_(1,
                                           torch.LongTensor(reset_h).cuda(),
                                           0.)
                hidden[1].data.index_fill_(1,
                                           torch.LongTensor(reset_h).cuda(),
                                           0.)

        inps, targs, auths, lens = dp.prepare_data(batch,
                                                   char_to_ix,
                                                   auth_to_ix,
                                                   leakage=leakage)

        # Reset the hidden states for which new docs have been sampled

        # Starting each batch, we detach the hidden state from how it was previously produced.
        # If we didn't, the model would try backpropagating all the way to start of the dataset.
        hidden = repackage_hidden(hidden)
        optim.zero_grad()

        # TODO
        if params['mode'] == 'generative':
            output, hidden = model.forward(inps, lens, hidden, auths)
            targets = pack_padded_sequence(Variable(targs).cuda(), lens)
            loss = criterion(pack_padded_sequence(output, lens)[0], targets[0])
        else:
            # for classifier auths is the target
            output, _ = model.forward_classify(targs,
                                               hidden,
                                               compute_softmax=False,
                                               lens=lens)
            targets = Variable(auths).cuda()
            lossClass = criterion(output, targets)
            if params['compression_layer']:
                loss = lossClass + (model.compression_W.weight.norm(
                    p=1, dim=1)).mean()
            else:
                loss = lossClass
        loss.backward()

        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm(model.parameters(), params['grad_clip'])

        # Take an optimization step
        optim.step()

        total_loss += loss.data.cpu().numpy()[0]
        class_loss += lossClass.data.cpu().numpy()[0]

        # Save the hidden states in cache for later use
        if params['randomize_batches']:
            if len(reset_next) > 0:
                hidden[0].data.index_fill_(1,
                                           torch.LongTensor(reset_next).cuda(),
                                           0.)
                hidden[1].data.index_fill_(1,
                                           torch.LongTensor(reset_next).cuda(),
                                           0.)
            dp.set_hid_cache(b_ids, hidden)

        if i % eval_every == 0 and i > 0:
            val_rank, val_score = eval_function(dp,
                                                model,
                                                params,
                                                char_to_ix,
                                                auth_to_ix,
                                                split='val',
                                                max_docs=params['num_eval'])

        if i % iter_per_epoch == 0 and i > 0 and leakage > params[
                'leakage_min']:
            leakage = leakage * params['leakage_decay']

        # if (i % iter_per_epoch == 0) and ((i//iter_per_epoch) >= params['lr_decay_st']):
        if i % params['log_interval'] == 0 and i > 0:
            cur_loss = total_loss / params['log_interval']
            class_loss = class_loss / params['log_interval']
            elapsed = time.time() - start_time
            print(
                '| epoch {:3.2f} | {:5d}/{:5d} batches | lr {:02.2e} | ms/batch {:5.2f} | '
                'loss {:5.2f} | ppl {:8.2f}'.format(
                    float(i) / iter_per_epoch, i, total_iters,
                    params['learning_rate'],
                    elapsed * 1000 / params['log_interval'], cur_loss,
                    math.exp(class_loss)))

            if val_rank >= best_loss:
                best_loss = val_rank
                save_checkpoint(
                    {
                        'iter': i,
                        'arch': params,
                        'val_mean_rank': val_rank,
                        'val_auc': val_score,
                        'char_to_ix': char_to_ix,
                        'ix_to_char': ix_to_char,
                        'auth_to_ix': auth_to_ix,
                        'state_dict': model.state_dict(),
                        'loss': cur_loss,
                        'optimizer': optim.state_dict(),
                    },
                    fappend=params['fappend'],
                    outdir=params['checkpoint_output_directory'])
                best_val = val_rank
            start_time = time.time()
            total_loss = 0.
            class_loss = 0.
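The RMSprop construction in this example passes two parameter groups so that decoder_W gets its own weight_decay setting while every other parameter uses the defaults. A minimal self-contained sketch of that per-parameter-group pattern (the model and hyperparameter values here are placeholders, not the ones used in this project):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))

# Split parameters into two groups so one group can use its own weight decay.
decay, no_decay = [], []
for name, p in model.named_parameters():
    (no_decay if name.endswith('bias') else decay).append(p)

optim = torch.optim.RMSprop(
    [{'params': decay, 'weight_decay': 1e-4},
     {'params': no_decay, 'weight_decay': 0.0}],
    lr=1e-3, alpha=0.99, eps=1e-8)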
Example #7
    def evaluate(self, data, eos_tokens=None, dump_hiddens=False):

        # get weights and compute WX for all words
        #weights_ih, bias_ih = self.rnn.module.weight_ih_l0, self.rnn.module.bias_ih_l0  # only one layer for the moment
        #weights_hh, bias_hh = self.rnn.module.weight_hh_l0, self.rnn.module.bias_hh_l0

        all_words = torch.LongTensor([i for i in range(self.ntoken)]).cuda()
        all_words = embedded_dropout(
            self.encoder,
            all_words,
            dropout=self.dropoute if self.training else 0).view(
                1, self.ntoken, -1)

        # iterate over data set and compute loss
        total_loss, hidden = 0, self.init_hidden(1)
        i = 0

        entropy, hiddens, all_hiddens = [], [], []
        while i < data.size(0):

            #all_words_times_W = self.rnn.module._in_times_W(all_words, hidden)

            #hidden_times_U = torch.nn.functional.linear(hidden[0].repeat(self.ntoken, 1), weights_hh, bias_hh)
            #print(all_words.size(), hidden[0].repeat(1,self.ntoken,1)[0].size())
            output = self.rnn(
                all_words, hidden[0].repeat(1, self.ntoken, 1)
            )[0]  #self._forward(all_words_times_W, hidden_times_U, hidden[0].repeat(self.ntoken, 1))

            if dump_hiddens:
                pass  #hiddens.append(output[data[i]].data.cpu().numpy())

            distance = self.dist_fn(hidden[0], output[0])
            #print(output.size(), distance.size(), hidden)
            if self.threshold is not None:
                distance = self._apply_threshold(distance, hidden[0])
            distance = self._apply_temperature(distance)
            distance = self._apply_bias(distance, self.bias)

            softmaxed = torch.nn.functional.log_softmax(-distance, dim=0)
            raw_loss = -softmaxed[data[i]].item()

            total_loss += raw_loss / data.size(0)
            entropy.append(raw_loss)

            if (eos_tokens is not None
                    and data[i].data.cpu().numpy()[0] in eos_tokens):
                hidden = self.init_hidden(1)
                hidden = hidden.detach()
                if dump_hiddens:
                    pass  #all_hiddens.append(hiddens)
                    hiddens = []
            else:
                hidden = output[0][data[i]].view(1, 1, -1)
            hidden = repackage_hidden(hidden).detach()

            i = i + 1

        all_hiddens = all_hiddens if eos_tokens is not None else hiddens

        if self.threshold_decr > 0:
            self.threshold_max_r = max(self.threshold_min_r,
                                       self.threshold_max_r * 0.95)

        if dump_hiddens:
            return total_loss, np.array(entropy), all_hiddens
        else:
            return total_loss, np.array(entropy)
Example #8
    def train():

        # Turn on training mode which enables dropout.
        total_loss, avrg_loss = 0, 0
        start_time = time.time()
        ntokens = len(corpus.dictionary)
        batch, i = 0, 0

        # need a hidden state for tl and one for mos
        h_tl = tl_model.init_hidden(args.batch_size)
        h_mos = mos_model.init_hidden(args.batch_size)
        data_keep = 5 * torch.ones(1, args.batch_size).cuda().long()
        while i < train_data.size(0) - 1:

            # get seq len from batch
            seq_len = seq_lens[batch] - 1

            # adapt learning rate
            lr2 = optimizer.param_groups[0]['lr']
            optimizer.param_groups[0]['lr'] = lr2 * seq_len / args.bptt
            tl_model.train()

            # get data and binary map
            data = get_batch(train_data, i, args, seq_len=seq_len)
            binary = get_batch(binary_data, i, args, seq_len=seq_len)

            # evaluate mos on data
            h_mos = repackage_hidden(h_mos)
            mos_data = torch.cat((data_keep, data), 0)[:-1]
            mos_data[mos_data >= 10000] = 0  # ugly fix!!!!
            log_prob, h_mos = mos_model(mos_data, h_mos)

            #print(torch.exp(log_prob))
            #print(data, mos_data)

            # get probability ranks from mos probability
            _, argsort = torch.sort(log_prob, descending=True)
            argsort = argsort.view(-1, 10000)

            #print(argsort.size())

            # Starting each batch, we detach the hidden state from how it was previously produced.
            # If we didn't, the model would try backpropagating all the way to start of the dataset.
            h_tl = tl_model.init_hidden(args.batch_size)
            h_tl = repackage_hidden(h_tl)

            optimizer.zero_grad()

            #raw_loss = model.train_crossentropy(data, eos_tokens)
            raw_loss = tl_model(data, binary, h_tl, argsort)
            avrg_loss = avrg_loss + (seq_len +
                                     1) * raw_loss.data / train_data.size(0)

            loss = raw_loss
            loss.backward()

            # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
            if args.clip:
                torch.nn.utils.clip_grad_norm_(params, args.clip)
            optimizer.step()

            total_loss += loss.data
            optimizer.param_groups[0]['lr'] = lr2
            if batch % args.log_interval == 0:
                cur_loss = total_loss.item() / args.log_interval
                elapsed = time.time() - start_time
                print(
                    '| epoch {:3d} | {:5d}/{:5d} batches | lr {:05.5f} | ms/batch {:5.2f} | '
                    'loss {:5.2f} | ppl {:8.2f} | bpc {:8.3f}'.format(
                        epoch, batch,
                        len(train_data) // args.bptt,
                        optimizer.param_groups[0]['lr'],
                        elapsed * 1000 / args.log_interval, cur_loss,
                        np.exp(cur_loss), cur_loss / math.log(2)))
                total_loss = 0
                start_time = time.time()
            ###
            batch += 1
            i += seq_len + 1

            #break

        return avrg_loss  #/ train_data.size(0)
Example #9
    def evaluate(data_source, epoch, batch_size=1):
        # Turn on evaluation mode which disables dropout.
        tl_model.eval()

        total_loss, i = 0, 0
        h_tl = tl_model.init_hidden(batch_size)
        h_mos = mos_model.init_hidden(batch_size)

        data_keep = 5 * torch.ones(1, 1).cuda().long()
        while i < data_source.size(0) - 1:

            seq_len = args.bptt  # one to many (data has size 2, need this because targets!)
            data = get_batch(data_source, i, args, seq_len=seq_len)
            mos_data = torch.cat(
                (data_keep, data),
                0)[:-1]  # no need to include eos token in mos_data
            data_keep = data[-1].view(1, 1)

            # evaluate mos for probability ranks
            h_mos = repackage_hidden(h_mos)
            log_prob, h_mos = mos_model(mos_data, h_mos)

            #print(torch.exp(log_prob), data.view(-1), mos_data.view(-1))

            # cut log_probs to actual sequence length
            #log_prob = log_prob[:-1]

            #print(data, mos_data)
            # get probability ranks from mos probability
            #print(torch.exp(log_prob), data[1:])
            _, argsort = torch.sort(log_prob, descending=True)
            argsort = argsort.view(-1, 10000)

            # evaluate tl model
            h_tl = repackage_hidden(h_tl)
            loss, h_tl, entropy = tl_model.evaluate(data, h_tl, argsort,
                                                    eos_tokens)

            total_loss = total_loss + loss * min(seq_len, data_source.size(0))
            i = i + seq_len + 1

        total_loss = total_loss / data_source.size(0)
        '''     

        if args.dump_hiddens:
            loss, entropy, hiddens = tl_model.evaluate(data_source, eos_tokens, args.dump_hiddens)
            dump_hiddens(hiddens, 'hiddens_' + str(epoch))
        else:
            loss, entropy = tl_model.evaluate(data_source, eos_tokens)

        #loss = loss.item()
        if args.dump_words:
            W = tl_model.rnn.module.weight_ih_l0.detach()
            dump_words(torch.nn.functional.linear(tl_model.encoder.weight.detach(), W).detach().cpu().numpy(), 'words_xW_' + str(epoch))
            dump_words(tl_model.encoder.weight.detach().cpu().numpy(), 'words_' + str(epoch))

        if not args.dump_entropy is None:
            dump(entropy, args.dump_entropy + str(epoch))

        '''
        return total_loss