def execute_graph(epoch, model, loader, grapher, optimizer=None, prefix='test'):
    """ Run one epoch of the classifier over ``loader``.

    When 'train' appears in ``prefix`` the model is switched to train mode and
    every minibatch is back-propagated and stepped with ``optimizer``. Otherwise
    the model is evaluated under ``torch.no_grad`` (through the Polyak/EMA
    weights when ``args.polyak_ema > 0``).

    :param epoch: the current epoch number
    :param model: the torch model
    :param loader: the train or **TEST** loader yielding (minibatch, labels)
    :param grapher: the graph writing helper (eg: visdom / tf wrapper)
    :param optimizer: the optimizer (required for training, must be None otherwise)
    :param prefix: 'train', 'test' or 'valid'
    :returns: the epoch-mean 'loss_mean' as a Python scalar, for early stopping
    :rtype: float

    """
    epoch_start = time.time()
    training = 'train' in prefix
    if training:
        model.train()
        assert optimizer is not None
    else:
        model.eval()
        assert optimizer is None
    running, seen = {}, 0

    for batch_idx, (minibatch, labels) in enumerate(loader):
        if args.cuda:
            minibatch = minibatch.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)

        with torch.no_grad() if not training else utils.dummy_context():
            if not training and args.polyak_ema > 0:
                # evaluate through the Polyak-averaged weights
                pred_logits = layers.get_polyak_prediction(
                    model, pred_fn=functools.partial(model, minibatch))
            else:
                pred_logits = model(minibatch)

            top1, top5 = metrics.topk(output=pred_logits, target=labels, topk=(1, 5))
            batch_losses = {
                # swap for F.mse_loss when doing regression
                'loss_mean': F.cross_entropy(input=pred_logits, target=labels),
                'top1_mean': top1,
                'top5_mean': top5,
            }

            # accumulate the per-minibatch values into the running sums
            if running:
                running = tree.map_structure(_extract_sum_scalars, running, batch_losses)
            else:
                running = batch_losses
            seen += minibatch.size(0)

        if training:
            optimizer.zero_grad()
            if args.half:
                # mixed precision: let amp scale the loss before backprop
                with amp.scale_loss(batch_losses['loss_mean'], optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                batch_losses['loss_mean'].backward()

            if args.clip > 0:
                # TODO: clip by value or norm? (clip_grad_nor_ is the alternative)
                nn.utils.clip_grad_value_(model.parameters(), args.clip)

            optimizer.step()
            if args.polyak_ema > 0:
                layers.polyak_ema_parameters(model, args.polyak_ema)

            del batch_losses

        if args.debug_step:  # for testing purposes
            break

    # turn the running sums into per-minibatch means
    running = tree.map_structure(lambda total: total / (batch_idx + 1), running)

    def as_python(value):
        return value.detach().item() if isinstance(value, torch.Tensor) else value

    log_fmt = '{}-{}[Epoch {}][{} samples][{:.2f} sec]:\t Loss: {:.4f}\tTop-1: {:.4f}\tTop-5: {:.4f}'
    print(log_fmt.format(
        prefix, args.distributed_rank, epoch, seen, time.time() - epoch_start,
        *[as_python(running[key]) for key in ('loss_mean', 'top1_mean', 'top5_mean')]))

    register_plots({**running}, grapher, epoch=epoch, prefix=prefix)

    # post at most 64 images of at most 64x64 from the last minibatch to save bandwidth
    keep = min(64, minibatch.shape[0])
    side = min(64, minibatch.shape[-1])
    register_images({'input_imgs': F.interpolate(minibatch[0:keep], size=(side, side))},
                    grapher, prefix=prefix)
    if grapher is not None:
        grapher.save()

    # cleanups (see https://tinyurl.com/ycjre67m) + return loss for early stopping
    result = as_python(running['loss_mean'])
    running.clear()
    return result
Exemple #2
0
def execute_graph(epoch,
                  model,
                  loader,
                  grapher,
                  optimizer=None,
                  prefix='test'):
    """ execute the graph; when 'train' is in the name the model runs the optimizer

    :param epoch: the current epoch number
    :param model: the torch model
    :param loader: the train or **TEST** loader yielding (minibatch, labels)
    :param grapher: the graph writing helper (eg: visdom / tf wrapper)
    :param optimizer: the optimizer
    :param prefix: 'train', 'test' or 'valid'
    :returns: the epoch-mean 'loss_mean' as a Python float, for early stopping
    :rtype: float

    """
    start_time = time.time()
    # NOTE(review): only prefix == 'test' switches to eval mode / no_grad; a 'valid'
    # prefix runs with the model in train mode and grads enabled (but never steps).
    # Confirm this is intended.
    model.eval() if prefix == 'test' else model.train()
    assert optimizer is not None if 'train' in prefix or 'valid' in prefix else optimizer is None
    loss_map, num_samples = {}, 0

    # iterate over train and valid data
    for minibatch, labels in loader:
        minibatch = minibatch.cuda() if args.cuda else minibatch
        labels = labels.cuda() if args.cuda else labels
        if args.half:
            minibatch = minibatch.half()

        if 'train' in prefix:
            optimizer.zero_grad()  # zero gradients on optimizer

        with torch.no_grad() if prefix == 'test' else dummy_context():
            pred_logits = model(minibatch)  # get normal predictions
            loss_t = {
                'loss_mean':
                F.cross_entropy(
                    input=pred_logits,
                    target=labels),  # change to F.mse_loss for regression
                'accuracy_mean':
                softmax_accuracy(preds=F.softmax(pred_logits, -1),
                                 targets=labels)
            }
            loss_map = _add_loss_map(loss_map, loss_t)
            num_samples += minibatch.size(0)

        if 'train' in prefix:  # compute bp and optimize
            loss_t['loss_mean'].backward()
            optimizer.step()

        if args.debug_step:  # for testing purposes
            break

    # compute the mean of the map
    loss_map = _mean_map(loss_map)  # reduce the map to get actual means
    print(
        '{}[Epoch {}][{} samples][{:.2f} sec]: Loss: {:.4f}\tAccuracy: {:.4f}'.
        format(prefix, epoch, num_samples,
               time.time() - start_time, loss_map['loss_mean'].item(),
               loss_map['accuracy_mean'].item() * 100.0))

    # plot the test accuracy, loss and images
    register_plots({**loss_map},
                   grapher,
                   epoch=epoch,
                   prefix='linear' + prefix)
    # F.upsample is deprecated; F.interpolate has the same defaults (mode='nearest')
    register_images({'input_imgs': F.interpolate(minibatch, size=(100, 100))},
                    grapher,
                    prefix=prefix)

    # return this for early stopping
    loss_val = loss_map['loss_mean'].detach().item()
    loss_map.clear()
    return loss_val
Exemple #3
0
def execute_graph(epoch,
                  model,
                  loader,
                  grapher,
                  optimizer=None,
                  prefix='test'):
    """ execute the graph; when 'train' is in the name the model runs the optimizer

    Runs one epoch of BYOL + linear-classifier training / evaluation over a loader
    that yields (augmentation1, augmentation2, labels) triples.

    :param epoch: the current epoch number
    :param model: the torch model
    :param loader: the train or **TEST** loader
    :param grapher: the graph writing helper (eg: visdom / tf wrapper)
    :param optimizer: the optimizer
    :param prefix: 'train', 'test' or 'valid'
    :returns: the epoch-mean 'loss_mean' (BYOL + classifier) as a Python scalar, for early stopping
    :rtype: float

    """
    start_time = time.time()
    is_eval = 'train' not in prefix
    model.eval() if is_eval else model.train()
    assert optimizer is None if is_eval else optimizer is not None
    loss_map, num_samples = {}, 0

    # iterate over data and labels
    for num_minibatches, (augmentation1, augmentation2,
                          labels) in enumerate(loader):
        augmentation1 = augmentation1.cuda(
            non_blocking=True) if args.cuda else augmentation1
        augmentation2 = augmentation2.cuda(
            non_blocking=True) if args.cuda else augmentation2
        labels = labels.cuda(non_blocking=True) if args.cuda else labels

        with torch.no_grad() if is_eval else utils.dummy_context():
            if is_eval and args.polyak_ema > 0:  # use the Polyak model for predictions
                output_dict = layers.get_polyak_prediction(
                    model,
                    pred_fn=functools.partial(model, augmentation1,
                                              augmentation2))
            else:
                output_dict = model(augmentation1,
                                    augmentation2)  # get normal predictions

            # The loss is the BYOL loss + classifier loss (with stop-grad of course).
            byol_loss = loss_function(
                online_prediction1=output_dict['online_prediction1'],
                online_prediction2=output_dict['online_prediction2'],
                target_projection1=output_dict['target_projection1'],
                target_projection2=output_dict['target_projection2'])
            # in train mode the linear head scores both augmentations, so the
            # labels are duplicated to match
            classifier_labels = labels if is_eval else torch.cat(
                [labels, labels], 0)
            classifier_loss = F.cross_entropy(
                input=output_dict['linear_preds'], target=classifier_labels)
            acc1, acc5 = metrics.topk(output=output_dict['linear_preds'],
                                      target=classifier_labels,
                                      topk=(1, 5))

            loss_t = {
                'loss_mean': byol_loss + classifier_loss,
                'byol_loss_mean': byol_loss,
                'linear_loss_mean': classifier_loss,
                'top1_mean': acc1,
                'top5_mean': acc5,
            }
            loss_map = loss_t if not loss_map else tree.map_structure(  # aggregate loss
                _extract_sum_scalars, loss_map, loss_t)
            num_samples += augmentation1.size(0)  # count minibatch samples

        if not is_eval:  # compute bp and optimize
            optimizer.zero_grad()  # zero gradients on optimizer
            if args.half:
                with amp.scale_loss(loss_t['loss_mean'],
                                    optimizer) as scaled_loss:
                    scaled_loss.backward()  # compute grads (fp16+fp32)
            else:
                loss_t['loss_mean'].backward()  # compute grads (fp32)

            if args.clip > 0:
                # TODO: clip by value or norm? torch.nn.utils.clip_grad_value_
                # torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
                nn.utils.clip_grad_value_(model.parameters(), args.clip)

            optimizer.step()
            if args.polyak_ema > 0:  # update Polyak mean if requested
                layers.polyak_ema_parameters(model, args.polyak_ema)

            del loss_t

        if args.debug_step:  # for testing purposes
            break

    # compute the mean of the dict
    loss_map = tree.map_structure(
        lambda v: v / (num_minibatches + 1),
        loss_map)  # reduce the map to get actual means

    # aggregated values may be tensors (single minibatch) or plain scalars
    # (after _extract_sum_scalars), so never call .item() on them blindly
    def tensor2item(t):
        return t.detach().item() if isinstance(t, torch.Tensor) else t

    # log some stuff
    to_log = '{}-{}[Epoch {}][{} samples][{:.2f} sec]:\t Loss: {:.4f}\tTop-1: {:.4f}\tTop-5: {:.4f}'
    print(
        to_log.format(prefix, args.distributed_rank, epoch, num_samples,
                      time.time() - start_time,
                      tensor2item(loss_map['loss_mean']),
                      tensor2item(loss_map['top1_mean']),
                      tensor2item(loss_map['top5_mean'])))

    # plot the test accuracy, loss and images
    register_plots({**loss_map}, grapher, epoch=epoch, prefix=prefix)

    # tack on images to grapher, making them smaller and only use 64 to not waste network bandwidth
    num_images_to_post = min(64, augmentation1.shape[0])
    image_size_to_post = min(64, augmentation1.shape[-1])
    image_map = {
        'augmentation1_imgs':
        F.interpolate(augmentation1[0:num_images_to_post],
                      size=(image_size_to_post, image_size_to_post)),
        'augmentation2_imgs':
        F.interpolate(augmentation2[0:num_images_to_post],
                      size=(image_size_to_post, image_size_to_post))
    }
    register_images({**image_map}, grapher, prefix=prefix)
    if grapher is not None:
        grapher.save()

    # cleanups (see https://tinyurl.com/ycjre67m) + return loss for early stopping
    loss_val = tensor2item(loss_map['loss_mean'])
    loss_map.clear()
    return loss_val
Exemple #4
0
def execute_graph(epoch,
                  model,
                  fisher,
                  data_loader,
                  grapher,
                  optimizer=None,
                  prefix='test'):
    ''' execute the graph; when 'train' is in the name the model runs the optimizer

    Runs one epoch of the VAE over ``data_loader`` (labels are ignored), plots the
    epoch means plus reparameterizer scalars and sample images.

    :param epoch: the current epoch number
    :param model: the torch model (must expose ``loss_function`` and ``student``)
    :param fisher: passed through to ``model.loss_function``
    :param data_loader: the train or test loader yielding (data, labels)
    :param grapher: the graph writing helper; plotting is skipped when falsy
    :param optimizer: the optimizer (required iff 'train' is in prefix)
    :param prefix: 'train', 'test' or 'valid'
    :returns: {'loss_mean': float, 'elbo_mean': float} for early stopping
    :rtype: dict
    '''
    model.eval() if 'train' not in prefix else model.train()
    assert optimizer is not None if 'train' in prefix else optimizer is None
    loss_map, num_samples = {}, 0

    for data, _ in data_loader:
        # torch.autograd.Variable is a no-op wrapper since PyTorch 0.4
        data = data.cuda() if args.cuda else data

        if 'train' in prefix:
            # zero gradients on optimizer
            # before forward pass
            optimizer.zero_grad()

        with torch.no_grad() if 'train' not in prefix else dummy_context():
            # run the VAE and extract loss
            output_map = model(data)
            loss_t = model.loss_function(output_map, fisher)

        if 'train' in prefix:
            # compute bp and optimize
            loss_t['loss_mean'].backward()

            # plot the L2 norm of the vectorized *gradients*; vectorizing
            # model.parameters() would measure the weights instead (and drag
            # their autograd graph into the loss map)
            grads = [p.grad.detach().reshape(-1)
                     for p in model.parameters() if p.grad is not None]
            loss_t['grad_norm_mean'] = torch.norm(torch.cat(grads)) if grads \
                else loss_t['loss_mean'].detach().new_zeros(())
            optimizer.step()

        with torch.no_grad() if 'train' not in prefix else dummy_context():
            loss_map = _add_loss_map(loss_map, loss_t)
            num_samples += data.size(0)

    loss_map = _mean_map(loss_map)  # reduce the map to get actual means
    print(
        '{}[Epoch {}][{} samples]: Average loss: {:.4f}\tELBO: {:.4f}\tKLD: {:.4f}\tNLL: {:.4f}\tMut: {:.4f}'
        .format(prefix, epoch, num_samples, loss_map['loss_mean'].item(),
                loss_map['elbo_mean'].item(), loss_map['kld_mean'].item(),
                loss_map['nll_mean'].item(), loss_map['mut_info_mean'].item()))

    # gather scalar values of reparameterizers (if they exist)
    reparam_scalars = model.student.get_reparameterizer_scalars()

    # plot the test accuracy, loss and images
    if grapher:  # only if grapher is not None
        register_plots({
            **loss_map,
            **reparam_scalars
        },
                       grapher,
                       epoch=epoch,
                       prefix=prefix)
        # images come from the last minibatch's output
        images = [
            output_map['augmented']['data'],
            output_map['student']['x_reconstr']
        ]
        img_names = ['original_imgs', 'vae_reconstructions']
        register_images(images, img_names, grapher, prefix=prefix)
        grapher.show()

    # return this for early stopping
    loss_val = {
        'loss_mean': loss_map['loss_mean'].detach().item(),
        'elbo_mean': loss_map['elbo_mean'].detach().item()
    }
    loss_map.clear()
    return loss_val
Exemple #5
0
def execute_graph(epoch, model, loader, grapher, optimizer=None, prefix='test'):
    """ execute the graph; when 'train' is in the name the model runs the optimizer

    Runs one epoch of the VAE over ``loader``. In train mode every minibatch is
    back-propagated (through ``amp`` when ``args.half``), optionally
    gradient-value clipped, stepped and Polyak-averaged. Afterwards the epoch
    means are plotted together with input images, reconstructions and, every
    ``post_images_every`` epochs, prior / aggregate-posterior generations.

    :param epoch: the current epoch number
    :param model: the torch model
    :param loader: the train or **TEST** loader
    :param grapher: the graph writing helper (eg: visdom / tf wrapper)
    :param optimizer: the optimizer
    :param prefix: 'train', 'test' or 'valid'
    :returns: the epoch-mean 'elbo_mean' (or 'nll_mean' when args.vae_type is
              'autoencoder') as a Python scalar, for early stopping
    :rtype: float

    """
    start_time = time.time()
    is_eval = 'train' not in prefix
    model.eval() if is_eval else model.train()
    assert optimizer is None if is_eval else optimizer is not None
    loss_map, num_samples = {}, 0

    # iterate over data and labels
    for num_minibatches, (minibatch, labels) in enumerate(loader):
        minibatch = minibatch.cuda(non_blocking=True) if args.cuda else minibatch
        labels = labels.cuda(non_blocking=True) if args.cuda else labels

        with torch.no_grad() if is_eval else utils.dummy_context():
            if is_eval and args.polyak_ema > 0:                                # use the Polyak model for predictions
                pred_logits, reparam_map = layers.get_polyak_prediction(
                    model, pred_fn=functools.partial(model, minibatch, labels=labels))
            else:
                pred_logits, reparam_map = model(minibatch, labels=labels)     # get normal predictions

            # compute loss
            loss_t = model.loss_function(recon_x=pred_logits,
                                         x=minibatch,
                                         params=reparam_map,
                                         K=args.monte_carlo_posterior_samples)
            loss_map = loss_t if not loss_map else tree.map_structure(         # aggregate loss
                _extract_sum_scalars, loss_map, loss_t)
            num_samples += minibatch.size(0)                                   # count minibatch samples

        if not is_eval:                                                        # compute bp and optimize
            optimizer.zero_grad()                                              # zero gradients on optimizer
            if args.half:
                with amp.scale_loss(loss_t['loss_mean'], optimizer) as scaled_loss:
                    scaled_loss.backward()                                     # compute grads (fp16+fp32)
            else:
                loss_t['loss_mean'].backward()                                 # compute grads (fp32)

            if args.clip > 0:
                # TODO: clip by value or norm? torch.nn.utils.clip_grad_value_
                # torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
                nn.utils.clip_grad_value_(model.parameters(), args.clip)

            optimizer.step()                                                   # update the parameters
            if args.polyak_ema > 0:                                            # update Polyak EMA if requested
                layers.polyak_ema_parameters(model, args.polyak_ema)

            del loss_t

        # debug mode processes three minibatches (indices 0, 1, 2) before stopping
        if args.debug_step and num_minibatches > 1:  # for testing purposes
            break

    # compute the mean of the dict
    loss_map = tree.map_structure(
        lambda v: v / (num_minibatches + 1), loss_map)                          # reduce the map to get actual means

    # calculate the true likelihood by marginalizing out the latent variable
    # (only on eval passes, every 20th epoch, with K=1000 importance samples)
    if epoch > 0 and epoch % 20 == 0 and is_eval:
        loss_map['likelihood_mean'] = model.likelihood(loader, K=1000)

    # activate the logits of the reconstruction and get the dict
    # (uses the logits from the last minibatch only)
    reconstr_map = model.get_activated_reconstructions(pred_logits)

    # tack on remote metrics information if requested, do it in-frequently.
    if args.metrics_server is not None and not is_eval:  # only eval train generations due to BN
        request_remote_metrics_calc(epoch, model, grapher, prefix, post_every=20)

    # gather scalar values of reparameterizers (if they exist)
    reparam_scalars = model.get_reparameterizer_scalars()

    # tack on images to grapher
    post_images_every = 5  # TODO(jramapuram): parameterize
    image_map = {'input_imgs': minibatch}

    # Add generations to our image dict
    if epoch > 0 and epoch % post_images_every == 0:
        with torch.no_grad():
            prior_generated = model.generate_synthetic_samples(args.batch_size, reset_state=True,
                                                               use_aggregate_posterior=False)
            ema_generated = model.generate_synthetic_samples(args.batch_size, reset_state=True,
                                                             use_aggregate_posterior=True)
            image_map['prior_generated_imgs'] = prior_generated
            image_map['ema_generated_imgs'] = ema_generated

            # tack on MSSIM information if requested
            # NOTE(review): compares the last minibatch with args.batch_size
            # generations; the final minibatch may be smaller — confirm the
            # metric tolerates mismatched batch sizes.
            if args.calculate_msssim:
                loss_map['prior_gen_msssim_mean'] = metrics.calculate_mssim(
                    minibatch, prior_generated)
                loss_map['ema_gen_msssim_mean'] = metrics.calculate_mssim(
                    minibatch, ema_generated)

    # plot the test accuracy, loss and images
    register_plots({**loss_map, **reparam_scalars}, grapher, epoch=epoch, prefix=prefix)

    # keep only the first 64 entries of every image tensor to limit bandwidth
    def reduce_num_images(struct, num): return tree.map_structure(lambda v: v[0:num], struct)
    register_images(reduce_num_images({**image_map, **reconstr_map}, num=64),
                    grapher, epoch=epoch, post_every=post_images_every, prefix=prefix)
    if grapher is not None:
        grapher.save()

    # log some stuff
    # aggregated values may be tensors or plain scalars, hence the isinstance guard
    def tensor2item(t): return t.detach().item() if isinstance(t, torch.Tensor) else t
    to_log = '{}-{}[Epoch {}][{} samples][{:.2f} sec]:\t Loss: {:.4f}\t-ELBO: {:.4f}\tNLL: {:.4f}\tKLD: {:.4f}\tMI: {:.4f}'
    print(to_log.format(
        prefix, args.distributed_rank, epoch, num_samples, time.time() - start_time,
        tensor2item(loss_map['loss_mean']),
        tensor2item(loss_map['elbo_mean']),
        tensor2item(loss_map['nll_mean']),
        tensor2item(loss_map['kld_mean']),
        tensor2item(loss_map['mut_info_mean'])))

    # cleanups (see https://tinyurl.com/ycjre67m) + return ELBO for early stopping
    # (autoencoders have no ELBO, so their NLL is returned instead)
    loss_val = tensor2item(loss_map['elbo_mean']) if args.vae_type != 'autoencoder' \
        else tensor2item(loss_map['nll_mean'])
    for d in [loss_map, image_map, reparam_map, reparam_scalars]:
        d.clear()

    del minibatch
    del labels

    return loss_val
Exemple #6
0
def execute_graph(epoch,
                  model,
                  data_loader,
                  grapher,
                  optimizer=None,
                  prefix='test',
                  plot_mem=False):
    ''' execute the graph; when 'train' is in the name the model runs the optimizer

    Runs one epoch of the related-image classifier over ``data_loader``. It then
    plots the epoch means (plus CPU / CUDA memory when ``plot_mem``) and the last
    minibatch's related / original images.

    :param epoch: the current epoch number
    :param model: the torch model
    :param data_loader: the loader; items are unpacked by _unpack_data_and_labels
    :param grapher: the graph writing helper (eg: visdom / tf wrapper)
    :param optimizer: the optimizer (required iff 'train' is in prefix)
    :param prefix: 'train', 'test' or 'valid'
    :param plot_mem: also plot process RSS and allocated CUDA memory (MB)
    :returns: {'loss_mean': float, 'acc_mean': float (percent)} for early stopping
    :rtype: dict
    '''
    start_time = time.time()
    model.eval() if 'train' not in prefix else model.train()
    assert optimizer is not None if 'train' in prefix else optimizer is None
    loss_map, num_samples = {}, 0
    x_original, x_related = None, None

    for item in data_loader:
        # first destructure the data, cuda-ize and wrap in vars
        x_original, x_related, labels = _unpack_data_and_labels(item)
        x_related, labels = cudaize(x_related,
                                    is_data_tensor=True), cudaize(labels)

        if 'train' in prefix:  # zero gradients on optimizer
            optimizer.zero_grad()

        with torch.no_grad() if 'train' not in prefix else dummy_context():
            with torch.autograd.detect_anomaly(
            ) if args.detect_anomalies else dummy_context():
                # NOTE: x_original is intentionally left where generate_related puts it
                x_original, x_related = generate_related(
                    x_related, x_original, args)

                # run the model and gather the loss map
                data_to_infer = x_original if args.use_full_resolution else x_related
                loss_logits_t = model(data_to_infer)
                loss_t = {
                    'loss_mean':
                    F.cross_entropy(input=loss_logits_t, target=labels)
                }

                # compute accuracy and aggregate into map
                loss_t['accuracy_mean'] = softmax_accuracy(F.softmax(
                    loss_logits_t, -1),
                                                           labels,
                                                           size_average=True)

                loss_map = _add_loss_map(loss_map, loss_t)
                num_samples += x_related.size(0)

        if 'train' in prefix:  # compute bp and optimize
            if args.half is True:
                # the fp16 optimizer wrapper owns loss scaling / backprop
                optimizer.backward(loss_t['loss_mean'])
            else:
                loss_t['loss_mean'].backward()

            if args.clip > 0:
                # TODO: clip by value or norm? torch.nn.utils.clip_grad_value_
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) \
                    if not args.half is True else optimizer.clip_master_grads(args.clip)

            optimizer.step()
            del loss_t

    loss_map = _mean_map(loss_map)  # reduce the map to get actual means

    # convert to a plain Python float so the returned dict does not keep a
    # (possibly CUDA) tensor alive after the cleanup below
    accuracy = loss_map['accuracy_mean']
    if isinstance(accuracy, torch.Tensor):
        accuracy = accuracy.detach().item()
    correct_percent = 100.0 * accuracy

    print(
        '''{}[Epoch {}][{} samples][{:.2f} sec]:Average loss: {:.4f}\tAcc: {:.4f}'''
        .format(prefix, epoch, num_samples,
                time.time() - start_time, loss_map['loss_mean'].item(),
                correct_percent))

    # add memory tracking
    if plot_mem:
        process = psutil.Process(os.getpid())
        loss_map['cpumem_scalar'] = process.memory_info().rss * 1e-6
        loss_map['cudamem_scalar'] = torch.cuda.memory_allocated() * 1e-6

    # plot all the scalar / mean values
    register_plots(loss_map, grapher, epoch=epoch, prefix=prefix)

    # plot images, crops, inlays and all relevant images
    def resize_4d_or_5d(img):
        # 4d: (N, C, H, W) -> resized; 5d: each slice along dim 1 is resized
        # and the slices are concatenated along the batch dimension
        if len(img.shape) == 4:
            return F.interpolate(img, (32, 32),
                                 mode='bilinear',
                                 align_corners=True)
        elif len(img.shape) == 5:
            return torch.cat([
                F.interpolate(img[:, i, :, :, :], (32, 32),
                              mode='bilinear',
                              align_corners=True) for i in range(img.shape[1])
            ], 0)
        else:
            raise Exception("only 4d or 5d images supported")

    input_imgs_map = {
        'related_imgs': resize_4d_or_5d(x_related),
        'original_imgs': resize_4d_or_5d(x_original)
    }
    register_images(input_imgs_map, grapher, prefix=prefix)
    if grapher is not None:
        grapher.show()

    # return this for early stopping
    loss_val = {
        'loss_mean': loss_map['loss_mean'].detach().item(),
        'acc_mean': correct_percent
    }

    # delete the data instances, see https://tinyurl.com/ycjre67m
    loss_map.clear()
    input_imgs_map.clear()
    del loss_map
    del input_imgs_map
    del x_related
    del x_original
    del labels
    gc.collect()

    # return loss and accuracy
    return loss_val
