Exemplo n.º 1
0
def build_loader_model_grapher(args):
    """Build the dataloader, the classification network, and the grapher.

    Mutates ``args`` in place with dataset sizing derived from the loader
    (``input_shape``, per-replica sample counts, steps per epoch).

    :param args: argparse Namespace with the experiment configuration
    :returns: (loader, network, grapher); grapher is None on non-zero ranks
    :rtype: tuple

    """
    train_transform, test_transform = build_train_and_test_transforms()
    loader_dict = {'train_transform': train_transform,
                   'test_transform': test_transform,
                   **vars(args)}
    loader = get_loader(**loader_dict)

    # set the input tensor shape (ignoring batch dimension) and related dataset sizing
    args.input_shape = loader.input_shape
    args.num_train_samples = loader.num_train_samples // args.num_replicas
    args.num_test_samples = loader.num_test_samples  # Test isn't currently split across devices
    args.num_valid_samples = loader.num_valid_samples // args.num_replicas
    args.steps_per_train_epoch = args.num_train_samples // args.batch_size  # drop-remainder
    args.total_train_steps = args.epochs * args.steps_per_train_epoch

    # build the network, then apply the optional transforms in order:
    # sync-BN conversion, TorchScript compilation, GPU placement
    network = models.__dict__[args.arch](pretrained=args.pretrained, num_classes=loader.output_size)
    if args.convert_to_sync_bn:
        network = nn.SyncBatchNorm.convert_sync_batchnorm(network)
    if args.jit:
        network = torch.jit.script(network)
    if args.cuda:
        network = network.cuda()

    # materialize lazily-constructed modules with one pass of real data, then init weights
    lazy_generate_modules(network, loader.train_loader)
    network = layers.init_weights(network, init=args.weight_initialization)

    if args.num_replicas > 1:
        print("wrapping model with DDP...")
        network = layers.DistributedDataParallelPassthrough(network,
                                                            device_ids=[0],   # set w/cuda environ var
                                                            output_device=0,  # set w/cuda environ var
                                                            find_unused_parameters=True)

    # Get some info about the structure and number of params.
    print(network)
    print("model has {} million parameters.".format(
        utils.number_of_parameters(network) / 1e6
    ))

    # build the grapher object; only rank 0 logs, every other rank returns None
    grapher = None
    if args.visdom_url is not None and args.distributed_rank == 0:
        grapher = Grapher('visdom', env=utils.get_name(args),
                          server=args.visdom_url,
                          port=args.visdom_port,
                          log_folder=args.log_dir)
    elif args.distributed_rank == 0:
        grapher = Grapher(
            'tensorboard', logdir=os.path.join(args.log_dir, utils.get_name(args)))

    return loader, network, grapher
Exemplo n.º 2
0
def build_loader_model_grapher(args):
    """Build the dataloader, the VAE network, and the grapher.

    Sets ``args.input_shape`` in place from the loader (resized when an
    image-size override is active).

    :param args: argparse Namespace with the experiment configuration
    :returns: (loader, network, grapher)
    :rtype: tuple

    """
    # Fix: the transform and input-shape branches previously disagreed
    # (truthiness vs. `is None`), so image_size_override=0 would skip the
    # resize transform yet still report a resized input shape. Use one flag.
    use_resize = bool(args.image_size_override)
    resize_shape = (args.image_size_override, args.image_size_override)
    transform = [torchvision.transforms.Resize(resize_shape)] if use_resize else None
    loader = get_loader(args, transform=transform, **vars(args))  # build the loader
    args.input_shape = [loader.img_shp[0], *resize_shape] if use_resize \
        else loader.img_shp                                       # set the input size

    # map the CLI choice to the corresponding VAE implementation
    vae_dict = {
        'simple': SimpleVAE,
        'msg': MSGVAE,
        'parallel': ParallellyReparameterizedVAE,
        'sequential': SequentiallyReparameterizedVAE,
        'vrnn': VRNN
    }
    network = vae_dict[args.vae_type](loader.img_shp, kwargs=deepcopy(vars(args)))
    lazy_generate_modules(network, loader.train_loader)  # materialize lazy layers with real data
    network = network.cuda() if args.cuda else network
    network = append_save_and_load_fns(network, prefix="VAE_")
    if args.ngpu > 1:
        print("data-paralleling...")
        network.parallel()

    # build the grapher object
    if args.visdom_url:
        grapher = Grapher('visdom', env=get_name(args),
                          server=args.visdom_url,
                          port=args.visdom_port)
    else:
        grapher = Grapher('tensorboard', comment=get_name(args))

    return loader, network, grapher
