Example #1
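# Not shown in these listings: the snippets assume the surrounding project's
# imports, roughly `os`, `functools`, `torch`, `torch.nn as nn`, `tqdm`, plus
# project modules such as `utils`, `train_fns`, and the model file selected
# via config['model'].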
def run(config):

    # Update the config dict as necessary
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cpu'

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    experiment_name = "test_{}".format(experiment_name)
    print('Experiment name is %s' % experiment_name)

    # Next, build the model
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)
    E = model.ImgEncoder(**config).to(device)
    # If using EMA, prepare it
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(
            config['ema_decay']))
        G_ema = model.Generator(**{
            **config, 'skip_init': True,
            'no_optim': True
        }).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None

    # FP16?
    if config['G_fp16']:
        print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        print('Casting D to fp16...')
        D = D.half()
        # Consider automatically reducing SN_eps?
    GDE = model.G_D_E(G, D, E)
    # print(G)
    # print(D)
    # print(E)
    print("Model Created!")
    print('Number of params in G: {} D: {} E: {}'.format(*[
        sum([p.data.nelement() for p in net.parameters()])
        for net in [G, D, E]
    ]))
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'config': config
    }

    # Load weights from the pre-trained model
    print('Loading weights...')
    utils.load_weights(
        G, D, E, state_dict, config['weights_root'],
        config['load_experiment_name'],
        config['load_weights'] if config['load_weights'] else None,
        G_ema if config['ema'] else None)
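    # Reset the iteration/epoch counters after loading: this run only performs
    # a short activation-extraction pass, so training progress is not resumed.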
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'config': config
    }
    # If parallel, parallelize the GD module
    if config['parallel']:
        GDE = nn.DataParallel(GDE)
        if config['cross_replica']:
            patch_replication_callback(GDE)

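    # Prepare data; each loader batch covers a full D iteration
    # (num_D_steps * num_D_accumulations).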
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    D_batch_size = (config['batch_size'] * config['num_D_steps'] *
                    config['num_D_accumulations'])
    loaders, train_dataset = utils.get_data_loaders(**{
        **config, 'batch_size': D_batch_size,
        'start_itr': 0
    })

    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'])
    # Prepare a fixed z & y to see individual sample evolution throughout training
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size,
                                         G.dim_z,
                                         config['n_classes'],
                                         device=device,
                                         fp16=config['G_fp16'])
    fixed_z.sample_()
    fixed_y.sample_()
    print("fixed_y original: {} {}".format(fixed_y.shape, fixed_y[:10]))
    fixed_x, fixed_y_of_x = utils.prepare_x_y(G_batch_size, train_dataset,
                                              experiment_name, config)

    # Prepare Sample function for use with inception metrics
    sample = functools.partial(
        utils.sample,
        G=(G_ema if config['ema'] and config['use_ema'] else G),
        z_=z_,
        y_=y_,
        config=config)

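    # Put G and E in eval mode before extracting activations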
    G.eval()
    E.eval()
    print("check1 -------------------------------")
    print("state_dict['itr']", state_dict['itr'])
    if config['pbar'] == 'mine':
        pbar = utils.progress(
            loaders[0],
            displaytype='s1k' if config['use_multiepoch_sampler'] else 'eta')
    else:
        pbar = tqdm(loaders[0])

    print("state_dict['itr']", state_dict['itr'])
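    # Run activation extraction on the first 20 loader batches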
    for i, (x, y) in enumerate(pbar):
        state_dict['itr'] += 1
        if config['D_fp16']:
            x, y = x.to(device).half(), y.to(device)
        else:
            x, y = x.to(device), y.to(device)
        print("x.shape", x.shape)
        print("y.shape", y.shape)

        activation_extract(G,
                           D,
                           E,
                           G_ema,
                           x,
                           y,
                           z_,
                           y_,
                           state_dict,
                           config,
                           experiment_name,
                           save_weights=False)
        if state_dict['itr'] == 20:
            break
Example #2
def run(config):

    # Update the config dict as necessary
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    # Next, build the model
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)
    E = model.ImgEncoder(**config).to(device)
    # E = model.Encoder(**config).to(device)

    # If using EMA, prepare it
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(
            config['ema_decay']))
        G_ema = model.Generator(**{
            **config, 'skip_init': True,
            'no_optim': True
        }).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None

    # FP16?
    if config['G_fp16']:
        print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        print('Casting D to fp16...')
        D = D.half()
        # Consider automatically reducing SN_eps?
    GDE = model.G_D_E(G, D, E)

    print('Number of params in G: {} D: {} E: {}'.format(*[
        sum([p.data.nelement() for p in net.parameters()])
        for net in [G, D, E]
    ]))
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'config': config
    }

    # If loading from a pre-trained model, load weights
    if config['resume']:
        print('Loading weights...')
        utils.load_weights(
            G, D, E, state_dict, config['weights_root'], experiment_name,
            config['load_weights'] if config['load_weights'] else None,
            G_ema if config['ema'] else None)

    # If parallel, parallelize the GD module
    if config['parallel']:
        GDE = nn.DataParallel(GDE)
        if config['cross_replica']:
            patch_replication_callback(GDE)

    # Prepare loggers for stats; metrics holds test metrics,
    # lmetrics holds any desired training metrics.
    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'],
                                              experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
    test_log = utils.MetricsLogger(test_metrics_fname,
                                   reinitialize=(not config['resume']))
    print('Training Metrics will be saved to {}'.format(train_metrics_fname))
    train_log = utils.MyLogger(train_metrics_fname,
                               reinitialize=(not config['resume']),
                               logstyle=config['logstyle'])
    # Write metadata
    utils.write_metadata(config['logs_root'], experiment_name, config,
                         state_dict)
    # Prepare data; the Discriminator's batch size is all that needs to be passed
    # to the dataloader, as G doesn't require dataloading.
    # Note that at every loader iteration we pass in enough data to complete
    # a full D iteration (regardless of number of D steps and accumulations)
    D_batch_size = (config['batch_size'] * config['num_D_steps'] *
                    config['num_D_accumulations'])
    loaders, train_dataset = utils.get_data_loaders(
        **{
            **config, 'batch_size': D_batch_size,
            'start_itr': state_dict['itr']
        })

    # # Prepare inception metrics: FID and IS
    # get_inception_metrics = inception_utils.prepare_inception_metrics(
    #     config['dataset'], config['parallel'], config['no_fid'])

    # Prepare noise and randomly sampled label arrays
    # Allow for different batch sizes in G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'])
    # Prepare a fixed z & y to see individual sample evolution throughout training
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size,
                                         G.dim_z,
                                         config['n_classes'],
                                         device=device,
                                         fp16=config['G_fp16'])
    fixed_z.sample_()
    fixed_y.sample_()
    print("fixed_y original: {} {}".format(fixed_y.shape, fixed_y[:10]))
    ## TODO: change the sample method to sample x and y
    fixed_x, fixed_y_of_x = utils.prepare_x_y(G_batch_size, train_dataset,
                                              experiment_name, config)

    # Build image pool to prevent mode collapse
    if config['img_pool_size'] != 0:
        img_pool = ImagePool(
            config['img_pool_size'], train_dataset.num_class,
            save_dir=os.path.join(config['imgbuffer_root'], experiment_name),
            resume_buffer=config['resume_buffer'])
    else:
        img_pool = None

    # Loaders are loaded, prepare the training function
    if config['which_train_fn'] == 'GAN':
        train = train_fns.GAN_training_function(G, D, E, GDE, ema, state_dict,
                                                config, img_pool)
    # Else, assume debugging and use the dummy train fn
    else:
        train = train_fns.dummy_training_function()
    # Prepare Sample function for use with inception metrics
    sample = functools.partial(
        utils.sample,
        G=(G_ema if config['ema'] and config['use_ema'] else G),
        z_=z_,
        y_=y_,
        config=config)

    # print('Beginning training at epoch %f...' % (state_dict['itr'] * D_batch_size / len(train_dataset)))
    print("Beginning training at Epoch {} (iteration {})".format(
        state_dict['epoch'], state_dict['itr']))
    # # Train for specified number of epochs, although we mostly track G iterations.
    # for epoch in range(state_dict['epoch'], config['num_epochs']):
    # Which progressbar to use? TQDM or my own?
    if config['pbar'] == 'mine':
        pbar = utils.progress(
            loaders[0],
            displaytype='s1k' if config['use_multiepoch_sampler'] else 'eta')
    else:
        pbar = tqdm(loaders[0])

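    # Main training loop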
    for i, (x, y) in enumerate(pbar):
        # Increment the iteration counter
        state_dict['itr'] += 1
        # Put G and D in eval mode for this pass.
        # For D, which typically doesn't have BN, this shouldn't matter much.
        G.eval()
        D.eval()
        if config['ema']:
            G_ema.eval()
        if config['D_fp16']:
            x, y = x.to(device).half(), y.to(device)
        else:
            x, y = x.to(device), y.to(device)

        # Run one training step with the prepared train function; its returned
        # metrics feed the progress printout below.
        metrics = train(x, y)

        # Every sv_log_interval, log singular values
        if (config['sv_log_interval'] >
                0) and (not (state_dict['itr'] % config['sv_log_interval'])):
            train_log.log(itr=int(state_dict['itr']),
                          **{
                              **utils.get_SVs(G, 'G'),
                              **utils.get_SVs(D, 'D')
                          })

        # If using my progbar, print metrics.
        if config['pbar'] == 'mine':
            print(', '.join(
                ['itr: %d' % state_dict['itr']] +
                ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]),
                  end=' ')

        # Save weights and copies as configured at specified interval
        if (not state_dict['itr'] % config['save_img_every']) or (
                not state_dict['itr'] % config['save_model_every']):
            if config['G_eval_mode']:
                print('Switching G to eval mode...')
                G.eval()
                if config['ema']:
                    G_ema.eval()
            save_weights = config['save_weights']
            if state_dict['itr'] % config['save_model_every']:
                save_weights = False
            train_fns.save_and_sample(G,
                                      D,
                                      E,
                                      G_ema,
                                      fixed_x,
                                      fixed_y_of_x,
                                      z_,
                                      y_,
                                      state_dict,
                                      config,
                                      experiment_name,
                                      img_pool,
                                      save_weights=save_weights)

        # # Test every specified interval
        # if not (state_dict['itr'] % config['test_every']):
        #     if config['G_eval_mode']:
        #         print('Switching G to eval mode...')
        #         G.eval()
        #     train_fns.test(G, D, G_ema, z_, y_, state_dict, config, sample,
        #                    get_inception_metrics, experiment_name, test_log)
        # Increment epoch counter at end of epoch
        state_dict['epoch'] = state_dict['itr'] * D_batch_size / (
            len(train_dataset))
        print("Finished Epoch {} (iteration {})".format(
            state_dict['epoch'], state_dict['itr']))
Example #3
def run(config):

    # Update the config dict as necessary
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    # Next, build the model
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)
    E = model.ImgEncoder(**config).to(device)
    # E = model.Encoder(**config).to(device)

    # If using EMA, prepare it
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(
            config['ema_decay']))
        G_ema = model.Generator(**{**config, 'skip_init': True,
                                   'no_optim': True}).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None

    # FP16?
    if config['G_fp16']:
        print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        print('Casting D to fp16...')
        D = D.half()
        # Consider automatically reducing SN_eps?
    GDE = model.G_D_E(G, D, E)

    print('Number of params in G: {} D: {} E: {}'.format(
        *[sum([p.data.nelement() for p in net.parameters()]) for net in [G, D, E]]))
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0,
                  'best_IS': 0, 'best_FID': 999999, 'config': config}

    # If loading from a pre-trained model, load weights
    if config['resume']:
        print('Loading weights...')
        utils.load_weights(G, D, E, state_dict,
                           config['weights_root'], experiment_name,
                           config['load_weights'] if config['load_weights'] else None,
                           G_ema if config['ema'] else None)

    # If parallel, parallelize the GD module
    if config['parallel']:
        GDE = nn.DataParallel(GDE)
        if config['cross_replica']:
            patch_replication_callback(GDE)

    # Prepare data; the Discriminator's batch size is all that needs to be passed
    # to the dataloader, as G doesn't require dataloading.
    # Note that at every loader iteration we pass in enough data to complete
    # a full D iteration (regardless of number of D steps and accumulations)
    D_batch_size = (config['batch_size'] * config['num_D_steps']
                    * config['num_D_accumulations'])
    loaders, train_dataset = utils.get_data_loaders(
        **{
            **config, 'batch_size': D_batch_size,
            'start_itr': state_dict['itr']
        })

    # Prepare noise and randomly sampled label arrays
    # Allow for different batch sizes in G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'],
                               device=device, fp16=config['G_fp16'])
    # Prepare a fixed z & y to see individual sample evolution throughout training
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size, G.dim_z,
                                         config['n_classes'], device=device,
                                         fp16=config['G_fp16'])
    fixed_z.sample_()
    fixed_y.sample_()
    print("fixed_y original: {} {}".format(fixed_y.shape, fixed_y[:10]))
    ## TODO: change the sample method to sample x and y
    fixed_x, fixed_y_of_x = utils.prepare_x_y(G_batch_size, train_dataset,
                                              experiment_name, config,
                                              device=device)

    # Build image pool to prevent mode collapse
    if config['img_pool_size'] != 0:
        img_pool = ImagePool(
            config['img_pool_size'], train_dataset.num_class,
            save_dir=os.path.join(config['imgbuffer_root'], experiment_name),
            resume_buffer=config['resume_buffer'])
    else:
        img_pool = None

    # Loaders are loaded, prepare the training function
    if config['which_train_fn'] == 'GAN':
        train = train_fns.GAN_training_function(G, D, E, GDE,
                                                ema, state_dict, config, img_pool)
    # Else, assume debugging and use the dummy train fn
    else:
        train = train_fns.dummy_training_function()
    # Prepare Sample function for use with inception metrics
    sample = functools.partial(utils.sample,
                               G=(G_ema if config['ema'] and config['use_ema']
                                   else G),
                               z_=z_, y_=y_, config=config)

    # print('Beginning training at epoch %f...' % (state_dict['itr'] * D_batch_size / len(train_dataset)))
    print("Beginning testing at Epoch {} (iteration {})".format(state_dict['epoch'], state_dict['itr']))

    if config['G_eval_mode']:
        print('Switching G to eval mode...')
        G.eval()
        if config['ema']:
            G_ema.eval()
    # vc visualization
    # # print("VC visualization ===============")
    # activation_extract(G, D, E, G_ema, fixed_x, fixed_y_of_x, z_, y_,
    #                             state_dict, config, experiment_name, device, normal_eval=False, eval_vc=True, return_mask=False)
    # normal activation
    print("Normal activation ===============")
    # Produce normal, fully activated images
    activation_extract(G, D, E, G_ema, fixed_x, fixed_y_of_x, z_, y_,
                       state_dict, config, experiment_name, device,
                       normal_eval=True, eval_vc=False, return_mask=False)
Example #4
def run(config):
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))

    # Next, build the model
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)
    E = model.ImgEncoder(**config).to(device)
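    # Combine G, D and E into a single module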
    GDE = model.G_D_E(G, D, E)

    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'config': config
    }

    print('Number of params in G: {} D: {} E: {}'.format(*[
        sum([p.data.nelement() for p in net.parameters()])
        for net in [G, D, E]
    ]))

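    # Load pre-trained weights; G_ema is not loaded here, matching is
    # non-strict, and optimizer state is skipped.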
    print('Loading weights...')
    utils.load_weights(
        G,
        D,
        E,
        state_dict,
        config['weights_root'],
        experiment_name,
        config['load_weights'] if config['load_weights'] else None,
        None,
        strict=False,
        load_optim=False)

    # ==============================================================================
    # prepare the data
    loaders, train_dataset = utils.get_data_loaders(**config)

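    # Allow for different batch sizes in G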
    G_batch_size = max(config['G_batch_size'], config['batch_size'])

    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device)

    # Prepare a fixed z & y to see individual sample evolution throughout training
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size,
                                         G.dim_z,
                                         config['n_classes'],
                                         device=device)

    fixed_z.sample_()
    fixed_y.sample_()
    print("fixed_y original: {} {}".format(fixed_y.shape, fixed_y[:10]))

    fixed_x, fixed_y_of_x = utils.prepare_x_y(G_batch_size, train_dataset,
                                              experiment_name, config)

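    # Evaluate samples with the attack option enabled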
    evaluate_sample(config,
                    fixed_x,
                    fixed_y,
                    G,
                    E,
                    experiment_name,
                    attack=True)