Exemple #7
0
def execute_graph(epoch, model, loader, grapher, optimizer=None, prefix='test'):
    """ execute the graph; when 'train' is in the name the model runs the optimizer

    :param epoch: the current epoch number
    :param model: the torch model
    :param loader: the train or **TEST** loader
    :param grapher: the graph writing helper (eg: visdom / tf wrapper)
    :param optimizer: the optimizer
    :param prefix: 'train', 'test' or 'valid'
    :returns: the epoch-mean 'elbo_mean' as a Python float, for early stopping
    :rtype: float

    """
    start_time = time.time()
    # NOTE(review): only prefix == 'test' switches to eval mode / no_grad; 'valid'
    # runs in train mode with grads enabled (but never steps). Confirm intent.
    model.eval() if prefix == 'test' else model.train()
    assert optimizer is not None if 'train' in prefix or 'valid' in prefix else optimizer is None
    loss_map, num_samples = {}, 0

    # iterate over data and labels
    for minibatch, labels in loader:
        minibatch = minibatch.cuda() if args.cuda else minibatch
        labels = labels.cuda() if args.cuda else labels
        if args.half:
            minibatch = minibatch.half()

        if 'train' in prefix:
            optimizer.zero_grad()                                                # zero gradients on optimizer

        with torch.no_grad() if prefix == 'test' else dummy_context():
            pred_logits, reparam_map = model(minibatch)                          # get normal predictions
            loss_t = model.loss_function(pred_logits, minibatch, reparam_map)
            loss_map = _add_loss_map(loss_map, loss_t)
            num_samples += minibatch.size(0)

        if 'train' in prefix:  # compute bp and optimize
            if args.half:
                # the fp16 optimizer wrapper owns loss scaling / backprop
                optimizer.backward(loss_t['loss_mean'])
            else:
                loss_t['loss_mean'].backward()

            if args.clip > 0:
                # TODO: clip by value or norm? torch.nn.utils.clip_grad_norm_ is the alternative
                nn.utils.clip_grad_value_(model.parameters(), args.clip) \
                    if not args.half else optimizer.clip_master_grads(args.clip)

            optimizer.step()
            del loss_t

        if args.debug_step:  # for testing purposes
            break

    # compute the mean of the map
    loss_map = _mean_map(loss_map)  # reduce the map to get actual means
    print('{}[Epoch {}][{} samples][{:.2f} sec]:\t Loss: {:.4f}\t-ELBO: {:.4f}\tNLL: {:.4f}\tKLD: {:.4f}\tMI: {:.4f}'.format(
        prefix, epoch, num_samples, time.time() - start_time,
        loss_map['loss_mean'].item(),
        loss_map['elbo_mean'].item(),
        loss_map['nll_mean'].item(),
        loss_map['kld_mean'].item(),
        loss_map['mut_info_mean'].item()))

    # activate the logits of the reconstruction and get the dict
    reconstr_map = model.get_activated_reconstructions(pred_logits)

    # tack on MSSIM information if requested
    if args.calculate_msssim:
        # the original passed an undefined name (`reconstr_image`) -> NameError.
        # NOTE(review): assumes get_activated_reconstructions stores the activated
        # images under 'reconstruction_imgs' — confirm against the model.
        loss_map['ms_ssim_mean'] = compute_mssim(reconstr_map['reconstruction_imgs'], minibatch)

    # gather scalar values of reparameterizers (if they exist)
    reparam_scalars = model.get_reparameterizer_scalars()

    # plot the test accuracy, loss and images
    register_plots({**loss_map, **reparam_scalars}, grapher, epoch=epoch, prefix=prefix)

    # get some generations; pixelcnn decoding is slow so only do it every 10 epochs.
    # Generation needs no autograd graph.
    generated = None
    if args.decoder_layer_type != 'pixelcnn' or epoch % 10 == 0:
        with torch.no_grad():
            generated = model.generate_synthetic_samples(args.batch_size, reset_state=True,
                                                         use_aggregate_posterior=args.use_aggregate_posterior)

    # tack on images to grapher (F.upsample is deprecated; F.interpolate has the same defaults)
    image_map = {
        'input_imgs': F.interpolate(minibatch, (100, 100)) if args.task == 'image_folder' else minibatch
    }
    if generated is not None:
        image_map['generated_imgs'] = F.interpolate(generated, (100, 100)) \
            if args.task == 'image_folder' else generated

    register_images({**image_map, **reconstr_map}, grapher, prefix=prefix)
    if grapher is not None:
        grapher.save()

    # cleanups (see https://tinyurl.com/ycjre67m) + return ELBO for early stopping
    loss_val = loss_map['elbo_mean'].detach().item()
    for d in (loss_map, image_map, reparam_map, reparam_scalars):
        d.clear()
    del minibatch
    del labels
    return loss_val
