Beispiel #1
0
def run(config):
    """Build G, D and the image encoder E, then extract and plot activations.

    The config dict is first extended with settings derived from the
    user-specified ones, the three networks (plus an optional EMA copy of G)
    are created on the GPU, weights are optionally restored, data / noise /
    fixed samples are prepared, and finally ``activation_extract`` and
    ``plot_channel_activation`` are invoked.

    :param config: dict of experiment settings; derived keys are written into
        it before it is rebound to the result of ``utils.update_config_roots``.
    """
    # Settings inferred from the user's choices: (key to write, lookup table,
    # key to read), e.g. image size and class count follow from the dataset
    # and the activation objects follow from their string names.
    derived_settings = (
        ('resolution', utils.imsize_dict, 'dataset'),
        ('n_classes', utils.nclass_dict, 'dataset'),
        ('G_activation', utils.activation_dict, 'G_nl'),
        ('D_activation', utils.activation_dict, 'D_nl'),
    )
    for target_key, lookup, source_key in derived_settings:
        config[target_key] = lookup[config[source_key]]

    # Resuming implies the weights come from disk, so initialization is skipped.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'

    # Seed the RNG, create the root folders and enable cudnn autotuning.
    utils.seed_rng(config['seed'])
    utils.prepare_root(config)
    torch.backends.cudnn.benchmark = True

    # The model file is chosen dynamically by module name.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name']
                       or utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    # Networks
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)
    E = model.ImgEncoder(**config).to(device)

    # Optional exponential-moving-average copy of G
    G_ema, ema = None, None
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(
            config['ema_decay']))
        ema_kwargs = dict(config, skip_init=True, no_optim=True)
        G_ema = model.Generator(**ema_kwargs).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])

    # Optional half-precision casts
    if config['G_fp16']:
        print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        print('Casting D to fp16...')
        D = D.half()
    GDE = model.G_D_E(G, D, E)

    param_counts = [
        sum(p.data.nelement() for p in net.parameters())
        for net in (G, D, E)
    ]
    print('Number of params in G: {} D: {} E: {}'.format(*param_counts))

    # Bookkeeping (iteration / epoch counters, best metrics, the config).
    state_dict = {
        'itr': 0, 'epoch': 0,
        'save_num': 0, 'save_best_num': 0,
        'best_IS': 0, 'best_FID': 999999,
        'config': config,
    }

    # Restore weights when resuming.
    if config['resume']:
        print('Loading weights...')
        weights_suffix = config['load_weights'] or None
        ema_to_load = G_ema if config['ema'] else None
        utils.load_weights(G, D, E, state_dict, config['weights_root'],
                           experiment_name, weights_suffix, ema_to_load)

    # Wrap the combined module for multi-GPU execution if requested.
    if config['parallel']:
        GDE = nn.DataParallel(GDE)
        if config['cross_replica']:
            patch_replication_callback(GDE)

    # One loader iteration supplies enough data for a full D iteration,
    # whatever the number of D steps and accumulations.
    D_batch_size = (config['batch_size']
                    * config['num_D_steps']
                    * config['num_D_accumulations'])
    loader_kwargs = dict(config,
                         batch_size=D_batch_size,
                         start_itr=state_dict['itr'])
    loaders, train_dataset = utils.get_data_loaders(**loader_kwargs)

    # G may use a larger batch than D.
    G_batch_size = max(config['G_batch_size'], config['batch_size'])

    def new_z_y():
        # Noise / label pair sized for G's batch.
        return utils.prepare_z_y(G_batch_size,
                                 G.dim_z,
                                 config['n_classes'],
                                 device=device,
                                 fp16=config['G_fp16'])

    z_, y_ = new_z_y()
    # A fixed z & y pair to follow individual samples throughout training.
    fixed_z, fixed_y = new_z_y()
    fixed_z.sample_()
    fixed_y.sample_()
    print("fixed_y original: {} {}".format(fixed_y.shape, fixed_y[:10]))
    # TODO: change the sample method to sample x and y
    fixed_x, fixed_y_of_x = utils.prepare_x_y(G_batch_size, train_dataset,
                                              experiment_name, config)

    # Image pool (intended to prevent mode collapse); disabled when size is 0.
    img_pool = None
    if config['img_pool_size'] != 0:
        pool_dir = os.path.join(config['imgbuffer_root'], experiment_name)
        img_pool = ImagePool(config['img_pool_size'],
                             train_dataset.num_class,
                             save_dir=pool_dir,
                             resume_buffer=config['resume_buffer'])

    # Training function: the real one for 'GAN', otherwise the dummy one
    # (assumed to be a debugging run).
    if config['which_train_fn'] != 'GAN':
        train = train_fns.dummy_training_function()
    else:
        train = train_fns.GAN_training_function(G, D, E, GDE, ema, state_dict,
                                                config, img_pool)

    # Sampling function for use with inception metrics.
    sampling_G = G_ema if config['ema'] and config['use_ema'] else G
    sample = functools.partial(utils.sample,
                               G=sampling_G,
                               z_=z_,
                               y_=y_,
                               config=config)

    print("Beginning testing at Epoch {} (iteration {})".format(
        state_dict['epoch'], state_dict['itr']))

    if config['G_eval_mode']:
        print('Switchin G to eval mode...')
        G.eval()
        if config['ema']:
            G_ema.eval()

    fixed_x, fixed_Gz, intermediates = activation_extract(
        G, D, E, G_ema,
        fixed_x, fixed_y_of_x,
        z_, y_,
        state_dict, config, experiment_name,
        save_weights=config['save_weights'])

    img_index = config['img_index']
    plot_channel_activation(intermediates, img_index, fixed_Gz[img_index],
                            experiment_name, config, state_dict)
Beispiel #2
0
    # NOTE(review): this is a fragment -- the enclosing definition starts
    # above the visible source and the final `if` below has no visible body.
    # Seed NumPy's and Python's RNGs with the same command-line seed.
    np.random.seed(args.seed)
    random.seed(args.seed)

    # Get training and validation data loaders.
    # Defaults: no input normalization, no class-index offset, and the
    # loss/top-1/top-5 evaluation function.
    normalize = None
    class_offset = 0
    performance_eval = eval_loss_acc1_acc5
    if args.net == 'mobilenet-imagenet':
        # This network gets mean/std 0.5 normalization on all three channels
        # and a class offset of 1.
        normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        class_offset = 1
    if args.dataset == 'cityscapes':
        # Cityscapes is evaluated with the loss/IoU function instead.
        performance_eval = eval_loss_iou

    tr_loader, val_loader, train_loader4eval = get_data_loaders(data_dir=args.datadir,
                                                                dataset=args.dataset,
                                                                batch_size=args.batch_size,
                                                                val_batch_size=args.val_batch_size,
                                                                num_workers=args.num_workers,
                                                                normalize=normalize)
    # Get the network model (and a teacher model); pretrained weights are
    # requested unless --randinit was given.
    model, teacher_model = get_net_model(net=args.net, pretrained_dataset=args.dataset, dropout=False,
                                         pretrained=not args.randinit)

    # Optionally overwrite the model weights from a user-supplied checkpoint.
    # A path that is given but is not an existing file is only reported, not
    # treated as an error.
    if args.pretrain is not None and os.path.isfile(args.pretrain):
        print('load pretrained model:{}'.format(args.pretrain))
        model.load_state_dict(torch.load(args.pretrain))
    elif args.pretrain is not None:
        print('fail to load pretrained model: {}'.format(args.pretrain))

    # Set up multi-GPU execution (body not visible in this fragment).
    if args.mgpu:
Beispiel #3
0
def _random_messages(batch_size, message_length, device):
    """Return a (batch_size, message_length) tensor of random 0/1 values on `device`."""
    bits = np.random.choice([0, 1], (batch_size, message_length))
    return torch.Tensor(bits).to(device)


def train(model: Hidden, device: torch.device,
          hidden_config: HiDDenConfiguration, train_options: TrainingOptions,
          this_run_folder: str, tb_logger):
    """
    Trains the HiDDeN model and validates it after every epoch.

    For each epoch the per-batch training losses are accumulated, printed
    periodically, written to ``train.csv`` and (optionally) to the tensorboard
    logger. The validation losses are then accumulated separately, written to
    ``validation.csv`` and stored with the checkpoint; the first validation
    batch is also saved as images.

    :param model: The model
    :param device: torch.device object, usually this is GPU (if available), otherwise CPU.
    :param hidden_config: The network configuration
    :param train_options: The training settings
    :param this_run_folder: The parent folder for the current training run to store training artifacts/results/logs.
    :param tb_logger: TensorBoardLogger object which is a thin wrapper for TensorboardX logger.
                Pass None to disable TensorboardX logging
    :return: None
    """

    train_data, val_data = utils.get_data_loaders(hidden_config, train_options)
    file_count = len(train_data.dataset)
    # Ceiling division: a trailing partial batch still counts as one step.
    steps_in_epoch = ((file_count + train_options.batch_size - 1)
                      // train_options.batch_size)

    print_each = 10
    images_to_save = 8
    saved_images_size = (512, 512)

    for epoch in range(train_options.start_epoch,
                       train_options.number_of_epochs + 1):
        print('\nStarting epoch {}/{}'.format(epoch,
                                              train_options.number_of_epochs))
        print('Batch size = {}\nSteps in epoch = {}'.format(
            train_options.batch_size, steps_in_epoch))

        # ---- training pass -------------------------------------------------
        training_losses = {}
        epoch_start = time.time()
        for step, (image, _) in enumerate(train_data, start=1):
            image = image.to(device)
            message = _random_messages(image.shape[0],
                                       hidden_config.message_length, device)
            losses, _ = model.train_on_batch([image, message])
            for name, loss in losses.items():
                training_losses.setdefault(name, []).append(loss)
            if step % print_each == 0 or step == steps_in_epoch:
                print('Epoch: {}/{} Step: {}/{}'.format(
                    epoch, train_options.number_of_epochs, step,
                    steps_in_epoch))
                utils.print_progress(training_losses)
                print('-' * 40)

        train_duration = time.time() - epoch_start
        print('Epoch {} training duration {:.2f} sec'.format(
            epoch, train_duration))
        print('-' * 40)
        utils.write_losses(os.path.join(this_run_folder, 'train.csv'),
                           training_losses, epoch, train_duration)
        if tb_logger is not None:
            tb_logger.save_losses(training_losses, epoch)
            tb_logger.save_grads(epoch)
            tb_logger.save_tensors(epoch)

        # ---- validation pass -----------------------------------------------
        # Validation gets its own accumulator. Appending to the training dict
        # (as was done before) mixed training and validation losses in the
        # printed summary, the checkpoint and validation.csv.
        validation_losses = {}

        print('Running validation for epoch {}/{}'.format(
            epoch, train_options.number_of_epochs))
        for batch_index, (image, _) in enumerate(val_data):
            image = image.to(device)
            message = _random_messages(image.shape[0],
                                       hidden_config.message_length, device)
            losses, (encoded_images, noised_images,
                     decoded_messages) = model.validate_on_batch(
                         [image, message])
            for name, loss in losses.items():
                validation_losses.setdefault(name, []).append(loss)
            if batch_index == 0:
                # Only the first validation batch is dumped as images.
                utils.save_images(
                    image.cpu()[:images_to_save, :, :, :],
                    encoded_images[:images_to_save, :, :, :].cpu(),
                    epoch,
                    os.path.join(this_run_folder, 'images'),
                    resize_to=saved_images_size)

        utils.print_progress(validation_losses)
        print('-' * 40)
        utils.save_checkpoint(model, epoch, validation_losses,
                              os.path.join(this_run_folder, 'checkpoints'))
        # The duration written here is measured from the start of the epoch,
        # so it covers the training pass as well as validation.
        utils.write_losses(os.path.join(this_run_folder, 'validation.csv'),
                           validation_losses, epoch,
                           time.time() - epoch_start)
Beispiel #4
0
def run(config):
    """Load a trained G/D/E experiment and run activation extraction on it.

    Derived settings are written into ``config``, the networks (and an
    optional EMA copy of G) are built on the GPU, weights are always loaded
    from ``config['load_experiment_name']``, data and noise are prepared, and
    ``activation_extract`` is called with G and E in eval mode.

    :param config: dict of experiment settings; derived keys are written into
        it before it is rebound to the result of ``utils.update_config_roots``.
    """
    # Settings inferred from the user's choices: (key to write, lookup table,
    # key to read).
    derived_settings = (
        ('resolution', utils.imsize_dict, 'dataset'),
        ('n_classes', utils.nclass_dict, 'dataset'),
        ('G_activation', utils.activation_dict, 'G_nl'),
        ('D_activation', utils.activation_dict, 'D_nl'),
    )
    for target_key, lookup, source_key in derived_settings:
        config[target_key] = lookup[config[source_key]]

    # Resuming implies the weights come from disk, so initialization is skipped.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'

    # Seed the RNG, create the root folders and enable cudnn autotuning.
    utils.seed_rng(config['seed'])
    utils.prepare_root(config)
    torch.backends.cudnn.benchmark = True

    # The model file is chosen dynamically by module name.
    model = __import__(config['model'])
    base_name = config['experiment_name'] or utils.name_from_config(config)
    # Outputs of this run are kept under a "test_"-prefixed experiment name.
    experiment_name = "test_{}".format(base_name)
    print('Experiment name is %s' % experiment_name)

    # Networks
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)
    E = model.ImgEncoder(**config).to(device)

    # Optional exponential-moving-average copy of G
    G_ema, ema = None, None
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(
            config['ema_decay']))
        ema_kwargs = dict(config, skip_init=True, no_optim=True)
        G_ema = model.Generator(**ema_kwargs).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])

    # Optional half-precision casts
    if config['G_fp16']:
        print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        print('Casting D to fp16...')
        D = D.half()
    GDE = model.G_D_E(G, D, E)
    print("Model Created!")
    param_counts = [
        sum(p.data.nelement() for p in net.parameters())
        for net in (G, D, E)
    ]
    print('Number of params in G: {} D: {} E: {}'.format(*param_counts))

    # Bookkeeping (iteration / epoch counters, best metrics, the config).
    state_dict = {
        'itr': 0, 'epoch': 0,
        'save_num': 0, 'save_best_num': 0,
        'best_IS': 0, 'best_FID': 999999,
        'config': config,
    }

    # Weights are loaded unconditionally here, from the experiment named by
    # config['load_experiment_name'] rather than from experiment_name.
    print('Loading weights...')
    weights_suffix = config['load_weights'] or None
    ema_to_load = G_ema if config['ema'] else None
    utils.load_weights(G, D, E, state_dict, config['weights_root'],
                       config['load_experiment_name'], weights_suffix,
                       ema_to_load)

    # Wrap the combined module for multi-GPU execution if requested.
    if config['parallel']:
        GDE = nn.DataParallel(GDE)
        if config['cross_replica']:
            patch_replication_callback(GDE)

    # G may use a larger batch than D; one loader iteration supplies enough
    # data for a full D iteration.
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    D_batch_size = (config['batch_size']
                    * config['num_D_steps']
                    * config['num_D_accumulations'])
    loader_kwargs = dict(config,
                         batch_size=D_batch_size,
                         start_itr=state_dict['itr'])
    loaders, train_dataset = utils.get_data_loaders(**loader_kwargs)

    def new_z_y():
        # Noise / label pair sized for G's batch.
        return utils.prepare_z_y(G_batch_size,
                                 G.dim_z,
                                 config['n_classes'],
                                 device=device,
                                 fp16=config['G_fp16'])

    z_, y_ = new_z_y()
    # A fixed z & y pair to follow individual samples throughout training.
    fixed_z, fixed_y = new_z_y()
    fixed_z.sample_()
    fixed_y.sample_()
    print("fixed_y original: {} {}".format(fixed_y.shape, fixed_y[:10]))
    fixed_x, fixed_y_of_x = utils.prepare_x_y(G_batch_size, train_dataset,
                                              experiment_name, config)

    # Sampling function for use with inception metrics.
    sampling_G = G_ema if config['ema'] and config['use_ema'] else G
    sample = functools.partial(utils.sample,
                               G=sampling_G,
                               z_=z_,
                               y_=y_,
                               config=config)

    G.eval()
    E.eval()
    print("check1 -------------------------------")
    activation_extract(
        G, D, E, G_ema,
        fixed_x, fixed_y_of_x,
        z_, y_,
        state_dict, config, experiment_name,
        save_weights=False)
    # NOTE(review): this indented line is the body of a conditional whose
    # header lies outside the visible source -- presumably an existence check
    # like the two below; confirm against the full file.
    os.makedirs(args.logs_folder)
# Create the plots and trained-models folders if they do not exist yet.
if not os.path.exists(args.plots_folder):
    os.makedirs(args.plots_folder)
if not os.path.exists(args.trained_models_folder):
    os.makedirs(args.trained_models_folder)
# Log to a file inside the logs folder; filemode 'w' starts the file afresh.
log_file = get_logfilename_with_datetime('train-log')
logging.basicConfig(filename=join(args.logs_folder, log_file),
                    level=logging.INFO,
                    filemode='w',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    format='%(asctime)s - %(levelname)s - %(message)s')

# Record the parsed arguments both on stdout and in the log file.
print(args)
logging.info(args)

