Example #1
def evaluate(data_source, batch_size=10):
    # Turn on evaluation mode which disables dropout.
    model.eval()
    if args.model == 'QRNN':
        model.reset()
    total_loss = 0
    ntokens = ds.ntokens
    hidden = model.init_hidden(batch_size)
    for i in range(0, data_source.size(0) - 1, args.bptt):
        data, targets = ds.get_batch(data_source, i)
        targets = targets.view(-1)
        output, hidden = model(data, hidden)
        total_loss += len(data) * criterion(
            model.decoder.weight, model.decoder.bias, output, targets).data
        hidden = repackage_hidden(hidden)
    return total_loss.item() / len(data_source)
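Every example in this section calls `repackage_hidden` to cut the autograd graph between BPTT windows; the helper itself is not shown. A minimal sketch, assuming the hidden state is either a tensor or a (possibly nested) tuple of tensors as returned by LSTM/GRU modules:

import torch

def repackage_hidden(h):
    # Detach hidden states from their history so gradients stop at the
    # current BPTT window instead of flowing back to the start of the data.
    if isinstance(h, torch.Tensor):
        return h.detach()
    return tuple(repackage_hidden(v) for v in h)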
Example #2
def evaluate(data_source, batch_size=10):
    model.eval()
    total_loss = 0
    hidden = model.init_hidden(batch_size)
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, args.bptt):
            data, targets = ds.get_batch(data_source, i)
            targets = targets.view(-1)

            output, hidden = model(data, hidden)
            loss = criterion(output.view(-1, output.size(2)), targets).data
            total_loss += len(data) * loss

            hidden = repackage_hidden(hidden)

    return total_loss.item() / len(data_source)
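`ds.get_batch(data_source, i)` is not defined in these snippets. A sketch of the usual implementation, assuming `data_source` is a 2-D tensor of token ids shaped (num_steps, batch_size) and that the window length comes from `args.bptt` (passed explicitly here, which is an assumption about the real signature):

def get_batch(source, i, bptt):
    # Slice a BPTT-sized window starting at row i; the targets are the
    # same tokens shifted forward by one time step.
    seq_len = min(bptt, source.size(0) - 1 - i)
    data = source[i:i + seq_len]
    targets = source[i + 1:i + 1 + seq_len]
    return data, targets

The callers above then flatten the targets with `.view(-1)` before handing them to the loss.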
Example #3
def evaluate_scores(epoch, batch_size):
    model.eval()
    if args.model == 'QRNN':
        model.reset()
    total_loss = 0
    ntokens = ds.ntokens
    hidden = model.init_hidden(batch_size)
    for data, targets in ds.train_seq():
        targets = targets.view(-1).contiguous()
        output, hidden = model(data, hidden)
        loss = criterion(model.decoder.weight, model.decoder.bias, output,
                         targets).data
        sk.add_prior_sample(epoch, loss.item())
        total_loss += len(data) * loss
        hidden = repackage_hidden(hidden)
    sk.save_prior_epoch()
    return total_loss.item() / ds.data_size
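`sk` (used here via `add_prior_sample`/`save_prior_epoch`, and via `add_sample`/`add_data` in the training examples below) is a project-specific score keeper that is not included in these snippets. A purely hypothetical sketch of the minimal interface these calls assume:

import json
from collections import defaultdict

class ScoreKeeper:
    # Hypothetical helper: collect per-sample losses keyed by epoch and
    # flush them to a JSON file at the end of the epoch.
    def __init__(self, path='scores.json'):
        self.path = path
        self.prior = defaultdict(list)

    def add_prior_sample(self, epoch, loss):
        self.prior[epoch].append(loss)

    def save_prior_epoch(self):
        with open(self.path, 'w') as f:
            json.dump(self.prior, f)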
Example #4
def evaluate(data_source, batch_size=10):
    # Turn on evaluation mode which disables dropout.
    model.eval()
    total_loss = 0
    hidden = model.init_hidden(batch_size)
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, args.bptt):
            data, targets = ds.get_batch(data_source, i)
            targets = targets.view(-1)

            log_prob, hidden = parallel_model(data, hidden)
            loss = nn.functional.nll_loss(log_prob.view(-1, log_prob.size(2)),
                                          targets).data

            total_loss += len(data) * loss
            hidden = repackage_hidden(hidden)

    return total_loss.item() / len(data_source)
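The `evaluate` functions above return the accumulated loss normalized by the length of the evaluation stream, i.e. an average cross-entropy per predicted position. A typical way to report it from a hypothetical caller (`val_data` is assumed to be a batchified validation tensor):

import math

val_loss = evaluate(val_data)  # average NLL per position
print('valid loss {:5.2f} | valid ppl {:8.2f}'.format(val_loss,
                                                      math.exp(val_loss)))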
Example #5
def train():
    global tot_steps
    batch = 0
    total_loss = 0
    model.train()
    hidden = model.init_hidden(args.batch_size)
    start_time = time.time()
    for data, targets, id_ in dh.train_seq():
        data = data.permute(1, 0, 2)
        if data.shape[1] != args.batch_size:
            continue
        if torch.cuda.is_available():
            data = data.cuda()
            targets = targets.cuda()
        targets = targets.view(-1)
        model.zero_grad()
        hidden = repackage_hidden(hidden)
        output, hidden = model(data, hidden)
        loss = criterion(output.view(-1, output.size(2)), targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        optimizer.step()

        total_loss += loss.data  # detach so the graph is not retained across the logging interval

        if batch % args.log_interval == 0 and batch > 0:
            cur_loss = total_loss.item() / args.log_interval
            elapsed = time.time() - start_time
            # ppl = math.exp(cur_loss)
            lr = optimizer.param_groups[0]['lr']
            ms_batch = elapsed * 1000 / args.log_interval
            logger.info(
                '| epoch {:3d} | '.format(epoch) +
                'lr {:02.2f} | ms/batch {:5.2f} | '.format(lr, ms_batch) +
                'loss {:5.2f}'.format(cur_loss))
            save_tb(tb, "train/loss", tot_steps, cur_loss)
            total_loss = 0
            start_time = time.time()

        batch += 1
        tot_steps += 1
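`save_tb(tb, tag, step, value)` is a thin logging helper that is not defined in these snippets. A minimal sketch, assuming `tb` is a `torch.utils.tensorboard.SummaryWriter` (the log directory is made up for illustration):

from torch.utils.tensorboard import SummaryWriter

tb = SummaryWriter(log_dir='runs/lm')  # hypothetical log directory

def save_tb(writer, tag, step, value):
    # Record one scalar under `tag` at the given global step.
    writer.add_scalar(tag, value, global_step=step)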
Example #6
def train():
    global tot_steps
    batch = 0
    total_loss = 0
    model.train()
    hidden = model.init_hidden(args.batch_size)
    start_time = time.time()
    for data, targets in ds.train_seq():
        targets = targets.view(-1)
        model.zero_grad()
        hidden = repackage_hidden(hidden)
        output, hidden = model(data, hidden)
        loss = criterion(output.view(-1, output.size(2)), targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        optimizer.step()

        total_loss += loss.data

        if batch % args.log_interval == 0 and batch > 0:
            cur_loss = total_loss.item() / args.log_interval
            elapsed = time.time() - start_time
            ppl = math.exp(cur_loss)
            lr = optimizer.param_groups[0]['lr']
            ms_batch = elapsed * 1000 / args.log_interval
            logger.info(f'| epoch {epoch:3d} | '
                        f'{batch:5d}/{ds.nbatch:5d} batches | '
                        f'lr {lr:02.2f} | ms/batch {ms_batch:5.2f} | '
                        f'loss {cur_loss:5.2f} | ppl {ppl:8.2f}')
            save_tb(tb, "train/loss", tot_steps, cur_loss)
            save_tb(tb, "train/ppl", tot_steps, ppl)
            total_loss = 0
            start_time = time.time()

        batch += 1
        tot_steps += 1

        if tot_steps >= args.max_steps:
            logger.info(f"Reached max-steps at tot step {tot_steps}, breaking "
                        "the train function")
            break
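`ds.train_seq()` and the `data_source` tensors above presuppose that the flat token stream has already been arranged into `batch_size` parallel columns. A sketch of the standard `batchify` step (an assumption about how `ds` prepares its data, not a confirmed detail of it):

def batchify(data, bsz, device='cpu'):
    # Trim the 1-D token tensor to a multiple of bsz, then reshape it into
    # (nbatch, bsz) so each column is one contiguous stream of text.
    nbatch = data.size(0) // bsz
    data = data.narrow(0, 0, nbatch * bsz)
    return data.view(bsz, -1).t().contiguous().to(device)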
Example #7
def train(epoch):
    global tot_steps
    # Turn on training mode which enables dropout.
    if args.model == 'QRNN':
        model.reset()
    total_loss = 0
    start_time = time.time()
    ntokens = ds.ntokens
    hidden = model.init_hidden(ds.batch_size)
    batch, i = 0, 0

    if (epoch % args.grad_interval == 0 or epoch == 1) and \
            (args.save_grad or args.save_gradPure):
        embed.requires_grad = True
        save_grad, save_gradPure = args.save_grad, args.save_gradPure
    else:
        if embed and embed.requires_grad:
            embed.requires_grad = False
        save_grad, save_gradPure = False, False

    for data, targets in train_seq():
        # shape of data is (bptt, batch_size)
        targets = targets.view(-1).contiguous()
        seq_len = args.bptt
        lr2 = optimizer.param_groups[0]['lr']
        optimizer.param_groups[0]['lr'] = lr2 * seq_len / args.bptt
        model.train()

        # Starting each batch, we detach the hidden state from how it was previously produced.
        # If we didn't, the model would try backpropagating all the way to the start of the dataset.
        hidden = repackage_hidden(hidden)
        optimizer.zero_grad()

        output, hidden, rnn_hs, dropped_rnn_hs = model(data,
                                                       hidden,
                                                       return_h=True)
        raw_loss = criterion(model.decoder.weight, model.decoder.bias, output,
                             targets)

        loss = raw_loss
        # Activation Regularization
        if args.alpha:
            loss = loss + sum(args.alpha * dropped_rnn_h.pow(2).mean()
                              for dropped_rnn_h in dropped_rnn_hs[-1:])
        # Temporal Activation Regularization (slowness)
        if args.beta:
            loss = loss + \
                sum(args.beta * (rnn_h[1:] - rnn_h[:-1]
                                 ).pow(2).mean() for rnn_h in rnn_hs[-1:])
        loss.backward(retain_graph=args.save_gradPure)
        sk.add_sample(epoch, i, loss.item())

        if save_grad:
            with torch.no_grad():
                # shape (bptt, batch_size, voc_size)
                grad = embed.last_oh.grad
                # shape(bptt, batch_size, 1, embed_size)
                res = torch.stack([
                    torch.stack([
                        torch.mm(
                            grad[token_i, batch_i].view(1, -1), 1 /
                            (embed.last_weight * args.emsize +
                             sys.float_info.epsilon))
                        for batch_i in range(args.batch_size)
                    ],
                                dim=0) for token_i in range(args.bptt)
                ],
                                  dim=0)
                assert list(
                    res.shape) == [args.bptt, args.batch_size, 1, args.emsize]
                sk.add_data("grad", epoch, i,
                            res.detach().cpu().numpy().tolist())

        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
        if args.clip:
            torch.nn.utils.clip_grad_norm_(params, args.clip)
        optimizer.step()

        total_loss += raw_loss.data
        optimizer.param_groups[0]['lr'] = lr2
        if i % args.log_interval == 0 and i > 0:
            cur_loss = total_loss.item() / args.log_interval
            elapsed = time.time() - start_time
            ppl = math.exp(cur_loss)
            bpc = cur_loss / math.log(2)
            logger.info(
                '| epoch {:3d} | {:5d}/{:5d} batches | lr {:05.5f} | ms/batch {:5.2f} | '
                'loss {:5.2f} | ppl {:8.2f} | bpc {:8.3f}'.format(
                    epoch, i, ds.nbatch, optimizer.param_groups[0]['lr'],
                    elapsed * 1000 / args.log_interval, cur_loss, ppl, bpc))
            save_tb(tb, "train/loss", tot_steps, cur_loss)
            save_tb(tb, "train/ppl", tot_steps, ppl)
            total_loss = 0
            start_time = time.time()
        ###

        ###
        if save_gradPure:
            optimizer.zero_grad()
            embed.last_oh.grad.zero_()
            output.sum().backward()
            with torch.no_grad():
                # shape (bptt, batch_size, voc_size)
                grad = embed.last_oh.grad
                res = torch.stack([
                    torch.stack([
                        torch.mm(
                            grad[token_i, batch_i].view(1, -1), 1 /
                            (embed.last_weight * args.emsize +
                             sys.float_info.epsilon))
                        for batch_i in range(args.batch_size)
                    ],
                                dim=0) for token_i in range(args.bptt)
                ],
                                  dim=0)
                sk.add_data("gradPure", epoch, i,
                            res.detach().cpu().numpy().tolist())

        tot_steps += 1
        i += 1

        if tot_steps in args.when_steps:
            logger.info(f'(Step {tot_steps}) Saving model before learning '
                        'rate decreased')
            model_save('{}.e{}'.format("model.pt", epoch))
            logger.info('Dividing learning rate by 10')
            optimizer.param_groups[0]['lr'] /= 10.

        if tot_steps >= args.max_steps:
            logger.info(f"Reached max-steps at tot step {tot_steps}, breaking "
                        "the train function")
            break
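`model_save` is not shown in this snippet. A minimal sketch, assuming the common pattern of checkpointing the model together with the criterion and optimizer state via `torch.save`:

import torch

def model_save(fn):
    # Persist model, criterion and optimizer in a single checkpoint file.
    with open(fn, 'wb') as f:
        torch.save([model, criterion, optimizer], f)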
Example #8
def train():
    global tot_steps
    assert args.batch_size % args.small_batch_size == 0, \
        'batch_size must be divisible by small_batch_size'

    # Turn on training mode which enables dropout.
    total_loss = 0
    start_time = time.time()
    hidden = [
        model.init_hidden(args.small_batch_size)
        for _ in range(args.batch_size // args.small_batch_size)
    ]
    batch = 0
    for data, targets in train_seq():
        seq_len = args.bptt

        lr2 = optimizer.param_groups[0]['lr']
        optimizer.param_groups[0]['lr'] = lr2 * seq_len / args.bptt
        model.train()

        optimizer.zero_grad()

        start, end, s_id = 0, args.small_batch_size, 0
        while start < args.batch_size:
            cur_data = data[:, start:end]
            cur_targets = targets[:, start:end].contiguous().view(-1)

            # Starting each batch, we detach the hidden state from how it was previously produced.
            # If we didn't, the model would try backpropagating all the way to the start of the dataset.
            hidden[s_id] = repackage_hidden(hidden[s_id])

            log_prob, hidden[s_id], rnn_hs, dropped_rnn_hs = parallel_model(
                cur_data, hidden[s_id], return_h=True)
            raw_loss = nn.functional.nll_loss(
                log_prob.view(-1, log_prob.size(2)), cur_targets)

            loss = raw_loss
            # Activation Regularization
            loss = loss + sum(args.alpha * dropped_rnn_h.pow(2).mean()
                              for dropped_rnn_h in dropped_rnn_hs[-1:])
            # Temporal Activation Regularization (slowness)
            loss = loss + \
                sum(args.beta * (rnn_h[1:] - rnn_h[:-1]
                                 ).pow(2).mean() for rnn_h in rnn_hs[-1:])
            loss *= args.small_batch_size / args.batch_size
            total_loss += raw_loss.data * args.small_batch_size / args.batch_size
            loss.backward()

            s_id += 1
            start = end
            end = start + args.small_batch_size

            gc.collect()

        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        optimizer.step()

        # total_loss += raw_loss.data
        optimizer.param_groups[0]['lr'] = lr2
        if batch % args.log_interval == 0 and batch > 0:
            cur_loss = total_loss.item() / args.log_interval
            elapsed = time.time() - start_time
            ppl = math.exp(cur_loss)
            logger.info(
                '| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | '
                'loss {:5.2f} | ppl {:8.2f}'.format(
                    epoch, batch, len(ds.current_seq),
                    optimizer.param_groups[0]['lr'],
                    elapsed * 1000 / args.log_interval, cur_loss, ppl))
            save_tb(tb, "train/loss", tot_steps, cur_loss)
            save_tb(tb, "train/ppl", tot_steps, ppl)
            total_loss = 0
            start_time = time.time()
        ###
        batch += 1
        tot_steps += 1
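All of the `train` variants rely on module-level globals (`model`, `optimizer`, `args`, `logger`, `tot_steps`, `epoch`, ...). A hypothetical driver loop, assuming those globals plus a batchified `val_data`, showing how a training function and one of the `evaluate` variants above would typically be tied together:

import math

tot_steps = 0
for epoch in range(1, args.epochs + 1):
    train()
    val_loss = evaluate(val_data, batch_size=10)
    logger.info('| end of epoch {:3d} | valid loss {:5.2f} | '
                'valid ppl {:8.2f}'.format(epoch, val_loss,
                                           math.exp(val_loss)))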