Exemple #8
0
def execute_graph(epoch,
                  model,
                  data_loader,
                  grapher,
                  optimizer=None,
                  prefix='test',
                  plot_mem=False):
    ''' execute the graph; when 'train' is in the name the model runs the optimizer

    Iterates once over ``data_loader``. When ``'train' in prefix`` every minibatch is
    back-propagated and ``optimizer`` is stepped; otherwise the model is put in eval
    mode and run under ``torch.no_grad()``. Afterwards the averaged scalars are
    printed and plotted, and images built from the *last* minibatch are registered.

    :param epoch: the current epoch number (used for printing / plotting)
    :param model: the torch model; must provide ``loss_function``, ``get_imgs`` and
                  ``vae.get_reparameterizer_scalars``
    :param data_loader: iterable of minibatches, destructured by ``_unpack_data_and_labels``
    :param grapher: the graph writing helper (eg: visdom / tf wrapper)
    :param optimizer: the optimizer; required when training, must be None otherwise
    :param prefix: 'train', 'test' or 'valid'
    :param plot_mem: when True also plot process RSS and allocated CUDA memory (in MB)
    :returns: dict with 'loss_mean', 'pred_loss_mean' and 'accuracy_mean'
    :rtype: dict
    :raises ValueError: if ``data_loader`` yields no minibatches

    '''
    start_time = time.time()
    is_train = 'train' in prefix
    if is_train:
        model.train()
    else:
        model.eval()

    assert optimizer is not None if is_train else optimizer is None
    loss_map, num_samples = {}, 0
    x_original, x_related = None, None
    output_map, labels = None, None  # remain None only if the loader is empty

    for item in data_loader:
        # first destructure the data, cuda-ize and wrap in vars
        x_original, x_related, labels = _unpack_data_and_labels(item)
        x_related = cudaize(x_related, is_data_tensor=True)
        labels = cudaize(labels)

        if is_train:  # zero gradients on optimizer
            optimizer.zero_grad()

        with dummy_context() if is_train else torch.no_grad():
            with torch.autograd.detect_anomaly() if args.detect_anomalies else dummy_context():
                x_original, x_related = generate_related(x_related, x_original, args)
                x_original = cudaize(x_original, is_data_tensor=True)

                # run the model and gather the loss map
                output_map = model(x_original, x_related)
                loss_t = model.loss_function(x_related, labels, output_map)

                # compute accuracy and aggregate into map: 1-d labels use
                # softmax_accuracy, anything else bce_accuracy (presumably
                # multi-label targets -- confirm against the dataset code)
                accuracy_fn = softmax_accuracy if len(labels.shape) == 1 else bce_accuracy
                loss_t['accuracy_mean'] = accuracy_fn(F.softmax(output_map['preds'], -1),
                                                      labels,
                                                      size_average=True)

                loss_map = _add_loss_map(loss_map, loss_t)
                num_samples += x_related.size(0)

        if is_train:  # compute bp and optimize
            if args.half is True:
                # NOTE(review): presumably an FP16 optimizer wrapper that scales the loss itself
                optimizer.backward(loss_t['loss_mean'])
            else:
                loss_t['loss_mean'].backward()

            if args.clip > 0:
                # TODO: clip by value or norm? (torch.nn.utils.clip_grad_norm_)
                if args.half is True:
                    optimizer.clip_master_grads(args.clip)
                else:
                    torch.nn.utils.clip_grad_value_(model.parameters(), args.clip)

            optimizer.step()
            del loss_t

    # fail fast with a clear message instead of a later KeyError / NameError
    if output_map is None:
        raise ValueError(
            "execute_graph: data_loader for '{}' yielded no minibatches".format(prefix))

    loss_map = _mean_map(loss_map)  # reduce the map to get actual means
    correct_percent = 100.0 * loss_map['accuracy_mean']

    # adjacent literals: a backslash-newline inside the literal would embed the
    # next line's indentation into the logged message
    print(('{}[Epoch {}][{} samples][{:.2f} sec]: '
           'Average loss: {:.4f}\tKLD: {:.4f}\tNLL: {:.4f}\tAcc: {:.4f}').format(
               prefix, epoch, num_samples,
               time.time() - start_time,
               loss_map['loss_mean'].item(),
               loss_map['kld_mean'].item(),
               loss_map['nll_mean'].item(),
               correct_percent))

    # gather scalar values of reparameterizers (if they exist)
    reparam_scalars = model.vae.get_reparameterizer_scalars()

    # add memory tracking (bytes * 1e-6 -> MB)
    if plot_mem:
        process = psutil.Process(os.getpid())
        loss_map['cpumem_scalar'] = process.memory_info().rss * 1e-6
        loss_map['cudamem_scalar'] = torch.cuda.memory_allocated() * 1e-6

    # plot all the scalar / mean values
    register_plots({**loss_map, **reparam_scalars}, grapher, epoch=epoch, prefix=prefix)

    # plot images, crops, inlays and all relevant images (from the last minibatch)
    input_imgs_map = {'related_imgs': x_related, 'original_imgs': x_original}
    imgs_map = model.get_imgs(x_related.size(0), output_map, input_imgs_map)
    register_images(imgs_map, grapher, prefix=prefix)

    # return this for early stopping; .item() already yields a detached python scalar
    loss_val = {
        'loss_mean': loss_map['loss_mean'].item(),
        'pred_loss_mean': loss_map['pred_loss_mean'].item(),
        'accuracy_mean': correct_percent,
    }

    # delete the data instances, see https://tinyurl.com/ycjre67m
    loss_map.clear()
    input_imgs_map.clear()
    imgs_map.clear()
    output_map.clear()
    reparam_scalars.clear()
    del loss_map, input_imgs_map, imgs_map, output_map, reparam_scalars
    del x_related, x_original, labels
    gc.collect()

    # return loss scalar map
    return loss_val