Example 1
import os

import numpy as np
import torch


def test(model, device, epoch, ema, data_loader, tag, root_process):
    # convert model to evaluation mode (no Dropout etc.)
    model.eval()

    # setup the reconstruction dataset
    recon_dataset = None
    nbatches = data_loader.batch_sampler.sampler.num_samples // data_loader.batch_size
    recon_batch_idx = int(torch.Tensor(1).random_(0, nbatches - 1))

    # setup testing metrics
    if root_process:
        logrecons = torch.zeros((nbatches), device=device)
        logdecs = torch.zeros((nbatches, model.nz), device=device)
        logencs = torch.zeros((nbatches, model.nz), device=device)

    elbos = []

    # switch to EMA parameters for evaluation
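    # (the raw training weights are restored from ema.get_default at the start of the next training epoch)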
    for name, param in model.named_parameters():
        if param.requires_grad:
            param.data = ema.get_ema(name)

    # allocate memory for the input data
    data = torch.zeros((data_loader.batch_size,) + model.xs, device=device)

    # enumerate over the batches
    for batch_idx, (batch, _) in enumerate(data_loader):
        # copy the mini-batch into the pre-allocated data variable
        data.copy_(batch)

        # save a copy of this batch for reconstruction
        # (clone, since `data` is an in-place buffer that is overwritten every iteration)
        if batch_idx == recon_batch_idx:
            recon_dataset = data.clone()

        with torch.no_grad():
            # evaluate the data under the model and calculate ELBO components
            logrecon, logdec, logenc, _ = model.loss(data)

            # construct the ELBO
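            # (strictly, the negative ELBO, used as the minimization objective:
            # -log p(x|z) plus the per-layer terms log q(z|·) - log p(z|·), i.e. the KL contributions)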
            elbo = -logrecon + torch.sum(-logdec + logenc)

            # compute the inference- and generative-model loss
            logdec = torch.sum(logdec, dim=1)
            logenc = torch.sum(logenc, dim=1)

        if root_process:
            # scale by image dimensions to get "bits/dim"
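            # (perdimsscale presumably equals 1 / (ln 2 * number of data dimensions), converting nats to bits per dimension)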
            elbo *= model.perdimsscale
            logrecon *= model.perdimsscale
            logdec *= model.perdimsscale
            logenc *= model.perdimsscale

            elbos.append(elbo.item())

            # log
            logrecons[batch_idx] += logrecon
            logdecs[batch_idx] += logdec
            logencs[batch_idx] += logenc

    if root_process:
        elbo = np.mean(elbos)

        entrecon = -torch.mean(logrecons).detach().cpu().numpy()
        entdec = -torch.mean(logdecs, dim=0).detach().cpu().numpy()
        entenc = -torch.mean(logencs, dim=0).detach().cpu().numpy()
        kl = entdec - entenc

        # print metrics to console and Tensorboard
        print(f'\nEpoch: {epoch}\tTest loss: {elbo:.6f}')
        model.logger.add_scalar('elbo/test', elbo, epoch)

        # log to Tensorboard
        model.logger.add_scalar('x/reconstruction/test', entrecon, epoch)
        for i in range(1, model.nz + 1):
            model.logger.add_scalar(f'z{i}/encoder/test', entenc[i - 1], epoch)
            model.logger.add_scalar(f'z{i}/decoder/test', entdec[i - 1], epoch)
            model.logger.add_scalar(f'z{i}/KL/test', kl[i - 1], epoch)

        # if the current ELBO is better than all previous ELBOs, save parameters
        if elbo < model.best_elbo and not np.isnan(elbo):
            model.logger.add_scalar('elbo/besttest', elbo, epoch)
            os.makedirs('params/mnist/', exist_ok=True)
            torch.save(model.state_dict(), f'params/mnist/{tag}')
            if epoch % 25 == 0:
                torch.save(model.state_dict(), f'params/mnist/epoch{epoch}_{tag}')
            print("saved params\n")
            model.best_elbo = elbo

            model.sample(device, epoch)
            model.reconstruct(recon_dataset, device, epoch)
        else:
            print("loss did not improve\n")
Example 2
import time

import torch
import torch.nn as nn

# note: lr_step (per-step learning-rate decay) is a repository helper not shown here


def train(model, device, epoch, data_loader, optimizer, ema, log_interval, root_process, schedule=True, decay=0.99995):
    # convert model to train mode (activate Dropout etc.)
    model.train()

    # get number of batches
    nbatches = data_loader.batch_sampler.sampler.num_samples // data_loader.batch_size

    # switch to parameters not affected by exponential moving average decay
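    # (these raw weights were saved via ema.register_default at the end of the previous training epoch)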
    for name, param in model.named_parameters():
        if param.requires_grad:
            param.data = ema.get_default(name)

    # setup training metrics
    if root_process:
        elbos = torch.zeros((nbatches), device=device)
        logrecons = torch.zeros((nbatches), device=device)
        logdecs = torch.zeros((nbatches, model.nz), device=device)
        logencs = torch.zeros((nbatches, model.nz), device=device)

    if root_process:
        start_time = time.time()

    # allocate memory for data
    data = torch.zeros((data_loader.batch_size,) + model.xs, device=device)

    # enumerate over the batches
    for batch_idx, (batch, _) in enumerate(data_loader):
        # keep track of the global step
        global_step = (epoch - 1) * len(data_loader) + (batch_idx + 1)

        # update the learning rate according to schedule
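        # (judging from this call site, lr_step applies one multiplicative decay step
        # per global step; a sketch is given after this example)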
        if schedule:
            for param_group in optimizer.param_groups:
                lr = param_group['lr']
                lr = lr_step(global_step, lr, decay=decay)
                param_group['lr'] = lr

        # empty all the gradients stored
        optimizer.zero_grad()

        # copy the mini-batch into the pre-allocated data variable
        data.copy_(batch)

        # evaluate the data under the model and calculate ELBO components
        logrecon, logdec, logenc, zsamples = model.loss(data)

        # free bits technique, in order to prevent posterior collapse
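        # each latent dimension's KL term is clamped from below at bits_pc, so the optimizer
        # gains nothing from collapsing q(z|x) onto the prior in individual dimensions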
        bits_pc = 1.
        kl = torch.sum(torch.max(-logdec + logenc, bits_pc * torch.ones((model.nz, model.zdim[0]), device=device)))

        # compute the inference- and generative-model loss
        logdec = torch.sum(logdec, dim=1)
        logenc = torch.sum(logenc, dim=1)

        # construct ELBO
        elbo = -logrecon + kl

        # scale by image dimensions to get "bits/dim"
        elbo *= model.perdimsscale
        logrecon *= model.perdimsscale
        logdec *= model.perdimsscale
        logenc *= model.perdimsscale

        # calculate gradients
        elbo.backward()

        # take gradient step
        total_norm = nn.utils.clip_grad_norm_(model.parameters(), 1., norm_type=2)
        optimizer.step()

        # log gradient norm
        if root_process:
            model.logger.add_scalar('gnorm', total_norm, global_step)

        # do ema update on parameters used for evaluation
        if root_process:
            with torch.no_grad():
                for name, param in model.named_parameters():
                    if param.requires_grad:
                        ema(name, param.data)

        # log
        if root_process:
            # detach so the per-epoch buffers don't retain autograd history
            elbos[batch_idx] += elbo.detach()
            logrecons[batch_idx] += logrecon.detach()
            logdecs[batch_idx] += logdec.detach()
            logencs[batch_idx] += logenc.detach()

        # periodically print and log metrics
        if root_process and batch_idx % log_interval == 0 and log_interval < nbatches:
            # print metrics to console
            print(f'Train Epoch: {epoch} [{batch_idx}/{nbatches} ({100. * batch_idx / len(data_loader):.0f}%)]\tLoss: {elbo.item():.6f}\tGnorm: {total_norm:.2f}\tSec/step: {(time.time() - start_time) / (batch_idx + 1):.3f}')

            model.logger.add_scalar('step-sec', (time.time() - start_time) / (batch_idx + 1), global_step)
            entrecon = -logrecon
            entdec = -logdec
            entenc = -logenc
            kl = entdec - entenc

            # log
            model.logger.add_scalar('elbo/train', elbo, global_step)
            for param_group in optimizer.param_groups:
                lr = param_group['lr']
            model.logger.add_scalar('lr', lr, global_step)

            model.logger.add_scalar('x/reconstruction/train', entrecon, global_step)
            for i in range(1, model.nz + 1):
                model.logger.add_scalar(f'z{i}/encoder/train', entenc[i - 1], global_step)
                model.logger.add_scalar(f'z{i}/decoder/train', entdec[i - 1], global_step)
                model.logger.add_scalar(f'z{i}/KL/train', kl[i - 1], global_step)

    # save training params, to be able to return to these values after evaluation
    with torch.no_grad():
        for name, param in model.named_parameters():
            if param.requires_grad:
                ema.register_default(name, param.data)

    # print the average loss of the epoch to the console
    if root_process:
        elbo = torch.mean(elbos).detach().cpu().numpy()
        print(f'====> Epoch: {epoch} Average loss: {elbo:.4f}')
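lr_step is called in Example 2 but not shown. A sketch consistent with its call site, lr_step(global_step, lr, decay=decay), where the floor value min_lr is an assumption:

def lr_step(step, curr_lr, decay=0.99995, min_lr=5e-4):
    # hypothetical helper: one multiplicative decay step per global step, with a floor
    if curr_lr > min_lr:
        curr_lr *= decay
    return curr_lr

Put together, a single-process run might drive the two functions like this (train_loader, test_loader, the epoch count, and the tag are assumed names, not part of the examples):

for epoch in range(1, 100 + 1):
    train(model, device, epoch, train_loader, optimizer, ema,
          log_interval=100, root_process=True)
    test(model, device, epoch, ema, test_loader,
         tag='bitswap_mnist', root_process=True)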