# Build the data loaders and log the split sizes and class names.
dataloaders, dataset_sizes, class_names = get_data_loaders(
    args.train_data, args.batch_size)
logging.info("Train size {}, Val size {}, Test size {}".format(
    dataset_sizes['train'], dataset_sizes['val'], dataset_sizes['test']))
logging.info('Class names:{}'.format(class_names))
# Pick the first CUDA device when available. The CPU fallback in this
# expression is never used in practice, because the script exits below when
# no GPU is present.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    total_gpus = torch.cuda.device_count()
    logging.info('Total number of GPUs:{}'.format(total_gpus))
    # multi_gpu is only assigned inside this branch: False for exactly one
    # GPU, True for more than one.
    if total_gpus == 1:
        multi_gpu = False
    elif total_gpus > 1:
        multi_gpu = True

else:
    # Without a GPU the script terminates with exit status 1.
    print("No GPUs, Cannot proceed. This training regime needs GPUs.")
    exit(1)
def run(config):
    """Build G and D, move them to an XLA TPU device and run the training loop.

    Derived settings are written into ``config``; weights are optionally
    restored when ``config['resume']`` is set; Adam optimizers are attached to
    G and D after the networks have been moved to the device; the data loader
    is then iterated until ``state_dict['itr']`` reaches
    ``config['total_steps']``, with periodic singular-value logging,
    saving/sampling and testing.

    :param config: dict of experiment settings; derived keys are written into
        it before it is rebound to the result of ``utils.update_config_roots``.
    """
    # Give PerDeviceLoader a __len__ that reports the length of the loader
    # found at self._loader._loader. The assignment patches the class itself,
    # so it affects every PerDeviceLoader instance.
    def len_parallelloader(self):
        return len(self._loader._loader)
    pl.PerDeviceLoader.__len__ = len_parallelloader

    # Update the config dict as necessary
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        xm.master_print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different
    # files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    xm.master_print('Experiment name is %s' % experiment_name)

    # The XLA device all networks are moved to further down.
    device = xm.xla_device(devkind='TPU')

    # Next, build the model. The networks are constructed without a device
    # here and are only moved to `device` after the optional weight loading.
    G = model.Generator(**config)
    D = model.Discriminator(**config)

    # If using EMA, prepare it
    if config['ema']:
        xm.master_print(
            'Preparing EMA for G with decay of {}'.format(
                config['ema_decay']))
        G_ema = model.Generator(**{**config, 'skip_init': True,
                                   'no_optim': True})
    else:
        xm.master_print('Not using ema...')
        # In the EMA case `ema` itself is created later, once G_ema has been
        # moved to the device.
        G_ema, ema = None, None

    # FP16?
    if config['G_fp16']:
        xm.master_print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        xm.master_print('Casting D to fp16...')
        D = D.half()

    # Prepare state dict, which holds things like itr #
    state_dict = {'itr': 0, 'save_num': 0, 'save_best_num': 0,
                  'best_IS': 0, 'best_FID': 999999, 'config': config}

    # If loading from a pre-trained model, load weights
    if config['resume']:
        xm.master_print('Loading weights...')
        utils.load_weights(
            G,
            D,
            state_dict,
            config['weights_root'],
            experiment_name,
            config['load_weights'] if config['load_weights'] else None,
            G_ema if config['ema'] else None)

    # move everything to TPU
    G = G.to(device)
    D = D.to(device)

    # Adam optimizers are attached as attributes of the networks, using the
    # hyperparameters stored on each network.
    # NOTE(review): presumably created after .to(device) so that they hold
    # the on-device parameters -- confirm.
    G.optim = optim.Adam(params=G.parameters(), lr=G.lr,
                         betas=(G.B1, G.B2), weight_decay=0,
                         eps=G.adam_eps)
    D.optim = optim.Adam(params=D.parameters(), lr=D.lr,
                         betas=(D.B1, D.B2), weight_decay=0,
                         eps=D.adam_eps)

    # for key, val in G.optim.state.items():
    #  G.optim.state[key]['exp_avg'] = G.optim.state[key]['exp_avg'].to(device)
    #  G.optim.state[key]['exp_avg_sq'] = G.optim.state[key]['exp_avg_sq'].to(device)

    # for key, val in D.optim.state.items():
    #  D.optim.state[key]['exp_avg'] = D.optim.state[key]['exp_avg'].to(device)
    #  D.optim.state[key]['exp_avg_sq'] = D.optim.state[key]['exp_avg_sq'].to(device)

    # The EMA copy is moved to the device before the EMA helper is built.
    if config['ema']:
        G_ema = G_ema.to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])

    # Consider automatically reducing SN_eps?
    GD = model.G_D(G, D)
    xm.master_print(G)
    xm.master_print(D)
    xm.master_print('Number of params in G: {} D: {}'.format(
        *[sum([p.data.nelement() for p in net.parameters()]) for net in [G, D]]))

    # Prepare loggers for stats; metrics holds test metrics,
    # lmetrics holds any desired training metrics.
    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'],
                                              experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    xm.master_print(
        'Test Metrics will be saved to {}'.format(test_metrics_fname))
    test_log = utils.MetricsLogger(test_metrics_fname,
                                   reinitialize=(not config['resume']))
    xm.master_print(
        'Training Metrics will be saved to {}'.format(train_metrics_fname))
    train_log = utils.MyLogger(train_metrics_fname,
                               reinitialize=(not config['resume']),
                               logstyle=config['logstyle'])

    # Only the master ordinal writes the metadata.
    if xm.is_master_ordinal():
        # Write metadata
        utils.write_metadata(
            config['logs_root'],
            experiment_name,
            config,
            state_dict)

    # Prepare data; the Discriminator's batch size is all that needs to be passed
    # to the dataloader, as G doesn't require dataloading.
    # Note that at every loader iteration we pass in enough data to complete
    # a full D iteration (regardless of number of D steps and accumulations)
    D_batch_size = (config['batch_size'] * config['num_D_steps']
                    * config['num_D_accumulations'])
    xm.master_print('Preparing data...')
    # The returned object is handed to pl.ParallelLoader in the loop below.
    loader = utils.get_data_loaders(**{**config, 'batch_size': D_batch_size,
                                       'start_itr': state_dict['itr']})

    # Prepare inception metrics: FID and IS
    xm.master_print('Preparing metrics...')

    get_inception_metrics = inception_utils.prepare_inception_metrics(
        config['dataset'], config['parallel'],
        no_inception=config['no_inception'],
        no_fid=config['no_fid'])

    # Prepare noise and randomly sampled label arrays
    # Allow for different batch sizes in G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])

    # Each call returns the result of a new utils.prepare_z_y call.
    def sample(): return utils.prepare_z_y(G_batch_size, G.dim_z,
                                           config['n_classes'], device=device,
                                           fp16=config['G_fp16'])
    # Prepare a fixed z & y to see individual sample evolution throughout
    # training; this one pair is reused for every save_and_sample call.
    fixed_z, fixed_y = sample()

    train = train_fns.GAN_training_function(G, D, GD, sample, ema, state_dict,
                                            config)

    xm.master_print('Beginning training...')

    # The progress bar exists only on the master ordinal and starts at the
    # current iteration count (non-zero when state_dict was restored).
    if xm.is_master_ordinal():
        pbar = tqdm(total=config['total_steps'])
        pbar.n = state_dict['itr']
        pbar.refresh()

    # NOTE(review): the rendezvous calls below are presumably cross-process
    # barriers, so every process has to reach them in the same order --
    # confirm against the torch_xla documentation before reordering anything.
    xm.rendezvous('training_starts')
    while (state_dict['itr'] < config['total_steps']):
        # A new per-device loader is created for every pass over the data.
        pl_loader = pl.ParallelLoader(
            loader, [device]).per_device_loader(device)

        for i, (x, y) in enumerate(pl_loader):
            if xm.is_master_ordinal():
                # Increment the iteration counter
                pbar.update(1)

            # The counter is advanced before the training step, so the
            # interval checks below see the number of the step just taken.
            state_dict['itr'] += 1
            # Make sure G and D are in training mode, just in case they got set to eval
            # For D, which typically doesn't have BN, this shouldn't matter
            # much.
            G.train()
            D.train()
            if config['ema']:
                G_ema.train()

            xm.rendezvous('data_collection')
            metrics = train(x, y)

            # train_log.log(itr=int(state_dict['itr']), **metrics)

            # Every sv_log_interval, log singular values
            if ((config['sv_log_interval'] > 0) and (not (state_dict['itr'] % config['sv_log_interval']))) :
                if xm.is_master_ordinal():
                    train_log.log(itr=int(state_dict['itr']),
                        **{**utils.get_SVs(G, 'G'), **utils.get_SVs(D, 'D')})
                # Outside the master check: reached by every process.
                xm.rendezvous('Log SVs.')

            # Save weights and copies as configured at specified interval
            if (not (state_dict['itr'] % config['save_every'])):
                if config['G_eval_mode']:
                    xm.master_print('Switchin G to eval mode...')
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                train_fns.save_and_sample(
                    G,
                    D,
                    G_ema,
                    sample,
                    fixed_z,
                    fixed_y,
                    state_dict,
                    config,
                    experiment_name)

            # Test every specified interval
            if (not (state_dict['itr'] % config['test_every'])):

                # Evaluate with the EMA generator only when EMA is both
                # enabled and selected for use.
                which_G = G_ema if config['ema'] and config['use_ema'] else G
                if config['G_eval_mode']:
                    xm.master_print('Switchin G to eval mode...')
                    which_G.eval()

                # Draws a new (z, y) pair and runs the chosen generator on it.
                def G_sample():
                    z, y = sample()
                    return which_G(z, which_G.shared(y))

                train_fns.test(
                    G,
                    D,
                    G_ema,
                    sample,
                    state_dict,
                    config,
                    G_sample,
                    get_inception_metrics,
                    experiment_name,
                    test_log)

            # Debug : Message print
            # if True:
            #     xm.master_print(met.metrics_report())

            # Leave the inner loop as soon as the step budget is reached; the
            # outer while condition then ends training.
            if state_dict['itr'] >= config['total_steps']:
                break
Beispiel #7
0
def _real_data_accuracy(GD, D, config, train):
    """Return the mean per-batch fraction of real images that D scores above 0.

    A loader for the requested split is built from ``config`` (``train``
    selects the split). D is put into eval mode for the pass and is always
    switched back to train mode afterwards, even if the pass raises.
    """
    D.eval()
    batch_accuracies = []
    try:
        loader = utils.get_data_loaders(
            **{
                **config, 'train': train,
                'use_multiepoch_sampler': False,
                'load_in_mem': False
            })[0]
        with torch.no_grad():
            for x, y in loader:
                D_real = GD(None, None, x=x, dy=y,
                            policy=config['DiffAugment'])
                batch_accuracies.append((D_real > 0).float().mean().item())
    finally:
        D.train()
    return np.mean(batch_accuracies)


def test(G, D, GD, G_ema, z_, y_, state_dict, config, sample,
         get_inception_metrics, experiment_name, test_log, acc_metrics,
         acc_itrs):
    """Evaluate D's accuracy and the inception metrics, then log the results.

    Steps: measure D's accuracy on real validation and real training data,
    optionally accumulate standing statistics for the generator in use,
    compute IS / FID, save a "best" checkpoint when the metric selected by
    ``config['which_best']`` improved, update the best values in
    ``state_dict`` and write everything to ``test_log``.

    :param G: generator
    :param D: discriminator (temporarily switched to eval mode)
    :param GD: callable used to score real batches with D
    :param G_ema: EMA copy of G, used when both config['ema'] and
        config['use_ema'] are set
    :param z_: noise passed to utils.accumulate_standing_stats
    :param y_: labels passed to utils.accumulate_standing_stats
    :param state_dict: training state; 'best_IS', 'best_FID' and
        'save_best_num' are updated in place
    :param config: experiment settings
    :param sample: sampling callable handed to get_inception_metrics
    :param get_inception_metrics: callable returning (IS_mean, IS_std, FID)
    :param experiment_name: name under which weights are saved
    :param test_log: logger with a ``log(**kwargs)`` method
    :param acc_metrics: dict of accumulated metrics; each value is divided by
        ``acc_itrs`` before logging
    :param acc_itrs: number of iterations over which acc_metrics were summed
    """
    print('Calculating validation accuracy...')
    D_acc_val = _real_data_accuracy(GD, D, config, train=False)

    print('Calculating training accuracy...')
    D_acc_train = _real_data_accuracy(GD, D, config, train=True)

    print('Gathering inception metrics...')
    if config['accumulate_stats']:
        utils.accumulate_standing_stats(
            G_ema if config['ema'] and config['use_ema'] else G, z_, y_,
            config['n_classes'], config['num_standing_accumulations'])
    IS_mean, IS_std, FID = get_inception_metrics(
        sample, config['num_inception_images'], num_splits=10)
    print(
        'Itr %d: PYTORCH UNOFFICIAL Inception Score is %3.3f +/- %3.3f, PYTORCH UNOFFICIAL FID is %5.4f'
        % (state_dict['itr'], IS_mean, IS_std, FID))

    # If improved over the previous best metric, save an appropriate copy.
    improved = (
        (config['which_best'] == 'IS' and IS_mean > state_dict['best_IS'])
        or (config['which_best'] == 'FID' and FID < state_dict['best_FID']))
    if improved:
        print('%s improved over previous best, saving checkpoint...' %
              config['which_best'])
        # With several best copies the slots are numbered and reused in
        # rotation; with a single copy the suffix is plain 'best'.
        best_tag = ('best%d' % state_dict['save_best_num']
                    if config['num_best_copies'] > 1 else 'best')
        utils.save_weights(
            G, D, state_dict, config['weights_root'], experiment_name,
            best_tag,
            G_ema if config['ema'] else None)
        state_dict['save_best_num'] = (state_dict['save_best_num'] +
                                       1) % config['num_best_copies']
    state_dict['best_IS'] = max(state_dict['best_IS'], IS_mean)
    state_dict['best_FID'] = min(state_dict['best_FID'], FID)

    # Log results to file
    test_log.log(itr=int(state_dict['itr']),
                 IS_mean=float(IS_mean),
                 IS_std=float(IS_std),
                 FID=float(FID),
                 D_acc_val=D_acc_val,
                 D_acc_train=D_acc_train,
                 **{k: v / acc_itrs
                    for k, v in acc_metrics.items()})
Beispiel #8
0
def main():
    """Run the last checkpoint of a HiDDeN run against several JPEG2000 noises.

    Parses the command line, loads the options and the most recent checkpoint
    found under ``--runs_root``, then iterates the validation loader once per
    noise setting and appends the accumulated losses to
    ``validation_run.csv`` inside the run folder.
    """
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    parser = argparse.ArgumentParser(description='Training of HiDDeN nets')
    parser.add_argument('--hostname',
                        default=socket.gethostname(),
                        help='the  host name of the running server')
    parser.add_argument('--data-dir', '-d', required=True, type=str,
                        help='The directory where the data is stored.')
    parser.add_argument(
        '--runs_root', '-r',
        default=os.path.join('.', 'experiments'),
        type=str,
        help='The root folder where data about experiments are stored.')
    parser.add_argument('--batch-size', '-b', default=1, type=int,
                        help='Validation batch size.')
    args = parser.parse_args()

    # Known machines map to (short host name, dataset location); every other
    # host keeps its name and uses the shared workspace path. Note that the
    # value given on the command line for --data-dir is always replaced.
    fallback = (args.hostname, '/workspace/data_local/imagenet_pytorch')
    host_overrides = {
        'ee898-System-Product-Name':
        ('ee898', '/home/ee898/Desktop/chaoning/ImageNet'),
        'DL178': ('DL178', '/media/user/SSD1TB-2/ImageNet'),
    }
    args.hostname, args.data_dir = host_overrides.get(args.hostname, fallback)
    assert args.data_dir

    print_each = 25

    # Sub-directories of the runs root, excluding the 'no-noise-defaults' one.
    completed_runs = []
    for entry in os.listdir(args.runs_root):
        if entry == 'no-noise-defaults':
            continue
        if os.path.isdir(os.path.join(args.runs_root, entry)):
            completed_runs.append(entry)
    print(completed_runs)

    current_run = args.runs_root
    print(f'Run folder: {current_run}')
    options_file = os.path.join(current_run, 'options-and-config.pickle')
    loaded_options = utils.load_options(options_file)
    train_options, hidden_config, noise_config = loaded_options

    # Both folders point at the 'val' split of the dataset.
    val_folder = os.path.join(args.data_dir, 'val')
    train_options.train_folder = val_folder
    train_options.validation_folder = val_folder
    train_options.batch_size = args.batch_size

    checkpoint_dir = os.path.join(current_run, 'checkpoints')
    checkpoint, chpt_file_name = utils.load_last_checkpoint(checkpoint_dir)
    print(f'Loaded checkpoint from file {chpt_file_name}')

    noiser = Noiser(noise_config, device, 'jpeg')
    model = Hidden(hidden_config, device, noiser, tb_logger=None)
    utils.model_from_checkpoint(model, checkpoint)

    print('Model loaded successfully. Starting validation run...')
    _, val_data = utils.get_data_loaders(hidden_config, train_options)

    # Number of batches per pass, counting a trailing partial batch.
    file_count = len(val_data.dataset)
    full_batches, leftover = divmod(file_count, train_options.batch_size)
    steps_in_epoch = full_batches + 1 if leftover else full_batches

    noises = [
        'jpeg2000_100', 'jpeg2000_250', 'jpeg2000_500', 'jpeg2000_750',
        'jpeg2000_900'
    ]
    csv_path = os.path.join(args.runs_root, 'validation_run.csv')

    with torch.no_grad():
        for noise_idx, noise in enumerate(noises):
            losses_accu = {}
            for step, (image, _) in enumerate(val_data, start=1):
                image = image.to(device)
                # One random binary message per image in the batch.
                bits = np.random.choice(
                    [0, 1], (image.shape[0], hidden_config.message_length))
                message = torch.Tensor(bits).to(device)

                losses, (_, _, _) = model.validate_on_batch_specific_noise(
                    [image, message], noise=noise)

                # Create one meter per loss name on the first batch.
                if not losses_accu:
                    losses_accu = {name: AverageMeter() for name in losses}
                for name, loss in losses.items():
                    losses_accu[name].update(loss)

                if step % print_each == 0 or step == steps_in_epoch:
                    print(f'Step {step}/{steps_in_epoch}')
                    utils.print_progress(losses_accu)
                    print('-' * 40)

            # The CSV header is written only together with the first noise.
            write_validation_loss(csv_path,
                                  losses_accu,
                                  noise,
                                  checkpoint['epoch'],
                                  write_header=(noise_idx == 0))
