Esempio n. 1
0
def build_loader_model_grapher(args):
    """Build the data loader, the classification network and the grapher.

    Side effect: populates ``args`` with dataset-derived sizing fields
    (``input_shape``, ``num_train_samples``, ``num_test_samples``,
    ``num_valid_samples``, ``steps_per_train_epoch``, ``total_train_steps``).

    :param args: parsed argparse namespace; every attribute is also forwarded
        to ``get_loader`` as a keyword argument.
    :returns: ``(loader, network, grapher)``; ``grapher`` is ``None`` on every
        process whose ``args.distributed_rank`` is not 0.
    :rtype: tuple

    """
    train_transform, test_transform = build_train_and_test_transforms()
    loader_dict = {'train_transform': train_transform,
                   'test_transform': test_transform,
                   **vars(args)}
    loader = get_loader(**loader_dict)

    # set the input tensor shape (ignoring batch dimension) and related dataset sizing
    args.input_shape = loader.input_shape
    # train / valid are sharded across replicas; test isn't currently split across devices
    args.num_train_samples = loader.num_train_samples // args.num_replicas
    args.num_test_samples = loader.num_test_samples
    args.num_valid_samples = loader.num_valid_samples // args.num_replicas
    args.steps_per_train_epoch = args.num_train_samples // args.batch_size  # drop-remainder
    args.total_train_steps = args.epochs * args.steps_per_train_epoch

    # build the network: optional sync-BN conversion -> optional TorchScript -> optional GPU
    network = models.__dict__[args.arch](pretrained=args.pretrained, num_classes=loader.output_size)
    if args.convert_to_sync_bn:
        network = nn.SyncBatchNorm.convert_sync_batchnorm(network)
    if args.jit:
        network = torch.jit.script(network)
    if args.cuda:
        network = network.cuda()
    lazy_generate_modules(network, loader.train_loader)
    network = layers.init_weights(network, init=args.weight_initialization)

    if args.num_replicas > 1:
        print("wrapping model with DDP...")
        network = layers.DistributedDataParallelPassthrough(network,
                                                            device_ids=[0],   # set w/cuda environ var
                                                            output_device=0,  # set w/cuda environ var
                                                            find_unused_parameters=True)

    # Get some info about the structure and number of params.
    print(network)
    print("model has {} million parameters.".format(
        utils.number_of_parameters(network) / 1e6
    ))

    # build the grapher object: only rank 0 logs; visdom if a URL was given, else tensorboard
    grapher = None
    if args.distributed_rank == 0:
        if args.visdom_url is not None:
            grapher = Grapher('visdom', env=utils.get_name(args),
                              server=args.visdom_url,
                              port=args.visdom_port,
                              log_folder=args.log_dir)
        else:
            grapher = Grapher(
                'tensorboard', logdir=os.path.join(args.log_dir, utils.get_name(args)))

    return loader, network, grapher
Esempio n. 2
0
    def build_encoder(self):
        """ helper to build the encoder type

        The encoder is created from ``self.config`` and sized so that its
        output matches ``self.reparameterizer.input_size``.

        :returns: an encoder, TorchScript-compiled when ``self.config['jit']``
            is truthy
        :rtype: nn.Module

        """
        encoder = layers.get_encoder(**self.config)(
            output_size=self.reparameterizer.input_size)
        # count is divided by 1e6, so report it in millions
        print('encoder has {} million parameters\n'.format(
            utils.number_of_parameters(encoder) / 1e6))
        return torch.jit.script(encoder) if self.config['jit'] else encoder
Esempio n. 3
0
    def build_decoder(self, reupsample=True):
        """ helper function to build convolutional or dense decoder

        :param reupsample: currently unused by this method; kept so existing
            callers that pass it keep working.
        :returns: a decoder (with the variance projection appended as
            necessary), TorchScript-compiled when ``self.config['jit']`` is truthy
        :rtype: nn.Module

        """
        # work on a deep copy so the shared config (and its input_shape) is never mutated
        dec_conf = deepcopy(self.config)
        if dec_conf['nll_type'] == 'pixel_wise':
            # scale the channel dimension by 256 -- presumably one logit per
            # 8-bit pixel value for a categorical likelihood (TODO confirm).
            # Convert to a list first so tuple / torch.Size shapes don't fail
            # on item assignment.
            input_shape = list(dec_conf['input_shape'])
            input_shape[0] *= 256
            dec_conf['input_shape'] = input_shape

        decoder = layers.get_decoder(
            output_shape=dec_conf['input_shape'],
            **dec_conf)(input_size=self.reparameterizer.output_size)
        # count is divided by 1e6, so report it in millions
        print('decoder has {} million parameters\n'.format(
            utils.number_of_parameters(decoder) / 1e6))

        # append the variance as necessary
        decoder = self._append_variance_projection(decoder)
        return torch.jit.script(decoder) if self.config['jit'] else decoder