Exemplo n.º 3
0
def build_loader_model_grapher(args):
    """Build the dataloader, a resnet18 classifier, and the grapher.

    Sets ``args.input_shape`` in place from the loader (resized when an
    image-size override is active).

    :param args: argparse Namespace with the experiment configuration
    :returns: (loader, network, grapher)
    :rtype: tuple

    """
    # Fix: the transform and input-shape branches previously disagreed
    # (truthiness vs. `is None`), so image_size_override=0 would skip the
    # resize transform yet still report a resized input shape. Use one flag.
    use_resize = bool(args.image_size_override)
    resize_shape = (args.image_size_override, args.image_size_override)
    transform = [transforms.Resize(resize_shape)] if use_resize else None
    loader = get_loader(args, transform=transform,
                        **vars(args))  # build the loader
    args.input_shape = [loader.img_shp[0], *resize_shape] if use_resize \
        else loader.img_shp                                       # set the input size

    # build the network; to use your own model import and construct it here
    network = resnet18(num_classes=loader.output_size)
    lazy_generate_modules(network, loader.train_loader)  # materialize lazy layers with real data
    network = network.cuda() if args.cuda else network
    # NOTE(review): "VAE_" prefix looks copied from the VAE variant of this
    # builder; kept to preserve checkpoint naming — confirm intended.
    network = append_save_and_load_fns(network, prefix="VAE_")
    if args.ngpu > 1:
        print("data-paralleling...")
        network.parallel()

    # build the grapher object
    if args.visdom_url:
        grapher = Grapher('visdom',
                          env=get_name(args),
                          server=args.visdom_url,
                          port=args.visdom_port)
    else:
        grapher = Grapher('tensorboard', comment=get_name(args))

    return loader, network, grapher
Exemplo n.º 4
0
def build_loader_model_grapher(args):
    """Build the dataloader, the VAE network, and the grapher.

    Mutates ``args`` in place with dataset sizing derived from the loader
    (input shape, output size, per-replica sample counts, steps per epoch).

    :param args: argparse Namespace with the experiment configuration
    :returns: (loader, network, grapher); grapher is None on non-zero ranks
    :rtype: tuple

    """
    train_transform, test_transform = build_train_and_test_transforms()
    loader_dict = {'train_transform': train_transform,
                   'test_transform': test_transform, **vars(args)}
    loader = get_loader(**loader_dict)

    # set the input tensor shape (ignoring batch dimension) and related dataset sizing
    args.input_shape = loader.input_shape
    args.output_size = loader.output_size
    args.num_train_samples = loader.num_train_samples // args.num_replicas
    args.num_test_samples = loader.num_test_samples  # Test isn't currently split across devices
    args.num_valid_samples = loader.num_valid_samples // args.num_replicas
    args.steps_per_train_epoch = args.num_train_samples // args.batch_size  # drop-remainder
    args.total_train_steps = args.epochs * args.steps_per_train_epoch

    # build the network, move to GPU if requested, materialize lazily-built
    # modules with one pass of real data, then initialize the weights
    network = build_vae(args.vae_type)(loader.input_shape, kwargs=deepcopy(vars(args)))
    if args.cuda:
        network = network.cuda()
    lazy_generate_modules(network, loader.train_loader)
    network = layers.init_weights(network, init=args.weight_initialization)

    if args.num_replicas > 1:
        print("wrapping model with DDP...")
        network = layers.DistributedDataParallelPassthrough(network,
                                                            device_ids=[0],   # set w/cuda environ var
                                                            output_device=0,  # set w/cuda environ var
                                                            find_unused_parameters=True)

    # Get some info about the structure and number of params.
    print(network)
    print("model has {} million parameters.".format(
        utils.number_of_parameters(network) / 1e6
    ))

    # add the test set as a np array for metrics calc
    if args.metrics_server is not None:
        network.test_images = get_numpy_dataset(task=args.task,
                                                data_dir=args.data_dir,
                                                test_transform=test_transform,
                                                split='test',
                                                image_size=args.image_size_override,
                                                cuda=args.cuda)
        print("Metrics test images: ", network.test_images.shape)

    # build the grapher object; only rank 0 logs, every other rank returns None
    grapher = None
    if args.visdom_url is not None and args.distributed_rank == 0:
        grapher = Grapher('visdom', env=utils.get_name(args),
                          server=args.visdom_url,
                          port=args.visdom_port,
                          log_folder=args.log_dir)
    elif args.distributed_rank == 0:
        grapher = Grapher(
            'tensorboard', logdir=os.path.join(args.log_dir, utils.get_name(args)))

    return loader, network, grapher