def main():
    """Train the capsule network and evaluate it after every epoch.

    Relies on module-level names that are not defined in this function
    (``opt``, ``thresholds``, ``writer``, ``save_dir``, ``chamfer_dist``,
    ``get_lr``, ``build_map_plot``, ``ITOPDataset``).

    Per epoch it runs one optimisation pass over the training loader, logs
    losses, point clouds and the per-joint threshold plot to ``writer``, then
    runs one evaluation pass over the test loader. Whenever the value returned
    by ``build_map_plot`` for the test pass exceeds the best seen so far, the
    wrapped module's state dict is saved under ``save_dir``.
    """
    train_dataloader, test_dataloader = get_data_loaders(opt)
    capsule_net = load_network(opt)

    criterion = nn.MSELoss(reduction='mean').cuda()
    optimizer = optim.Adam(capsule_net.parameters(), lr=0.0001)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)
    capsule_net.train()

    if opt.load_weights != '':
        print('Load weights from', opt.load_weights)
        capsule_net.module.load_state_dict(torch.load(opt.load_weights))

    # =============== TRAIN =========================
    global_train_step = 0
    global_test_step = 0
    best_map = 0
    for epoch in range(opt.nepoch):
        train_loss_sum = 0
        # Row per joint, column per threshold: number of examples whose joint
        # error fell below that threshold.
        true_values_per_joint_per_threshold = np.zeros(
            (15, thresholds.shape[0]))
        total_examples = 0
        capsule_net.train()
        print(
            f'======>>>>> Online epoch: {epoch}, lr={get_lr(optimizer)} <<<<<======'
        )
        for data in tqdm(train_dataloader):
            global_train_step += 1
            points, coords, mean, maxv = data
            points, coords = points.cuda(non_blocking=True), coords.cuda(
                non_blocking=True)
            # Stop the epoch at the first batch smaller than the configured
            # batch size.
            if points.size(0) < opt.batchSize:
                break

            optimizer.zero_grad()
            estimation, reconstructions = capsule_net(points)

            dist1, dist2 = chamfer_dist(reconstructions, points)
            reconstruction_loss = (torch.mean(dist1)) + (torch.mean(dist2))
            regression_loss = criterion(estimation, coords)
            total_loss = reconstruction_loss + regression_loss
            total_loss.backward()
            optimizer.step()
            torch.cuda.synchronize()

            writer.add_scalar('train/reconstruction-loss',
                              reconstruction_loss.detach().cpu().numpy(),
                              global_step=global_train_step)
            writer.add_scalar('train/regression-loss',
                              regression_loss.detach().cpu().numpy(),
                              global_step=global_train_step)

            # Only the first example of the batch is visualised.
            points_to_display = np.expand_dims(points.cpu().numpy()[0], axis=0)
            reconstructions_to_display = np.expand_dims(
                reconstructions.detach().cpu().numpy()[0], axis=0)
            points_to_display = ITOPDataset.denormalize(
                points_to_display, mean[0].numpy(), maxv[0].numpy())
            reconstructions_to_display = ITOPDataset.denormalize(
                reconstructions_to_display, mean[0].numpy(), maxv[0].numpy())

            writer.add_mesh('train/points',
                            vertices=points_to_display,
                            global_step=global_train_step)
            writer.add_mesh('train/points-reconstruction',
                            vertices=reconstructions_to_display,
                            global_step=global_train_step)
            writer.add_scalar('train/lr',
                              get_lr(optimizer),
                              global_step=global_train_step)

            # mAP calculating
            total_examples += len(estimation)
            estimation = ITOPDataset.denormalize(
                estimation.detach().cpu().numpy(), mean.numpy(), maxv.numpy())
            coords = ITOPDataset.denormalize(coords.detach().cpu().numpy(),
                                             mean.numpy(), maxv.numpy())
            batch_diff = np.linalg.norm(estimation - coords,
                                        axis=2)  # N x JOINT_SIZE
            for example in batch_diff:
                for joint_idx, joint_diff in enumerate(example):
                    true_values_per_joint_per_threshold[joint_idx] += (
                        joint_diff < thresholds).astype(int)

            train_loss_sum += total_loss.item()

        scheduler.step()
        torch.cuda.synchronize()

        map_fig, map_01 = build_map_plot(
            thresholds, true_values_per_joint_per_threshold / total_examples)
        writer.add_figure('train/map', map_fig, global_step=global_train_step)

        # =============== EVAL =========================

        test_reconstruction_loss_sum = 0
        test_regression_loss_sum = 0
        true_values_per_joint_per_threshold = np.zeros(
            (15, thresholds.shape[0]))
        total_examples = 0
        # Switch to evaluation mode once, and disable autograd for the whole
        # pass: no backward is run here, so building graphs only costs memory.
        capsule_net.eval()
        with torch.no_grad():
            for data in tqdm(test_dataloader):
                global_test_step += 1
                points, coords, mean, maxv = data
                points, coords = points.cuda(), coords.cuda()

                estimation, reconstructions = capsule_net(points)
                dist1, dist2 = chamfer_dist(points, reconstructions)
                test_reconstruction_loss = (torch.mean(dist1)) + (
                    torch.mean(dist2))
                test_regression_loss = criterion(estimation, coords)
                test_reconstruction_loss_sum += test_reconstruction_loss.item()
                test_regression_loss_sum += test_regression_loss.item()

                points_to_display = np.expand_dims(points.cpu().numpy()[0],
                                                   axis=0)
                reconstructions_to_display = np.expand_dims(
                    reconstructions.detach().cpu().numpy()[0], axis=0)
                points_to_display = ITOPDataset.denormalize(
                    points_to_display, mean[0].numpy(), maxv[0].numpy())
                reconstructions_to_display = ITOPDataset.denormalize(
                    reconstructions_to_display, mean[0].numpy(),
                    maxv[0].numpy())
                writer.add_mesh('test/points',
                                vertices=points_to_display,
                                global_step=global_test_step)
                writer.add_mesh('test/points-reconstruction',
                                vertices=reconstructions_to_display,
                                global_step=global_test_step)

                # -------- mAP calculating
                total_examples += len(estimation)

                estimation = ITOPDataset.denormalize(
                    estimation.detach().cpu().numpy(), mean.numpy(),
                    maxv.numpy())
                coords = ITOPDataset.denormalize(
                    coords.detach().cpu().numpy(), mean.numpy(), maxv.numpy())
                batch_diff = np.linalg.norm(estimation - coords,
                                            axis=2)  # N x JOINT_SIZE
                for example in batch_diff:
                    for joint_idx, joint_diff in enumerate(example):
                        true_values_per_joint_per_threshold[joint_idx] += (
                            joint_diff < thresholds).astype(int)

        # Losses are averaged over the number of test batches.
        avg_reconstruction = test_reconstruction_loss_sum / len(
            test_dataloader)
        avg_regression = test_regression_loss_sum / len(test_dataloader)
        writer.add_scalar('test/reconstruction-loss',
                          avg_reconstruction,
                          global_step=global_test_step)
        writer.add_scalar('test/regression-loss',
                          avg_regression,
                          global_step=global_test_step)

        map_fig, map_01 = build_map_plot(
            thresholds, true_values_per_joint_per_threshold / total_examples)
        writer.add_figure('test/map', map_fig, global_step=global_test_step)

        # Keep a checkpoint each time the test metric improves.
        if best_map < map_01:
            best_map = map_01
            torch.save(
                capsule_net.module.state_dict(),
                f'{save_dir}/{epoch:03}-{best_map:0.3}-capsule_net-module.pth')
def prepare_inception_metrics(dataset, parallel, config):
    """Compute real-data Inception statistics and build a metrics closure.

    Iterates the data loader once, writes every image as a PNG into
    ``<samples_root>/<experiment_name>/training_images``, feeds the batches
    through the pickled Inception feature network, and stores the resulting
    mean/covariance in ``<dataset>_inception_moments.pkl``.

    Args:
        dataset: Dataset name; a trailing ``'_hdf5'`` is removed before it is
            used as the prefix of the moments file.
        parallel: Unused by this function.
        config: Dict expanded into ``utils.get_data_loaders`` and read for
            ``'samples_root'`` and ``'experiment_name'``.

    Returns:
        A tuple ``(get_inception_metrics, sigma_real, pool)`` where
        ``get_inception_metrics`` computes IS/FID for a sampler, ``sigma_real``
        is the covariance of the real features and ``pool`` the concatenated
        real features.
    """
    # Remove only the literal suffix. ``str.strip('_hdf5')`` would instead
    # strip any of the characters '_', 'h', 'd', 'f', '5' from BOTH ends and
    # corrupt names such as 'fashion_hdf5'.
    hdf5_suffix = '_hdf5'
    if dataset.endswith(hdf5_suffix):
        dataset = dataset[:-len(hdf5_suffix)]

    dnnlib.tflib.init_tf()
    inception_v3_features = dnnlib.util.load_pkl(
        'http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/inception_v3_features.pkl'
    )
    inception_v3_softmax = dnnlib.util.load_pkl(
        'http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/inception_v3_softmax.pkl'
    )

    print('Calculating inception features for the training set...')
    loader = utils.get_data_loaders(
        **{
            **config, 'train': False,
            'mirror_augment': False,
            'use_multiepoch_sampler': False,
            'load_in_mem': False,
            'pin_memory': False
        })[0]
    pool = []
    cnt = 0
    num_gpus = torch.cuda.device_count()
    # Create directory for training images
    folder = '%s/%s' % (config['samples_root'], config['experiment_name'])
    training_dir = os.path.join(folder, 'training_images')
    os.makedirs(training_dir, exist_ok=True)
    print('Saving training images at %s' % training_dir)
    for images, _ in loader:
        # Channels-last uint8 copy for writing to disk.
        # assumes loader output is in [-1, 1] - TODO confirm
        images = (
            (images.permute(0, 2, 3, 1).cpu().numpy() * 0.5 + 0.5) * 255 +
            0.5).astype(np.uint8)
        for img_idx, img in enumerate(images):
            file_idx = 'train_image_{:05d}.png'.format(cnt + img_idx)
            file_name = os.path.join(training_dir, file_idx)
            imsave(file_name, img)
        # Back to channels-first before running the feature network.
        images = np.transpose(images, (0, 3, 1, 2))
        pool.append(
            inception_v3_features.run(images,
                                      num_gpus=num_gpus,
                                      assume_frozen=True))
        cnt += images.shape[0]

    pool = np.concatenate(pool)
    mu_real, sigma_real = np.mean(pool, axis=0), np.cov(pool, rowvar=False)
    moments_file = dataset + '_inception_moments.pkl'
    dnnlib.util.save_pkl((mu_real, sigma_real), moments_file)
    mu_real, sigma_real = dnnlib.util.load_pkl(moments_file)

    def get_inception_metrics(sample,
                              num_inception_images,
                              folder_fake,
                              num_splits=10,
                              prints=True,
                              use_torch=True):
        """Return ``(IS_mean, IS_std, FID, sigma_fake, pool)`` for ``sample``.

        ``prints`` and ``use_torch`` are accepted but not used here.
        """
        pool, logits = accumulate_inception_activations(
            sample, inception_v3_features, inception_v3_softmax,
            num_inception_images, folder_fake)
        IS_mean, IS_std = calculate_inception_score(logits, num_splits)
        mu_fake, sigma_fake = np.mean(pool, axis=0), np.cov(pool, rowvar=False)
        # FID = |mu_f - mu_r|^2 + Tr(S_f + S_r - 2 * sqrt(S_f S_r))
        m = np.square(mu_fake - mu_real).sum()
        s, _ = scipy.linalg.sqrtm(np.dot(sigma_fake, sigma_real), disp=False)  # pylint: disable=no-member
        dist = m + np.trace(sigma_fake + sigma_real - 2 * s)
        # sqrtm may return complex values; only the real part is kept.
        FID = np.real(dist)
        return IS_mean, IS_std, FID, sigma_fake, pool

    return get_inception_metrics, sigma_real, pool