Esempio n. 4
0
def run(args):
    """Train (or restore) the model, then record one-time test metrics.

    With ``args.restore`` unset, the model is trained for ``args.epochs``
    epochs. Each epoch samples via ``generate``; early stopping and LR decay
    for ``sgd_momentum`` are optional. Otherwise the weights are loaded from
    ``args.restore`` and the model is evaluated once.

    :param args: parsed argparse namespace
    :raises AssertionError: if ``args.restore`` is set and ``model.load``
        reports failure
    """
    # collect our model and data loader
    model, loader, grapher = get_model_and_loader()
    print("model has {} params".format(number_of_parameters(model)))

    # collect our optimizer
    optimizer = build_optimizer(model)

    # metrics dict returned by ``test``; stays empty if no epoch runs
    test_map = {}
    if args.restore is None:
        # train the VAE on the same distributions as the model pool
        print("training current distribution for {} epochs".format(
            args.epochs))
        early = EarlyStopping(model, burn_in_interval=100,
                              max_steps=80) if args.early_stop else None

        for epoch in range(1, args.epochs + 1):
            generate(epoch, model, grapher)
            train(epoch, model, optimizer, loader.train_loader, grapher)
            test_map = test(epoch, model, loader.test_loader, grapher)

            if args.early_stop and early(test_map['pred_loss_mean']):
                early.restore()  # restore and test again
                test_map = test(epoch, model, loader.test_loader, grapher)
                break

            # adjust the LR if using momentum sgd
            if args.optimizer == 'sgd_momentum':
                decay_lr_every(optimizer, args.lr, epoch)

        grapher.save()  # save to endpoint after training
    else:
        # call load outside of an ``assert`` so it still runs under ``python -O``
        if not model.load(args.restore):
            raise AssertionError("Failed to load model")
        # no training loop ran, so there is no ``epoch``: evaluate at epoch 0;
        # ``test`` returns the same metrics dict as in the training branch
        test_map = test(0, model, loader.test_loader, grapher)

    # evaluate one-time metrics
    scalar_map_to_csvs(test_map)

    # cleanups
    grapher.close()
Esempio n. 5
0
def run(args):
    """Train (or restore) the model, then append the test accuracy to a CSV.

    With ``args.restore`` unset, the model is trained for ``args.epochs``
    epochs, with optional early stopping and LR decay for ``sgd_momentum``.
    Otherwise a whole model object is loaded from ``args.restore`` and
    evaluated once. The final ``acc_mean`` is appended to
    ``"<uid>_test_acc.csv"``.

    :param args: parsed argparse namespace
    """
    # collect our model and data loader
    model, loader, grapher = get_model_and_loader()
    print("model has {} params".format(number_of_parameters(model)))

    # collect our optimizer
    optimizer = build_optimizer(model)

    # metrics dict returned by ``test``; None means no evaluation happened
    test_loss = None
    if args.restore is None:
        # train the VAE on the same distributions as the model pool
        print("training current distribution for {} epochs".format(
            args.epochs))
        early = EarlyStopping(model, burn_in_interval=100,
                              max_steps=80) if args.early_stop else None

        for epoch in range(1, args.epochs + 1):
            train(epoch, model, optimizer, loader.train_loader, grapher)
            test_loss = test(epoch, model, loader.test_loader, grapher)

            if args.early_stop and early(test_loss['loss_mean']):
                early.restore()  # restore the best weights and test again
                test_loss = test(epoch, model, loader.test_loader, grapher)
                break

            # adjust the LR if using momentum sgd
            if args.optimizer == 'sgd_momentum':
                decay_lr_every(optimizer, args.lr, epoch)

        grapher.save()  # save to endpoint after training
    else:
        # NOTE(review): torch.load unpickles arbitrary objects -- only restore
        # checkpoints from trusted sources.
        model = torch.load(args.restore)
        # no training loop ran, so there is no ``epoch``: evaluate at epoch 0
        test_loss = test(0, model, loader.test_loader, grapher)

    # evaluate one-time metrics (nothing to record if zero epochs were run)
    if test_loss is not None:
        append_to_csv([test_loss['acc_mean']], "{}_test_acc.csv".format(args.uid))

    # cleanups
    grapher.close()