Exemplo n.º 5
0
def execute_graph(epoch,
                  model,
                  loader,
                  grapher,
                  optimizer=None,
                  prefix='test'):
    """ execute the graph; when 'train' is in the name the model runs the optimizer

    Relies on the module-level ``args`` Namespace and the ``layers``,
    ``metrics``, ``utils`` and ``tree`` helpers.

    :param epoch: the current epoch number
    :param model: the torch model
    :param loader: the train or **TEST** loader
    :param grapher: the graph writing helper (eg: visdom / tf wrapper)
    :param optimizer: the optimizer
    :param prefix: 'train', 'test' or 'valid'
    :returns: dictionary with scalars
    :rtype: dict

    NOTE(review): neither the optimizer use nor a return statement is
    visible in this portion of the body — the function likely continues
    beyond this view; confirm against the full file.
    """
    start_time = time.time()
    is_eval = 'train' not in prefix  # any non-'train' prefix runs in eval mode
    model.eval() if is_eval else model.train()
    loss_map, num_samples = {}, 0

    # iterate over data and labels
    for num_minibatches, (minibatch, labels) in enumerate(loader):
        minibatch = minibatch.cuda(
            non_blocking=True) if args.cuda else minibatch
        labels = labels.cuda(non_blocking=True) if args.cuda else labels

        with torch.no_grad():
            if is_eval and args.polyak_ema > 0:  # use the Polyak model for predictions
                pred_logits, reparam_map = layers.get_polyak_prediction(
                    model,
                    pred_fn=functools.partial(model, minibatch, labels=labels))
            else:
                pred_logits, reparam_map = model(
                    minibatch, labels=labels)  # get normal predictions

            # compute loss
            loss_t = model.loss_function(recon_x=pred_logits,
                                         x=minibatch,
                                         params=reparam_map,
                                         K=args.monte_carlo_posterior_samples)
            # running element-wise sum of the structured loss dict
            loss_map = loss_t if not loss_map else tree.map_structure(  # aggregate loss
                _extract_sum_scalars, loss_map, loss_t)
            num_samples += minibatch.size(0)  # count minibatch samples
            del loss_t

        if args.debug_step and num_minibatches > 1:  # for testing purposes
            break

    # compute the mean of the dict
    # num_minibatches is the LAST enumerate index, hence the +1 for the count
    loss_map = tree.map_structure(
        lambda v: v / (num_minibatches + 1),
        loss_map)  # reduce the map to get actual means

    # log some stuff
    def tensor2item(t):
        # unwrap tensors to python scalars; pass non-tensor values through
        return t.detach().item() if isinstance(t, torch.Tensor) else t

    to_log = '{}-{}[Epoch {}][{} samples][{:.2f} sec]:\t Loss: {:.4f}\t-ELBO: {:.4f}\tNLL: {:.4f}\tKLD: {:.4f}\tMI: {:.4f}'
    print(
        to_log.format(prefix, args.distributed_rank, epoch, num_samples,
                      time.time() - start_time,
                      tensor2item(loss_map['loss_mean']),
                      tensor2item(loss_map['elbo_mean']),
                      tensor2item(loss_map['nll_mean']),
                      tensor2item(loss_map['kld_mean']),
                      tensor2item(loss_map['mut_info_mean'])))

    # build the image map; `minibatch` deliberately leaks out of the loop
    # above, so this is the LAST minibatch seen
    image_map = {'input_imgs': minibatch}

    # activate the logits of the reconstruction and get the dict
    image_map = {
        **image_map,
        **model.get_activated_reconstructions(pred_logits)
    }

    # tack on remote metrics information if requested, do it in-frequently.
    if args.metrics_server is not None:
        request_remote_metrics_calc(epoch, model, grapher, prefix)

    # Add generations to our image dict
    with torch.no_grad():
        prior_generated = model.generate_synthetic_samples(
            10000, reset_state=True, use_aggregate_posterior=False)
        ema_generated = model.generate_synthetic_samples(
            10000, reset_state=True, use_aggregate_posterior=True)
        image_map['prior'] = prior_generated
        image_map['ema'] = ema_generated

        # tack on MSSIM information if requested
        # (compare only the first len(minibatch) generations against the inputs)
        if args.calculate_msssim:
            loss_map['prior_gen_msssim_mean'] = metrics.calculate_mssim(
                minibatch, prior_generated[0:minibatch.shape[0]])
            loss_map['ema_gen_msssim_mean'] = metrics.calculate_mssim(
                minibatch, ema_generated[0:minibatch.shape[0]])

    # save all the images
    image_dir = os.path.join(args.log_dir, utils.get_name(args), 'images')
    os.makedirs(image_dir, exist_ok=True)

    # one grid PNG per key, plus one PNG per individual sample
    for k, v in image_map.items():
        grid = torchvision.utils.make_grid(v, normalize=True, scale_each=True)
        grid_filename = os.path.join(image_dir, "{}.png".format(k))
        transforms.ToPILImage()(grid.cpu()).save(grid_filename)

        for idx, sample in enumerate(v):
            current_filename = os.path.join(image_dir,
                                            "{}_{}.png".format(idx, k))
            transforms.ToPILImage()(sample.cpu()).save(current_filename)

    # cleanups (see https://tinyurl.com/ycjre67m) + return ELBO for early stopping
    for d in [loss_map, image_map, reparam_map]:
        d.clear()

    del minibatch
    del labels