Beispiel #11
0
def train_test(depth, width, which_dataset, seed, augment, validate, test,
               validate_seed, validate_split, epochs, weights_fname,
               batch_size, resume, model, fp16, parallel, validate_every,
               duplicate_at_checkpoints, fold, top5):
    """Train a network, optionally validating/testing, and save it each epoch.

    Args:
        depth, width: Forwarded to ``model_module.Network`` and used in the
            default file name.
        which_dataset: Key into ``nClass_dict`` and passed to the data loaders.
        seed: Seed for the torch (CPU and CUDA) and numpy RNGs.
        augment, fold, validate_seed: Forwarded to ``get_data_loaders``.
        validate, test: Whether to run a validation / test pass.
        validate_split: Percentage; divided by 100 before being forwarded.
        epochs: Total number of epochs.
        weights_fname: Base name for the weights and log files; the value
            ``'default_save'`` is replaced by a name derived from the settings.
        batch_size: Batch size for the loaders and the iteration bookkeeping.
        resume: Load ``weights/<weights_fname>.pth`` instead of building anew.
        model: Name of the module (looked up in ``models``) to import.
        fp16: Convert the network and its inputs to half precision.
        parallel: Wrap the network in ``torch.nn.DataParallel``.
        validate_every: Run validation/test every this many epochs.
        duplicate_at_checkpoints: Also save an epoch-tagged copy every 5 epochs.
        top5: Count top-5 instead of top-1 errors in ``test_fn``.
    """

    # Number of classes
    nClasses = nClass_dict[which_dataset]

    # Seed RNG
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)

    # Create logs and weights folder if they don't exist
    if not os.path.exists('logs'):
        os.mkdir('logs')
    if not os.path.exists('weights'):
        os.mkdir('weights')

    # Name of the file to which we're saving losses and errors.
    # Compare by value: identity (`is`) against a string literal only works
    # by accident of interning and fails for strings built at runtime.
    if weights_fname == 'default_save':
        weights_fname = '_'.join([
            item for item in [
                model, 'D' + str(depth), 'K' +
                str(width), 'fp16' if fp16 else None, which_dataset, 'seed' +
                str(seed), 'val' if validate else None,
                str(epochs) + 'epochs'
            ] if item is not None
        ])

    # Prepare metrics logging
    metrics_fname = 'logs/' + weights_fname + '_log.jsonl'
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s %(levelname)s| %(message)s')
    logging.info('Metrics will be saved to {}'.format(metrics_fname))
    mlog = MetricsLogger(metrics_fname, reinitialize=(not resume))

    # Import the model module
    sys.path.append('models')
    model_module = __import__(model)

    # Build network, either by initializing it or re-loading.
    if resume:
        logging.info('loading network ' + weights_fname + '...')
        net = torch.load('weights/' + weights_fname + '.pth')

        # NOTE(review): when `parallel` is set the attribute lookups below and
        # later (epoch, j, lr_sched, optim) happen on the DataParallel wrapper,
        # and unlike the fresh-build branch nothing is mirrored onto it here -
        # confirm resume + parallel works as intended.
        if parallel:
            net = torch.nn.DataParallel(net)
        if fp16:
            net = net.half()

        # Which epoch we're starting from
        start_epoch = net.epoch + 1 if hasattr(net, 'epoch') else 0

        # Rescale iteration counter if batchsize requires it.
        if hasattr(net, 'j'):
            net.j = int(net.j * net.batch_size / float(batch_size))
            net.batch_size = batch_size

    else:
        net = model_module.Network(width,
                                   depth,
                                   nClasses=nClasses,
                                   epochs=epochs)
        net = net.cuda()
        net.batch_size = batch_size

        if fp16:
            net = net.half()
        if parallel:
            net = torch.nn.DataParallel(net)
            # Expose the wrapped module's training helpers on the wrapper.
            net.lr_sched = net.module.lr_sched
            net.update_lr = net.module.update_lr
            net.optim = net.module.optim

        start_epoch = 0

    logging.info('Number of params: {}'.format(
        sum([p.data.nelement() for p in net.parameters()])))
    # Get information specific to each dataset
    loaders = get_data_loaders(which_dataset, augment, validate, test,
                               batch_size, fold, validate_seed,
                               validate_split / 100.)
    train_loader = loaders[0]

    if validate:
        val_loader = loaders[1]
    if test:
        test_loader = loaders[2]

    # NOTE(review): `loss.data[0]` and `volatile=True` below are the
    # pre-0.4 PyTorch idioms; left as-is to stay compatible with the version
    # this file targets - confirm before upgrading.

    # Training Function, presently only returns training loss
    # x: input data
    # y: target labels
    def train_fn(x, y):
        """Run one optimisation step; return (loss, number of errors)."""
        net.optim.zero_grad()
        inputs = V(x.cuda().half()) if fp16 else V(x.cuda())

        output = net(inputs)

        loss = F.nll_loss(output, V(y.cuda()))
        training_error = output.data.max(1)[1].cpu().ne(y).sum()

        loss.backward()
        net.optim.step()
        return loss.data[0], training_error

    # Testing function, returns test loss and test error for a batch
    # x: input data
    # y: target labels
    def test_fn(x, y):
        """Evaluate one batch; return (loss, number of errors)."""

        inputs = V(x.cuda().half(), volatile=True) if fp16 else V(
            x.cuda(), volatile=True)
        output = net(inputs)
        test_loss = F.nll_loss(output, V(y.cuda(), volatile=True)).data[0]

        # If we're running Imagenet, we may want top-5 error:
        if top5:
            # Indices of the five largest outputs per row.
            top5_preds = np.argsort(output.data.cpu().numpy())[:, :-6:-1]
            test_error = len(y) - np.sum(
                [np.any(top5_i == y_i) for top5_i, y_i in zip(top5_preds, y)])
        else:
            # Get the index of the max log-probability as the prediction.
            pred = output.data.max(1)[1].cpu()
            test_error = pred.ne(y).sum()

        return test_loss, test_error

    # Finally, launch the training loop.
    logging.info('Starting training at epoch ' + str(start_epoch) + '...')
    for epoch in range(start_epoch, epochs):

        # Pin the current epoch on the network.
        net.epoch = epoch

        # shrink learning rate at scheduled intervals, if desired
        if 'epoch' in net.lr_sched and epoch in net.lr_sched['epoch']:

            logging.info('Annealing learning rate...')

            # Optionally checkpoint at annealing
            if net.checkpoint_before_anneal:
                torch.save(
                    net,
                    'weights/' + str(epoch) + '_' + weights_fname + '.pth')

            for param_group in net.optim.param_groups:
                param_group['lr'] *= 0.1

        # List where we'll store training loss
        train_loss, train_err = [], []

        # Prepare the training data
        batches = progress(train_loader,
                           desc='Epoch %d/%d, Batch ' % (epoch + 1, epochs),
                           total=len(train_loader.dataset) // batch_size)

        # Put the network into training mode
        net.train()

        # Execute training pass
        for x, y in batches:

            # Update LR if using cosine annealing
            if 'itr' in net.lr_sched:
                net.update_lr(max_j=epochs * len(train_loader.dataset) //
                              batch_size)

            loss, err = train_fn(x, y)
            train_loss.append(loss)
            train_err.append(err)

        # Report training metrics
        train_loss = float(np.mean(train_loss))
        train_err = 100 * float(np.sum(train_err)) / len(train_loader.dataset)
        print('  training loss:\t%.6f, training error: \t%.2f%%' %
              (train_loss, train_err))
        mlog.log(epoch=epoch, train_loss=train_loss, train_err=train_err)

        # Optionally, take a pass over the validation set.
        if validate and not ((epoch + 1) % validate_every):

            # Lists to store
            val_loss, val_err = [], []

            # Set network into evaluation mode
            net.eval()

            # Execute validation pass
            for x, y in tqdm(val_loader):
                loss, err = test_fn(x, y)
                val_loss.append(loss)
                val_err.append(err)

            # Report validation metrics
            val_loss = float(np.mean(val_loss))
            val_err = 100 * float(np.sum(val_err)) / len(val_loader.dataset)
            print('  validation loss:\t%.6f,  validation error:\t%.2f%%' %
                  (val_loss, val_err))
            mlog.log(epoch=epoch, val_loss=val_loss, val_err=val_err)

        # Optionally, take a pass over the test set.
        if test and not ((epoch + 1) % validate_every):

            # Lists to store
            test_loss, test_err = [], []

            # Set network into evaluation mode
            net.eval()

            # Execute test pass
            for x, y in tqdm(test_loader):
                loss, err = test_fn(x, y)
                test_loss.append(loss)
                test_err.append(err)

            # Report test metrics
            test_loss = float(np.mean(test_loss))
            test_err = 100 * float(np.sum(test_err)) / len(test_loader.dataset)
            print('  test loss:\t%.6f,  test error:\t%.2f%%' %
                  (test_loss, test_err))
            mlog.log(epoch=epoch, test_loss=test_loss, test_err=test_err)
        # Save weights for this epoch
        print('saving weights to ' + weights_fname + '...')
        torch.save(net, 'weights/' + weights_fname + '.pth')

        # If requested, save a checkpointed copy with a different name
        # so that we have them for reference later.
        if duplicate_at_checkpoints and not epoch % 5:
            torch.save(net,
                       'weights/' + weights_fname + '_e' + str(epoch) + '.pth')

    # At the end of it all, save weights even if we didn't checkpoint.
    if weights_fname:
        torch.save(net, 'weights/' + weights_fname + '.pth')
Beispiel #12
0
def train():
    """Fine-tune a GPT-2 double-heads model with ignite and save checkpoints.

    Builds the argument parser, parses a hard-coded argument list (the real
    command line is not read), prepares tokenizer/model/optimizer, optional
    apex fp16 and distributed wrapping, wires up the ignite trainer and
    evaluator with metrics, logging and checkpointing, runs training, and
    finally renames the last checkpoint to ``WEIGHTS_NAME`` in the log dir.
    """

    def _str2bool(value):
        """Parse a textual boolean for argparse.

        ``type=bool`` would turn every non-empty string, including "False",
        into True. A ValueError makes argparse report an invalid value.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise ValueError('expected a boolean, got %r' % (value,))

    parser = ArgumentParser()
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="",
        help="Path or url of the dataset. If empty download from S3.")
    parser.add_argument("--dataset_cache",
                        type=str,
                        default='./dataset_cache',
                        help="Path or url of the dataset cache")
    parser.add_argument("--model_checkpoint",
                        type=str,
                        default="openai-gpt",
                        help="Path, url or short name of the model")
    parser.add_argument("--num_candidates",
                        type=int,
                        default=2,
                        help="Number of candidates for training")
    parser.add_argument("--max_history",
                        type=int,
                        default=2,
                        help="Number of previous exchanges to keep in history")
    parser.add_argument("--train_batch_size",
                        type=int,
                        default=4,
                        help="Batch size for training")
    parser.add_argument("--valid_batch_size",
                        type=int,
                        default=4,
                        help="Batch size for validation")
    parser.add_argument("--gradient_accumulation_steps",
                        type=int,
                        default=8,
                        help="Accumulate gradients on several steps")
    parser.add_argument("--lr",
                        type=float,
                        default=6.25e-5,
                        help="Learning rate")
    parser.add_argument("--lm_coef",
                        type=float,
                        default=1.0,
                        help="LM loss coefficient")
    parser.add_argument("--mc_coef",
                        type=float,
                        default=1.0,
                        help="Multiple-choice loss coefficient")
    parser.add_argument("--max_norm",
                        type=float,
                        default=1.0,
                        help="Clipping gradient norm")
    parser.add_argument("--n_epochs",
                        type=int,
                        default=3,
                        help="Number of training epochs")
    parser.add_argument("--personality_permutations",
                        type=int,
                        default=1,
                        help="Number of permutations of personality sentences")
    parser.add_argument(
        "--eval_before_start",
        action='store_true',
        help="If true start with a first evaluation before training")
    parser.add_argument("--device",
                        type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument(
        "--fp16",
        type=str,
        default="",
        help=
        "Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)")
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="Local rank for distributed training (-1: not distributed)")
    parser.add_argument("--no_persona",
                        type=_str2bool,
                        default=False,
                        help="keep persona in training or not ")
    # NOTE: the arguments are fixed here; sys.argv is not consulted.
    args = parser.parse_args([
        '--gradient_accumulation_steps',
        '4',
        '--lm_coef',
        '2.0',
        '--max_history',
        '2',
        '--n_epochs',
        '1',
        '--num_candidates',
        '4',
        '--personality_permutations',
        '2',
        '--train_batch_size',
        '2',
        '--valid_batch_size',
        '2',
        '--fp16',
        'O1',
        '--model_checkpoint',
        '/datadrive/ssd/117M',
        '--no_persona',
        'True'
    ])

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Running process %d", args.local_rank
    )  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(args))

    # Initialize distributed training if needed
    args.distributed = (args.local_rank != -1)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    logger.info(
        "Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning"
    )
    tokenizer_class = GPT2Tokenizer  # if "gpt2" in args.model_checkpoint else OpenAIGPTTokenizer
    tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint)
    model_class = GPT2DoubleHeadsModel  # if "gpt2" in args.model_checkpoint else OpenAIGPTDoubleHeadsModel
    model = model_class.from_pretrained(args.model_checkpoint)
    model.to(args.device)
    optimizer = OpenAIAdam(model.parameters(), lr=args.lr)

    # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    if args.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16)
    if args.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(
        args, tokenizer, no_persona=args.no_persona)

    # Training function and trainer
    update = partial(update_full, args=args, model=model, optimizer=optimizer)
    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    inference = partial(inference_full,
                        args=args,
                        model=model,
                        tokenizer=tokenizer)
    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: evaluator.run(val_loader))
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED,
                                  lambda _: evaluator.run(val_loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED,
                                  lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if args.distributed:
        trainer.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, args.lr),
                                 (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {
        "nll":
        Loss(torch.nn.CrossEntropyLoss(ignore_index=-1),
             output_transform=lambda x: (x[0][0], x[1][0])),
        "accuracy":
        Accuracy(output_transform=lambda x: (x[0][1], x[1][1]))
    }
    metrics.update({
        "average_nll":
        MetricsLambda(average_distributed_scalar, metrics["nll"], args),
        "average_accuracy":
        MetricsLambda(average_distributed_scalar, metrics["accuracy"], args)
    })
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(
            Events.COMPLETED, lambda _: pbar.log_message(
                "Validation: %s" % pformat(evaluator.state.metrics)))

        tb_logger = TensorboardLogger(log_dir=None)
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training",
                                                   metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator,
                         log_handler=OutputHandler(tag="validation",
                                                   metric_names=list(
                                                       metrics.keys()),
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(tb_logger.writer.log_dir,
                                             'checkpoint',
                                             save_interval=1,
                                             n_saved=3)
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED, checkpoint_handler,
            {'mymodel': getattr(model, 'module', model)
             })  # "getattr" take care of distributed encapsulation

        torch.save(args, tb_logger.writer.log_dir + '/model_training_args.bin')
        getattr(model, 'module', model).config.to_json_file(
            os.path.join(tb_logger.writer.log_dir, CONFIG_NAME))
        tokenizer.save_vocabulary(tb_logger.writer.log_dir)

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)

    # On the main process: rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method) and close the tensorboard logger
    if args.local_rank in [-1, 0]:
        # A checkpoint only exists if at least one epoch was run.
        if args.n_epochs > 0:
            os.rename(
                checkpoint_handler._saved[-1][1][-1],
                os.path.join(tb_logger.writer.log_dir, WEIGHTS_NAME)
            )  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        # Close the logger even when no epoch ran, so the writer is released.
        tb_logger.close()
Beispiel #13
0
            outputs = model(images)
            loss = criterion(outputs.squeeze(), labels.float())
            train_loss += loss * (len(images) / float(num_train_examples))
            predicted = (outputs.squeeze() > 0.).long()

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    return model.state_dict()


if __name__ == '__main__':
    from utils import get_data_loaders

    # Settings for the non-private logistic-regression run.
    batch_size, num_epochs = 64, 100
    learning_rate, lmbda = 1e-1, 5e-3
    data_seed, num_train = 0, 1000

    # Build the loaders; the second returned value is not needed here.
    loaders, _ = get_data_loaders(data_seed, batch_size, num_train)

    # Fit on the 'train' loader and keep whatever the trainer returns.
    nonprivate_params = nonprivate_logistic_regression(
        loaders['train'], num_epochs, learning_rate, lmbda)
def run(config):
    """Build G, D and the image encoder E, load their weights, and evaluate.

    The config dict is first extended with settings derived from the chosen
    dataset and activation names. The three networks are then built on the
    GPU, their weights are loaded non-strictly (no optimizer state), and a
    fixed batch of inputs is handed to ``evaluate_sample`` with
    ``attack=True``.

    Args:
        config: dict of experiment settings; it is mutated and then replaced
            by the result of ``utils.update_config_roots``.
    """
    # Settings derived from the user-specified configuration.
    dataset = config['dataset']
    config['resolution'] = utils.imsize_dict[dataset]
    config['n_classes'] = utils.nclass_dict[dataset]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # Resuming implies the networks must not be re-initialised.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'

    # Seeding, output folders and cudnn autotuning happen before any network
    # is created; this order is kept exactly as in the original.
    utils.seed_rng(config['seed'])
    utils.prepare_root(config)
    torch.backends.cudnn.benchmark = True

    # The model module is selected dynamically by name.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name']
                       or utils.name_from_config(config))

    # Networks.
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)
    E = model.ImgEncoder(**config).to(device)
    GDE = model.G_D_E(G, D, E)  # built but not used further in this function

    # Bookkeeping dict (epoch / iteration counters, best scores, config).
    state_dict = dict(itr=0,
                      epoch=0,
                      save_num=0,
                      save_best_num=0,
                      best_IS=0,
                      best_FID=999999,
                      config=config)

    param_counts = [
        sum(p.data.nelement() for p in net.parameters())
        for net in (G, D, E)
    ]
    print('Number of params in G: {} D: {} E: {}'.format(*param_counts))

    print('Loading weights...')
    weights_suffix = config['load_weights'] or None
    utils.load_weights(G,
                       D,
                       E,
                       state_dict,
                       config['weights_root'],
                       experiment_name,
                       weights_suffix,
                       None,
                       strict=False,
                       load_optim=False)

    # ==========================================================================
    # Data
    loaders, train_dataset = utils.get_data_loaders(**config)

    G_batch_size = max(config['G_batch_size'], config['batch_size'])

    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device)

    # A fixed z & y pair, sampled once.
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size,
                                         G.dim_z,
                                         config['n_classes'],
                                         device=device)
    fixed_z.sample_()
    fixed_y.sample_()
    print("fixed_y original: {} {}".format(fixed_y.shape, fixed_y[:10]))

    fixed_x, fixed_y_of_x = utils.prepare_x_y(G_batch_size, train_dataset,
                                              experiment_name, config)

    # NOTE(review): the randomly sampled fixed_y is passed here, not the
    # fixed_y_of_x returned with fixed_x -- confirm this is intended.
    evaluate_sample(config,
                    fixed_x,
                    fixed_y,
                    G,
                    E,
                    experiment_name,
                    attack=True)
def main(args):
    """Fine-tune a GPT-2 based question generator with an ignite training loop.

    Loads the tokenizer and ``GPT2ConditionalLMHeadModel`` from
    ``args.qgen_model_path``, adds the special tokens, builds the data loader,
    and runs ``args.n_epochs`` epochs with a linearly decaying learning rate.
    A checkpoint is written to ``args.checkpoint`` after every epoch, together
    with the training arguments, the model config and the vocabulary.

    Args:
        args: run configuration, read by attribute (``qgen_model_path``,
            ``dataset_cache``, ``device``, ``learning_rate``, ``adam_epsilon``,
            ``n_gpu``, ``learning``, ``n_epochs``, ``checkpoint``).
    """
    # Load a pre-defined tokenizer (GPT-2), create config and model
    logger.info(
        "Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning"
    )

    gpt_tokenizer = GPT2Tokenizer.from_pretrained(args.qgen_model_path,
                                                  cache_dir=args.dataset_cache)
    gpt_tokenizer.add_tokens(SPECIAL_TOKENS)
    gpt_tokenizer.sep_token = '<sep>'

    qgen = GPT2ConditionalLMHeadModel.from_pretrained(
        args.qgen_model_path, cache_dir=args.dataset_cache)
    # The embedding matrix must grow to cover the newly added special tokens.
    qgen.resize_token_embeddings(len(gpt_tokenizer))
    qgen.to(args.device)
    qgen_optimizer = AdamW(qgen.parameters(),
                           lr=args.learning_rate,
                           eps=args.adam_epsilon)

    # Only `pad` is used below (for batch trimming); the other ids are unpacked
    # for completeness.
    bos, eos, ctx, ans, que, pad = gpt_tokenizer.convert_tokens_to_ids(
        SPECIAL_TOKENS)

    if args.n_gpu > 1:
        qgen = torch.nn.DataParallel(qgen)

    logger.info("Prepare datasets")
    dataloader = get_data_loaders(args, gpt_tokenizer)

    # Define training function
    def update(engine, batch):
        """Run one training step on ``batch`` and return the loss as a float."""
        # remove extra pad from batches
        batch = trim_batch(batch, pad)
        qgen.train()

        loss = 0.0
        ###################################
        # MLE training with teacher forcing
        ###################################
        if 'sl' in args.learning:
            input_ids, lm_labels, token_type_ids, _, _, _ = tuple(
                input_tensor.to(args.device) for input_tensor in batch)
            loss_ce = qgen(input_ids=input_ids,
                           labels=lm_labels,
                           token_type_ids=token_type_ids)[0]
            loss = apply_loss(engine.state.iteration, qgen_optimizer, loss_ce,
                              args)
        # When no objective ran, `loss` is still the plain Python float it was
        # initialised with, which has no `.item()`; return it unchanged instead
        # of raising AttributeError.
        if isinstance(loss, float):
            return loss
        return loss.item()

    trainer = Engine(update)

    # Add progressbar with loss
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    ProgressBar(persist=True).attach(trainer, metric_names=['loss'])

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(qgen_optimizer, "lr",
                                [(0, args.learning_rate),
                                 (args.n_epochs * len(dataloader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Save checkpoints
    checkpoint_handler = ModelCheckpoint(args.checkpoint,
                                         'checkpoint',
                                         save_interval=1,
                                         n_saved=6,
                                         require_empty=False)
    trainer.add_event_handler(
        Events.EPOCH_COMPLETED, checkpoint_handler,
        {'mymodel': getattr(qgen, 'module', qgen)
         })  # "getattr" take care of distributed encapsulation

    # save training config
    # NOTE(review): dict(args) only works if `args` is mapping-like; a plain
    # argparse.Namespace would need vars(args) -- confirm the type of `args`.
    torch.save(dict(args), os.path.join(args.checkpoint, 'training_args.bin'))
    getattr(qgen, 'module', qgen).config.to_json_file(
        os.path.join(args.checkpoint, CONFIG_NAME))
    gpt_tokenizer.save_vocabulary(args.checkpoint)

    trainer.run(dataloader, max_epochs=args.n_epochs)
Beispiel #16
0
def run(config):
    """Train a BigGAN generator/discriminator pair as described by ``config``.

    The config dict is extended with dataset- and activation-derived settings,
    G and D (and optionally an EMA copy of G) are built on the GPU, weights are
    loaded when resuming, and the training loop then runs from the stored epoch
    up to ``config['num_epochs']``. Weights/samples are saved every
    ``config['save_every']`` iterations and metrics are computed every
    ``config['test_every']`` iterations.

    Args:
        config: dict of experiment settings; it is mutated and then replaced
            by the result of ``utils.update_config_roots``.
    """

    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    model = BigGAN

    # Next, build the model
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)

    # If using EMA (Exponential Moving Average of G's parameters), prepare it
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(
            config['ema_decay']))
        G_ema = model.Generator(**{
            **config, 'skip_init': True,
            'no_optim': True
        }).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        # G_ema is passed unconditionally to save_and_sample() and test()
        # below, so it must be bound even when EMA is disabled.
        G_ema, ema = None, None

    GD = model.G_D(G, D)

    print('Number of params in G: {} D: {}'.format(
        *
        [sum([p.data.nelement() for p in net.parameters()])
         for net in [G, D]]))
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'config': config
    }

    # If loading from a pre-trained model, load weights
    if config['resume']:
        print('Loading weights...')
        utils.load_weights(
            G, D, state_dict, config['weights_root'], experiment_name,
            config['load_weights'] if config['load_weights'] else None,
            G_ema if config['ema'] else None)

    # If parallel, parallelize the GD module
    if config['parallel']:
        GD = nn.DataParallel(GD)

    # Prepare loggers for stats; metrics holds test metrics,
    # lmetrics holds any desired training metrics.
    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'],
                                              experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
    test_log = utils.MetricsLogger(test_metrics_fname,
                                   reinitialize=(not config['resume']))
    print('Training Metrics will be saved to {}'.format(train_metrics_fname))

    # Write metadata
    utils.write_metadata(config['logs_root'], experiment_name, config,
                         state_dict)

    # The loader yields enough samples per step for all D steps and
    # gradient accumulations of one training iteration.
    D_batch_size = (config['batch_size'] * config['num_D_steps'] *
                    config['num_D_accumulations'])
    loaders = utils.get_data_loaders(**{
        **config, 'batch_size': D_batch_size,
        'start_itr': state_dict['itr']
    })

    # Prepare inception metrics: FID and IS
    get_inception_metrics = inception_utils.prepare_inception_metrics(
        config['dataset'], config['parallel'], config['no_fid'])

    # Prepare noise and randomly sampled label arrays
    # Allow for different batch sizes in G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'])
    # Prepare a fixed z & y to see individual sample evolution throughout training
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size,
                                         G.dim_z,
                                         config['n_classes'],
                                         device=device,
                                         fp16=config['G_fp16'])
    fixed_z.sample_()
    fixed_y.sample_()
    # Loaders are loaded, prepare the training function
    train = train_fns.GAN_training_function(G, D, GD, z_, y_, ema, state_dict,
                                            config)

    # Prepare Sample function for use with inception metrics
    sample = functools.partial(
        utils.sample,
        G=(G_ema if config['ema'] and config['use_ema'] else G),
        z_=z_,
        y_=y_,
        config=config)

    print('Beginning training at epoch %d...' % state_dict['epoch'])
    # Train for specified number of epochs, although we mostly track G iterations.
    for epoch in range(state_dict['epoch'], config['num_epochs']):

        pbar = utils.progress(
            loaders[0],
            displaytype='s1k' if config['use_multiepoch_sampler'] else 'eta')
        for i, (x, y) in enumerate(pbar):
            # Increment the iteration counter
            state_dict['itr'] += 1
            # Make sure G and D are in training mode, just in case they got set to eval
            # For D, which typically doesn't have BN, this shouldn't matter much.
            G.train()
            D.train()
            if config['ema']:
                G_ema.train()
            if config['D_fp16']:
                x, y = x.to(device).half(), y.to(device)
            else:
                x, y = x.to(device), y.to(device)
            metrics = train(x, y)

            print(', '.join(
                ['itr: %d' % state_dict['itr']] +
                ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]),
                  end=' ')

            # Save weights and copies as configured at specified interval
            if not (state_dict['itr'] % config['save_every']):
                if config['G_eval_mode']:
                    print('Switchin G to eval mode...')
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z,
                                          fixed_y, state_dict, config,
                                          experiment_name)

            # Test every specified interval
            if not (state_dict['itr'] % config['test_every']):
                if config['G_eval_mode']:
                    print('Switchin G to eval mode...')
                    G.eval()
                train_fns.test(G, D, G_ema, z_, y_, state_dict, config, sample,
                               get_inception_metrics, experiment_name,
                               test_log)
        # Increment epoch counter at end of epoch
        state_dict['epoch'] += 1
Beispiel #17
0
def run(config):
    """Load a trained generator (and discriminator) and run evaluation tasks.

    Which tasks run is controlled by boolean/valued entries of ``config``:
    discriminator accuracy on real data (``get_test_error`` /
    ``get_train_error``), on generated data (``get_self_error``), accuracy of
    an external classifier on generated data (``get_generator_error``),
    standing-statistics accumulation, NPZ sampling, the TF-based "official"
    IS / FID, sample / interpolation / random sheets, the PyTorch inception
    metrics and truncation curves. Accuracy and IS/FID results are stored in a
    per-experiment history file (``scoring_hist.npy``), keyed by the loaded
    iteration number.

    Args:
        config: dict of settings; it is mutated (and partly overwritten from
            the checkpoint when ``config_from_name`` is set).
    """
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'config': config
    }

    # Optionally, get the configuration from the state dict. This allows for
    # recovery of the config provided only a state dict and experiment name,
    # and can be convenient for writing less verbose sample shell scripts.
    if config['config_from_name']:
        utils.load_weights(None,
                           None,
                           state_dict,
                           config['weights_root'],
                           config['experiment_name'],
                           config['load_weights'],
                           None,
                           strict=False,
                           load_optim=False)
        # Ignore items which we might want to overwrite from the command line
        for item in state_dict['config']:
            if item not in [
                    'z_var', 'base_root', 'batch_size', 'G_batch_size',
                    'use_ema', 'G_eval_mode'
            ]:
                config[item] = state_dict['config'][item]

    # update config (see train.py for explanation)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    config = utils.update_config_roots(config)
    # Weights are loaded below, so skip initialisation and optimizer creation.
    config['skip_init'] = True
    config['no_optim'] = True
    device = 'cuda'

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    G = model.Generator(**config).cuda()
    utils.count_parameters(G)

    # In some cases we need to load D
    # NOTE(review): the leading `True or` short-circuits the whole condition,
    # so D is currently always built regardless of the get_*_error flags.
    if True or config['get_test_error'] or config['get_train_error'] or config[
            'get_self_error'] or config['get_generator_error']:
        disc_config = config.copy()
        # The multi-hinge losses use one extra discriminator output.
        if config['mh_csc_loss'] or config['mh_loss']:
            disc_config['output_dim'] = disc_config['n_classes'] + 1
        D = model.Discriminator(**disc_config).to(device)

        def get_n_correct_from_D(x, y):
            """Return how many samples in ``x`` D "classifies" as label ``y``.

            Args:
                x: batch of images; moved to ``device`` here.
                y: the correct labels. Moved to ``device`` unless
                    ``config['get_self_error']`` is set.

            In the case of projection discrimination we have to pass in all
            the labels as conditionings to get the class specific affinity.
            """
            x = x.to(device)
            if config['model'] == 'BigGAN':  # projection discrimination case
                if not config['get_self_error']:
                    y = y.to(device)
                # Column 0 holds D's output for the true label; columns
                # 1..n_classes-1 hold the outputs for the cyclically shifted
                # labels (y + i) % n_classes.
                yhat = D(x, y)
                for i in range(1, config['n_classes']):
                    yhat_ = D(x, ((y + i) % config['n_classes']))
                    yhat = torch.cat([yhat, yhat_], 1)
                # A sample counts as correct when the true-label column wins.
                preds_ = yhat.data.max(1)[1].cpu()
                return preds_.eq(0).cpu().sum()
            else:  # the mh gan case
                if not config['get_self_error']:
                    y = y.to(device)
                yhat = D(x)
                # Only the first n_classes outputs are compared with y; any
                # additional output is ignored for the prediction.
                preds_ = yhat[:, :config['n_classes']].data.max(1)[1]
                return preds_.eq(y.data).cpu().sum()

    # Load weights
    print('Loading weights...')
    # Here is where we deal with the ema--load ema weights or load normal weights
    # G is passed either in the first slot (regular weights) or in the slot
    # after the weights name (presumably the G_ema slot -- confirm against
    # utils.load_weights). With use_ema set but ema unset, both are None.
    utils.load_weights(G if not (config['use_ema']) else None,
                       D,
                       state_dict,
                       config['weights_root'],
                       experiment_name,
                       config['load_weights'],
                       G if config['ema'] and config['use_ema'] else None,
                       strict=False,
                       load_optim=False)
    # Update batch size setting used for G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'],
                               z_var=config['z_var'])

    if config['G_eval_mode']:
        print('Putting G in eval mode..')
        G.eval()
    else:
        print('G is in %s mode...' % ('training' if G.training else 'eval'))

    sample = functools.partial(utils.sample, G=G, z_=z_, y_=y_, config=config)
    # Last 30 characters of the experiment name, used as a log prefix.
    brief_expt_name = config['experiment_name'][-30:]

    # load results dict always
    HIST_FNAME = 'scoring_hist.npy'

    def load_or_make_hist(d):
        """Return the scoring history stored in directory ``d``.

        Loads ``HIST_FNAME`` from ``d`` when it exists, otherwise returns a
        fresh ``defaultdict(dict)``.

        Raises:
            Exception: if ``d`` is not an existing directory.
        """
        if not os.path.isdir(d):
            raise Exception('%s is not a valid directory' % d)
        f = os.path.join(d, HIST_FNAME)
        if os.path.isfile(f):
            # The file holds a pickled object array; .item() unwraps it.
            return np.load(f, allow_pickle=True).item()
        else:
            return defaultdict(dict)

    hist_dir = os.path.join(config['weights_root'], config['experiment_name'])
    hist = load_or_make_hist(hist_dir)

    # ---- Discriminator accuracy on real (train or test) data ----
    if config['get_test_error'] or config['get_train_error']:
        loaders = utils.get_data_loaders(
            **{
                **config, 'batch_size': config['batch_size'],
                'start_itr': state_dict['itr'],
                'use_test_set': config['get_test_error']
            })
        # When both flags are set, the test set is used.
        acc_type = 'Test' if config['get_test_error'] else 'Train'

        pbar = tqdm(loaders[0])
        loader_total = len(loaders[0]) * config['batch_size']
        sample_todo = min(config['sample_num_error'], loader_total)
        print('Getting %s error accross %i examples' % (acc_type, sample_todo))
        correct = 0
        total = 0

        with torch.no_grad():
            for i, (x, y) in enumerate(pbar):
                correct += get_n_correct_from_D(x, y)
                # NOTE(review): counts the configured batch size per step, not
                # the actual number of samples in x.
                total += config['batch_size']
                if loader_total > total and total >= config['sample_num_error']:
                    print('Quitting early...')
                    break

        accuracy = float(correct) / float(total)
        # Reload before writing so the latest file contents are updated.
        hist = load_or_make_hist(hist_dir)
        hist[state_dict['itr']][acc_type] = accuracy
        np.save(os.path.join(hist_dir, HIST_FNAME), hist)

        print('[%s][%06d] %s accuracy: %f.' %
              (brief_expt_name, state_dict['itr'], acc_type, accuracy * 100))

    # ---- Discriminator accuracy on generated data ----
    if config['get_self_error']:
        n_used_imgs = config['sample_num_error']
        correct = 0
        imageSize = config['resolution']
        # NOTE(review): this buffer is allocated but not written in this branch.
        x = np.empty((n_used_imgs, imageSize, imageSize, 3), dtype=np.uint8)
        # NOTE(review): n_used_imgs // G_batch_size batches are scored, but the
        # accuracy below divides by n_used_imgs.
        for l in tqdm(range(n_used_imgs // G_batch_size),
                      desc='Generating [%s][%06d]' %
                      (brief_expt_name, state_dict['itr'])):
            with torch.no_grad():
                images, y = sample()
                correct += get_n_correct_from_D(images, y)

        accuracy = float(correct) / float(n_used_imgs)
        print('[%s][%06d] %s accuracy: %f.' %
              (brief_expt_name, state_dict['itr'], 'Self', accuracy * 100))
        hist = load_or_make_hist(hist_dir)
        hist[state_dict['itr']]['Self'] = accuracy
        np.save(os.path.join(hist_dir, HIST_FNAME), hist)

    # ---- Accuracy of an external classifier on generated data ----
    if config['get_generator_error']:

        # Each supported dataset loads its comparison network from a
        # hard-coded, machine-specific checkpoint path.
        if config['dataset'] == 'C10':
            from classification.models.densenet import DenseNet121
            from torchvision import transforms
            compnet = DenseNet121()
            compnet = torch.nn.DataParallel(compnet)
            #checkpoint = torch.load(os.path.join('/scratch0/ilya/locDoc/classifiers/densenet121','ckpt_47.t7'))
            checkpoint = torch.load(
                os.path.join(
                    '/fs/vulcan-scratch/ilyak/locDoc/experiments/classifiers/cifar/densenet121',
                    'ckpt_47.t7'))
            compnet.load_state_dict(checkpoint['net'])
            compnet = compnet.to(device)
            compnet.eval()
            minimal_trans = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010)),
            ])
        elif config['dataset'] == 'C100':
            from classification.models.densenet import DenseNet121
            from torchvision import transforms
            compnet = DenseNet121(num_classes=100)
            compnet = torch.nn.DataParallel(compnet)
            checkpoint = torch.load(
                os.path.join(
                    '/scratch0/ilya/locDoc/classifiers/cifar100/densenet121',
                    'ckpt.copy.t7'))
            #checkpoint = torch.load(os.path.join('/fs/vulcan-scratch/ilyak/locDoc/experiments/classifiers/cifar100/densenet121','ckpt.copy.t7'))
            compnet.load_state_dict(checkpoint['net'])
            compnet = compnet.to(device)
            compnet.eval()
            minimal_trans = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.507, 0.487, 0.441),
                                     (0.267, 0.256, 0.276)),
            ])
        elif config['dataset'] == 'STL48':
            from classification.models.wideresnet import WideResNet48
            from torchvision import transforms
            checkpoint = torch.load(
                os.path.join(
                    '/fs/vulcan-scratch/ilyak/locDoc/experiments/classifiers/stl/mixmatch_48',
                    'model_best.pth.tar'))
            compnet = WideResNet48(num_classes=10)
            compnet = compnet.to(device)
            # Detach the parameters before loading the 'ema_state_dict' weights.
            for param in compnet.parameters():
                param.detach_()
            compnet.load_state_dict(checkpoint['ema_state_dict'])
            compnet.eval()
            minimal_trans = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010)),
            ])
        else:
            raise ValueError('Dataset %s has no comparison network.' %
                             config['dataset'])

        n_used_imgs = 10000
        correct = 0
        # Histogram of predicted labels, normalised after the loop.
        mean_label = np.zeros(config['n_classes'])
        imageSize = config['resolution']
        # NOTE(review): this buffer is allocated but not written in this branch.
        x = np.empty((n_used_imgs, imageSize, imageSize, 3), dtype=np.uint8)
        for l in tqdm(range(n_used_imgs // G_batch_size),
                      desc='Generating [%s][%06d]' %
                      (brief_expt_name, state_dict['itr'])):
            with torch.no_grad():
                images, y = sample()
                # Map generator output from [-1, 1] to uint8 pixels, then apply
                # the classifier's own preprocessing image by image.
                fake = images.data.cpu().numpy()
                fake = np.floor((fake + 1) * 255 / 2.0).astype(np.uint8)
                fake_input = np.zeros(fake.shape)
                for bi in range(fake.shape[0]):
                    # Channel axis is moved last before the transform.
                    fake_input[bi] = minimal_trans(np.moveaxis(
                        fake[bi], 0, -1))
                # Reuse the `images` tensor as the classifier input.
                images.data.copy_(torch.from_numpy(fake_input))
                lab = compnet(images).max(1)[1]
                mean_label += np.bincount(lab.data.cpu(),
                                          minlength=config['n_classes'])
                correct += int((lab == y).sum().cpu())

        accuracy = float(correct) / float(n_used_imgs)
        mean_label_normalized = mean_label / float(n_used_imgs)

        print(
            '[%s][%06d] %s accuracy: %f.' %
            (brief_expt_name, state_dict['itr'], 'Generator', accuracy * 100))
        hist = load_or_make_hist(hist_dir)
        hist[state_dict['itr']]['Generator'] = accuracy
        hist[state_dict['itr']]['Mean_Label'] = mean_label_normalized
        np.save(os.path.join(hist_dir, HIST_FNAME), hist)

    if config['accumulate_stats']:
        print('Accumulating standing stats across %d accumulations...' %
              config['num_standing_accumulations'])
        utils.accumulate_standing_stats(G, z_, y_, config['n_classes'],
                                        config['num_standing_accumulations'])

    # Sample a number of images and save them to an NPZ, for use with TF-Inception
    if config['sample_npz']:
        # Lists to hold images and labels for images
        x, y = [], []
        print('Sampling %d images and saving them to npz...' %
              config['sample_num_npz'])
        for i in trange(
                int(np.ceil(config['sample_num_npz'] / float(G_batch_size)))):
            with torch.no_grad():
                images, labels = sample()
            x += [np.uint8(255 * (images.cpu().numpy() + 1) / 2.)]
            y += [labels.cpu().numpy()]
        # The last batch may overshoot; trim to the requested count.
        x = np.concatenate(x, 0)[:config['sample_num_npz']]
        y = np.concatenate(y, 0)[:config['sample_num_npz']]
        print('Images shape: %s, Labels shape: %s' % (x.shape, y.shape))
        npz_filename = '%s/%s/samples.npz' % (config['samples_root'],
                                              experiment_name)
        print('Saving npz to %s...' % npz_filename)
        np.savez(npz_filename, **{'x': x, 'y': y})

    # Reference statistics for the official FID; used further below.
    if config['official_FID']:
        f = np.load(config['dataset_is_fid'])
        # this is for using the downloaded one from
        # https://github.com/bioinf-jku/TTUR
        #mdata, sdata = f['mu'][:], f['sigma'][:]

        # this one is for my format files
        mdata, sdata = f['mfid'], f['sfid']

    # Sample a number of images and stick them in memory, for use with TF-Inception official_IS and official_FID
    # Generation is skipped when every requested score is already in `hist`
    # for this iteration and 'overwrite' is not set.
    data_gen_necessary = False
    if config['sample_np_mem']:
        is_saved = int('IS' in hist[state_dict['itr']])
        is_todo = int(config['official_IS'])
        fid_saved = int('FID' in hist[state_dict['itr']])
        fid_todo = int(config['official_FID'])
        data_gen_necessary = config['overwrite'] or (is_todo > is_saved) or (
            fid_todo > fid_saved)
    if config['sample_np_mem'] and data_gen_necessary:
        n_used_imgs = 50000
        imageSize = config['resolution']
        x = np.empty((n_used_imgs, imageSize, imageSize, 3), dtype=np.uint8)
        # NOTE(review): only n_used_imgs // G_batch_size full batches are
        # written; if G_batch_size does not divide n_used_imgs, the trailing
        # rows of the np.empty buffer stay uninitialised.
        for l in tqdm(range(n_used_imgs // G_batch_size),
                      desc='Generating [%s][%06d]' %
                      (brief_expt_name, state_dict['itr'])):
            start = l * G_batch_size
            end = start + G_batch_size

            with torch.no_grad():
                images, labels = sample()
            fake = np.uint8(255 * (images.cpu().numpy() + 1) / 2.)
            # Store channels-last.
            x[start:end] = np.moveaxis(fake, 1, -1)
            #y += [labels.cpu().numpy()]

    # ---- Official (TF) Inception Score ----
    # NOTE(review): `x` here is whatever an earlier branch left behind; it is
    # presumably meant to be the sample_np_mem buffer -- confirm that
    # 'sample_np_mem' is always set together with 'official_IS'/'official_FID'.
    if config['official_IS']:
        if (not ('IS' in hist[state_dict['itr']])) or config['overwrite']:
            mis, sis = iscore.get_inception_score(x)
            print('[%s][%06d] IS mu: %f. IS sigma: %f.' %
                  (brief_expt_name, state_dict['itr'], mis, sis))
            hist = load_or_make_hist(hist_dir)
            hist[state_dict['itr']]['IS'] = [mis, sis]
            np.save(os.path.join(hist_dir, HIST_FNAME), hist)
        else:
            mis, sis = hist[state_dict['itr']]['IS']
            print(
                '[%s][%06d] Already done (skipping...): IS mu: %f. IS sigma: %f.'
                % (brief_expt_name, state_dict['itr'], mis, sis))

    # ---- Official (TF) FID ----
    if config['official_FID']:
        import tensorflow as tf

        def fid_ms_for_imgs(images, mem_fraction=0.5):
            """Return the inception activation mean and covariance of ``images``.

            Runs in a TensorFlow session limited to ``mem_fraction`` of GPU
            memory (uses the session-based ``tf.Session`` API).
            """
            gpu_options = tf.GPUOptions(
                per_process_gpu_memory_fraction=mem_fraction)
            inception_path = fid.check_or_download_inception(None)
            fid.create_inception_graph(
                inception_path)  # load the graph into the current TF graph
            with tf.Session(config=tf.ConfigProto(
                    gpu_options=gpu_options)) as sess:
                sess.run(tf.global_variables_initializer())
                mu_gen, sigma_gen = fid.calculate_activation_statistics(
                    images, sess, batch_size=100)
            return mu_gen, sigma_gen

        if (not ('FID' in hist[state_dict['itr']])) or config['overwrite']:
            m1, s1 = fid_ms_for_imgs(x)
            # mdata / sdata are the reference statistics loaded earlier.
            fid_value = fid.calculate_frechet_distance(m1, s1, mdata, sdata)
            print('[%s][%06d] FID: %f' %
                  (brief_expt_name, state_dict['itr'], fid_value))
            hist = load_or_make_hist(hist_dir)
            hist[state_dict['itr']]['FID'] = fid_value
            np.save(os.path.join(hist_dir, HIST_FNAME), hist)
        else:
            fid_value = hist[state_dict['itr']]['FID']
            print('[%s][%06d] Already done (skipping...): FID: %f' %
                  (brief_expt_name, state_dict['itr'], fid_value))

    # Prepare sample sheets
    if config['sample_sheets']:
        print('Preparing conditional sample sheets...')
        folder_number = config['sample_sheet_folder_num']
        # -1 means "name the folder after the loaded weights".
        if folder_number == -1:
            folder_number = config['load_weights']
        utils.sample_sheet(
            G,
            classes_per_sheet=utils.classes_per_sheet_dict[config['dataset']],
            num_classes=config['n_classes'],
            samples_per_class=10,
            parallel=config['parallel'],
            samples_root=config['samples_root'],
            experiment_name=experiment_name,
            folder_number=folder_number,
            z_=z_,
        )
    # Sample interp sheets
    if config['sample_interps']:
        print('Preparing interp sheets...')
        folder_number = config['sample_sheet_folder_num']
        if folder_number == -1:
            folder_number = config['load_weights']
        # Three sheets: neither fixed, y fixed, z fixed.
        for fix_z, fix_y in zip([False, False, True], [False, True, False]):
            utils.interp_sheet(G,
                               num_per_sheet=16,
                               num_midpoints=8,
                               num_classes=config['n_classes'],
                               parallel=config['parallel'],
                               samples_root=config['samples_root'],
                               experiment_name=experiment_name,
                               folder_number=int(folder_number),
                               sheet_number=0,
                               fix_z=fix_z,
                               fix_y=fix_y,
                               device='cuda')
    # Sample random sheet
    if config['sample_random']:
        print('Preparing random sample sheet...')
        images, labels = sample()
        torchvision.utils.save_image(
            images.float(),
            '%s/%s/%s.jpg' %
            (config['samples_root'], experiment_name, config['load_weights']),
            nrow=int(G_batch_size**0.5),
            normalize=True)

    # Prepare a simple function get metrics that we use for trunc curves
    def get_metrics():
        """Compute the PyTorch inception metrics for G and print a summary line."""
        # Get Inception Score and FID
        get_inception_metrics = inception_utils.prepare_inception_metrics(
            config['dataset'], config['parallel'], config['no_fid'])
        sample = functools.partial(utils.sample,
                                   G=G,
                                   z_=z_,
                                   y_=y_,
                                   config=config)
        IS_mean, IS_std, FID = get_inception_metrics(
            sample,
            config['num_inception_images'],
            num_splits=10,
            prints=False)
        # Prepare output string
        outstring = 'Using %s weights ' % ('ema'
                                           if config['use_ema'] else 'non-ema')
        outstring += 'in %s mode, ' % ('eval' if config['G_eval_mode'] else
                                       'training')
        outstring += 'with noise variance %3.3f, ' % z_.var
        outstring += 'over %d images, ' % config['num_inception_images']
        if config['accumulate_stats'] or not config['G_eval_mode']:
            outstring += 'with batch size %d, ' % G_batch_size
        if config['accumulate_stats']:
            outstring += 'using %d standing stat accumulations, ' % config[
                'num_standing_accumulations']
        outstring += 'Itr %d: PYTORCH UNOFFICIAL Inception Score is %3.3f +/- %3.3f, PYTORCH UNOFFICIAL FID is %5.4f' % (
            state_dict['itr'], IS_mean, IS_std, FID)
        print(outstring)

    if config['sample_inception_metrics']:
        print('Calculating Inception metrics...')
        get_metrics()

    # Sample truncation curve stuff. This is basically the same as the inception metrics code
    if config['sample_trunc_curves']:
        # The setting is a 'start_step_end' string of three floats.
        start, step, end = [
            float(item) for item in config['sample_trunc_curves'].split('_')
        ]
        print(
            'Getting truncation values for variance in range (%3.3f:%3.3f:%3.3f)...'
            % (start, step, end))
        # `end + step` as the stop value is meant to include `end` itself.
        for var in np.arange(start, end + step, step):
            z_.var = var
            # Optionally comment this out if you want to run with standing stats
            # accumulated at one z variance setting
            if config['accumulate_stats']:
                utils.accumulate_standing_stats(
                    G, z_, y_, config['n_classes'],
                    config['num_standing_accumulations'])
            get_metrics()
Beispiel #18
0
"""
Basic linear regression from the primary and secondary color to the composition.
"""
from datasets import GlazeColor2CompositionDataset
from sklearn import linear_model
from sklearn.metrics import mean_squared_error
from utils import get_data_loaders

# Build the dataset and split it with the project's helper, which is unpacked
# here as a (train, test) pair.
ds = GlazeColor2CompositionDataset()
train_ds, test_ds = get_data_loaders(ds)

reg = linear_model.LinearRegression()
# Every item yielded by train_ds is unpacked as an (input, target) pair.
# Targets are converted with .numpy() (presumably torch tensors -- TODO
# confirm); inputs are passed to scikit-learn as they come.
# NOTE(review): train_ds is iterated twice; if iteration order is not
# deterministic the two lists may not line up -- confirm against the helper.
train_x = [d for d, _ in train_ds]
train_y = [d.numpy() for _, d in train_ds]
# Debug output: first sample of the full dataset and of the training split.
print(ds[0][0])
print(ds[0][1].numpy())
print(train_ds[0][0])
print(train_ds[0][1].numpy())
print(train_y[0])

# Fit the regression on the training pairs. test_ds is not used below.
reg.fit(train_x, train_y)
Beispiel #19
0
def run(config):
    """Train a GAN described by ``config`` and track FID during training.

    Builds the generator and discriminator from the module named by
    ``config['model']``, loads weights from a hard-coded checkpoint, then
    runs the epoch loop: every iteration calls the training function and logs
    its metrics; at configurable intervals it saves/samples and computes FID
    against a precomputed CIFAR-10 statistics file.

    Args:
        config: dict of settings. It is mutated in place (derived entries such
            as ``resolution`` and ``n_classes`` are added) and also stored in
            the script-level ``state_dict``.

    Side effects:
        Uses the hard-coded device ``'cuda'``; writes logs, metadata, samples
        and weights through ``utils`` / ``train_fns``; calls ``sys.exit()``
        when the dataset is neither ``'C10U'`` nor ``'C10'``.
    """
    # Derive dataset- and nonlinearity-dependent settings from lookup tables.
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # When resuming, skip weight initialisation.
    if config['resume']:
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'
    utils.seed_rng(config['seed'])
    utils.prepare_root(config)
    torch.backends.cudnn.benchmark = True
    # The model module is selected dynamically by name.
    model = __import__(config['model'])
    # Fall back to a name derived from the config when none is given.
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)
    # Optional second generator used by the EMA helper; it is built with
    # 'skip_init' and 'no_optim' forced to True.
    if config['ema']:
        G_ema = model.Generator(**{
            **config, 'skip_init': True,
            'no_optim': True
        }).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None
    # Optional half-precision casts.
    if config['G_fp16']:
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        D = D.half()
    # Combined G/D wrapper module.
    GD = model.G_D(G, D, config['conditional'])
    # Script state: counters, best scores so far, and the config itself.
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'config': config
    }
    if config['resume']:
        utils.load_weights(
            G, D, state_dict, config['weights_root'], experiment_name,
            config['load_weights'] if config['load_weights'] else None,
            G_ema if config['ema'] else None)
    # NOTE(review): this second load is unconditional and uses a hard-coded
    # path/experiment/suffix, so it runs even after the resume branch above
    # has loaded weights -- confirm that this is intended.
    utils.load_weights(
        G, D, state_dict,
        '../Task1_CIFAR_MNIST_KLWGAN_Simulation_Experiment/weights', 'C10Ukl5',
        'best0', G_ema if config['ema'] else None)
    if config['parallel']:
        GD = nn.DataParallel(GD)
        if config['cross_replica']:
            patch_replication_callback(GD)
    # Loggers: test metrics go to a .jsonl file, training metrics to a
    # location named after the experiment.
    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'],
                                              experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    test_log = utils.MetricsLogger(test_metrics_fname,
                                   reinitialize=(not config['resume']))
    train_log = utils.MyLogger(train_metrics_fname,
                               reinitialize=(not config['resume']),
                               logstyle=config['logstyle'])
    utils.write_metadata(config['logs_root'], experiment_name, config,
                         state_dict)
    # The loader batch covers all D steps and accumulations of one iteration.
    D_batch_size = (config['batch_size'] * config['num_D_steps'] *
                    config['num_D_accumulations'])
    # These two values are already in config; they are passed again
    # explicitly to the data-loader helper below.
    abnormal_class = config['abnormal_class']
    select_dataset = config['select_dataset']
    loaders = utils.get_data_loaders(
        **{
            **config, 'batch_size': D_batch_size,
            'start_itr': state_dict['itr'],
            'abnormal_class': abnormal_class,
            'select_dataset': select_dataset
        })
    # Noise/label holders for training plus a fixed pair that is sampled once.
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'])
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size,
                                         G.dim_z,
                                         config['n_classes'],
                                         device=device,
                                         fp16=config['G_fp16'])
    fixed_z.sample_()
    fixed_y.sample_()
    # In the unconditional setting all labels are zeroed.
    if not config['conditional']:
        fixed_y.zero_()
        y_.zero_()
    # Any value other than 'GAN' selects the dummy training function.
    if config['which_train_fn'] == 'GAN':
        train = train_fns.GAN_training_function(G, D, GD, z_, y_, ema,
                                                state_dict, config)
    else:
        train = train_fns.dummy_training_function()
    # NOTE(review): `sample` is not used again in this function.
    sample = functools.partial(
        utils.sample,
        G=(G_ema if config['ema'] and config['use_ema'] else G),
        z_=z_,
        y_=y_,
        config=config)
    # Reference FID statistics exist only for the CIFAR-10 dataset names; the
    # file name is relative to the current working directory.
    if config['dataset'] == 'C10U' or config['dataset'] == 'C10':
        data_moments = 'fid_stats_cifar10_train.npz'
    else:
        print("Cannot find the image data set.")
        # NOTE(review): sys.exit() without an argument exits with status 0.
        sys.exit()
    for epoch in range(state_dict['epoch'], config['num_epochs']):
        # Progress bar: the project's own or tqdm.
        if config['pbar'] == 'mine':
            pbar = utils.progress(loaders[0],
                                  displaytype='s1k' if
                                  config['use_multiepoch_sampler'] else 'eta')
        else:
            pbar = tqdm(loaders[0])
        for i, (x, y) in enumerate(pbar):
            state_dict['itr'] += 1
            # Put the networks back into training mode; they may have been
            # switched to eval by the save/test branches below.
            G.train()
            D.train()
            if config['ema']:
                G_ema.train()
            if config['D_fp16']:
                x, y = x.to(device).half(), y.to(device)
            else:
                x, y = x.to(device), y.to(device)
            # Emits a blank line on every iteration.
            print('')
            # Print the random seed once, on the very first batch.
            if epoch == 0 and i == 0:
                print(config['seed'])
            metrics = train(x, y)
            train_log.log(itr=int(state_dict['itr']), **metrics)
            # Periodically log values returned by utils.get_SVs for G and D.
            if (config['sv_log_interval'] > 0) and (
                    not (state_dict['itr'] % config['sv_log_interval'])):
                train_log.log(itr=int(state_dict['itr']),
                              **{
                                  **utils.get_SVs(G, 'G'),
                                  **utils.get_SVs(D, 'D')
                              })
            if config['pbar'] == 'mine':
                print(', '.join(
                    ['itr: %d' % state_dict['itr']] +
                    ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]),
                      end=' ')
            # Save and sample every `save_every` iterations.
            if not (state_dict['itr'] % config['save_every']):
                if config['G_eval_mode']:
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z,
                                          fixed_y, state_dict, config,
                                          experiment_name)
            # The experiment name is recomputed on every iteration.
            experiment_name = (config['experiment_name']
                               if config['experiment_name'] else
                               utils.name_from_config(config))
            # FID evaluation every `test_every` iterations, once the epoch
            # index has reached `start_eval`.
            if (not (state_dict['itr'] % config['test_every'])) and (
                    epoch >= config['start_eval']):
                if config['G_eval_mode']:
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                # Presumably writes samples.npz into the per-epoch folder
                # that is read back just below -- TODO confirm.
                utils.sample_inception(
                    G_ema if config['ema'] and config['use_ema'] else G,
                    config, str(epoch))
                folder_number = str(epoch)
                sample_moments = '%s/%s/%s/samples.npz' % (
                    config['samples_root'], experiment_name, folder_number)
                FID = fid_score.calculate_fid_given_paths(
                    [data_moments, sample_moments],
                    batch_size=50,
                    cuda=True,
                    dims=2048)
                train_fns.update_FID(G, D, G_ema, state_dict, config, FID,
                                     experiment_name, test_log)
        state_dict['epoch'] += 1
    # Final save under the suffix 'last0'.
    utils.save_weights(G, D, state_dict, config['weights_root'],
                       experiment_name, 'last%d' % 0,
                       G_ema if config['ema'] else None)
Beispiel #20
0
def run(config):
    """Dump the training split of ``config['dataset']`` into an HDF5 file.

    The file is written to ``<data_root>/ILSVRC<image_size>.hdf5`` with two
    resizable datasets, ``imgs`` (uint8) and ``labels`` (int64). The first
    batch creates the file; every later batch re-opens it in append mode and
    grows both datasets along axis 0.

    Args:
        config: dict with at least ``dataset``, ``compression``,
            ``num_workers``, ``batch_size``, ``data_root`` and ``chunk_size``.
            ``image_size`` is added and ``compression`` is normalised to
            ``'lzf'`` or ``None`` in place.

    Raises:
        ValueError: if the dataset name contains ``'hdf5'``, i.e. the source
            looks like an HDF5 file that this run would probably overwrite.
    """
    if 'hdf5' in config['dataset']:
        raise ValueError(
            'Reading from an HDF5 file which you will probably be '
            'about to overwrite! Override this error only if you know '
            "what you're doing!")
    # Get image size
    config['image_size'] = utils.imsize_dict[config['dataset']]

    # Normalise the compression entry: any truthy value selects 'lzf',
    # anything else disables compression.
    config['compression'] = 'lzf' if config['compression'] else None

    # Get dataset
    kwargs = {
        'num_workers': config['num_workers'],
        'pin_memory': False,
        'drop_last': False
    }
    train_loader = utils.get_data_loaders(dataset=config['dataset'],
                                          batch_size=config['batch_size'],
                                          shuffle=False,
                                          data_root=config['data_root'],
                                          use_multiepoch_sampler=False,
                                          **kwargs)[0]

    # HDF5 supports chunking and compression. You may want to experiment
    # with different chunk sizes to see how it runs on your machines.
    # Chunk Size/compression     Read speed @ 256x256   Read speed @ 128x128  Filesize @ 128x128    Time to write @128x128
    # 1 / None                   20/s
    # 500 / None                 ramps up to 77/s       102/s                 61GB                  23min
    # 500 / LZF                                         8/s                   56GB                  23min
    # 1000 / None                78/s
    # 5000 / None                81/s
    # auto:(125,1,16,32) / None                         11/s                  61GB

    print(
        'Starting to load %s into an HDF5 file with chunk size %i and compression %s...'
        % (config['dataset'], config['chunk_size'], config['compression']))

    # The output path does not change between batches, so build it once.
    hdf5_path = config['data_root'] + '/ILSVRC%i.hdf5' % config['image_size']
    image_size = config['image_size']
    total_images = len(train_loader.dataset)

    # Loop over train loader
    for i, (x, y) in enumerate(tqdm(train_loader)):
        # Map x to [0, 255] bytes; assumes the loader yields values in
        # [-1, 1] -- TODO confirm against the loader's normalisation.
        x = (255 * ((x + 1) / 2.0)).byte().numpy()
        # Numpyify y
        y = y.numpy()
        if i == 0:
            # First batch: create the file and both resizable datasets.
            with h5.File(hdf5_path, 'w') as f:
                print('Producing dataset of len %d' % total_images)
                imgs_dset = f.create_dataset(
                    'imgs',
                    x.shape,
                    dtype='uint8',
                    maxshape=(total_images, 3, image_size, image_size),
                    chunks=(config['chunk_size'], 3, image_size, image_size),
                    compression=config['compression'])
                print('Image chunks chosen as ' + str(imgs_dset.chunks))
                imgs_dset[...] = x
                labels_dset = f.create_dataset(
                    'labels',
                    y.shape,
                    dtype='int64',
                    maxshape=(total_images, ),
                    chunks=(config['chunk_size'], ),
                    compression=config['compression'])
                print('Label chunks chosen as ' + str(labels_dset.chunks))
                labels_dset[...] = y
        else:
            # Later batches: grow both datasets and write into the new tail.
            with h5.File(hdf5_path, 'a') as f:
                f['imgs'].resize(f['imgs'].shape[0] + x.shape[0], axis=0)
                f['imgs'][-x.shape[0]:] = x
                f['labels'].resize(f['labels'].shape[0] + y.shape[0], axis=0)
                f['labels'][-y.shape[0]:] = y
def main():
    """Re-run validation for every finished run and append the losses to a CSV.

    Each sub-directory of ``--runs_root`` (except ``no-noise-defaults``) is
    treated as a completed run: its options and last checkpoint are loaded,
    the model is validated on ``<data-dir>/val`` with random binary messages,
    and the accumulated losses are written to
    ``<runs_root>/validation_run.csv``. The header flag is set only for the
    first run.
    """
    # Validation runs on the CPU.
    device = torch.device('cpu')

    parser = argparse.ArgumentParser(description='Training of HiDDeN nets')
    parser.add_argument('--data-dir',
                        '-d',
                        type=str,
                        required=True,
                        help='The directory where the data is stored.')
    parser.add_argument(
        '--runs_root',
        '-r',
        type=str,
        default=os.path.join('.', 'experiments'),
        help='The root folder where data about experiments are stored.')
    args = parser.parse_args()

    # Print progress every this many validation steps.
    report_interval = 25

    # Collect the run directories, leaving out the defaults folder.
    completed_runs = []
    for entry in os.listdir(args.runs_root):
        if entry == 'no-noise-defaults':
            continue
        if os.path.isdir(os.path.join(args.runs_root, entry)):
            completed_runs.append(entry)

    print(completed_runs)

    csv_path = os.path.join(args.runs_root, 'validation_run.csv')
    for run_index, run_name in enumerate(completed_runs):
        current_run = os.path.join(args.runs_root, run_name)
        print(f'Run folder: {current_run}')

        train_options, hidden_config, noise_config = utils.load_options(
            os.path.join(current_run, 'options-and-config.pickle'))
        # Both folders point at the validation split.
        val_folder = os.path.join(args.data_dir, 'val')
        train_options.train_folder = val_folder
        train_options.validation_folder = val_folder
        train_options.batch_size = 4
        checkpoint = utils.load_last_checkpoint(
            os.path.join(current_run, 'checkpoints'))

        noiser = Noiser(noise_config, device)
        model = Hidden(hidden_config, device, noiser, tb_logger=None)
        utils.model_from_checkpoint(model, checkpoint)

        print('Model loaded successfully. Starting validation run...')
        _, val_data = utils.get_data_loaders(hidden_config, train_options)

        # Number of batches per pass, rounding up for a partial last batch.
        file_count = len(val_data.dataset)
        full_batches, leftover = divmod(file_count, train_options.batch_size)
        steps_in_epoch = full_batches + 1 if leftover else full_batches

        losses_accu = {}
        for step, (image, _) in enumerate(val_data, start=1):
            image = image.to(device)
            # One random 0/1 message per image in the batch.
            bits = np.random.choice(
                [0, 1], (image.shape[0], hidden_config.message_length))
            message = torch.Tensor(bits).to(device)
            losses, (encoded_images, noised_images,
                     decoded_messages) = model.validate_on_batch(
                         [image, message])
            # Create one list per loss name while nothing is accumulated yet.
            if not losses_accu:
                losses_accu = {name: [] for name in losses}
            for name, loss in losses.items():
                losses_accu[name].append(loss)
            if step % report_interval == 0:
                print(f'Step {step}/{steps_in_epoch}')
                utils.print_progress(losses_accu)
                print('-' * 40)

        utils.print_progress(losses_accu)
        write_validation_loss(csv_path,
                              losses_accu,
                              run_name,
                              checkpoint['epoch'],
                              write_header=(run_index == 0))
Beispiel #22
0
def run(config):
    """Train the classifier described by ``config``, with optional val/test passes.

    Builds ``Network`` from the module named by ``config['model']`` (looked up
    on ``sys.path`` after appending ``'models'``), trains it for
    ``config['epochs']`` epochs, logs metrics to ``logs/<name>_log.jsonl`` and
    saves the whole network object to ``weights/<name>.pth`` after every epoch
    and once more at the end.

    Args:
        config: dict of settings; ``num_classes`` is added in place and
            ``validate_split`` is forced to 0 when validation is disabled.

    Side effects:
        Creates the ``logs`` and ``weights`` directories, appends to
        ``sys.path``, and uses the hard-coded device ``'cuda:0'``.
    """
    # Number of classes, add it to the config
    config['num_classes'] = utils.num_class_dict[config['dataset']]

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # choose device; by default (and unless this code is modified) it's GPU 0
    device = 'cuda:0'

    # Create logs and weights folders if they don't exist; exist_ok avoids
    # the check-then-create race of exists() followed by mkdir().
    os.makedirs('logs', exist_ok=True)
    os.makedirs('weights', exist_ok=True)

    # Name of the file to which we're saving losses and errors.
    # Strings must be compared by value: `is not 'default'` depended on
    # string interning and is a SyntaxWarning on Python 3.8+.
    # NOTE(review): as written, a non-default name is replaced by the
    # config-derived name and 'default' is kept verbatim; this looks
    # inverted -- confirm the intended direction.
    experiment_name = (utils.name_from_config(config)
                       if config['experiment_name'] != 'default' else
                       config['experiment_name'])

    # Prepare metrics logging
    metrics_fname = 'logs/' + experiment_name + '_log.jsonl'
    print('Metrics will be saved to {}'.format(metrics_fname))
    mlog = utils.MetricsLogger(metrics_fname,
                               reinitialize=(not config['resume']))

    # Import the model module
    sys.path.append('models')
    model_module = __import__(config['model'])

    # Build network, either by initializing it or re-loading.
    net = model_module.Network(**config)
    net = net.to(device)
    print(net)
    if config['fp16']:
        net = net.half()

    if config['resume']:
        print('loading network ' + experiment_name + '...')
        # This unpickles a whole network object and replaces the one built
        # above; only load checkpoint files you trust.
        net = torch.load('weights/%s.pth' % experiment_name)

    # Prepare the run_net: the (optionally parallel) module used for the
    # forward pass. `net` itself keeps the optimizer and schedule.
    if config['parallel']:
        print('Parallelizing net...')
        run_net = nn.DataParallel(net)
    else:
        run_net = net

    # Continue after the epoch pinned on a network that carries one.
    start_epoch = net.epoch + 1 if hasattr(net, 'epoch') else 0
    print('Number of params: {}'.format(
        sum(p.data.nelement() for p in net.parameters())))
    # If not validating, set val split to 0
    if not config['validate']:
        config['validate_split'] = 0
    # Get information specific to each dataset
    loaders = utils.get_data_loaders(config['dataset'], config['augment'],
                                     config['validate'], config['test'],
                                     config['batch_size'], config['fold'],
                                     config['validate_seed'],
                                     config['validate_split'] / 100.,
                                     config['num_workers'])
    train_loader = loaders[0]

    if config['validate']:
        val_loader = loaders[1]
    if config['test']:
        test_loader = loaders[2]

    def train_fn(x, y):
        """Run one optimisation step; return (loss, misclassified count)."""
        net.optim.zero_grad()
        output = run_net(x)
        loss = F.nll_loss(output, y)
        training_error = output.data.max(1)[1].ne(y).sum()
        loss.backward()
        net.optim.step()
        # Return plain Python numbers so the per-epoch numpy reductions do
        # not have to convert device tensors.
        return loss.item(), training_error.item()

    def test_fn(x, y):
        """Evaluate one batch without gradients; return (loss, error count)."""
        with torch.no_grad():
            output = run_net(x)
            test_loss = F.nll_loss(output, y)

            # If we're running Imagenet, we may want top-5 error:
            if config['top5']:
                # Indices of the five largest outputs per row, best first.
                top5_preds = np.argsort(output.data.cpu().numpy())[:, :-6:-1]
                labels = y.cpu().numpy()
                hits = (top5_preds == labels[:, None]).any(axis=1)
                test_error = len(labels) - int(hits.sum())
            else:
                # Get the index of the max log-probability as the prediction.
                pred = output.data.max(1)[1]
                test_error = pred.ne(y).sum().item()

            return test_loss.item(), test_error

    # Finally, launch the training loop.
    print('Starting training at epoch ' + str(start_epoch) + '...')
    for epoch in range(start_epoch, config['epochs']):
        # Pin the current epoch on the network.
        net.epoch = epoch

        # shrink learning rate at scheduled intervals, if desired
        if 'epoch' in net.lr_sched and epoch in net.lr_sched['epoch']:

            print('Annealing learning rate...')

            # Optionally checkpoint at annealing
            if net.checkpoint_before_anneal:
                torch.save(
                    net,
                    'weights/' + str(epoch) + '_' + experiment_name + '.pth')

            for param_group in net.optim.param_groups:
                param_group['lr'] *= 0.1

        # Lists where we'll store training loss and error counts
        train_loss, train_err = [], []

        # Prepare the training data
        if config['progbar']:
            batches = utils.progress(
                train_loader,
                desc='Epoch %d/%d, Batch ' % (epoch + 1, config['epochs']),
                total=len(train_loader.dataset) // config['batch_size'])
        else:
            batches = train_loader

        # Put the network into training mode
        net.train()

        # Execute training pass
        for x, y in batches:
            # Update LR if using cosine annealing
            if 'itr' in net.lr_sched:
                net.update_lr(max_j=config['epochs'] *
                              len(train_loader.dataset) //
                              config['batch_size'])

            loss, err = train_fn(x.to(device), y.to(device))
            train_loss.append(loss)
            train_err.append(err)

        # Report training metrics
        train_loss = float(np.mean(train_loss))
        train_err = 100 * float(np.sum(train_err)) / len(train_loader.dataset)
        print('  training loss:\t%.6f, training error: \t%.2f%%' %
              (train_loss, train_err))
        mlog.log(epoch=epoch, train_loss=train_loss, train_err=train_err)

        # Optionally, take a pass over the validation set.
        if config['validate'] and not ((epoch + 1) % config['validate_every']):

            # Lists to store
            val_loss, val_err = [], []

            # Set network into evaluation mode
            net.eval()

            # Execute validation pass
            if config['progbar']:
                batches = tqdm(val_loader)
            else:
                batches = val_loader

            for x, y in batches:
                loss, err = test_fn(x.to(device), y.to(device))
                val_loss.append(loss)
                val_err.append(err)

            # Report validation metrics
            val_loss = float(np.mean(val_loss))
            val_err = 100 * float(np.sum(val_err)) / len(val_loader.dataset)
            print('  validation loss:\t%.6f,  validation error:\t%.2f%%' %
                  (val_loss, val_err))
            mlog.log(epoch=epoch, val_loss=val_loss, val_err=val_err)

        # Optionally, take a pass over the test set (same schedule as
        # validation: every `validate_every` epochs).
        if config['test'] and not ((epoch + 1) % config['validate_every']):
            # Lists to store
            test_loss, test_err = [], []

            # Set network into evaluation mode
            net.eval()

            # Execute test pass
            if config['progbar']:
                batches = tqdm(test_loader)
            else:
                batches = test_loader

            for x, y in batches:
                loss, err = test_fn(x.to(device), y.to(device))
                test_loss.append(loss)
                test_err.append(err)

            # Report test metrics
            test_loss = float(np.mean(test_loss))
            test_err = 100 * float(np.sum(test_err)) / len(test_loader.dataset)
            print('  test loss:\t%.6f,  test error:\t%.2f%%' %
                  (test_loss, test_err))
            mlog.log(epoch=epoch, test_loss=test_loss, test_err=test_err)

        # Save weights for this epoch. `net` is only a DataParallel wrapper
        # if the resumed checkpoint was one; this function wraps `run_net`.
        if isinstance(net, nn.DataParallel):
            print('saving de-parallelized weights to ' + experiment_name +
                  '...')
            torch.save(net.module, 'weights/%s.pth' % experiment_name)
        else:
            print('saving weights to %s...' % experiment_name)
            torch.save(net, 'weights/%s.pth' % experiment_name)

    # At the end of it all, save weights even if we didn't checkpoint.
    if experiment_name:
        torch.save(net, 'weights/%s.pth' % experiment_name)
def run(config):
    """Train only the discriminator using the 'MINE' training function.

    Derives dataset-dependent settings, builds ``Discriminator`` from the
    module named by ``config['model']``, sets up the text and TensorBoard
    loggers, and then iterates over the first data loader for
    ``config['num_epochs']`` epochs, logging the metrics returned by every
    training step.

    Args:
        config: dict of settings; derived entries are added to it in place.

    Side effects:
        Uses the hard-coded device ``'cuda'``, deletes files whose names
        start with ``'events'`` from the TensorBoard log directory, and
        writes logs and metadata under ``config['logs_root']``.
    """
    # Fill in settings that follow from the dataset / nonlinearity names.
    dataset = config['dataset']
    config['resolution'] = utils.imsize_dict[dataset]
    config['n_classes'] = utils.nclass_dict[dataset]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # Resuming implies that weight initialisation is skipped.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'

    # Seed the RNG, prepare output folders, enable the cudnn autotuner.
    utils.seed_rng(config['seed'])
    utils.prepare_root(config)
    torch.backends.cudnn.benchmark = True

    # The model module is chosen dynamically by name.
    model = __import__(config['model'])
    experiment_name = config['experiment_name']
    if not experiment_name:
        experiment_name = utils.name_from_config(config)
    print('Experiment name is %s' % experiment_name)

    # Build the discriminator, optionally in half precision.
    disc = model.Discriminator(**config).to(device)
    if config['D_fp16']:
        print('Casting D to fp16...')
        disc = disc.half()

    print(disc)

    # Script state: iteration and epoch counters plus the config.
    state_dict = {'itr': 0, 'epoch': 0, 'config': config}

    # Optionally wrap the discriminator for multi-GPU execution.
    if config['parallel']:
        disc = nn.DataParallel(disc)
        if config['cross_replica']:
            patch_replication_callback(disc)

    # Plain-text training logger.
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    print('Training Metrics will be saved to {}'.format(train_metrics_fname))
    train_log = utils.MyLogger(train_metrics_fname,
                               reinitialize=(not config['resume']),
                               logstyle=config['logstyle'])

    # TensorBoard logger: clear event files left by a previous run first.
    tb_logdir = '%s/%s/tblogs' % (config['logs_root'], experiment_name)
    if os.path.exists(tb_logdir):
        stale_events = [
            entry for entry in os.listdir(tb_logdir)
            if entry.startswith('events')
        ]
        for entry in stale_events:
            os.remove(os.path.join(tb_logdir, entry))
    tb_writer = SummaryWriter(log_dir=tb_logdir)

    # Write metadata
    utils.write_metadata(config['logs_root'], experiment_name, config,
                         state_dict)

    # Every loader batch holds enough data for one full D iteration
    # (all D steps and all gradient accumulations).
    D_batch_size = (config['batch_size'] * config['num_D_steps'] *
                    config['num_D_accumulations'])
    loader_kwargs = dict(config)
    loader_kwargs['batch_size'] = D_batch_size
    loader_kwargs['start_itr'] = state_dict['itr']
    loaders = utils.get_data_loaders(**loader_kwargs)

    # Prepared exactly as before; the returned callable is not used below.
    get_inception_metrics = inception_utils.prepare_inception_metrics(
        config['dataset'], config['parallel'], config['no_fid'])

    # Anything other than 'MINE' falls back to the dummy training function.
    if config['which_train_fn'] != 'MINE':
        train = train_fns.dummy_training_function()
    else:
        train = train_fns.MINE_training_function(disc, state_dict, config)

    print('Beginning training at epoch %d...' % state_dict['epoch'])

    for epoch in range(state_dict['epoch'], config['num_epochs']):
        # Progress bar: tqdm unless the project's own one is requested.
        if config['pbar'] != 'mine':
            pbar = tqdm(loaders[0])
        else:
            pbar = utils.progress(loaders[0],
                                  displaytype='s1k' if
                                  config['use_multiepoch_sampler'] else 'eta')
        for x, y in pbar:
            # Increment the iteration counter
            state_dict['itr'] += 1
            # Make sure D is in training mode.
            disc.train()
            x = x.to(device)
            if config['D_fp16']:
                x = x.half()
            y = y.to(device)

            metrics = train(x, y)
            print(metrics)
            train_log.log(itr=int(state_dict['itr']), **metrics)
            for metric_name, metric_value in metrics.items():
                tb_writer.add_scalar('Train/%s' % metric_name, metric_value,
                                     state_dict['itr'])

            # With the project's progress bar, print the metrics inline.
            if config['pbar'] == 'mine':
                parts = ['itr: %d' % state_dict['itr']]
                parts.extend('%s : %+4.3f' % (key, metrics[key])
                             for key in metrics)
                print(', '.join(parts), end=' ')

        # Increment epoch counter at end of epoch
        state_dict['epoch'] += 1
Beispiel #24
0
def run(config):
    """Build G and D from ``config`` and run the GAN training loop.

    The function derives extra settings from the user configuration, builds
    the Generator/Discriminator from the module named by ``config['model']``,
    optionally resumes from saved weights, prepares loggers, data loaders and
    inception metrics, and then trains for ``config['num_epochs']`` epochs,
    periodically saving/sampling (``config['save_every']``) and testing
    (``config['test_every']``).

    Args:
        config: dict of settings. It is mutated in place (``resolution``,
            ``n_classes``, ``G_activation``, ``D_activation`` and, when
            resuming, ``skip_init`` are written) and then replaced by the
            value returned from ``utils.update_config_roots``.

    Returns:
        None. Results are produced through the loggers and ``train_fns``
        helpers.

    Note:
        The device is hard-coded to ``'cuda'``.
    """
    # Update the config dict as necessary
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    # Next, build the model
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)

    # If using EMA, prepare it
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(
            config['ema_decay']))
        G_ema = model.Generator(**{
            **config, 'skip_init': True,
            'no_optim': True
        }).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        # G_ema must be bound even without EMA: it is passed unconditionally
        # to train_fns.save_and_sample and train_fns.test further down.
        G_ema, ema = None, None

    # FP16?
    if config['G_fp16']:
        print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        print('Casting D to fp16...')
        D = D.half()
        # Consider automatically reducing SN_eps?
    GD = model.G_D(G, D)
    print(G)
    print(D)
    print('Number of params in G: {} D: {}'.format(
        *
        [sum([p.data.nelement() for p in net.parameters()])
         for net in [G, D]]))
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'config': config
    }

    # If loading from a pre-trained model, load weights
    if config['resume']:
        print('Loading weights...')
        utils.load_weights(
            G, D, state_dict, config['weights_root'], experiment_name,
            config['load_weights'] if config['load_weights'] else None,
            G_ema if config['ema'] else None)

    # If parallel, parallelize the GD module
    if config['parallel']:
        GD = nn.DataParallel(GD)
        if config['cross_replica']:
            patch_replication_callback(GD)

    # Prepare loggers for stats; test_log holds test metrics,
    # train_log holds any desired training metrics.
    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'],
                                              experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
    test_log = utils.MetricsLogger(test_metrics_fname,
                                   reinitialize=(not config['resume']))
    print('Training Metrics will be saved to {}'.format(train_metrics_fname))
    train_log = utils.MyLogger(train_metrics_fname,
                               reinitialize=(not config['resume']),
                               logstyle=config['logstyle'])
    # Write metadata
    utils.write_metadata(config['logs_root'], experiment_name, config,
                         state_dict)
    # Prepare data; the Discriminator's batch size is all that needs to be passed
    # to the dataloader, as G doesn't require dataloading.
    # Note that at every loader iteration we pass in enough data to complete
    # a full D iteration (regardless of number of D steps and accumulations)
    D_batch_size = (config['batch_size'] * config['num_D_steps'] *
                    config['num_D_accumulations'])
    loaders = utils.get_data_loaders(**{
        **config, 'batch_size': D_batch_size,
        'start_itr': state_dict['itr']
    })

    # Prepare inception metrics: FID and IS
    get_inception_metrics = inception_utils.prepare_inception_metrics(
        config['dataset'], config['parallel'], config['no_fid'])

    # Prepare noise and randomly sampled label arrays
    # Allow for different batch sizes in G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'])
    # Prepare a fixed z & y to see individual sample evolution throughout training
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size,
                                         G.dim_z,
                                         config['n_classes'],
                                         device=device,
                                         fp16=config['G_fp16'])
    fixed_z.sample_()
    fixed_y.sample_()
    # Loaders are loaded, prepare the training function
    if config['which_train_fn'] == 'GAN':
        train = train_fns.GAN_training_function(G, D, GD, z_, y_, ema,
                                                state_dict, config)
    # Else, assume debugging and use the dummy train fn
    else:
        train = train_fns.dummy_training_function()
    # Prepare Sample function for use with inception metrics
    sample = functools.partial(
        utils.sample,
        G=(G_ema if config['ema'] and config['use_ema'] else G),
        z_=z_,
        y_=y_,
        config=config)

    print('Beginning training at epoch %d...' % state_dict['epoch'])
    # Train for specified number of epochs, although we mostly track G iterations.
    for epoch in range(state_dict['epoch'], config['num_epochs']):
        # Which progressbar to use? TQDM or my own?
        if config['pbar'] == 'mine':
            pbar = utils.progress(loaders[0],
                                  displaytype='s1k' if
                                  config['use_multiepoch_sampler'] else 'eta')
        else:
            pbar = tqdm(loaders[0])
        for x, y in pbar:
            # Increment the iteration counter
            state_dict['itr'] += 1
            # Make sure G and D are in training mode, just in case they got set to eval
            # For D, which typically doesn't have BN, this shouldn't matter much.
            G.train()
            D.train()
            if config['ema']:
                G_ema.train()
            if config['D_fp16']:
                x, y = x.to(device).half(), y.to(device)
            else:
                x, y = x.to(device), y.to(device)
            metrics = train(x, y)
            train_log.log(itr=int(state_dict['itr']), **metrics)

            # Every sv_log_interval, log singular values
            if (config['sv_log_interval'] > 0) and (
                    not (state_dict['itr'] % config['sv_log_interval'])):
                train_log.log(itr=int(state_dict['itr']),
                              **{
                                  **utils.get_SVs(G, 'G'),
                                  **utils.get_SVs(D, 'D')
                              })

            # If using my progbar, print metrics.
            if config['pbar'] == 'mine':
                print(', '.join(
                    ['itr: %d' % state_dict['itr']] +
                    ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]),
                      end=' ')

            # Save weights and copies as configured at specified interval
            if not (state_dict['itr'] % config['save_every']):
                if config['G_eval_mode']:
                    print('Switchin G to eval mode...')
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z,
                                          fixed_y, state_dict, config,
                                          experiment_name)

            # Test every specified interval
            if not (state_dict['itr'] % config['test_every']):
                if config['G_eval_mode']:
                    print('Switchin G to eval mode...')
                    G.eval()
                # NOTE(review): unlike the save branch above, G_ema is not
                # switched to eval here even though `sample` may draw from
                # it -- confirm whether that is intended.
                train_fns.test(G, D, G_ema, z_, y_, state_dict, config, sample,
                               get_inception_metrics, experiment_name,
                               test_log)
        # Increment epoch counter at end of epoch
        state_dict['epoch'] += 1
Beispiel #25
0
def run(config):
    """Dump Inception pool activations of real images, one file per class.

    Iterates over ``loaders[0]`` from ``utils.get_data_loaders(**config)``,
    feeds every batch through the Inception graph built below and writes the
    collected activations to ``real_act/act_<label>.npy`` each time the label
    changes. The loop detects class boundaries by comparing labels of
    consecutive samples, so it presumably expects the loader to yield samples
    grouped by class -- TODO confirm against the dataset/loader settings.

    Args:
        config: dict forwarded to ``utils.get_data_loaders``; its
            ``'drop_last'`` entry is overwritten with ``False``.

    Returns:
        None. Output is written with ``np.save`` into the ``real_act``
        directory, which this function does not create.
    """
    # Get loader
    config['drop_last'] = False
    loaders = utils.get_data_loaders(**config)

    # build graph
    INCEPTION_FINAL_POOL = 'pool_3:0'
    INCEPTION_DEFAULT_IMAGE_SIZE = 299
    ACTIVATION_DIM = 2048

    def inception_activations(images, height=INCEPTION_DEFAULT_IMAGE_SIZE, width=INCEPTION_DEFAULT_IMAGE_SIZE, num_splits = 1):
        # Resize to the Inception input size, then run the classifier over
        # `num_splits` chunks and stitch the per-chunk outputs back together.
        images = tf.image.resize_bilinear(images, [height, width])
        generated_images_list = array_ops.split(images, num_or_size_splits = num_splits)
        activations = functional_ops.map_fn(
            fn = functools.partial(tfgan.eval.run_inception, output_tensor = INCEPTION_FINAL_POOL),
            elems = array_ops.stack(generated_images_list),
            parallel_iterations = 1,
            back_prop = False,
            swap_memory = True,
            name = 'RunClassifier')
        activations = array_ops.concat(array_ops.unstack(activations), 0)
        return activations

    images_holder = tf.placeholder(tf.float32, [None, 128, 128, 3])
    activations = inception_activations(images_holder)
    # The FID ops are never evaluated in this function; they are kept so the
    # finalized graph contains the same nodes as before.
    real_acts = tf.placeholder(tf.float32, [None, ACTIVATION_DIM], name = 'real_activations')
    fake_acts = tf.placeholder(tf.float32, [None, ACTIVATION_DIM], name = 'fake_activations')
    fid = tfgan.eval.frechet_classifier_distance_from_activations(real_acts, fake_acts)

    # Separate name so the `config` argument is not shadowed.
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.per_process_gpu_memory_fraction = 0.5

    sess = tf.Session(config=sess_config)
    sess.run(tf.global_variables_initializer())
    sess.graph.finalize()

    out_real_acts = []
    prev_y = 0
    for x, y in tqdm(loaders[0]):
        x = np.transpose(x.numpy(), (0, 2, 3, 1)) # [NCHW] -> [NHWC]
        y = y.numpy()

        # Label changes between the previous & current batch. Nothing is
        # flushed when no activations were collected yet (e.g. the very first
        # batch does not start at label 0): np.concatenate([]) would raise.
        if y[0] != prev_y and out_real_acts:
            out_real_acts = np.concatenate(out_real_acts, axis=0)
            np.save('real_act/act_{}.npy'.format(prev_y), out_real_acts)
            print('# A class finished !', out_real_acts.shape)
            out_real_acts = []

        # True for the samples that share the label of the batch's first sample.
        mask = (y == y[0])
        out_act = sess.run(activations, {images_holder: x})

        if np.sum(mask) != len(x): # label changes inside this batch
            out_real_acts.append(out_act[mask])
            out_real_acts = np.concatenate(out_real_acts, axis=0)
            np.save('real_act/act_{}.npy'.format(y[0]), out_real_acts)
            print('@ A class finished !', out_real_acts.shape)
            # Everything not matching y[0] is carried over as the next class.
            out_real_acts = [out_act[~mask]]
        else:
            out_real_acts.append(out_act)
        prev_y = y[-1]

    # The last class: concatenate like every other class so the file holds a
    # single 2-D array instead of a (possibly ragged) list of per-batch arrays.
    # NOTE(review): the file name hard-codes label 999 -- presumably a
    # 1000-class dataset; confirm before using with other datasets.
    if out_real_acts:
        np.save('real_act/act_999.npy', np.concatenate(out_real_acts, axis=0))