Esempio n. 6
0
def train_loop(data_loaders, model, fid_model, grapher, args):
    ''' simple helper to run the entire train loop; not needed for eval modes

    Trains ``model`` on each loader of ``data_loaders`` in sequence. After each
    loader, the following one-time metrics are appended to CSVs in
    ``args.output_dir``:

    * test ELBO
    * synthetic and true sample counts
    * epochs run
    * consistency w.r.t. the previous loader
    * optional FID

    The sample counts and the student config are also pushed to the grapher.

    Before moving on to the next loader, the EWC Fisher estimate is
    optionally accumulated. Then either a new student is forked or
    ``model.current_model`` is bumped, and a fresh Grapher is created.

    :param data_loaders: sequence of loaders, one per training distribution
    :param model: student-teacher model (uses ``student``, ``ratio``,
        ``fork()``, ``current_model`` and ``get_name()``)
    :param fid_model: model handed to ``calculate_fid``
    :param grapher: grapher used for the first distribution
    :param args: parsed argparse namespace
    '''
    def print_param_summary():
        # parameter-tensor and element counts for the whole st-model and the student
        print(
            "there are {} params with {} elems in the st-model and {} params in the student with {} elems"
            .format(len(list(model.parameters())), number_of_parameters(model),
                    len(list(model.student.parameters())),
                    number_of_parameters(model.student)))

    def csv_path(suffix):
        # e.g. csv_path('fid') -> <output_dir>/<uid>_fid.csv
        return os.path.join(args.output_dir,
                            "{}_{}.csv".format(args.uid, suffix))

    optimizer = build_optimizer(model.student)  # collect our optimizer
    print_param_summary()

    # main training loop
    fisher = None
    for j, loader in enumerate(data_loaders):
        num_epochs = args.epochs  # TODO: randomize epochs by something like: + np.random.randint(0, 13)
        print("training current distribution for {} epochs".format(num_epochs))
        early = EarlyStopping(
            model, max_steps=50,
            burn_in_interval=None) if args.early_stop else None

        test_loss = None
        for epoch in range(1, num_epochs + 1):
            train(epoch, model, fisher, optimizer, loader.train_loader,
                  grapher)
            test_loss = test(epoch, model, fisher, loader.test_loader, grapher)
            if args.early_stop and early(test_loss['loss_mean']):
                early.restore()  # restore and test+generate again
                test_loss = test_and_generate(epoch, model, fisher, loader,
                                              grapher)
                break

            generate(model, grapher, 'student')  # generate student samples
            generate(model, grapher, 'teacher')  # generate teacher samples

        # evaluate and save away one-time metrics, these include:
        #    1. test elbo
        #    2. FID
        #    3. consistency
        #    4. num synth + num true samples
        #    5. dump config to visdom
        check_or_create_dir(args.output_dir)
        append_to_csv([test_loss['elbo_mean']], csv_path('test_elbo'))

        # ``epoch`` is the last epoch actually run (early stopping may cut it short)
        num_synth_samples = np.ceil(epoch * args.batch_size * model.ratio)
        num_true_samples = np.ceil(epoch * (args.batch_size -
                                            (args.batch_size * model.ratio)))
        append_to_csv([num_synth_samples], csv_path('numsynth'))
        append_to_csv([num_true_samples], csv_path('numtrue'))
        append_to_csv([epoch], csv_path('epochs'))
        grapher.vis.text(num_synth_samples,
                         opts=dict(title="num_synthetic_samples"))
        grapher.vis.text(num_true_samples, opts=dict(title="num_true_samples"))
        grapher.vis.text(pprint.PrettyPrinter(indent=4).pformat(
            model.student.config),
                         opts=dict(title="config"))

        # calc the consistency using the **PREVIOUS** loader
        if j > 0:
            append_to_csv(
                calculate_consistency(model, data_loaders[j - 1],
                                      args.reparam_type, args.vae_type,
                                      args.cuda),
                csv_path('consistency'))

        if args.calculate_fid_with is not None:
            # TODO: parameterize num fid samples, currently use less for inceptionv3 as it's COSTLY
            num_fid_samples = 4000 if args.calculate_fid_with != 'inceptionv3' else 1000
            append_to_csv(
                calculate_fid(fid_model=fid_model,
                              model=model,
                              loader=loader,
                              grapher=grapher,
                              num_samples=num_fid_samples,
                              cuda=args.cuda),
                csv_path('fid'))

        grapher.save()  # save the remote visdom graphs
        if j != len(data_loaders) - 1:
            if args.ewc_gamma > 0:
                # estimate the fisher on the loader we just trained on and
                # accumulate it into the running total
                print("computing fisher info matrix....")
                fisher_tmp = estimate_fisher(
                    model.student,  # this is pre-fork
                    loader,
                    args.batch_size,
                    cuda=args.cuda)
                if fisher is not None:
                    assert len(fisher) == len(
                        fisher_tmp), "#fisher params != #new fisher params"
                    # entries are paired positionally (dict insertion order)
                    for kf, kft in zip(fisher, fisher_tmp):
                        fisher[kf] += fisher_tmp[kft]
                else:
                    fisher = fisher_tmp

            # spawn a new student & rebuild grapher; we also pass
            # the new model's parameters through a new optimizer.
            if not args.disable_student_teacher:
                model.fork()
                lazy_generate_modules(model, data_loaders[0].img_shp)
                optimizer = build_optimizer(model.student)
                print_param_summary()

            else:
                # increment anyway for vanilla models
                # so that we can have a separate visdom env
                model.current_model += 1

            grapher = Grapher(env=model.get_name(),
                              server=args.visdom_url,
                              port=args.visdom_port)
