import torch
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_

import moses.utils as mosesvocab  # assumed import path for CircularBuffer


def _train_epoch(model, epoch, tqdm_data, kl_weight, optimizer=None):
    if optimizer is None:
        model.eval()
    else:
        model.train()

    kl_loss_values = mosesvocab.CircularBuffer(1000)
    recon_loss_values = mosesvocab.CircularBuffer(1000)
    loss_values = mosesvocab.CircularBuffer(1000)
    for i, input_batch in enumerate(tqdm_data):
        input_batch = tuple(data.cuda() for data in input_batch)

        # Forward
        kl_loss, recon_loss, _ = model(input_batch)
        kl_loss = torch.sum(kl_loss, 0)
        recon_loss = torch.sum(recon_loss, 0)

        loss = kl_weight * kl_loss + recon_loss

        # Backward
        if optimizer is not None:
            optimizer.zero_grad()
            # with amp.scale_loss(loss, optimizer) as scaled_loss:
            #     scaled_loss.backward()
            loss.backward()
            clip_grad_norm_((p for p in model.parameters() if p.requires_grad),
                            50)
            optimizer.step()

        # Log
        kl_loss_values.add(kl_loss.item())
        recon_loss_values.add(recon_loss.item())
        loss_values.add(loss.item())
        lr = (optimizer.param_groups[0]['lr']
              if optimizer is not None else None)

        # Update tqdm
        kl_loss_value = kl_loss_values.mean()
        recon_loss_value = recon_loss_values.mean()
        loss_value = loss_values.mean()
        lr_str = f'{lr:.5f}' if lr is not None else 'n/a'  # lr is None in eval mode
        postfix = [
            f'loss={loss_value:.5f}', f'(kl={kl_loss_value:.5f}',
            f'recon={recon_loss_value:.5f})',
            f'klw={kl_weight:.5f} lr={lr_str}'
        ]
        tqdm_data.set_postfix_str(' '.join(postfix))

    postfix = {
        'epoch': epoch,
        'kl_weight': kl_weight,
        'lr': lr,
        'kl_loss': kl_loss_value,
        'recon_loss': recon_loss_value,
        'loss': loss_value,
        'mode': 'Eval' if optimizer is None else 'Train'
    }

    return postfix
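A minimal sketch of how an epoch function like this might be driven, with a linear KL warm-up; the fit() wrapper, loader name, learning rate, and annealing schedule below are assumptions for illustration, not part of the original code:

from tqdm import tqdm

def fit(model, train_loader, n_epochs=50):
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
    for epoch in range(n_epochs):
        # Linear KL warm-up over the first 10 epochs (one common choice)
        kl_weight = min(1.0, epoch / 10)
        tqdm_data = tqdm(train_loader, desc=f'Train (epoch #{epoch})')
        print(_train_epoch(model, epoch, tqdm_data, kl_weight, optimizer))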
def _train_epoch_binding(model, epoch, tqdm_data, optimizer=None):
    if optimizer is None:
        model.eval()
    else:
        model.train()

    loss_values = mosesvocab.CircularBuffer(10)
    binding_loss_values = mosesvocab.CircularBuffer(10)
    for i, (input_batch, binding) in enumerate(tqdm_data):
        input_batch = tuple(data.cuda() for data in input_batch)
        binding = binding.cuda().view(-1, 1)
        # Forward
        _, _, binding_loss, _ = model(input_batch, binding)

        binding_loss = torch.sum(binding_loss, 0)

        loss = binding_loss

        # Backward
        if optimizer is not None:
            optimizer.zero_grad()

            # with amp.scale_loss(loss, optimizer) as scaled_loss:
            #     scaled_loss.backward()
            loss.backward()
            clip_grad_norm_((p for p in model.parameters() if p.requires_grad),
                            50)
            optimizer.step()

        # Log
        loss_values.add(loss.item())
        binding_loss_values.add(binding_loss.item())
        lr = (optimizer.param_groups[0]['lr']
              if optimizer is not None else None)

        # Update tqdm
        loss_value = loss_values.mean()
        binding_loss_value = binding_loss_values.mean()
        postfix = [f'loss={loss_value:.5f}', f'bloss={binding_loss_value:.5f}']
        tqdm_data.set_postfix_str(' '.join(postfix))

    postfix = {
        'epoch': epoch,
        'lr': lr,
        'loss': loss_value,
        'mode': 'Eval' if optimizer is None else 'Train'
    }

    return postfix
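The CircularBuffer used for the running means above comes from the MOSES utilities; if that dependency is unavailable, a minimal stand-in covering the add()/mean() interface exercised here (an assumption about the interface, not the library's actual implementation) could look like:

from collections import deque

class CircularBuffer:
    """Keeps the last `size` values and reports their running mean."""

    def __init__(self, size):
        self.data = deque(maxlen=size)

    def add(self, value):
        self.data.append(value)

    def mean(self):
        return sum(self.data) / len(self.data) if self.data else 0.0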
Example #3
def _train_epoch_binding(model, epoch, tqdm_data, kl_weight, encoder_optimizer,
                         decoder_optimizer):
    model.train()

    kl_loss_values = mosesvocab.CircularBuffer(10)
    recon_loss_values = mosesvocab.CircularBuffer(10)
    loss_values = mosesvocab.CircularBuffer(10)
    for i, (input_batch, _) in enumerate(tqdm_data):

        if epoch < 10:
            # Warm-up phase: additionally pre-train the encoder on an
            # auxiliary loader. Note: train_loader_agg_tqdm is defined
            # outside this function; `i % 1 == 0` is always true and acts
            # as a leftover update-frequency knob.
            if i % 1 == 0:
                for (input_batch_, _) in train_loader_agg_tqdm:
                    encoder_optimizer.zero_grad()
                    decoder_optimizer.zero_grad()
                    input_batch_ = tuple(data.cuda() for data in input_batch_)
                    # Forward
                    kl_loss, recon_loss, _, logvar, x, y = model(input_batch_)
                    kl_loss = torch.sum(kl_loss, 0)
                    recon_loss = torch.sum(recon_loss, 0)
                    _, predict = torch.max(F.softmax(y, dim=-1), -1)

                    correct = float(
                        (x == predict).sum().cpu().detach().item()) / float(
                            x.shape[0] * x.shape[1])
                    kl_weight = min(kl_weight * 1e-1 + 1e-3, 1)
                    # kl_weight = 1
                    loss = kl_weight * kl_loss + recon_loss
                    # loss = kl_loss + recon_loss
                    loss.backward()
                    clip_grad_norm_(
                        (p for p in model.parameters() if p.requires_grad), 25)
                    encoder_optimizer.step()
                    loss_value = loss.item()
                    kl_loss_value = kl_loss.item()
                    recon_loss_value = recon_loss.item()

                    postfix = [
                        f'loss={loss_value:.5f}', f'(kl={kl_loss_value:.5f}',
                        f'recon={recon_loss_value:.5f})',
                        f'correct={correct:.5f}'
                    ]
                    train_loader_agg_tqdm.set_postfix_str(' '.join(postfix))

        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()
        input_batch = tuple(data.cuda() for data in input_batch)
        # Forward
        kl_loss, recon_loss, _, logvar, x, y = model(input_batch)
        _, predict = torch.max(F.softmax(y, dim=-1), -1)

        correct = float((x == predict).sum().cpu().detach().item()) / float(
            x.shape[0] * x.shape[1])

        kl_loss = torch.sum(kl_loss, 0)
        recon_loss = torch.sum(recon_loss, 0)

        # kl_weight = min(kl_weight + 1e-3, 1)
        kl_weight = 1
        # Add the KL term before backward() so its gradient actually reaches
        # the encoder once the warm-up epochs are over
        loss = recon_loss
        if epoch >= 10:
            loss = loss + kl_weight * kl_loss

        loss.backward()
        clip_grad_norm_((p for p in model.parameters() if p.requires_grad), 50)

        if epoch >= 10:
            encoder_optimizer.step()
        decoder_optimizer.step()

        # Log
        kl_loss_values.add(kl_loss.item())
        recon_loss_values.add(recon_loss.item())
        loss_values.add(loss.item())
        lr = encoder_optimizer.param_groups[0]['lr']

        # Update tqdm
        kl_loss_value = kl_loss_values.mean()
        recon_loss_value = recon_loss_values.mean()
        loss_value = loss_values.mean()
        postfix = [
            f'loss={loss_value:.5f}', f'(kl={kl_loss_value:.5f}',
            f'recon={recon_loss_value:.5f})',
            f'klw={kl_weight:.5f} lr={lr:.5f}',
            f'correct={correct:.5f}'
        ]
        tqdm_data.set_postfix_str(' '.join(postfix))

    postfix = {
        'epoch': epoch,
        'kl_weight': kl_weight,
        'lr': lr,
        'kl_loss': kl_loss_value,
        'recon_loss': recon_loss_value,
        'loss': loss_value
    }

    return postfix
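Example #3 expects separate encoder and decoder optimizers. One plausible way to construct them, assuming the VAE exposes encoder/decoder submodules (hypothetical attribute names and learning rate):

encoder_optimizer = torch.optim.Adam(model.encoder.parameters(), lr=3e-4)
decoder_optimizer = torch.optim.Adam(model.decoder.parameters(), lr=3e-4)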
Example #4
def _train_epoch_binding(model, epoch, tqdm_data, kl_weight, optimizer=None):
    model.eval()  # the VAE stays frozen here; only the binding head trains
    # Note: bindingmodel is defined outside this function
    if optimizer is None:
        bindingmodel.eval()
    else:
        bindingmodel.train()

    kl_loss_values = mosesvocab.CircularBuffer(10)
    recon_loss_values = mosesvocab.CircularBuffer(10)
    loss_values = mosesvocab.CircularBuffer(10)
    binding_loss_values = mosesvocab.CircularBuffer(10)
    for i, (input_batch, binding) in enumerate(tqdm_data):
        input_batch = tuple(data.cuda() for data in input_batch)
        binding = binding.cuda().view(-1, 1)
        # Forward
        kl_loss, recon_loss, _, z = model(input_batch, binding)
        b_pred = bindingmodel(z.detach())

        # Binarize affinities at a 0.35 threshold and up-weight the sparse
        # positive class for the BCE objective
        weights = (binding > 0.35).float()  # hard 0/1 targets, already on GPU
        class_weights = torch.full_like(binding, 0.5)
        class_weights[binding > 0.35] = 5.0
        binding_bce = F.binary_cross_entropy_with_logits(
            b_pred, weights, pos_weight=class_weights)

        kl_loss = torch.sum(kl_loss, 0)
        recon_loss = torch.sum(recon_loss, 0)
        binding_loss = binding_bce  # BCE already reduces to a scalar mean

        # Only the binding head is optimized here; the VAE losses are
        # tracked purely for logging
        loss = binding_loss

        # Backward
        if optimizer is not None:
            optimizer.zero_grad()

            # with amp.scale_loss(loss, optimizer) as scaled_loss:
            #     scaled_loss.backward()
            loss.backward()
            clip_grad_norm_(bindingmodel.parameters(), 50)
            optimizer.step()

        # Log
        kl_loss_values.add(kl_loss.item())
        recon_loss_values.add(recon_loss.item())
        loss_values.add(loss.item())
        binding_loss_values.add(binding_loss.item())
        lr = (optimizer.param_groups[0]['lr']
              if optimizer is not None else None)

        # Update tqdm
        kl_loss_value = kl_loss_values.mean()
        recon_loss_value = recon_loss_values.mean()
        loss_value = loss_values.mean()
        binding_loss_value = binding_loss_values.mean()
        lr_str = f'{lr:.5f}' if lr is not None else 'n/a'  # lr is None in eval mode
        postfix = [
            f'loss={loss_value:.5f}', f'(kl={kl_loss_value:.5f}',
            f'recon={recon_loss_value:.5f})',
            f'klw={kl_weight:.5f} lr={lr_str}',
            f'bloss={binding_loss_value:.5f}'
        ]
        tqdm_data.set_postfix_str(' '.join(postfix))

    postfix = {
        'epoch': epoch,
        'kl_weight': kl_weight,
        'lr': lr,
        'kl_loss': kl_loss_value,
        'recon_loss': recon_loss_value,
        'loss': loss_value,
        'mode': 'Eval' if optimizer is None else 'Train'
    }

    return postfix
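The bindingmodel trained in Example #4 is defined outside the snippet; a plausible minimal head that maps the detached VAE latent z to a single binding logit (purely illustrative, with an assumed latent size) might be:

import torch.nn as nn

latent_dim = 128  # assumed; must match the dimensionality of z
bindingmodel = nn.Sequential(
    nn.Linear(latent_dim, 64),
    nn.ReLU(),
    nn.Linear(64, 1),  # raw logit, consumed by BCE-with-logits above
).cuda()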