Esempio n. 7
0
def build_loader_model_grapher(args):
    """Build the data loader, the VAE network and the grapher.

    Side effect: populates ``args`` with dataset-derived sizing fields
    (``input_shape``, ``output_size``, ``num_train_samples``,
    ``num_test_samples``, ``num_valid_samples``, ``steps_per_train_epoch``,
    ``total_train_steps``).

    :param args: parsed argparse namespace. Every attribute is forwarded to
        ``get_loader``, and a deep copy of ``vars(args)`` (including the
        sizing fields above) is passed to the VAE constructor.
    :returns: ``(loader, network, grapher)``; ``grapher`` is ``None`` on every
        process whose ``args.distributed_rank`` is not 0.
    :rtype: tuple

    """
    train_transform, test_transform = build_train_and_test_transforms()
    loader_dict = {'train_transform': train_transform,
                   'test_transform': test_transform, **vars(args)}
    loader = get_loader(**loader_dict)

    # set the input tensor shape (ignoring batch dimension) and related dataset sizing.
    # These must be set before build_vae below, which receives a snapshot of vars(args).
    args.input_shape = loader.input_shape
    args.output_size = loader.output_size
    # train / valid are sharded across replicas; test isn't currently split across devices
    args.num_train_samples = loader.num_train_samples // args.num_replicas
    args.num_test_samples = loader.num_test_samples
    args.num_valid_samples = loader.num_valid_samples // args.num_replicas
    args.steps_per_train_epoch = args.num_train_samples // args.batch_size  # drop-remainder
    args.total_train_steps = args.epochs * args.steps_per_train_epoch

    # build the network
    network = build_vae(args.vae_type)(loader.input_shape, kwargs=deepcopy(vars(args)))
    if args.cuda:
        network = network.cuda()
    lazy_generate_modules(network, loader.train_loader)
    network = layers.init_weights(network, init=args.weight_initialization)

    if args.num_replicas > 1:
        print("wrapping model with DDP...")
        network = layers.DistributedDataParallelPassthrough(network,
                                                            device_ids=[0],   # set w/cuda environ var
                                                            output_device=0,  # set w/cuda environ var
                                                            find_unused_parameters=True)

    # Get some info about the structure and number of params.
    print(network)
    print("model has {} million parameters.".format(
        utils.number_of_parameters(network) / 1e6
    ))

    # add the test set as a np array for metrics calc
    # NOTE(review): with num_replicas > 1 this attribute is set on the DDP
    # wrapper, not the inner module -- confirm the passthrough handles that.
    if args.metrics_server is not None:
        network.test_images = get_numpy_dataset(task=args.task,
                                                data_dir=args.data_dir,
                                                test_transform=test_transform,
                                                split='test',
                                                image_size=args.image_size_override,
                                                cuda=args.cuda)
        print("Metrics test images: ", network.test_images.shape)

    # build the grapher object: only rank 0 logs; visdom if a URL was given, else tensorboard
    grapher = None
    if args.distributed_rank == 0:
        if args.visdom_url is not None:
            grapher = Grapher('visdom', env=utils.get_name(args),
                              server=args.visdom_url,
                              port=args.visdom_port,
                              log_folder=args.log_dir)
        else:
            grapher = Grapher(
                'tensorboard', logdir=os.path.join(args.log_dir, utils.get_name(args)))

    return loader, network, grapher