예제 #1
0
def train(trainX, trainU, model, ema_model, optimizer, epoch):
    """Run one MixMatch training epoch and return the accumulated metrics.

    An epoch is a fixed number of optimizer steps (1024). The labeled
    (``trainX``) and unlabeled (``trainU``) datasets are shuffled and batched
    (64 examples per batch). Whenever either iterator runs out of data it is
    re-created with a fresh shuffle, so the epoch always runs every step.
    After each optimizer step, ``ema(model, ema_model, 0.999)`` is called and
    weight decay is applied to ``model``.

    Args:
        trainX: labeled dataset whose elements have ``'image'`` and ``'label'``
            entries (presumably a ``tf.data.Dataset`` of dicts -- confirm).
        trainU: unlabeled dataset whose elements have an ``'image'`` entry.
        model: model being trained.
        ema_model: model passed to ``ema`` as the moving-average target.
        optimizer: optimizer applied to ``model.trainable_variables``.
        epoch: current epoch index, used by the unlabeled-loss weight ramp-up.

    Returns:
        Tuple ``(xe_loss_avg, l2u_loss_avg, total_loss_avg, accuracy)`` of the
        Keras metric objects accumulated over this epoch.
    """
    batch_size = 64
    steps_per_epoch = 1024

    xe_loss_avg = tf.keras.metrics.Mean()
    l2u_loss_avg = tf.keras.metrics.Mean()
    total_loss_avg = tf.keras.metrics.Mean()
    accuracy = tf.keras.metrics.SparseCategoricalAccuracy()

    def shuffle_and_batch(dataset):
        return dataset.shuffle(buffer_size=int(1e6)).batch(
            batch_size=batch_size, drop_remainder=True)

    iteratorX = iter(shuffle_and_batch(trainX))
    iteratorU = iter(shuffle_and_batch(trainU))

    progress_bar = tqdm(range(steps_per_epoch), unit='batch')
    for batch_num in progress_bar:
        # Unlabeled-loss weight: 100 scaled by linear_rampup over a fractional
        # epoch position (presumably ramping over 16 epochs -- confirm).
        lambda_u = 100 * linear_rampup(epoch + batch_num / steps_per_epoch, 16)

        # Restart an exhausted dataset with a fresh shuffle. Only end-of-data
        # (StopIteration) is handled here; any other error now propagates
        # instead of being silently retried.
        try:
            batchX = next(iteratorX)
        except StopIteration:
            iteratorX = iter(shuffle_and_batch(trainX))
            batchX = next(iteratorX)
        try:
            batchU = next(iteratorU)
        except StopIteration:
            iteratorU = iter(shuffle_and_batch(trainU))
            batchU = next(iteratorU)

        # Mixing coefficient drawn from Beta(0.75, 0.75).
        beta = np.random.beta(0.75, 0.75)
        with tf.GradientTape() as tape:
            # Run MixMatch; XU holds the mixed input batches and XUy their
            # targets. The first batch_size rows of XUy pair with logits_x.
            XU, XUy = mixmatch(model, batchX['image'], batchX['label'],
                               batchU['image'], 0.5, 2, beta)
            logits = [model(batch) for batch in XU]
            logits = interleave(logits, batch_size)
            logits_x = logits[0]
            logits_u = tf.concat(logits[1:], axis=0)

            # xe_loss on the labeled part, l2u_loss on the unlabeled part,
            # combined with the ramped-up weight lambda_u.
            xe_loss, l2u_loss = semi_loss(XUy[:batch_size], logits_x,
                                          XUy[batch_size:], logits_u)
            total_loss = xe_loss + lambda_u * l2u_loss

        # Compute gradients and run optimizer step, then update the EMA model
        # and apply weight decay.
        grads = tape.gradient(total_loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        ema(model, ema_model, 0.999)
        weight_decay(model=model, decay_rate=0.02 * 0.01)

        xe_loss_avg(xe_loss)
        l2u_loss_avg(l2u_loss)
        total_loss_avg(total_loss)
        # Labels are reduced with argmax over axis 1 (so presumably one-hot);
        # accuracy is measured on the unmixed labeled batch in inference mode.
        accuracy(tf.argmax(batchX['label'], axis=1, output_type=tf.int32),
                 model(tf.cast(batchX['image'], dtype=tf.float32), training=False))

        progress_bar.set_postfix({
            'XE Loss': f'{xe_loss_avg.result():.4f}',
            'L2U Loss': f'{l2u_loss_avg.result():.4f}',
            'WeightU': f'{lambda_u:.3f}',
            'Total Loss': f'{total_loss_avg.result():.4f}',
            'Accuracy': f'{accuracy.result():.3%}'
        })
    return xe_loss_avg, l2u_loss_avg, total_loss_avg, accuracy
 def predict_workload(self):
     """Return a simple per-parameter forecast of host statistics.

     Refreshes ``self.statistics`` from
     ``self.get_host_data(use_statistics=True)[1]``. For each parameter, every
     observed timestamp is shifted forward by ``step`` (5) and paired with the
     corresponding output of ``utils.ema`` applied to the observed values.

     Returns:
         dict mapping each parameter to a list of ``(time + 5, ema_value)``
         tuples. The pairs are materialized as a list so the result can be
         iterated more than once (a bare ``zip`` is a one-shot iterator on
         Python 3).
     """
     step = 5  # offset added to each timestamp (same units as the data)
     predict = {}
     self.statistics = self.get_host_data(use_statistics=True)[1]
     # Assumes self.statistics is a mapping of param -> iterable of
     # (time, value) pairs. .items() works on Python 2 and 3 alike, whereas
     # .iteritems() only exists on Python 2 dicts.
     for param, dataset in self.statistics.items():
         next_time_points = [float(t) + float(step) for t, _ in dataset]
         predict_data_points = utils.ema([float(v) for _, v in dataset])
         predict[param] = list(zip(next_time_points, predict_data_points))
     return predict
 def force_index(self):
     """Force Index.

     Indicates how strong the actual buying or selling pressure is:
     the one-period change in the closing price multiplied by volume,
     smoothed with ``utils.ema``.

     High = Rising Trend
     Low = Downward Trend

     ref: https://school.stockcharts.com/doku.php?id=technical_indicators:force_index

     Returns:
         Whatever ``utils.ema`` returns for the raw force series.
     """
     closes = self.data.Close
     previous_close = closes.shift(1)
     raw_force = (closes - previous_close) * self.data.Volume
     return utils.ema(raw_force)
def task_current_signal(sma=config.SMA,
                        ema=config.EMA,
                        table='historical_ratios'):
    """Continuously compute and store an open/close signal for every ticker.

    Waits 30 s, then loops while ``app.running``. For each ticker it:

    * loads the latest rows from ``table``;
    * computes the SMA and EMA;
    * when both are truthy, saves ``'open'`` if EMA <= SMA and ``'close'``
      otherwise.

    It sleeps 15 s between sweeps.

    Args:
        sma: SMA window length (defaults to ``config.SMA``).
        ema: EMA window length (defaults to ``config.EMA``).
        table: table name passed to ``utils.bring_data_db``.
    """
    time.sleep(30)

    tickers = utils.currencies()[0]

    # Fetch twice the longer window plus some headroom (presumably so both
    # averages have enough history).
    cant_rows = max(sma, ema) * 2 + 10

    while app.running:
        for ticker in tickers:
            try:
                data = utils.bring_data_db(ticker=ticker,
                                           k=cant_rows,
                                           table=table)
                sma_value = utils.sma(data=data, k=sma)
                ema_value = utils.ema(data=data, k=ema)

                # Truthiness check: skips the ticker when either value is
                # missing (e.g. None). Note it also skips a value of 0.
                if ema_value and sma_value:
                    resultado = 'open' if ema_value <= sma_value else 'close'
                    save_current_signal(ticker, resultado)

            # Best-effort per ticker: log and move on to the next one.
            # `except Exception` (not a bare except) lets KeyboardInterrupt
            # and SystemExit actually stop this worker loop.
            except Exception:
                traceback.print_exc()
                print(f'error con {ticker}')

        time.sleep(15)
예제 #5
0
def run(config):
  """Train a BigGAN-style G/D pair as configured by ``config``.

  Steps:
  * Derive dataset-dependent settings into ``config``.
  * Build G and D, plus an EMA copy of G when ``config['ema']`` is set.
  * Optionally resume from saved weights.
  * Train for ``config['num_epochs']`` epochs.

  Checkpointing (``train_fns.save_and_sample``) and evaluation
  (``train_fns.test``) run every ``save_every`` / ``test_every`` iterations.
  When ``save_every <= 100`` they run every ``save_every`` / ``test_every``
  epochs instead. With ``config['on_kaggle']`` set, saving/testing is skipped,
  metrics are printed once per epoch, and ``train_fns.generate_submission``
  runs at the end.

  Args:
    config: dict of run settings. It is updated in place with derived keys,
      then re-bound to ``utils.update_config_roots(config)``.
  """
  # Update the config dict with settings derived from the user-specified
  # configuration, e.g.:
  # * the number of classes and image size, inferred from the dataset;
  # * the activation objects, looked up from their string names.
  config['resolution'] = utils.imsize_dict[config['dataset']]
  config['n_classes'] = utils.nclass_dict[config['dataset']]
  config['G_activation'] = utils.activation_dict[config['G_nl']]
  config['D_activation'] = utils.activation_dict[config['D_nl']]
  # By default, skip init if resuming training.
  if config['resume']:
    print('Skipping initialization for training resumption...')
    config['skip_init'] = True
  config = utils.update_config_roots(config)
  device = 'cuda'
  if config['base_root']:
    os.makedirs(config['base_root'], exist_ok=True)

  # Seed RNG
  utils.seed_rng(config['seed'])

  # Prepare root folders if necessary
  utils.prepare_root(config)

  # Setup cudnn.benchmark for free speed
  torch.backends.cudnn.benchmark = True

  # Import the model module named in the config; this allows us to
  # dynamically select different model files.
  model = __import__(config['model'])
  experiment_name = (config['experiment_name'] if config['experiment_name']
                     else utils.name_from_config(config))
  print('Experiment name is %s' % experiment_name)

  # Next, build the model
  G = model.Generator(**config).to(device)
  D = model.Discriminator(**config).to(device)

  # If using EMA, prepare it
  if config['ema']:
    print('Preparing EMA for G with decay of {}'.format(config['ema_decay']))
    G_ema = model.Generator(**{**config, 'skip_init': True,
                               'no_optim': True}).to(device)
    ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
  else:
    G_ema, ema = None, None

  # FP16?
  if config['G_fp16']:
    print('Casting G to float16...')
    G = G.half()
    if config['ema']:
      G_ema = G_ema.half()
  if config['D_fp16']:
    print('Casting D to fp16...')
    D = D.half()
    # Consider automatically reducing SN_eps?
  GD = model.G_D(G, D)
  print(G)
  print(D)
  print('Number of params in G: {} D: {}'.format(
    *[sum(p.data.nelement() for p in net.parameters()) for net in [G, D]]))
  # Prepare state dict, which holds things like epoch # and itr #
  state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0,
                'best_IS': 0, 'best_FID': 999999, 'config': config}

  # If loading from a pre-trained model, load weights
  if config['resume']:
    print('Loading weights...')
    utils.load_weights(G, D, state_dict,
                       config['weights_root'], experiment_name,
                       config['load_weights'] if config['load_weights'] else None,
                       G_ema if config['ema'] else None,
                       )
    # Step the LR schedulers to the resumed epoch.
    if G.lr_sched is not None:
      G.lr_sched.step(state_dict['epoch'])
    if D.lr_sched is not None:
      D.lr_sched.step(state_dict['epoch'])

  # If parallel, parallelize the GD module
  if config['parallel']:
    GD = nn.DataParallel(GD)
    if config['cross_replica']:
      patch_replication_callback(GD)

  # Prepare loggers for stats: test_log holds test (Inception) metrics,
  # train_log holds any desired training metrics.
  test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'],
                                            experiment_name)
  train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
  print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
  test_log = utils.MetricsLogger(test_metrics_fname,
                                 reinitialize=(not config['resume']))
  print('Training Metrics will be saved to {}'.format(train_metrics_fname))
  train_log = utils.MyLogger(train_metrics_fname,
                             reinitialize=(not config['resume']),
                             logstyle=config['logstyle'])
  # Write metadata
  utils.write_metadata(config['logs_root'], experiment_name, config, state_dict)
  # Prepare data. Only the Discriminator's batch size is passed to the
  # dataloader, because G doesn't require dataloading.
  # Each loader iteration yields enough data for a full D iteration
  # (regardless of the number of D steps and accumulations).
  D_batch_size = (config['batch_size'] * config['num_D_steps']
                  * config['num_D_accumulations'])
  loaders = utils.get_data_loaders(**{**config, 'batch_size': D_batch_size,
                                      'start_itr': state_dict['itr']})

  # Prepare inception metrics: FID and IS. They are only consumed when not
  # running on Kaggle; bind None otherwise so the name is always defined.
  if not config['on_kaggle']:
    get_inception_metrics = inception_utils.prepare_inception_metrics(
      config['base_root'], config['dataset'], config['parallel'],
      config['no_fid'])
  else:
    get_inception_metrics = None

  # Prepare noise and randomly sampled label arrays.
  # Allow for different batch sizes in G.
  G_batch_size = max(config['G_batch_size'], config['batch_size'])
  y_dist = 'categorical_dog_cnt' if config['use_dog_cnt'] else 'categorical'

  # With mix_style, z is drawn at twice G.dim_z (presumably two latent codes
  # per sample for style mixing -- confirm against the model).
  dim_z = G.dim_z * 2 if config['mix_style'] else G.dim_z
  z_, y_ = utils.prepare_z_y(G_batch_size, dim_z, config['n_classes'],
                             device=device, fp16=config['G_fp16'],
                             z_dist=config['z_dist'],
                             threshold=config['truncated_threshold'],
                             y_dist=y_dist)
  # Prepare a fixed z & y to see individual sample evolution throughout training
  fixed_z, fixed_y = utils.prepare_z_y(G_batch_size, dim_z,
                                       config['n_classes'], device=device,
                                       fp16=config['G_fp16'],
                                       z_dist=config['z_dist'],
                                       threshold=config['truncated_threshold'],
                                       y_dist=y_dist)
  fixed_z.sample_()
  fixed_y.sample_()
  # Loaders are loaded, prepare the training function
  if config['which_train_fn'] == 'GAN':
    train = train_fns.GAN_training_function(G, D, GD, z_, y_,
                                            ema, state_dict, config)
  # Else, assume debugging and use the dummy train fn
  else:
    train = train_fns.dummy_training_function()
  # Prepare Sample function for use with inception metrics
  sample = functools.partial(utils.sample,
                             G=(G_ema if config['ema'] and config['use_ema']
                                else G),
                             z_=z_, y_=y_, config=config)

  print('Beginning training at epoch %d...' % state_dict['epoch'])
  # Checkpoint/evaluate per epoch when save_every is small (<= 100), otherwise
  # per iteration; save_every and test_every are read in those units.
  by_epoch = config['save_every'] <= 100

  # Train for specified number of epochs, although we mostly track G iterations.
  start_time = time.time()
  for epoch in range(state_dict['epoch'], config['num_epochs']):
    # Which progressbar to use? TQDM or my own?
    if config['on_kaggle']:
      pbar = loaders[0]
    elif config['pbar'] == 'mine':
      pbar = utils.progress(
        loaders[0],
        displaytype='s1k' if config['use_multiepoch_sampler'] else 'eta')
    else:
      pbar = tqdm(loaders[0])
    epoch_start_time = time.time()
    for i, (x, y) in enumerate(pbar):
      # Increment the iteration counter
      state_dict['itr'] += 1
      # Make sure G and D are in training mode, just in case they got set to eval
      # For D, which typically doesn't have BN, this shouldn't matter much.
      G.train()
      D.train()
      if config['ema']:
        G_ema.train()
      # Labels may arrive as a list/tuple of per-label tensors; stack them
      # column-wise (assuming each yi is 1-D, this gives (batch, n_labels)).
      if isinstance(y, (list, tuple)):
        y = torch.cat([yi.unsqueeze(1) for yi in y], dim=1)

      if config['D_fp16']:
        x, y = x.to(device).half(), y.to(device)
      else:
        x, y = x.to(device), y.to(device)
      metrics = train(x, y)
      train_log.log(itr=int(state_dict['itr']), **metrics)

      # Every sv_log_interval, log singular values
      if (config['sv_log_interval'] > 0) and (not (state_dict['itr'] % config['sv_log_interval'])):
        train_log.log(itr=int(state_dict['itr']),
                      **{**utils.get_SVs(G, 'G'), **utils.get_SVs(D, 'D')})

      # Print metrics: once per epoch on Kaggle, every iteration with my progbar.
      if config['on_kaggle']:
        if i == len(loaders[0]) - 1:
          metrics_str = ', '.join(['%s : %+4.3f' % (key, metrics[key]) for key in metrics])
          epoch_time = (time.time() - epoch_start_time) / 60
          total_time = (time.time() - start_time) / 60
          print(f"[{epoch+1}/{config['num_epochs']}][{epoch_time:.1f}min/{total_time:.1f}min] {metrics_str}")
      elif config['pbar'] == 'mine':
        fields = ['epoch:%d' % (epoch + 1)]
        if D.lr_sched is not None:
          fields.append('lr:%.5f' % D.lr_sched.get_lr()[0])
        fields.append('itr: %d' % state_dict['itr'])
        fields += ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]
        print(', '.join(fields), end=' ')
      if not by_epoch:
        # Save weights and copies as configured at specified interval
        if not (state_dict['itr'] % config['save_every']) and not config['on_kaggle']:
          if config['G_eval_mode']:
            print('Switchin G to eval mode...')
            G.eval()
            if config['ema']:
              G_ema.eval()
          train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z, fixed_y,
                                    state_dict, config, experiment_name)

        # Test every specified interval
        if not (state_dict['itr'] % config['test_every']) and not config['on_kaggle']:
          if config['G_eval_mode']:
            print('Switchin G to eval mode...')
            G.eval()
          train_fns.test(G, D, G_ema, z_, y_, state_dict, config, sample,
                         get_inception_metrics, experiment_name, test_log)

    if by_epoch:
      # Save weights and copies as configured at specified interval
      if not ((epoch + 1) % config['save_every']) and not config['on_kaggle']:
        if config['G_eval_mode']:
          print('Switchin G to eval mode...')
          G.eval()
          if config['ema']:
            G_ema.eval()
        train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z, fixed_y,
                                  state_dict, config, experiment_name)

      # Test every specified interval
      if not ((epoch + 1) % config['test_every']) and not config['on_kaggle']:
        if config['G_eval_mode']:
          print('Switchin G to eval mode...')
          G.eval()
        train_fns.test(G, D, G_ema, z_, y_, state_dict, config, sample,
                       get_inception_metrics, experiment_name, test_log)

      if G_ema is not None and (epoch + 1) % config['test_every'] == 0 and not config['on_kaggle']:
        # Use the resolved experiment_name, like every other save in this
        # function. config['experiment_name'] may be empty when the name was
        # derived via utils.name_from_config, which would misplace the file.
        # Create the folder in case no save_and_sample has run yet.
        ema_dir = '%s/%s' % (config['weights_root'], experiment_name)
        os.makedirs(ema_dir, exist_ok=True)
        torch.save(G_ema.state_dict(),
                   '%s/G_ema_epoch_%03d.pth' % (ema_dir, epoch + 1))
    # Increment epoch counter at end of epoch
    state_dict['epoch'] += 1
    if G.lr_sched is not None:
      G.lr_sched.step()
    if D.lr_sched is not None:
      D.lr_sched.step()
  if config['on_kaggle']:
    train_fns.generate_submission(sample, config, experiment_name)
예제 #6
0
def run(config):
    """Train a BigGAN-style G/D pair as configured by ``config``.

    Steps:
    * Derive dataset-dependent settings into ``config``.
    * Build G and D, plus an EMA copy of G when ``config['ema']`` is set.
    * Optionally resume from saved weights.
    * Train for ``config['num_epochs']`` epochs.

    During training it calls ``train_fns.save_and_sample`` every
    ``save_every`` iterations and ``train_fns.test`` (with the Inception
    metrics) every ``test_every`` iterations.

    NOTE(port): this copy carries the original author's notes for porting the
    script from PyTorch to PaddlePaddle. They are kept below, translated, as
    ``NOTE(port)`` comments.

    Args:
        config: dict of run settings. It is updated in place with derived
            keys, then re-bound to ``utils.update_config_roots(config)``.
    """

    # Update the config dict as necessary
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string)
    # Image resolution for config['dataset'] (e.g. I128_hdf5).
    # NOTE(port): the C10 dataset may be needed here instead.
    config['resolution'] = utils.imsize_dict[config['dataset']]
    # Number of classes for config['dataset'].
    # NOTE(port): for C10 this would be its 10 classes.
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    # Activation objects for G and D, looked up by name in activation_dict.
    # NOTE(port): the author notes both use ReLU and that the key is the
    # lowercase 'relu'. It is unclear whether a port needs a capitalized name.
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]

    # NOTE(port): training from scratch (no saved parameters), so this is left
    # unchanged; per the author, not resuming is the default.
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True

    # Resolve the root directories stored in the config; per the author this
    # probably needs no change for the port.
    config = utils.update_config_roots(config)
    device = 'cuda'

    # Seed RNG
    # NOTE(port): needs to be switched to Paddle's seeding.
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    # (probably needs no change for the port)
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    # NOTE(port): intentionally commented out. Paddle does not necessarily
    # need this setting, which speeds up cuDNN for fixed network shapes.
    # torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different files.
    # NOTE(port): imports the module named by config['model'] (e.g. BigGAN);
    # its network definitions need reviewing for the port.
    model = __import__(config['model'])
    # Use the explicit experiment name, or derive one from the config values.
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    # Next, build the model
    # NOTE(port): both constructors (Generator and Discriminator) need porting.
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)

    # If using EMA, prepare it
    # NOTE(port): the author notes EMA is off by default, so this part can be
    # ported later.
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(
            config['ema_decay']))
        G_ema = model.Generator(**{
            **config, 'skip_init': True,
            'no_optim': True
        }).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None

    # FP16?
    # NOTE(port): C10 is small, so the FP16 paths for G and D can stay
    # unported for now and the default precision used.
    if config['G_fp16']:
        print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        print('Casting D to fp16...')
        D = D.half()
        # Consider automatically reducing SN_eps?
    # Wrap the configured G and D into a single G_D module.
    GD = model.G_D(G, D)
    # NOTE(port): these two prints could be dropped. They rely on the
    # printable representation inherited from nn.Module.
    print(G)
    print(D)
    # NOTE(port): parameters() is also inherited from torch's nn.Module.
    print('Number of params in G: {} D: {}'.format(
        *
        [sum([p.data.nelement() for p in net.parameters()])
         for net in [G, D]]))
    # Prepare state dict, which holds things like epoch # and itr #
    # (initial training-state record; no change needed for the port)
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'config': config
    }

    # If loading from a pre-trained model, load weights
    # NOTE(port): no pretrained weights are used for now, so this block is
    # left unchanged.
    if config['resume']:
        print('Loading weights...')
        utils.load_weights(
            G, D, state_dict, config['weights_root'], experiment_name,
            config['load_weights'] if config['load_weights'] else None,
            G_ema if config['ema'] else None)

    # If parallel, parallelize the GD module
    # NOTE(port): can be ignored for now; per the author GD is not
    # parallelized by default.
    if config['parallel']:
        GD = nn.DataParallel(GD)
        if config['cross_replica']:
            patch_replication_callback(GD)

    # Logging; per the author this probably needs no change. IS and FID
    # results could be extracted from test_log if needed.
    # Prepare loggers for stats; metrics holds test metrics,
    # lmetrics holds any desired training metrics.
    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'],
                                              experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
    test_log = utils.MetricsLogger(test_metrics_fname,
                                   reinitialize=(not config['resume']))
    print('Training Metrics will be saved to {}'.format(train_metrics_fname))
    train_log = utils.MyLogger(train_metrics_fname,
                               reinitialize=(not config['resume']),
                               logstyle=config['logstyle'])

    # Write metadata
    # (the original author flags this as important for collecting results)
    utils.write_metadata(config['logs_root'], experiment_name, config,
                         state_dict)

    # NOTE(port): per the author, get_data_loaders uses torchvision transforms
    # to load D's data, which will need porting.
    # Prepare data; the Discriminator's batch size is all that needs to be passed
    # to the dataloader, as G doesn't require dataloading.
    # Note that at every loader iteration we pass in enough data to complete
    # a full D iteration (regardless of number of D steps and accumulations)
    D_batch_size = (config['batch_size'] * config['num_D_steps'] *
                    config['num_D_accumulations'])
    loaders = utils.get_data_loaders(**{
        **config, 'batch_size': D_batch_size,
        'start_itr': state_dict['itr']
    })

    # Prepare inception metrics: FID and IS
    # (per the author, a NumPy implementation can be used, so no change needed)
    get_inception_metrics = inception_utils.prepare_inception_metrics(
        config['dataset'], config['parallel'], config['no_fid'])

    # Prepare noise and randomly sampled label arrays
    # Allow for different batch sizes in G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    # NOTE(port): prepare_z_y uses some torch tensor operations that need
    # porting. It returns the noise z_ and the labels y_.
    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'])

    # Prepare a fixed z & y to see individual sample evolution throughout training
    # Why two pairs:
    # * z_/y_ are passed to the training function, the sample partial,
    #   save_and_sample and test.
    # * fixed_z/fixed_y are sampled once below and only passed to
    #   save_and_sample.
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size,
                                         G.dim_z,
                                         config['n_classes'],
                                         device=device,
                                         fp16=config['G_fp16'])

    # Draw the fixed samples once.
    # NOTE(port): sample_() comes from the project's distribution wrapper
    # (per the author: Gaussian or categorical sampling).
    fixed_z.sample_()
    fixed_y.sample_()
    # Loaders are loaded, prepare the training function
    if config['which_train_fn'] == 'GAN':
        train = train_fns.GAN_training_function(G, D, GD, z_, y_, ema,
                                                state_dict, config)
    # Else, assume debugging and use the dummy train fn
    # (exercises the pipeline without a real training function)
    else:
        train = train_fns.dummy_training_function()
    # Prepare Sample function for use with inception metrics
    # Binds G (the EMA copy when ema and use_ema are set), z_, y_ and config
    # into utils.sample so callers supply only the remaining arguments.
    sample = functools.partial(
        utils.sample,
        G=(G_ema if config['ema'] and config['use_ema'] else G),
        z_=z_,
        y_=y_,
        config=config)

    print('Beginning training at epoch %d...' % state_dict['epoch'])
    # Train for specified number of epochs, although we mostly track G iterations.
    for epoch in range(state_dict['epoch'], config['num_epochs']):
        # Which progressbar to use? TQDM or my own?
        if config['pbar'] == 'mine':
            # NOTE(port): no porting needed here. loaders[0] is the data
            # loader iterated over each epoch.
            pbar = utils.progress(loaders[0],
                                  displaytype='s1k' if
                                  config['use_multiepoch_sampler'] else 'eta')
        else:
            pbar = tqdm(loaders[0])
        for i, (x, y) in enumerate(pbar):
            # Increment the iteration counter
            state_dict['itr'] += 1
            # Make sure G and D are in training mode, just in case they got set to eval
            # For D, which typically doesn't have BN, this shouldn't matter much.
            # NOTE(port): train()/eval() are inherited from nn.Module.
            G.train()
            D.train()
            if config['ema']:
                G_ema.train()

            if config['D_fp16']:
                x, y = x.to(device).half(), y.to(device)
            else:
                x, y = x.to(device), y.to(device)
            # Run one training step on this batch.
            # NOTE(port): train itself needs substantial rewriting.
            metrics = train(x, y)
            # Log every metric returned by the training step.
            train_log.log(itr=int(state_dict['itr']), **metrics)

            # Every sv_log_interval, log singular values
            if (config['sv_log_interval'] > 0) and (
                    not (state_dict['itr'] % config['sv_log_interval'])):
                train_log.log(itr=int(state_dict['itr']),
                              **{
                                  **utils.get_SVs(G, 'G'),
                                  **utils.get_SVs(D, 'D')
                              })

            # If using my progbar, print metrics.
            if config['pbar'] == 'mine':
                print(', '.join(
                    ['itr: %d' % state_dict['itr']] +
                    ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]),
                      end=' ')

            # Save weights and copies as configured at specified interval
            # (per the author, the default save_every is 2000 iterations)
            if not (state_dict['itr'] % config['save_every']):
                if config['G_eval_mode']:
                    print('Switchin G to eval mode...')
                    # eval() is inherited from nn.Module.
                    G.eval()
                    # If using an exponential moving average of G, switch the
                    # EMA copy to eval mode as well.
                    if config['ema']:
                        G_ema.eval()
                train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z,
                                          fixed_y, state_dict, config,
                                          experiment_name)

            # Test every specified interval
            # (per the author, the default test_every is 5000 iterations)
            if not (state_dict['itr'] % config['test_every']):
                if config['G_eval_mode']:
                    print('Switchin G to eval mode...')
                    G.eval()
                train_fns.test(G, D, G_ema, z_, y_, state_dict, config, sample,
                               get_inception_metrics, experiment_name,
                               test_log)
        # Increment epoch counter at end of epoch
        state_dict['epoch'] += 1
    step_updates, task_updates, task_metrics, regularizer_fn
"""



def PATH_INT_PROTOCOL(omega_decay, xi):
    """Return the ``(name, spec)`` pair describing the path-integral protocol.

    ``name`` encodes the hyper-parameters, e.g.
    ``'path_int[omega_decay=0.9,xi=0.1]'``.

    ``spec`` maps each update phase to a list of ``(variable_name, update_fn)``
    pairs. Every ``update_fn`` is called as ``update_fn(vars, w, prev_val)``
    and returns the new value of that variable for weight ``w``:

    * ``init_updates``: snapshot ``w.value()`` into ``cweights``.
    * ``step_updates``: subtract ``unreg_grads[w] * deltas[w]`` from ``grads2``.
    * ``task_updates``: in order,
      - fold ``grads2 / ((cweights - w)**2 + xi)`` into ``omega`` via
        ``ema(omega_decay, ...)`` and clamp it with ``tf.nn.relu``;
      - re-snapshot ``cweights``;
      - reset ``grads2`` to zero.
    * ``regularizer_fn``: ``quadratic_regularizer``.

    NOTE(review): this resembles the Synaptic Intelligence per-weight
    importance (Zenke et al., 2017) -- confirm before relying on that reading.
    """
    label = 'path_int[omega_decay=%s,xi=%s]' % (omega_decay, xi)

    # Parameter names (vars / opt, w, prev_val) mirror the original lambdas so
    # any keyword-based callers keep working.
    def snapshot_on_init(vars, w, prev_val):
        return w.value()

    def accumulate_path_integral(vars, w, prev_val):
        return prev_val - vars['unreg_grads'][w] * vars['deltas'][w]

    def consolidate_omega(vars, w, prev_val):
        path_integral = vars['grads2'][w]
        drift = vars['cweights'][w] - w.value()
        normalized = path_integral / (drift ** 2 + xi)
        return tf.nn.relu(ema(omega_decay, prev_val, normalized))

    def snapshot_on_task(opt, w, prev_val):
        return w.value()

    def reset_accumulator(vars, w, prev_val):
        return prev_val * 0.0

    spec = {
        'init_updates': [
            ('cweights', snapshot_on_init),
        ],
        'step_updates': [
            ('grads2', accumulate_path_integral),
        ],
        'task_updates': [
            ('omega', consolidate_omega),
            ('cweights', snapshot_on_task),
            ('grads2', reset_accumulator),
        ],
        'regularizer_fn': quadratic_regularizer,
    }
    return label, spec


FISHER_PROTOCOL = lambda omega_decay:(
    'fisher[omega_decay=%s]'%omega_decay,
{
    'task_updates':  [
        ('omega', lambda vars, w, prev_val: ema(omega_decay, prev_val, vars['task_fisher'][w]/vars['nb_data'])),
        ('cweights', lambda opt, w, prev_val: w.value()),
예제 #8
0
def run(config):
    """Train a video GAN with an image discriminator and an optional video discriminator.

    Derives resolution, class count and activation objects from ``config``,
    then builds G and D (``ImageDiscriminator``) from the module named by
    ``config['model']``. Dv (``VideoDiscriminator``) is built unless
    ``config['no_Dv']`` is set, and an EMA copy of G is built when
    ``config['ema']`` is set.

    Training loops over ``loaders[0]`` from the current epoch up to
    ``config['num_epochs']``. Metrics are logged to ``utils.MyLogger`` and
    to a TensorBoard ``SummaryWriter``, which is always closed on exit.
    Weights are saved every ``config['save_every']`` iterations. Unless
    ``config['skip_testing']`` is set, IS/FID are computed every
    ``config['test_every']`` iterations.

    Args:
        config: dict of run settings. The derived keys ``resolution``,
            ``n_classes``, ``G_activation``, ``D_activation`` and, when
            resuming, ``skip_init`` are written into it in place.
    """
    # Update the config dict as necessary.
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string).
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'
    # Seed RNG
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    # The video discriminator is optional; every Dv-specific step checks this.
    use_Dv = not config['no_Dv']

    # Next, build the model
    G = model.Generator(**config).to(device)
    D = model.ImageDiscriminator(**config).to(device)
    Dv = model.VideoDiscriminator(**config).to(device) if use_Dv else None

    # If using EMA, prepare it
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(
            config['ema_decay']))
        G_ema = model.Generator(**{
            **config, 'skip_init': True,
            'no_optim': True
        }).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None

    # FP16?
    if config['G_fp16']:
        print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        print('Casting D to fp16...')
        D = D.half()
        if use_Dv:
            Dv = Dv.half()
        # Consider automatically reducing SN_eps?
    # This variant of G_D additionally takes k and T_into_B; their meaning is
    # defined in the model module.
    GD = model.G_D(G, D, Dv, config['k'], config['T_into_B'])
    nets = [G, D, Dv] if use_Dv else [G, D]
    param_counts = [
        sum([p.data.nelement() for p in net.parameters()]) for net in nets
    ]
    if use_Dv:
        print('Number of params in G: {} D: {} Dv: {}'.format(*param_counts))
    else:
        print('Number of params in G: {} D: {}'.format(*param_counts))
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'config': config
    }

    # If loading from a pre-trained BigGAN model, load weights
    # (Dv is not passed to this loader).
    if config['biggan_init']:
        print('Loading weights from pre-trained BigGAN...')
        utils.load_biggan_weights(G,
                                  D,
                                  state_dict,
                                  config['biggan_weights_root'],
                                  G_ema if config['ema'] else None,
                                  load_optim=False)

    # If loading from a pre-trained model, load weights
    if config['resume']:
        print('Loading weights...')
        utils.load_weights(
            G, D, Dv, state_dict, config['weights_root'], experiment_name,
            config['load_weights'] if config['load_weights'] else None,
            G_ema if config['ema'] else None)

    # If parallel, parallelize the GD module
    if config['parallel']:
        GD = nn.DataParallel(GD)
        if config['cross_replica']:
            patch_replication_callback(GD)

    # Prepare loggers for stats; metrics holds test metrics,
    # lmetrics holds any desired training metrics.
    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'],
                                              experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
    test_log = utils.MetricsLogger(test_metrics_fname,
                                   reinitialize=(not config['resume']))
    print('Training Metrics will be saved to {}'.format(train_metrics_fname))
    train_log = utils.MyLogger(train_metrics_fname,
                               reinitialize=(not config['resume']),
                               logstyle=config['logstyle'])
    # Write metadata
    utils.write_metadata(config['logs_root'], experiment_name, config,
                         state_dict)
    # Prepare data; the Discriminator's batch size is all that needs to be passed
    # to the dataloader, as G doesn't require dataloading.
    # Note that at every loader iteration we pass in enough data to complete
    # a full D iteration (regardless of number of D steps and accumulations)
    D_batch_size = (config['batch_size'] * config['num_D_steps'] *
                    config['num_D_accumulations'])
    loader_kwargs = {
        **config, 'batch_size': D_batch_size,
        'start_itr': state_dict['itr']
    }
    if config['dataset'] == 'C10':
        loaders = utils.get_video_cifar_data_loader(**loader_kwargs)
    else:
        loaders = utils.get_video_data_loaders(**loader_kwargs)
    print('D loss weight:', config['D_loss_weight'])
    # Prepare inception metrics: FID and IS (only needed when testing).
    if not config['skip_testing']:
        get_inception_metrics = inception_utils.prepare_inception_metrics(
            config['dataset'], config['parallel'], config['no_fid'])

    # Prepare noise and randomly sampled label arrays
    # Allow for different batch sizes in G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'])
    # Prepare a fixed z & y to see individual sample evolution throughout training
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size,
                                         G.dim_z,
                                         config['n_classes'],
                                         device=device,
                                         fp16=config['G_fp16'])
    fixed_z.sample_()
    fixed_y.sample_()
    # Loaders are loaded, prepare the training function
    if config['which_train_fn'] == 'GAN':
        train = train_fns.GAN_training_function(G, D, Dv, GD, z_, y_, ema,
                                                state_dict, config)
    # Else, assume debugging and use the dummy train fn
    else:
        train = train_fns.dummy_training_function()
    # Prepare Sample function for use with inception metrics
    sample = functools.partial(
        utils.sample,
        G=(G_ema if config['ema'] and config['use_ema'] else G),
        z_=z_,
        y_=y_,
        config=config)

    print('Beginning training at epoch %d...' % state_dict['epoch'])
    unique_id = datetime.datetime.now().strftime('%Y%m-%d%H-%M%S-')
    tensorboard_path = os.path.join(config['logs_root'], 'tensorboard_logs',
                                    unique_id)
    # unique_id only has one-second resolution; two runs started in the same
    # second share the directory instead of crashing with FileExistsError.
    os.makedirs(tensorboard_path, exist_ok=True)
    writer = SummaryWriter(log_dir=tensorboard_path)
    # Train for specified number of epochs, although we mostly track G iterations.
    try:
        for epoch in range(state_dict['epoch'], config['num_epochs']):
            # Which progressbar to use? TQDM or my own?
            if config['pbar'] == 'mine':
                pbar = utils.progress(loaders[0],
                                      displaytype='s1k' if
                                      config['use_multiepoch_sampler'] else 'eta')
            else:
                pbar = tqdm(loaders[0])
            # Global step used for TensorBoard is epoch * batches-per-epoch + i.
            iteration = epoch * len(pbar)
            for i, (x, y) in enumerate(pbar):
                step = iteration + i
                # Increment the iteration counter
                state_dict['itr'] += 1
                # Make sure G and D are in training mode, just in case they got set to eval
                # For D, which typically doesn't have BN, this shouldn't matter much.
                G.train()
                D.train()
                if use_Dv:
                    Dv.train()
                if config['ema']:
                    G_ema.train()
                if config['D_fp16']:
                    x, y = x.to(device).half(), y.to(device)
                else:
                    x, y = x.to(device), y.to(device)
                metrics = train(x, y, writer, step)
                train_log.log(itr=int(state_dict['itr']), **metrics)

                # Every sv_log_interval, log singular values
                if (config['sv_log_interval'] > 0) and (
                        not (state_dict['itr'] % config['sv_log_interval'])):
                    svs = {**utils.get_SVs(G, 'G'), **utils.get_SVs(D, 'D')}
                    if use_Dv:
                        svs.update(utils.get_SVs(Dv, 'Dv'))
                    train_log.log(itr=int(state_dict['itr']), **svs)

                # If using my progbar, print metrics.
                if config['pbar'] == 'mine':
                    print(', '.join(
                        ['itr: %d' % state_dict['itr']] +
                        ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]),
                          end=' ')

                # Save weights and copies as configured at specified interval
                if not (state_dict['itr'] % config['save_every']):
                    if config['G_eval_mode']:
                        print('Switchin G to eval mode...')
                        G.eval()
                        if config['ema']:
                            G_ema.eval()
                    train_fns.save_and_sample(G, D, Dv, G_ema, z_, y_, fixed_z,
                                              fixed_y, state_dict, config,
                                              experiment_name)
                # Test every specified interval, unless testing is switched off
                # via skip_testing (e.g. when no inception data is available).
                if (not (state_dict['itr'] % config['test_every'])
                        and not config['skip_testing']):
                    if config['G_eval_mode']:
                        print('Switchin G to eval mode...')
                        G.eval()
                    IS_mean, IS_std, FID = train_fns.test(
                        G, D, Dv, G_ema, z_, y_, state_dict, config, sample,
                        get_inception_metrics, experiment_name, test_log)
                    writer.add_scalar('Inception/IS', IS_mean, step)
                    writer.add_scalar('Inception/IS_std', IS_std, step)
                    writer.add_scalar('Inception/FID', FID, step)
            # Increment epoch counter at end of epoch
            state_dict['epoch'] += 1
    finally:
        # Flush and release the event file even if training raises.
        writer.close()
예제 #9
0
파일: train.py 프로젝트: fagan2888/fairgen
def run(config):
    """Train a GAN whose loader yields ``(x, y, ratio)`` batches and compute FID every epoch.

    ``n_classes`` is forced to 1 (unconditional), or 2 when
    ``config['conditional']`` is set. G and D are built from the module named
    by ``config['model']``, plus an EMA copy of G when ``config['ema']`` is
    set.

    Training runs on ``loaders[0]``, and weights are saved every
    ``config['save_every']`` iterations. From epoch ``config['start_eval']``
    onward, each epoch samples images, computes FID against precomputed
    statistics under ``../../fid_stats/``, and hands the result to
    ``train_fns.update_FID``.

    Args:
        config: dict of run settings. The derived keys ``resolution``,
            ``n_classes``, ``G_activation``, ``D_activation`` and, when
            resuming, ``skip_init`` are written into it in place.
    """
    # Update the config dict as necessary.
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string).
    config['resolution'] = utils.imsize_dict[config['dataset']]

    # NOTE: setting n_classes to 1 except in conditional case to train as unconditional model
    config['n_classes'] = 1
    if config['conditional']:
        config['n_classes'] = 2
    print('n classes: {}'.format(config['n_classes']))
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    # Next, build the model
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)

    # If using EMA, prepare it
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(
            config['ema_decay']))
        G_ema = model.Generator(**{
            **config, 'skip_init': True,
            'no_optim': True
        }).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None

    # FP16?
    if config['G_fp16']:
        print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        print('Casting D to fp16...')
        D = D.half()
        # Consider automatically reducing SN_eps?
    GD = model.G_D(
        G, D,
        config['conditional'])  # check if labels are 0's if "unconditional"
    print(G)
    print(D)
    print('Number of params in G: {} D: {}'.format(
        *
        [sum([p.data.nelement() for p in net.parameters()])
         for net in [G, D]]))
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num_fair': 0,
        'save_best_num_fid': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'best_fair_d': 999999,
        'config': config
    }

    # If loading from a pre-trained model, load weights
    if config['resume']:
        print('Loading weights...')
        utils.load_weights(
            G, D, state_dict, config['weights_root'], experiment_name,
            config['load_weights'] if config['load_weights'] else None,
            G_ema if config['ema'] else None)

    # If parallel, parallelize the GD module
    if config['parallel']:
        GD = nn.DataParallel(GD)
        if config['cross_replica']:
            patch_replication_callback(GD)

    # Prepare loggers for stats; metrics holds test metrics,
    # lmetrics holds any desired training metrics.
    test_metrics_fname = '%s/%s_log.json' % (config['logs_root'],
                                             experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
    test_log = utils.MetricsLogger(test_metrics_fname,
                                   reinitialize=(not config['resume']))
    print('Training Metrics will be saved to {}'.format(train_metrics_fname))
    train_log = utils.MyLogger(train_metrics_fname,
                               reinitialize=(not config['resume']),
                               logstyle=config['logstyle'])
    # Write metadata
    utils.write_metadata(config['logs_root'], experiment_name, config,
                         state_dict)
    # Prepare data; the Discriminator's batch size is all that needs to be passed
    # to the dataloader, as G doesn't require dataloading.
    # Note that at every loader iteration we pass in enough data to complete
    # a full D iteration (regardless of number of D steps and accumulations)
    D_batch_size = (config['batch_size'] * config['num_D_steps'] *
                    config['num_D_accumulations'])
    loaders = utils.get_data_loaders(
        config, **{
            **config, 'batch_size': D_batch_size,
            'start_itr': state_dict['itr']
        })

    # NOTE: FID is computed below with fid_score from precomputed statistics,
    # so the Inception-metrics helper (and the sampling partial it needs) is
    # not prepared here; it would only hold an unused network in memory.

    # Prepare noise and randomly sampled label arrays
    # Allow for different batch sizes in G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'],
                               true_prop=config['true_prop'])
    # Prepare a fixed z & y to see individual sample evolution throughout training
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size,
                                         G.dim_z,
                                         config['n_classes'],
                                         device=device,
                                         fp16=config['G_fp16'])
    fixed_z.sample_()
    fixed_y.sample_()

    # NOTE: "unconditional" GAN
    if not config['conditional']:
        fixed_y.zero_()
        y_.zero_()

    # Loaders are loaded, prepare the training function
    if config['which_train_fn'] == 'GAN':
        train = train_fns.GAN_training_function(G, D, GD, z_, y_, ema,
                                                state_dict, config)
    # Else, assume debugging and use the dummy train fn
    else:
        train = train_fns.dummy_training_function()

    print('Beginning training at epoch %d...' % state_dict['epoch'])
    # Train for specified number of epochs, although we mostly track G iterations.
    for epoch in range(state_dict['epoch'], config['num_epochs']):
        # Which progressbar to use? TQDM or my own?
        if config['pbar'] == 'mine':
            pbar = utils.progress(loaders[0],
                                  displaytype='s1k' if
                                  config['use_multiepoch_sampler'] else 'eta')
        else:
            pbar = tqdm(loaders[0])

        # iterate through the dataloaders
        for i, (x, y, ratio) in enumerate(pbar):
            # Increment the iteration counter
            state_dict['itr'] += 1
            # Make sure G and D are in training mode, just in case they got set to eval
            # For D, which typically doesn't have BN, this shouldn't matter much.
            G.train()
            D.train()
            if config['ema']:
                G_ema.train()
            if config['D_fp16']:
                x, y, ratio = x.to(device).half(), y.to(device), ratio.to(
                    device)
            else:
                x, y, ratio = x.to(device), y.to(device), ratio.to(device)
            metrics = train(x, y, ratio)
            train_log.log(itr=int(state_dict['itr']), **metrics)

            # Every sv_log_interval, log singular values
            if (config['sv_log_interval'] > 0) and (
                    not (state_dict['itr'] % config['sv_log_interval'])):
                train_log.log(itr=int(state_dict['itr']),
                              **{
                                  **utils.get_SVs(G, 'G'),
                                  **utils.get_SVs(D, 'D')
                              })

            # If using my progbar, print metrics.
            if config['pbar'] == 'mine':
                print(', '.join(
                    ['itr: %d' % state_dict['itr']] +
                    ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]),
                      end=' ')

            # Save weights and copies as configured at specified interval
            if not (state_dict['itr'] % config['save_every']):
                if config['G_eval_mode']:
                    print('Switchin G to eval mode...')
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z,
                                          fixed_y, state_dict, config,
                                          experiment_name)

        # Test every epoch (not at a fixed iteration interval)
        if epoch >= config['start_eval']:
            # Pick the precomputed FID statistics matching config['multi'].
            if config['multi']:
                data_moments = '../../fid_stats/unbiased_all_multi_fid_stats.npz'
            else:
                data_moments = '../../fid_stats/unbiased_all_gender_fid_stats.npz'
            print('Loaded data moments at: {}'.format(data_moments))

            # eval mode for FID computation
            if config['G_eval_mode']:
                print('Switching G to eval mode...')
                G.eval()
                if config['ema']:
                    G_ema.eval()
            utils.sample_inception(
                G_ema if config['ema'] and config['use_ema'] else G, config,
                str(epoch))
            # Samples for this epoch are read back from
            # <samples_root>/<experiment_name>/<epoch>/samples.npz.
            folder_number = str(epoch)
            sample_moments = '%s/%s/%s/samples.npz' % (
                config['samples_root'], experiment_name, folder_number)
            # Calculate FID
            FID = fid_score.calculate_fid_given_paths(
                [data_moments, sample_moments],
                batch_size=100,
                cuda=True,
                dims=2048)
            print("FID calculated")
            train_fns.update_FID(G, D, G_ema, state_dict, config, FID,
                                 experiment_name, test_log,
                                 epoch)  # added epoch logging
        # Increment epoch counter at end of epoch
        print('Completed epoch {}'.format(epoch))
        state_dict['epoch'] += 1
예제 #10
0
 def test_one_ema(self):
     """Check that ``utils.ema(1.0, 2.0, 1.0)`` equals 1.0 to 6 decimal places."""
     ema = utils.ema(1.0, 2.0, 1.0)
     # The failure message names the expected value asserted here (1.0).
     self.assertAlmostEqual(ema, 1.0, 6, "ema - 1.0: test failed")
예제 #11
0
def run(config):
    """Train with two GAN pairs initialised from hard-coded checkpoints, tracking FID.

    Two Generator/Discriminator pairs are built. G3/D3 are loaded from
    ``../Task2_CIFAR_MNIST_KLWGAN_Simulation_Experiment/weights`` (``last0``)
    and G/D from ``../Task3_CIFAR_MNIST_KLWGAN_Simulation_Experiment/weights``
    (``best0``). Both pairs are passed to ``train_fns.GAN_training_function``.

    Loaders are built with ``config['select_dataset']`` and
    ``config['abnormal_class']``. G is kept in eval mode inside the training
    loop. From epoch ``config['start_eval']`` onward, FID is computed every
    ``config['test_every']`` iterations. Final weights are saved as
    ``last0``.

    Only CIFAR-10 (``'C10'`` / ``'C10U'``) is supported. Any other dataset
    prints an error and exits with status 1.

    Args:
        config: dict of run settings. The derived keys are written into it in
            place.

    Example (CIFAR-10, abnormal class 0):
        --select_dataset cifar10 --abnormal_class 0 --shuffle --batch_size 64 --parallel --num_G_accumulations 1 --num_D_accumulations 1 --num_epochs 500 --num_D_steps 4 --G_lr 2e-4 --D_lr 2e-4 --dataset C10 --data_root ../Task2_CIFAR_MNIST_KLWGAN_Simulation_Experiment/data/ --G_ortho 0.0 --G_attn 0 --D_attn 0 --G_init N02 --D_init N02 --ema --use_ema --ema_start 1000 --start_eval 50 --test_every 5000 --save_every 2000 --num_best_copies 5 --num_save_copies 2 --loss_type kl_5 --seed 2 --which_best FID --model BigGAN --experiment_name C10Ukl5

    Example (MNIST, abnormal class 1):
        --select_dataset mnist --abnormal_class 1 (remaining flags as above)
    """
    # FID statistics only exist for CIFAR-10. Check this first so an
    # unsupported dataset fails before any folders, metadata, models or
    # checkpoints are touched, and with a non-zero exit status so launch
    # scripts see the failure.
    if config['dataset'] == 'C10U' or config['dataset'] == 'C10':
        data_moments = 'fid_stats_cifar10_train.npz'
    else:
        print("Cannot find the data set.")
        sys.exit(1)

    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    if config['resume']:
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'
    utils.seed_rng(config['seed'])
    utils.prepare_root(config)
    torch.backends.cudnn.benchmark = True
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name'] else utils.name_from_config(config))
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)
    G3 = model.Generator(**config).to(device)
    D3 = model.Discriminator(**config).to(device)
    if config['ema']:
        G_ema = model.Generator(**{**config, 'skip_init': True, 'no_optim': True}).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None
    # NOTE(review): only G/G_ema/D are cast to fp16; G3/D3 stay full precision.
    if config['G_fp16']:
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        D = D.half()
    GD = model.G_D(G, D, config['conditional'])
    GD3 = model.G_D(G3, D3, config['conditional'])
    state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0, 'best_IS': 0, 'best_FID': 999999, 'config': config}
    if config['resume']:
        utils.load_weights(G, D, state_dict, config['weights_root'], experiment_name, config['load_weights'] if config['load_weights'] else None, G_ema if config['ema'] else None)
    # Hard-coded initial checkpoints. These loads run unconditionally after
    # any resume above, so they overwrite the resumed weights (and
    # state_dict). G_ema is passed to both calls, so it ends up holding the
    # Task3 weights.
    # Alternatives previously used:
    #   G/D  <- Task1 or Task3 'best0';  G3/D3 <- Task2 'best0'.
    utils.load_weights(G3, D3, state_dict, '../Task2_CIFAR_MNIST_KLWGAN_Simulation_Experiment/weights', 'C10Ukl5', 'last0', G_ema if config['ema'] else None)
    utils.load_weights(G, D, state_dict, '../Task3_CIFAR_MNIST_KLWGAN_Simulation_Experiment/weights', 'C10Ukl5', 'best0', G_ema if config['ema'] else None)
    # NOTE(review): only GD is wrapped in DataParallel; GD3 is not.
    if config['parallel']:
        GD = nn.DataParallel(GD)
        if config['cross_replica']:
            patch_replication_callback(GD)
    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'], experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    test_log = utils.MetricsLogger(test_metrics_fname, reinitialize=(not config['resume']))
    train_log = utils.MyLogger(train_metrics_fname, reinitialize=(not config['resume']), logstyle=config['logstyle'])
    utils.write_metadata(config['logs_root'], experiment_name, config, state_dict)
    D_batch_size = (config['batch_size'] * config['num_D_steps'] * config['num_D_accumulations'])
    abnormal_class = config['abnormal_class']
    select_dataset = config['select_dataset']
    loaders = utils.get_data_loaders(**{**config, 'batch_size': D_batch_size, 'start_itr': state_dict['itr'], 'abnormal_class': abnormal_class, 'select_dataset': select_dataset})
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'], device=device, fp16=config['G_fp16'])
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'], device=device, fp16=config['G_fp16'])
    fixed_z.sample_()
    fixed_y.sample_()
    if not config['conditional']:
        fixed_y.zero_()
        y_.zero_()
    if config['which_train_fn'] == 'GAN':
        train = train_fns.GAN_training_function(G3, D3, GD3, G3, D3, GD3, G, D, GD, z_, y_, ema, state_dict, config)
    else:
        train = train_fns.dummy_training_function()
    sample = functools.partial(utils.sample, G=(G_ema if config['ema'] and config['use_ema'] else G), z_=z_, y_=y_, config=config)
    for epoch in range(state_dict['epoch'], config['num_epochs']):
        if config['pbar'] == 'mine':
            pbar = utils.progress(loaders[0], displaytype='s1k' if config['use_multiepoch_sampler'] else 'eta')
        else:
            pbar = tqdm(loaders[0])
        for i, (x, y) in enumerate(pbar):
            state_dict['itr'] += 1
            # G is deliberately kept in eval mode here (unlike D and G_ema).
            G.eval()
            D.train()
            if config['ema']:
                G_ema.train()
            if config['D_fp16']:
                x, y = x.to(device).half(), y.to(device)
            else:
                x, y = x.to(device), y.to(device)
            print('')
            # Print the random seed once, at the very first iteration.
            if epoch == 0 and i == 0:
                print(config['seed'])
            metrics = train(x, y)
            train_log.log(itr=int(state_dict['itr']), **metrics)
            if (config['sv_log_interval'] > 0) and (not (state_dict['itr'] % config['sv_log_interval'])):
                train_log.log(itr=int(state_dict['itr']), **{**utils.get_SVs(G, 'G'), **utils.get_SVs(D, 'D')})
            if config['pbar'] == 'mine':
                print(', '.join(['itr: %d' % state_dict['itr']] + ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]), end=' ')
            if not (state_dict['itr'] % config['save_every']):
                if config['G_eval_mode']:
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z, fixed_y, state_dict, config, experiment_name)
            if (not (state_dict['itr'] % config['test_every'])) and (epoch >= config['start_eval']):
                if config['G_eval_mode']:
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                utils.sample_inception(
                    G_ema if config['ema'] and config['use_ema'] else G, config, str(epoch))
                folder_number = str(epoch)
                sample_moments = '%s/%s/%s/samples.npz' % (config['samples_root'], experiment_name, folder_number)
                FID = fid_score.calculate_fid_given_paths([data_moments, sample_moments], batch_size=50, cuda=True, dims=2048)
                train_fns.update_FID(G, D, G_ema, state_dict, config, FID, experiment_name, test_log)
        state_dict['epoch'] += 1
    utils.save_weights(G, D, state_dict, config['weights_root'], experiment_name, 'last%d' % 0, G_ema if config['ema'] else None)
def run(config):
    """Build a BigGAN generator/discriminator pair (plus optional EMA copy of G) and train it.

    The resolution and class count are hard-coded below (64 and 120), so any
    values already in ``config`` for those keys are overwritten. Training runs
    from ``state_dict['epoch']`` to ``config['num_epochs']`` over ``loaders[0]``
    returned by ``dataset.get_data_loaders``. Metrics are printed every
    ``config['log_interval']`` iterations, and ``train_fns.save_and_sample`` is
    called every ``config['save_every']`` iterations.

    Args:
        config: dict of training settings. It is mutated in place (resolution,
            n_classes, activations, skip_init), then rebound to the result of
            ``utils.update_config_roots``.

    Returns:
        None.

    NOTE(review): there is no save after the final epoch. Iterations after the
    last multiple of ``save_every`` are not persisted by this function.
    """
    # Update the config dict as necessary
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string)
    # Hard-coded rather than looked up from the dataset (unlike other variants
    # that use utils.imsize_dict / utils.nclass_dict).
    config['resolution'] = 64
    config['n_classes'] = 120
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'
    # Seed RNG
    # Seeding happens before any module is built so weight init is reproducible.
    utils.seed_rng(config['seed'])
    # Prepare root folders if necessary
    utils.prepare_root(config)
    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else 'generative_dog_images')
    print('Experiment name is %s' % experiment_name)

    G = BigGAN.Generator(**config).to(device)
    D = BigGAN.Discriminator(**config).to(device)

    # If using EMA, prepare it
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(
            config['ema_decay']))
        # The EMA copy skips its own init and optimizer. Its weights are
        # maintained by the utils.ema helper below.
        G_ema = BigGAN.Generator(**{
            **config, 'skip_init': True,
            'no_optim': True
        }).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None

    # Wrapper module that runs G and D together. It is handed to the train fn.
    GD = BigGAN.G_D(G, D)
    print(G)
    print(D)
    print('Number of params in G: {} D: {}'.format(
        *
        [sum([p.data.nelement() for p in net.parameters()])
         for net in [G, D]]))
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'config': config}

    # If loading from a pre-trained model, load weights
    # (presumably utils.load_weights also restores itr/epoch into state_dict;
    # verify against utils).
    if config['resume']:
        print('Loading weights...')
        utils.load_weights(
            G, D, state_dict, config['weights_root'], experiment_name,
            config['load_weights'] if config['load_weights'] else None,
            G_ema if config['ema'] else None)

    # Prepare data; the Discriminator's batch size is all that needs to be passed
    # to the dataloader, as G doesn't require dataloading.
    # Note that at every loader iteration we pass in enough data to complete
    # a full D iteration (regardless of number of D steps and accumulations)
    D_batch_size = (config['batch_size'] * config['num_D_steps'] *
                    config['num_D_accumulations'])
    loaders = dataset.get_data_loaders(data_root=config['data_root'],
                                       label_root=config['label_root'],
                                       batch_size=D_batch_size,
                                       num_workers=config['num_workers'],
                                       shuffle=config['shuffle'],
                                       pin_memory=config['pin_memory'],
                                       drop_last=True)

    # Prepare noise and randomly sampled label arrays
    # Allow for different batch sizes in G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'])
    # Prepare a fixed z & y to see individual sample evolution throughout training
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size,
                                         G.dim_z,
                                         config['n_classes'],
                                         device=device,
                                         fp16=config['G_fp16'])
    fixed_z.sample_()
    fixed_y.sample_()
    # Loaders are loaded, prepare the training function
    train = train_fns.create_train_fn(G, D, GD, z_, y_, ema, state_dict,
                                      config)

    print('Beginning training at epoch %d...' % state_dict['epoch'])
    start_time = time.perf_counter()
    # Total over all configured epochs, not just the remaining ones. It is only
    # used for the progress display in the log line.
    total_iters = config['num_epochs'] * len(loaders[0])

    # Train for specified number of epochs, although we mostly track G iterations.
    for epoch in range(state_dict['epoch'], config['num_epochs']):
        for i, (x, y) in enumerate(loaders[0]):
            # Increment the iteration counter
            state_dict['itr'] += 1
            # Make sure G and D are in training mode, just in case they got set to eval
            # For D, which typically doesn't have BN, this shouldn't matter much.
            G.train()
            D.train()
            if config['ema']:
                G_ema.train()
            x, y = x.to(device), y.to(device)
            # `metrics` is used as a name -> number mapping in the log line below.
            metrics = train(x, y)

            if not (state_dict['itr'] % config['log_interval']):
                curr_time = time.perf_counter()
                # NOTE(review): perf_counter() has an arbitrary epoch, so this
                # wall-clock string is not the real time of day. Confirm intent.
                curr_time_str = datetime.datetime.fromtimestamp(
                    curr_time).strftime('%H:%M:%S')
                elapsed = str(
                    datetime.timedelta(seconds=(curr_time - start_time)))
                log = ("[{}] [{}] [{} / {}] Ep {}, ".format(
                    curr_time_str, elapsed, state_dict['itr'], total_iters,
                    epoch) + ', '.join([
                        '%s : %+4.3f' % (key, metrics[key]) for key in metrics
                    ]))
                print(log)

            # Save weights and copies as configured at specified interval
            if not (state_dict['itr'] % config['save_every']):
                if config['G_eval_mode']:
                    print('Switching G to eval mode...')
                    # Only G is switched to eval here. G_ema's switch is
                    # commented out. G returns to train mode at the top of
                    # the next iteration.
                    G.eval()
                    # if config['ema']:
                    # G_ema.eval()
                train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z,
                                          fixed_y, state_dict, config,
                                          experiment_name)

        # Increment epoch counter at end of epoch
        state_dict['epoch'] += 1
예제 #13
0
 def test_half_ema(self):
     """Check that utils.ema(1.0, 2.0, 0.5) equals 1.5 to six decimal places."""
     smoothed = utils.ema(1.0, 2.0, 0.5)
     self.assertAlmostEqual(smoothed, 1.5, places=6, msg="ema [0.5] test failed")
예제 #14
0
            XU, XUy = mixmatch(model, batchX['image'], batchX['label'], batchU['image'], 0.5, 2, beta)
            logits = [model(XU[0])]
            for batch in XU[1:]:
                logits.append(model(batch))
            logits = interleave(logits, 64)
            logits_x = logits[0]
            logits_u = tf.concat(logits[1:], axis=0)

            # compute loss
            xe_loss, l2u_loss = semi_loss(XUy[:64], logits_x, XUy[64:], logits_u)
            total_loss = xe_loss + lambda_u * l2u_loss

        # compute gradients and run optimizer step
        grads = tape.gradient(total_loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        ema(model, ema_model, 0.999)
        weight_decay(model=model, decay_rate=0.02 * 0.002)

        xe_loss_avg(xe_loss)
        l2u_loss_avg(l2u_loss)
        total_loss_avg(total_loss)
        accuracy(tf.argmax(batchX['label'], axis=1, output_type=tf.int32), model(tf.cast(batchX['image'], dtype=tf.float32), training=False))

        # progress_bar.set_postfix({
        #     'XE Loss': f'{xe_loss_avg.result():.4f}',
        #     'L2U Loss': f'{l2u_loss_avg.result():.4f}',
        #     'WeightU': f'{lambda_u:.3f}',
        #     'Total Loss': f'{total_loss_avg.result():.4f}',
        #     'Accuracy': f'{accuracy.result():.3%}'
        # })
                if batch_num % 512 == 0:
        print hostmon.net_state
    elif data_type == "loadavg":
        print hostmon.avg_load
    elif data_type == "free_mem":
        print hostmon.mem_state['free_memory']
    elif data_type == "next_loadavg":
        """ return the next time-period average loadavg"""
        num_data = len(hostmon.predict_statistics['loadavg'])
        load = 0.0
        for t,v in hostmon.predict_statistics['loadavg']:
            load += v
        print float(load/num_data)
    elif data_type == "now_and_next_loadavg":
        """ return now and the next time-period average loadavg """
        num_data = len(hostmon.predict_statistics['loadavg'])
        load = 0.0
        for t,v in hostmon.predict_statistics['loadavg']:
            load += v
        print (float(hostmon.avg_load)+float(load/num_data))/2
    elif data_type == "predict_test":
        for param, dataset in hostmon.statistics.iteritems():
            print "[%s]" % param
            time_points = [float(t) for t,d in dataset]
            next_time_points = [float(t)+float(5) for t,d in dataset]
            data_points = [float(d) for t,d in dataset]
            pred_data_points = utils.ema([float(d) for t,d in dataset])
            print "time=%s" % time_points
            print "data=%s" % data_points
            print "next_time=%s" % next_time_points
            print "pred_data=%s" % pred_data_points
def main(
        num_workers=8,
        num_filters=32,
        dataset='cifar10',
        data_path='/tmp/data',
        output_dir='/tmp/sla',
        run_id=None,
        num_labeled=40,
        seed=1,
        num_epochs=1024,
        batches_per_epoch=1024,
        checkpoint_interval=1024,
        snapshot_interval=None,
        max_checkpoints=25,
        optimizer='sgd',
        lr=0.03,
        momentum=0.9,
        nesterov=True,
        weight_decay=5e-4,
        bn_momentum=1e-3,
        labeled_batch_size=64,
        unlabeled_batch_size=64*7,
        unlabeled_weight=1.,
        exp_moving_avg_decay=1e-3,
        allocation_schedule=((0., 1.), (0., 1.)),
        entropy_reg=100.,
        update_tol=0.01,
        labeled_aug='weak',
        unlabeled_aug=('weak', 'strong'),
        whiten=True,
        sample_mode='label_dist_min1',
        upper_bound_method='empirical',
        upper_bound_kwargs=None,
        mixed_precision=True,
        devices=('cuda:0',)):
    """Train a WideResNet with SLA self-training on a semi-supervised benchmark.

    The function proceeds in these steps:

    1. Seed the ``random``, NumPy and torch RNGs with ``seed``.
    2. Pickle the call arguments to ``<output_dir>/<run_id>/args.pkl``.
    3. Load ``dataset`` with ``num_labeled`` labeled examples.
    4. Train for ``num_epochs * batches_per_epoch`` batches with
       ``SLASelfTraining``.

    Every ``checkpoint_interval`` batches, and after the last batch, both the
    raw and the parameter-averaged model are evaluated on the test split and a
    rolling ``ckpt-XXXXXXXX.pt`` file is written. Only the newest
    ``max_checkpoints`` rolling checkpoints are kept. Every
    ``snapshot_interval`` batches a ``snapshot-XXXXXXXX.pt`` file is written;
    snapshots are never pruned.

    Note: ``optimizer`` is recorded in args.pkl, but SGD is always used.

    Args:
        dataset: one of 'cifar10', 'cifar100', 'svhn'.
        output_dir: parent directory; the run writes into ``output_dir/run_id``.
        run_id: sub-directory name. Defaults to the current ISO timestamp.
        upper_bound_kwargs: extra keyword arguments forwarded to
            ``SLASelfTraining``. ``None`` means no extra arguments.
        devices: torch device strings. Evaluation uses ``devices[0]``.
        (Remaining arguments are forwarded to the dataset loader, model,
        optimizer, scheduler and trainer under the same names.)

    Raises:
        ValueError: if ``dataset`` is not a supported name.
    """
    # Never share a mutable default between calls; materialize a fresh dict.
    if upper_bound_kwargs is None:
        upper_bound_kwargs = {}

    # initial setup
    num_batches = num_epochs * batches_per_epoch

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    args = dict(locals())
    logger.info(pprint.pformat(args))

    run_id = datetime.datetime.now().isoformat() if run_id is None else run_id
    output_dir = os.path.join(output_dir, str(run_id))
    logger.info('output dir = %s', output_dir)
    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, 'args.pkl'), 'wb') as f:
        pickle.dump(args, f)
    train_logger, eval_logger = TableLogger(), TableLogger()

    # load datasets
    if dataset == 'cifar10':
        dataset_fn = get_cifar10
    elif dataset == 'cifar100':
        dataset_fn = get_cifar100
    elif dataset == 'svhn':
        dataset_fn = get_svhn
    else:
        raise ValueError('Invalid dataset ' + dataset)
    datasets = dataset_fn(
        data_path, num_labeled, labeled_aug=labeled_aug, unlabeled_aug=unlabeled_aug,
        sample_mode=sample_mode, whiten=whiten)

    model = modules.WideResNet(
        num_classes=datasets['labeled'].num_classes, bn_momentum=bn_momentum, channels=num_filters)
    # NOTE: this rebinds the `optimizer` argument; SGD is always used.
    optimizer = partial(torch.optim.SGD, lr=lr, momentum=momentum, nesterov=nesterov, weight_decay=weight_decay)
    scheduler = partial(utils.WarmupCosineLrScheduler, warmup_iter=0, max_iter=num_batches)
    evaluator = ModelEvaluator(datasets['test'], labeled_batch_size + unlabeled_batch_size, num_workers)

    def evaluate(model, avg_model, iter):
        """Evaluate both models on the test split, log the results and return the averaged model's accuracy."""
        results = evaluator.evaluate(model, device=devices[0])
        avg_results = evaluator.evaluate(avg_model, device=devices[0])
        valid_stats = {
            'valid_loss': avg_results.log_loss,
            'valid_accuracy': avg_results.accuracy,
            'valid_loss_noavg': results.log_loss,
            'valid_accuracy_noavg': results.accuracy
        }
        eval_logger.write(
            iter=iter,
            **valid_stats)
        eval_logger.step()
        return avg_results.accuracy

    def checkpoint(model, avg_model, optimizer, scheduler, iter, fmt='ckpt-{:08d}.pt'):
        """Save training state under output_dir, prune old rolling checkpoints and dump the logs."""
        path = os.path.join(output_dir, fmt.format(iter))
        torch.save(dict(
            iter=iter,
            model=model.state_dict(),
            avg_model=avg_model.state_dict(),
            optimizer=optimizer.state_dict(),
            scheduler=scheduler.state_dict()), path)
        # Only rolling 'ckpt-*' files are pruned ('snapshot-*' files are kept).
        # The zero-padded iteration number makes lexicographic order chronological.
        checkpoint_files = sorted(
            name for name in os.listdir(output_dir)
            if re.match(r'^ckpt-[0-9]+\.pt$', name))
        num_excess = max(len(checkpoint_files) - max_checkpoints, 0)
        for stale in checkpoint_files[:num_excess]:
            os.remove(os.path.join(output_dir, stale))
        train_logger.to_dataframe().to_pickle(os.path.join(output_dir, 'train.log.pkl'))
        eval_logger.to_dataframe().to_pickle(os.path.join(output_dir, 'eval.log.pkl'))

    trainer = SLASelfTraining(
        num_epochs=num_epochs,
        batches_per_epoch=batches_per_epoch,
        num_workers=num_workers,
        model_optimizer_ctor=optimizer,
        lr_scheduler_ctor=scheduler,
        param_avg_ctor=partial(modules.EMA, alpha=exp_moving_avg_decay),
        labeled_batch_size=labeled_batch_size,
        unlabeled_batch_size=unlabeled_batch_size,
        unlabeled_weight=unlabeled_weight,
        allocation_schedule=utils.PiecewiseLinear(*allocation_schedule),
        entropy_reg=entropy_reg,
        update_tol=update_tol,
        upper_bound_method=upper_bound_method,
        upper_bound_kwargs=upper_bound_kwargs,
        mixed_precision=mixed_precision,
        devices=devices)

    timer = utils.Timer()
    with tqdm(desc='train', total=num_batches, position=0) as train_pbar:
        train_iter = utils.Generator(
            trainer.train_iter(model, datasets['labeled'].num_classes, datasets['labeled'], datasets['unlabeled']))
        # utils.ema is used as a coroutine here: primed with send(None), then fed values.
        smoothed_loss = utils.ema(0.3, avg_only=True)
        smoothed_loss.send(None)
        smoothed_acc = utils.ema(1., avg_only=False)
        smoothed_acc.send(None)
        eval_stats = None, None

        # training loop
        for i, stats in enumerate(train_iter):
            if isinstance(stats, trainer.__class__.Stats):
                train_pbar.set_postfix(
                    loss=smoothed_loss.send(stats.loss), eval_acc=eval_stats[0], eval_v=eval_stats[1], refresh=False)
                train_pbar.update()
                train_logger.write(
                    loss=stats.loss, loss_labeled=stats.loss_labeled, loss_unlabeled=stats.loss_unlabeled,
                    mean_imputed_labels=stats.label_vars.data.mean(0).cpu().numpy(),
                    scaling_vars=stats.scaling_vars.data.cpu().numpy(),
                    allocation_param=stats.allocation_param,
                    assigned_frac=stats.label_vars.data.sum(-1).mean(),
                    assignment_err=stats.assgn_err, assignment_iters=stats.assgn_iters, time=timer())

                if (checkpoint_interval is not None
                    and i > 0 and (i + 1) % checkpoint_interval == 0) or (i == num_batches - 1):
                    eval_acc = evaluate(stats.model, stats.avg_model, i+1)
                    eval_stats = smoothed_acc.send(eval_acc)
                    checkpoint(stats.model, stats.avg_model, stats.optimizer, stats.scheduler, i+1)
                    logger.info('eval acc = %.4f | allocated frac = %.4f | allocation param = %.4f',
                                eval_acc, stats.label_vars.mean(0).sum().cpu().item(), stats.allocation_param)
                    logger.info('assignment err = %.4e | assignment iters = %d', stats.assgn_err, stats.assgn_iters)
                    logger.info('batch assignments = {}'.format(stats.label_vars.mean(0).cpu().numpy()))
                    logger.info('scaling vars = {}'.format(stats.scaling_vars.cpu().numpy()))

                # take snapshots that are guaranteed to be preserved
                if snapshot_interval is not None and i > 0 and (i + 1) % snapshot_interval == 0:
                    checkpoint(stats.model, stats.avg_model, stats.optimizer,
                               stats.scheduler, i + 1, 'snapshot-{:08d}.pt')

                train_logger.step()
예제 #17
0
def run(config):
  """Load a trained G/D pair, optionally write samples, then probe D on one generated image.

  Steps, each gated by the config flag named:
    * config_from_name: copy the saved config over ``config``, except for the
      keys that may be overridden from the command line.
    * Load G weights. These are the EMA weights when use_ema is set.
    * accumulate_stats, sample_npz, sample_sheets, sample_interps,
      sample_random: the corresponding sampling outputs under
      samples_root/experiment_name.
    * Build the optional EMA copy, apply the fp16 casts, wrap G and D in G_D,
      optionally resume G/D weights, and optionally wrap in DataParallel.
    * Finally, print D's output on a single freshly generated sample.

  Args:
    config: dict of settings; mutated in place, then rebound to the result
      of ``utils.update_config_roots``.
  """
  # Prepare state dict, which holds things like epoch # and itr #
  state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0,
                'best_IS': 0, 'best_FID': 999999, 'config': config}
                
  # Optionally, get the configuration from the state dict. This allows for
  # recovery of the config provided only a state dict and experiment name,
  # and can be convenient for writing less verbose sample shell scripts.
  if config['config_from_name']:
    utils.load_weights(None, None, state_dict, config['weights_root'], 
                       config['experiment_name'], config['load_weights'], None,
                       strict=False, load_optim=False)
    # Ignore items which we might want to overwrite from the command line
    for item in state_dict['config']:
      if item not in ['z_var', 'base_root', 'batch_size', 'G_batch_size', 'use_ema', 'G_eval_mode']:
        config[item] = state_dict['config'][item]
  # Update the config dict as necessary
  # This is for convenience, to add settings derived from the user-specified
  # configuration into the config-dict (e.g. inferring the number of classes
  # and size of the images from the dataset, passing in a pytorch object
  # for the activation specified as a string)
  config['resolution'] = utils.imsize_dict[config['dataset']]
  config['n_classes'] = utils.nclass_dict[config['dataset']]
  config['G_activation'] = utils.activation_dict[config['G_nl']]
  config['D_activation'] = utils.activation_dict[config['D_nl']]
  # By default, skip init if resuming training.
  if config['resume']:
    print('Skipping initialization for training resumption...')
    config['skip_init'] = True
  config = utils.update_config_roots(config)
  device = 'cuda'
  
  # Seed RNG
  utils.seed_rng(config['seed'])

  # Prepare root folders if necessary
  utils.prepare_root(config)

  # Setup cudnn.benchmark for free speed
  torch.backends.cudnn.benchmark = True

  # Import the model--this line allows us to dynamically select different files.
  model = __import__(config['model'])
  experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
  print('Experiment name is %s' % experiment_name)

  # Next, build the model
  G = model.Generator(**config).to(device)
  D = model.Discriminator(**config).to(device)
  utils.count_parameters(G)
  
  # Load weights
  print('Loading weights...')
  # Here is where we deal with the ema--load ema weights or load normal weights
  utils.load_weights(G if not (config['use_ema']) else None, None, state_dict, 
                     config['weights_root'], experiment_name, config['load_weights'],
                     G if config['ema'] and config['use_ema'] else None,
                     strict=False, load_optim=False)
  # Update batch size setting used for G
  G_batch_size = max(config['G_batch_size'], config['batch_size']) 
  z_, y_ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'],
                             device=device, fp16=config['G_fp16'], 
                             z_var=config['z_var'])
  
  if config['G_eval_mode']:
    print('Putting G in eval mode..')
    G.eval()
  else:
    print('G is in %s mode...' % ('training' if G.training else 'eval'))
    
  #Sample function
  sample = functools.partial(utils.sample, G=G, z_=z_, y_=y_, config=config)  
  if config['accumulate_stats']:
    print('Accumulating standing stats across %d accumulations...' % config['num_standing_accumulations'])
    utils.accumulate_standing_stats(G, z_, y_, config['n_classes'],
                                    config['num_standing_accumulations'])
    
  
  # Sample a number of images and save them to an NPZ, for use with TF-Inception
  if config['sample_npz']:
    # Lists to hold images and labels for images
    x, y = [], []
    print('Sampling %d images and saving them to npz...' % config['sample_num_npz'])
    for i in trange(int(np.ceil(config['sample_num_npz'] / float(G_batch_size)))):
      with torch.no_grad():
        images, labels = sample()
      x += [np.uint8(255 * (images.cpu().numpy() + 1) / 2.)]
      y += [labels.cpu().numpy()]
    x = np.concatenate(x, 0)[:config['sample_num_npz']]
    y = np.concatenate(y, 0)[:config['sample_num_npz']]    
    print('Images shape: %s, Labels shape: %s' % (x.shape, y.shape))
    npz_filename = '%s/%s/samples.npz' % (config['samples_root'], experiment_name)
    print('Saving npz to %s...' % npz_filename)
    np.savez(npz_filename, **{'x' : x, 'y' : y})
  
  # Prepare sample sheets
  if config['sample_sheets']:
    print('Preparing conditional sample sheets...')
    utils.sample_sheet(G, classes_per_sheet=utils.classes_per_sheet_dict[config['dataset']], 
                         num_classes=config['n_classes'], 
                         samples_per_class=10, parallel=config['parallel'],
                         samples_root=config['samples_root'], 
                         experiment_name=experiment_name,
                         folder_number=config['sample_sheet_folder_num'],
                         z_=z_,)
  # Sample interp sheets
  if config['sample_interps']:
    print('Preparing interp sheets...')
    for fix_z, fix_y in zip([False, False, True], [False, True, False]):
      utils.interp_sheet(G, num_per_sheet=16, num_midpoints=8,
                         num_classes=config['n_classes'], 
                         parallel=config['parallel'], 
                         samples_root=config['samples_root'], 
                         experiment_name=experiment_name,
                         folder_number=config['sample_sheet_folder_num'], 
                         sheet_number=0,
                         fix_z=fix_z, fix_y=fix_y, device='cuda')
  # Sample random sheet
  if config['sample_random']:
    print('Preparing random sample sheet...')
    images, labels = sample()
    print("labels size", labels.size())
    torchvision.utils.save_image(images.float(),
                                 '%s/%s/random_samples.jpg' % (config['samples_root'], experiment_name),
                                 nrow=int(G_batch_size**0.5),
                                 normalize=True)

  # If using EMA, prepare it
  if config['ema']:
    print('Preparing EMA for G with decay of {}'.format(config['ema_decay']))
    G_ema = model.Generator(**{**config, 'skip_init':True, 
                               'no_optim': True}).to(device)
    ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
  else:
    G_ema, ema = None, None
  
  # FP16?
  if config['G_fp16']:
    print('Casting G to float16...')
    G = G.half()
    if config['ema']:
      G_ema = G_ema.half()
  if config['D_fp16']:
    print('Casting D to fp16...')
    D = D.half()
    # Consider automatically reducing SN_eps?
  GD = model.G_D(G, D)
  #print(G)
  #print(D)
  print('Number of params in G: {} D: {}'.format(
    *[sum([p.data.nelement() for p in net.parameters()]) for net in [G,D]]))
  # Prepare state dict, which holds things like epoch # and itr #
  # NOTE(review): this replaces the state_dict populated above; anything
  # loaded via config_from_name (other than config) is discarded here.
  state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0,
                'best_IS': 0, 'best_FID': 999999, 'config': config}

  # If loading from a pre-trained model, load weights
  if config['resume']:
    print('Loading weights...')
    utils.load_weights(G, D, state_dict,
                       config['weights_root'], experiment_name, 
                       config['load_weights'] if config['load_weights'] else None,
                       G_ema if config['ema'] else None)

  # If parallel, parallelize the GD module
  if config['parallel']:
    GD = nn.DataParallel(GD)
    if config['cross_replica']:
      patch_replication_callback(GD)

  # Probe D on one generated sample. `images`/`labels` above exist only when
  # sample_npz or sample_random ran, so draw a fresh batch here. Slicing with
  # [:1] keeps the batch dimension and pairs each image with its own label.
  # Assumes sample() returns batch-first images and a 1-D label tensor
  # (TODO confirm against utils.sample / model.Discriminator.forward).
  with torch.no_grad():
    probe_images, probe_labels = sample()
    D_fake = D(probe_images[:1], probe_labels[:1])
  print("D_fake ",D_fake)
예제 #18
0
testset = stockloader.StockData(True)
pred = nn.DataParallel(model.TraderAI())
# predII = nn.DataParallel(model.TraderAI(273))
optim = op.Adadelta(pred.parameters())  #  + list(predII.parameters())
sched = op.lr_scheduler.ExponentialLR(optim, 0.64)
# pred.load_state_dict(torch.load("predictorMKII-A.pt"))
# predII.load_state_dict(torch.load("networkMVI-B.pt"))
if True:
    for i in range(100):
        s = 0
        for dp in a:
            #  pred(dp)
            short = utils.tema(dp.numpy(), 16)
            longt = utils.tema(dp.numpy(), 75)

            shorte = utils.ema(dp.numpy(), 50)
            longe = utils.ema(dp.numpy(), 200)

            conv_line, base_line, lead_span_a, lead_span_b = utils.ichimoku_cloud(
                dp.numpy())

            # import matplotlib.pyplot as plt

            #plt.plot(dp.numpy())
            #plt.plot(short)
            #plt.plot(longt)
            # plt.ylabel('some numbers')

            data_stream = [
                dp.clone(), short, longt, shorte, longe, conv_line, base_line,
                lead_span_a, lead_span_b
예제 #19
0
def run(config):

    # Update the config dict as necessary
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[
        config['dataset']] * config['cluster_per_class']
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    if config['is_encoder']:
        config['E_fp16'] = float(config['D_fp16'])
        config['num_E_accumulations'] = int(config['num_D_accumulations'])
        config['dataset_channel'] = utils.channel_dict[config['dataset']]
        config['lambda_encoder'] = config['resolution']**2 * config[
            'dataset_channel']

    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    # Next, build the model
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)
    if config['is_encoder']:
        E = model.Encoder(**{**config, 'D': D}).to(device)
    Prior = layers.Prior(**config).to(device)
    # If using EMA, prepare it
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(
            config['ema_decay']))
        G_ema = model.Generator(**{
            **config, 'skip_init': True,
            'no_optim': True
        }).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None

    # FP16?
    if config['G_fp16']:
        print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
        if not config['prior_type'] == 'default':
            Prior = Prior.half()
    if config['D_fp16']:
        print('Casting D to fp16...')
        D = D.half()
        # Consider automatically reducing SN_eps?
    if config['is_encoder'] and config['E_fp16']:
        print('Casting E to fp16...')
        E = E.half()

    print(G)
    print(D)
    if config['is_encoder']:
        print(E)
    print(Prior)
    if not config['is_encoder']:
        GD = model.G_D(G, D)
        print('Number of params in G: {} D: {}'.format(*[
            sum([p.data.nelement() for p in net.parameters()])
            for net in [G, D]
        ]))
    else:
        GD = model.G_D(G, D, E, Prior)
        GE = model.G_E(G, E, Prior)
        print('Number of params in G: {} D: {} E: {}'.format(*[
            sum([p.data.nelement() for p in net.parameters()])
            for net in [G, D, E]
        ]))

    # Prepare state dict, which holds things like epoch # and itr #
    # ¡¡¡¡¡¡¡¡¡ Put rec error, discriminator loss and generator loss !!!!!!!!!!!?????????
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'best_error_rec': 99999,
        'config': config
    }

    # If loading from a pre-trained model, load weights
    if config['resume']:
        print('Loading weights...')
        utils.load_weights(
            G,
            D,
            state_dict,
            config['weights_root'],
            experiment_name,
            config['load_weights'] if config['load_weights'] else None,
            G_ema if config['ema'] else None,
            E=None if not config['is_encoder'] else E,
            Prior=Prior if not config['prior_type'] == 'default' else None)

    # If parallel, parallelize the GD module
    if config['parallel']:
        GD = nn.DataParallel(GD)
        if config['cross_replica']:
            patch_replication_callback(GD)

    # If parallel, parallelize the GD module
    #if config['parallel'] and config['is_encoder']:
    #  GE = nn.DataParallel(GE)
    #  if config['cross_replica']:
    #    patch_replication_callback(GE)

    # Prepare loggers for stats; metrics holds test metrics,
    # lmetrics holds any desired training metrics.
    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'],
                                              experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
    test_log = utils.MetricsLogger(test_metrics_fname,
                                   reinitialize=(not config['resume']))
    print('Training Metrics will be saved to {}'.format(train_metrics_fname))
    train_log = utils.MyLogger(train_metrics_fname,
                               reinitialize=(not config['resume']),
                               logstyle=config['logstyle'])
    # Write metadata
    utils.write_metadata(config['logs_root'], experiment_name, config,
                         state_dict)
    # Prepare data; the Discriminator's batch size is all that needs to be passed
    # to the dataloader, as G doesn't require dataloading.
    # Note that at every loader iteration we pass in enough data to complete
    # a full D iteration (regardless of number of D steps and accumulations)
    D_batch_size = (config['batch_size'] * config['num_D_steps'] *
                    config['num_D_accumulations'])
    loaders = utils.get_data_loaders(**{
        **config, 'batch_size': D_batch_size,
        'start_itr': state_dict['itr']
    })
    if config['is_encoder']:
        config_aux = config.copy()
        config_aux['augment'] = False
        dataloader_noaug = utils.get_data_loaders(
            **{
                **config_aux, 'batch_size': D_batch_size,
                'start_itr': state_dict['itr']
            })

    # Prepare inception metrics: FID and IS
    if (config['dataset'] in ['C10']):
        get_inception_metrics = inception_utils.prepare_inception_metrics(
            config['dataset'], config['parallel'], config['no_fid'])
    else:
        get_inception_metrics = None

    # Loaders are loaded, prepare the training function
    if config['which_train_fn'] == 'GAN':
        train = train_fns.GAN_training_function(
            G, D, GD, Prior, ema, state_dict, config,
            losses.Loss_obj(**config), None if not config['is_encoder'] else E)
    # Else, assume debugging and use the dummy train fn
    else:
        train = train_fns.dummy_training_function()
    # Prepare Sample function for use with inception metrics
    sample = functools.partial(
        utils.sample,
        G=(G_ema if config['ema'] and config['use_ema'] else G),
        Prior=Prior,
        config=config)

    # Create fixed
    fixed_z, fixed_y = Prior.sample_noise_and_y()
    fixed_z, fixed_y = fixed_z.clone(), fixed_y.clone()
    iter_num = 0
    print('Beginning training at epoch %d...' % state_dict['epoch'])
    # Train for specified number of epochs, although we mostly track G iterations.
    for epoch in range(state_dict['epoch'], config['num_epochs']):
        # Which progressbar to use? TQDM or my own?
        if config['pbar'] == 'mine':
            pbar = utils.progress(loaders[0],
                                  displaytype='s1k' if
                                  config['use_multiepoch_sampler'] else 'eta')
        else:
            pbar = tqdm(loaders[0])
        for i, (x, y) in enumerate(pbar):
            # Increment the iteration counter
            state_dict['itr'] += 1
            # Make sure G and D are in training mode, just in case they got set to eval
            # For D, which typically doesn't have BN, this shouldn't matter much.
            G.train()
            D.train()
            if config['is_encoder']:
                E.train()
            if not config['prior_type'] == 'default':
                Prior.train()
            if config['ema']:
                G_ema.train()
            if config['D_fp16']:
                x, y = x.to(device).half(), y.to(device)
            else:
                x, y = x.to(device), y.to(device)

            metrics = train(x, y, iter_num)
            train_log.log(itr=int(state_dict['itr']), **metrics)

            # Every sv_log_interval, log singular values
            if (config['sv_log_interval'] > 0) and (
                    not (state_dict['itr'] % config['sv_log_interval'])):
                train_log.log(itr=int(state_dict['itr']),
                              **{
                                  **utils.get_SVs(G, 'G'),
                                  **utils.get_SVs(D, 'D')
                              })
                if config['is_encoder']:
                    train_log.log(itr=int(state_dict['itr']),
                                  **{**utils.get_SVs(E, 'E')})

            # If using my progbar, print metrics.
            if config['pbar'] == 'mine':
                print(', '.join(
                    ['itr: %d' % state_dict['itr']] +
                    ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]),
                      end=' ')

            # Save weights and copies as configured at specified interval
            if not (state_dict['itr'] % config['save_every']):
                if config['G_eval_mode']:
                    print('Switchin G to eval mode...')
                    G.eval()
                    if not config['prior_type'] == 'default':
                        Prior.eval()
                    if config['ema']:
                        G_ema.eval()
                train_fns.save_and_sample(
                    G, D, G_ema, Prior, fixed_z, fixed_y, state_dict, config,
                    experiment_name, None if not config['is_encoder'] else E)

            if not (state_dict['itr'] %
                    config['test_every']) and config['is_encoder']:
                if not config['prior_type'] == 'default':
                    test_acc, test_acc_iter, error_rec = train_fns.test_accuracy(
                        GE, dataloader_noaug, device, config['D_fp16'], config)
                    p_mse, p_lik = train_fns.test_p_acc(GE, device, config)
                if config['n_classes'] == 10:
                    utils.reconstruction_sheet(
                        GE,
                        classes_per_sheet=utils.classes_per_sheet_dict[
                            config['dataset']],
                        num_classes=config['n_classes'],
                        samples_per_class=20,
                        parallel=config['parallel'],
                        samples_root=config['samples_root'],
                        experiment_name=experiment_name,
                        folder_number=state_dict['itr'],
                        dataloader=dataloader_noaug,
                        device=device,
                        D_fp16=config['D_fp16'],
                        config=config)

            # Test every specified interval
            if not (state_dict['itr'] % config['test_every']):
                if config['G_eval_mode']:
                    print('Switchin G to eval mode...')
                    if not config['prior_type'] == 'default':
                        Prior.eval()
                    G.eval()
                train_fns.test(
                    G, D, G_ema, Prior, state_dict, config, sample,
                    get_inception_metrics, experiment_name, test_log,
                    None if not config['is_encoder'] else E,
                    None if config['prior_type'] == 'default' else
                    (test_acc, test_acc_iter, error_rec, p_mse, p_lik))

            if not (state_dict['itr'] % config['test_every']):
                utils.create_curves(train_metrics_fname,
                                    plot_sv=False,
                                    prior_type=config['prior_type'],
                                    is_E=config['is_encoder'])
                utils.plot_IS_FID(train_metrics_fname)

        # Increment epoch counter at end of epoch
        iter_num += 1
        state_dict['epoch'] += 1
예제 #20
0
파일: trainer.py 프로젝트: zd8692/robustOT
    def __init__(self, config):
        """Set up data loaders, G/D models, optimizers and bookkeeping state.

        Args:
            config: run configuration object. Its attributes are read directly
                (e.g. ``config.dataset``, ``config.batchSize``, ``config.lrD``),
                and ``num_classes`` is written back into ``config.__dict__``
                once the training loader has been built.

        Side effects:
            Creates ``config.logdir`` if needed and writes a one-line header to
            ``<logdir>/log.txt`` when that file does not exist yet. It then
            restores state from ``config.restore`` (if non-empty) and from
            ``<logdir>/model_state.pth`` (if present, e.g. after preemption).
        """
        self.config = config
        self.device = 'cuda:0'

        dset_fn = dataset_factory[config.dataset]

        # Creating dataloaders. Only the CelebA-attribute dataset needs the
        # extra attribute list; every other keyword is shared by all datasets.
        extra_kwargs = {}
        if config.dataset == 'celeba_attribute':
            # Each entry is (attribute_name, attribute_value, fraction).
            # NOTE(review): presumably value +1/-1 selects attribute present /
            # absent — confirm against the dataset factory.
            attribute_list = [('Male', 1, config.anomaly_frac),
                              ('Male', -1, 1 - config.anomaly_frac)]
            extra_kwargs['input_attribute_list'] = attribute_list

        self.dataloader, self.num_classes = dset_fn(
            config.dataroot,
            config.batchSize,
            imgSize=config.imageSize,
            anomaly_frac=config.anomaly_frac,
            anomalypath=config.anopath,
            savepath=config.logdir,
            train=True,
            **extra_kwargs)
        # The test loader uses a fixed batch size of 32 and no savepath.
        self.testloader, _ = dset_fn(config.dataroot,
                                     32,
                                     imgSize=config.imageSize,
                                     anomaly_frac=config.anomaly_frac,
                                     anomalypath=config.anopath,
                                     train=False,
                                     **extra_kwargs)
        config.__dict__['num_classes'] = self.num_classes

        if self.num_classes == 1:
            # A conditional model is not supported with a single class.
            assert not config.conditional

        # Creating models
        gen_model_fn = models.generator_factory[config.netG]
        disc_model_fn = models.discriminator_factory[config.netD]
        self.netG = gen_model_fn(config).to(self.device)
        self.netD = disc_model_fn(config).to(self.device)
        if self.config.ngpu > 1:
            self.netG = nn.DataParallel(self.netG)
            self.netD = nn.DataParallel(self.netD)

        # Print the final (possibly DataParallel-wrapped) models once.
        print(self.netG)
        print(self.netD)

        # Creating optimizers
        self.optimizerD = optim.Adam(self.netD.parameters(),
                                     lr=config.lrD,
                                     betas=(config.beta1D, config.beta2D))
        self.optimizerG = optim.Adam(self.netG.parameters(),
                                     lr=config.lrG,
                                     betas=(config.beta1G, config.beta2G))

        # Total number of discriminator iterations; true division, so this
        # may be fractional when the dataset size is not a multiple of the
        # batch size.
        num_iters = (len(self.dataloader.dataset) /
                     config.batchSize) * config.nepochs
        print('Running for {} discriminator iterations'.format(num_iters))
        if config.lrdecay:
            self.schedulerD = utils.LinearLR(self.optimizerD, num_iters)
            self.schedulerG = utils.LinearLR(self.optimizerG, num_iters)

        # Creating loss functions
        [self.disc_loss_fn,
         self.gen_loss_fn] = losses.loss_factory[config.loss]
        self.aux_loss_fn = losses.aux_loss

        self.epoch = 0
        self.itr = 0
        Path(self.config.logdir).mkdir(exist_ok=True, parents=True)

        if config.use_ema:
            self.ema = utils.ema(self.netG)

        # Create the text log with a header only on the first run, so an
        # existing log is never truncated.
        self.log_path = osp.join(self.config.logdir, 'log.txt')
        if not osp.exists(self.log_path):
            with open(self.log_path, 'w') as fh:
                fh.write('Logging {} GAN training\n'.format(self.config.dataset))

        self.inception_evaluator = inception.Evaluator(config, self.testloader)
        self.prev_gen_loss = 0
        self.best_is = 0
        self.best_fid = 100000
        self.best_is_std = 0
        self.best_intra_fid = 1000000
        self.eps = 0.0001

        # Weights
        self.rho = self.config.rho
        self.weight_update_flag = self.config.weight_update
        self.weight_update_type = self.config.weight_update_type
        if self.weight_update_flag:
            if self.weight_update_type == 'discrete':
                # One float32 entry per training datapoint, initialised to 1.
                self.num_datapoints = len(self.dataloader.dataset)
                self.weight_vector = torch.ones(
                    self.num_datapoints, dtype=torch.float32, device=self.device)
                self.disc_vector = torch.ones(
                    self.num_datapoints, dtype=torch.float32, device=self.device)
                self.disc_vector_cur = torch.ones(
                    self.num_datapoints, dtype=torch.float32, device=self.device)
            else:
                # Otherwise a separate network (netW) with its own Adam
                # optimizer produces the weights.
                weight_model_fn = models.weight_factory[config.netD]
                self.netW = weight_model_fn(config).to(self.device)
                if self.config.ngpu > 1:
                    self.netW = nn.DataParallel(self.netW)
                self.optimizerW = optim.Adam(self.netW.parameters(),
                                             lr=config.lrD,
                                             betas=(config.beta1D,
                                                    config.beta2D))
                self.weight_loss_fn = losses.loss_factory_weights[config.loss]

        # Code for restoring models from checkpoint
        if config.restore != '':
            self.restore_state(config.restore)

        # Checking if state exists (in case of preemption)
        if os.path.exists(osp.join(self.config.logdir, 'model_state.pth')):
            self.restore_state(osp.join(self.config.logdir, 'model_state.pth'))
예제 #21
0
def run(config):
    """Load pretrained G/D/E weights and run ``activation_extract`` on data.

    Derived settings (resolution, class count, activation modules) are added
    to ``config``. The models are built on CPU, and weights are loaded from
    ``config['weights_root']`` / ``config['load_experiment_name']``.
    ``activation_extract`` is then called once per loader batch, stopping
    after 20 iterations or when the loader is exhausted.
    """
    # Settings derived from the dataset and nonlinearity names.
    dataset_key = config['dataset']
    config['resolution'] = utils.imsize_dict[dataset_key]
    config['n_classes'] = utils.nclass_dict[dataset_key]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    if config['resume']:
        # Resuming: weights come from a checkpoint, so skip init.
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cpu'

    utils.seed_rng(config['seed'])
    utils.prepare_root(config)
    torch.backends.cudnn.benchmark = True

    # The model module is chosen at runtime from the config.
    model = __import__(config['model'])
    base_name = config['experiment_name'] or utils.name_from_config(config)
    experiment_name = "test_{}".format(base_name)
    print('Experiment name is %s' % experiment_name)

    # Build the networks.
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)
    E = model.ImgEncoder(**config).to(device)

    G_ema, ema = None, None
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(config['ema_decay']))
        ema_config = {**config, 'skip_init': True, 'no_optim': True}
        G_ema = model.Generator(**ema_config).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])

    # Optional half precision.
    if config['G_fp16']:
        print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        print('Casting D to fp16...')
        D = D.half()
    GDE = model.G_D_E(G, D, E)
    print("Model Created!")
    param_counts = [sum(p.data.nelement() for p in net.parameters())
                    for net in (G, D, E)]
    print('Number of params in G: {} D: {} E: {}'.format(*param_counts))

    def _initial_state_dict():
        # Fresh counters / best-metric trackers (a new dict on every call).
        return {
            'itr': 0,
            'epoch': 0,
            'save_num': 0,
            'save_best_num': 0,
            'best_IS': 0,
            'best_FID': 999999,
            'config': config,
        }

    state_dict = _initial_state_dict()
    print('Loading weights...')
    weights_suffix = config['load_weights'] if config['load_weights'] else None
    ema_target = G_ema if config['ema'] else None
    utils.load_weights(G, D, E, state_dict, config['weights_root'],
                       config['load_experiment_name'], weights_suffix,
                       ema_target)
    # load_weights received state_dict (presumably to restore counters into
    # it); start again from a fresh one so the 20-iteration cap below counts
    # from zero.
    state_dict = _initial_state_dict()

    if config['parallel']:
        GDE = nn.DataParallel(GDE)
        if config['cross_replica']:
            patch_replication_callback(GDE)

    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    D_batch_size = (config['batch_size'] * config['num_D_steps'] *
                    config['num_D_accumulations'])
    loader_config = {**config, 'batch_size': D_batch_size, 'start_itr': 0}
    loaders, train_dataset = utils.get_data_loaders(**loader_config)

    def _noise_and_labels():
        # Noise/label pair for G, sized by G_batch_size.
        return utils.prepare_z_y(G_batch_size,
                                 G.dim_z,
                                 config['n_classes'],
                                 device=device,
                                 fp16=config['G_fp16'])

    z_, y_ = _noise_and_labels()
    # A second, fixed pair is sampled once and kept for the whole run.
    fixed_z, fixed_y = _noise_and_labels()
    fixed_z.sample_()
    fixed_y.sample_()
    print("fixed_y original: {} {}".format(fixed_y.shape, fixed_y[:10]))
    fixed_x, fixed_y_of_x = utils.prepare_x_y(G_batch_size, train_dataset,
                                              experiment_name, config)

    sampling_G = G_ema if config['ema'] and config['use_ema'] else G
    sample = functools.partial(utils.sample,
                               G=sampling_G,
                               z_=z_,
                               y_=y_,
                               config=config)

    G.eval()
    E.eval()
    print("check1 -------------------------------")
    print("state_dict['itr']", state_dict['itr'])
    if config['pbar'] == 'mine':
        display = 's1k' if config['use_multiepoch_sampler'] else 'eta'
        pbar = utils.progress(loaders[0], displaytype=display)
    else:
        pbar = tqdm(loaders[0])
    print("state_dict['itr']", state_dict['itr'])

    for x, y in pbar:
        state_dict['itr'] += 1
        x = x.to(device)
        if config['D_fp16']:
            x = x.half()
        y = y.to(device)
        print("x.shape", x.shape)
        print("y.shape", y.shape)

        activation_extract(G, D, E, G_ema, x, y, z_, y_, state_dict, config,
                           experiment_name, save_weights=False)
        if state_dict['itr'] == 20:
            break
# Path-integral protocol: builds the kwargs for the SynapticOptimizer
# (credit: Zenke et al., 2017).
def PATH_INT_PROTOCOL(omega_decay, xi):
    """Return ``(name, spec)`` describing the path-integral regularizer.

    ``name`` is ``'path_int[omega_decay=<omega_decay>,xi=<xi>]'``. ``spec``
    maps 'init_updates', 'step_updates' and 'task_updates' to lists of
    ``(variable_name, update_fn)`` pairs, where each
    ``update_fn(vars, w, prev_val)`` returns the new value of that per-weight
    variable. ``spec['regularizer_fn']`` is ``importancePenalty``.
    """
    protocol_name = 'path_int[omega_decay=%s,xi=%s]' % (omega_decay, xi)

    init_updates = [
        # Snapshot of the current weight values.
        ('cweights', lambda vars, w, prev_val: w.value()),
    ]
    step_updates = [
        # Accumulate -(unregularized grad) * (weight delta) every step.
        ('grads2', lambda vars, w, prev_val:
            prev_val - vars['unreg_grads'][w] * vars['deltas'][w]),
    ]
    task_updates = [
        # omega <- relu(ema(omega_decay, omega, grads2 / (drift**2 + xi))),
        # where drift is the weight change since the last snapshot.
        ('omega', lambda vars, w, prev_val: tf.nn.relu(
            ema(omega_decay, prev_val,
                vars['grads2'][w] /
                ((vars['cweights'][w] - w.value())**2 + xi)))),
        # Re-snapshot the weights and reset the path integral for the next task.
        ('cweights', lambda opt, w, prev_val: w.value()),
        ('grads2', lambda vars, w, prev_val: prev_val * 0.0),
    ]

    return protocol_name, {
        'init_updates': init_updates,
        'step_updates': step_updates,
        'task_updates': task_updates,
        'regularizer_fn': importancePenalty,
    }

# Optimization parameters
img_rows = img_cols = 28
train_size = 50000
valid_size = 10000
test_size = 10000
예제 #23
0
 def test_zero_ema(self):
     """utils.ema(1.0, 2.0, 0.0) is expected to equal 2.0 (to 6 places)."""
     result = utils.ema(1.0, 2.0, 0.0)
     self.assertAlmostEqual(result, 2.0, places=6,
                            msg="ema - 1.0: test failed")
예제 #24
0
def run(config):
    """Iterate the data with a G/D/E model and periodically save samples/weights.

    No optimization step runs in the loop. G, D (and G_ema) are put in eval
    mode every iteration. The loop logs singular values at
    ``sv_log_interval``, calls ``train_fns.save_and_sample`` at
    ``save_img_every`` / ``save_model_every``, and records fractional epoch
    progress in ``state_dict``.

    Args:
        config: dict of run settings. Derived entries ('resolution',
            'n_classes', 'G_activation', 'D_activation', and 'skip_init' when
            resuming) are written into it before use.
    """
    # Add settings derived from the user-specified configuration (number of
    # classes and image size from the dataset, activation modules from their
    # string names).
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    # Next, build the model
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)
    E = model.ImgEncoder(**config).to(device)

    # If using EMA, prepare it
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(
            config['ema_decay']))
        G_ema = model.Generator(**{
            **config, 'skip_init': True,
            'no_optim': True
        }).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None

    # FP16?
    if config['G_fp16']:
        print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        print('Casting D to fp16...')
        D = D.half()
        # Consider automatically reducing SN_eps?
    GDE = model.G_D_E(G, D, E)

    print('Number of params in G: {} D: {} E: {}'.format(*[
        sum([p.data.nelement() for p in net.parameters()])
        for net in [G, D, E]
    ]))
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'config': config
    }

    # If resuming, load weights (and whatever load_weights restores into
    # state_dict) from this experiment's checkpoint.
    if config['resume']:
        print('Loading weights...')
        utils.load_weights(
            G, D, E, state_dict, config['weights_root'], experiment_name,
            config['load_weights'] if config['load_weights'] else None,
            G_ema if config['ema'] else None)

    # If parallel, parallelize the GD module
    if config['parallel']:
        GDE = nn.DataParallel(GDE)
        if config['cross_replica']:
            patch_replication_callback(GDE)

    # Prepare loggers for stats; test_log holds test metrics, train_log holds
    # any desired training metrics.
    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'],
                                              experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
    # NOTE(review): test_log is created but nothing writes to it here because
    # the inception/test step is disabled in this script.
    test_log = utils.MetricsLogger(test_metrics_fname,
                                   reinitialize=(not config['resume']))
    print('Training Metrics will be saved to {}'.format(train_metrics_fname))
    train_log = utils.MyLogger(train_metrics_fname,
                               reinitialize=(not config['resume']),
                               logstyle=config['logstyle'])
    # Write metadata
    utils.write_metadata(config['logs_root'], experiment_name, config,
                         state_dict)
    # Prepare data; the Discriminator's batch size is all that needs to be
    # passed to the dataloader, as G doesn't require dataloading. Every loader
    # iteration yields enough data for a full D iteration (regardless of the
    # number of D steps and accumulations).
    D_batch_size = (config['batch_size'] * config['num_D_steps'] *
                    config['num_D_accumulations'])
    loaders, train_dataset = utils.get_data_loaders(
        **{
            **config, 'batch_size': D_batch_size,
            'start_itr': state_dict['itr']
        })

    # Prepare noise and randomly sampled label arrays.
    # Allow for different batch sizes in G.
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'])
    # Prepare a fixed z & y to see individual sample evolution throughout training
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size,
                                         G.dim_z,
                                         config['n_classes'],
                                         device=device,
                                         fp16=config['G_fp16'])
    fixed_z.sample_()
    fixed_y.sample_()
    print("fixed_y original: {} {}".format(fixed_y.shape, fixed_y[:10]))
    # TODO: change the sample method to sample x and y
    fixed_x, fixed_y_of_x = utils.prepare_x_y(G_batch_size, train_dataset,
                                              experiment_name, config)

    # Build image pool to prevent mode collapse
    if config['img_pool_size'] != 0:
        img_pool = ImagePool(config['img_pool_size'], train_dataset.num_class,
                             save_dir=os.path.join(config['imgbuffer_root'],
                                                   experiment_name),
                             resume_buffer=config['resume_buffer'])
    else:
        img_pool = None

    # Loaders are loaded, prepare the training function.
    # NOTE(review): `train` is built but never called in the loop below, so
    # no optimization step happens in this script — confirm this is intended.
    if config['which_train_fn'] == 'GAN':
        train = train_fns.GAN_training_function(G, D, E, GDE, ema, state_dict,
                                                config, img_pool)
    # Else, assume debugging and use the dummy train fn
    else:
        train = train_fns.dummy_training_function()
    # Prepare Sample function for use with inception metrics
    sample = functools.partial(
        utils.sample,
        G=(G_ema if config['ema'] and config['use_ema'] else G),
        z_=z_,
        y_=y_,
        config=config)

    print("Beginning training at Epoch {} (iteration {})".format(
        state_dict['epoch'], state_dict['itr']))
    # A single pass over loaders[0]; epochs are derived from the iteration
    # count below. Which progressbar to use? TQDM or my own?
    if config['pbar'] == 'mine':
        pbar = utils.progress(
            loaders[0],
            displaytype='s1k' if config['use_multiepoch_sampler'] else 'eta')
    else:
        pbar = tqdm(loaders[0])

    for x, y in pbar:
        # Increment the iteration counter
        state_dict['itr'] += 1
        # No training step runs in this loop; G, D (and G_ema) are kept in
        # eval mode. NOTE(review): E is left in whatever mode it was built in.
        G.eval()
        D.eval()
        if config['ema']:
            G_ema.eval()
        if config['D_fp16']:
            x, y = x.to(device).half(), y.to(device)
        else:
            x, y = x.to(device), y.to(device)

        # Every sv_log_interval, log singular values
        if (config['sv_log_interval'] > 0
                and not state_dict['itr'] % config['sv_log_interval']):
            train_log.log(itr=int(state_dict['itr']),
                          **{
                              **utils.get_SVs(G, 'G'),
                              **utils.get_SVs(D, 'D')
                          })

        # If using my progbar, print the iteration counter. There is no
        # training-metrics dict in this loop, so only `itr` is reported
        # (referencing an undefined `metrics` here raised NameError).
        if config['pbar'] == 'mine':
            print('itr: %d' % state_dict['itr'], end=' ')

        # Save images and/or weights at the configured intervals
        if (not state_dict['itr'] % config['save_img_every']) or (
                not state_dict['itr'] % config['save_model_every']):
            if config['G_eval_mode']:
                print('Switchin G to eval mode...')
                G.eval()
                if config['ema']:
                    G_ema.eval()
            # Weights are only written on save_model_every boundaries.
            save_weights = config['save_weights']
            if state_dict['itr'] % config['save_model_every']:
                save_weights = False
            train_fns.save_and_sample(G,
                                      D,
                                      E,
                                      G_ema,
                                      fixed_x,
                                      fixed_y_of_x,
                                      z_,
                                      y_,
                                      state_dict,
                                      config,
                                      experiment_name,
                                      img_pool,
                                      save_weights=save_weights)

        # NOTE(review): the periodic inception/test evaluation is disabled in
        # this script (get_inception_metrics is never prepared).

        # Record fractional epoch progress after every iteration.
        state_dict['epoch'] = state_dict['itr'] * D_batch_size / (
            len(train_dataset))
        print("Finished Epoch {} (iteration {})".format(
            state_dict['epoch'], state_dict['itr']))
예제 #25
0
def run(config):
    """Train a GAN from ``config`` (or only evaluate it when resuming in eval mode).

    Derives dataset-dependent settings into ``config`` and builds G and D.
    When ``config['ema']`` is set, it also builds an EMA copy of G. It can
    resume weights from ``global_cfg.resume_cfg.weights_root``. It then
    prepares the training/validation loaders and the inception-metric
    function, and loops over epochs calling the training step on every batch.

    Inside the loop it does the following:

    * logs training metrics;
    * optionally logs singular values;
    * saves weights/samples every ``config['save_every']`` iterations;
    * runs ``train_fns.test`` on the first iteration, every
      ``config['test_every']`` iterations (when positive), and whenever the
      number of shown images crosses a multiple of
      ``global_cfg['test_every_images']`` (when that key is configured).

    If ``global_cfg.resume_cfg.eval`` is truthy, ``train_fns.test`` is called
    once after setup and the function returns without training.

    Args:
        config: Mutable dict of run settings. Derived keys ('resolution',
            'n_classes', 'G_activation', 'D_activation', and 'skip_init' when
            resuming) are written into it before it is passed through
            ``utils.update_config_roots``.

    Note:
        Relies on module-level names defined elsewhere in the project
        (``global_cfg``, ``textlogger``, ``summary_defaultdict2txtfig``,
        ``utils``, ``train_fns``, ``inception_utils``, ``tqdm``,
        ``patch_replication_callback``).
    """
    logger = logging.getLogger('tl')
    # Update the config dict as necessary.
    # This adds settings derived from the user-specified configuration
    # (e.g. number of classes and image size inferred from the dataset, and
    # a pytorch activation object looked up from its string name).
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'

    # Seed RNG before any model is constructed so initialisation is reproducible.
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different files.
    model = importlib.import_module(config['model'])
    # The experiment name is fixed here; config['experiment_name'] is ignored.
    experiment_name = 'exp'
    print('Experiment name is %s' % experiment_name)

    # Next, build the model
    G = model.Generator(**config,
                        cfg=getattr(global_cfg, 'generator', None)).to(device)
    D = model.Discriminator(**config,
                            cfg=getattr(global_cfg, 'discriminator',
                                        None)).to(device)

    # If using EMA (exponential moving average of G's weights), prepare it.
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(
            config['ema_decay']))
        G_ema = model.Generator(
            **{**config, 'skip_init': True, 'no_optim': True},
            cfg=getattr(global_cfg, 'generator', None)).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None

    # FP16?
    if config['G_fp16']:
        print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        print('Casting D to fp16...')
        D = D.half()
        # Consider automatically reducing SN_eps?
    GD = model.G_D(G, D)
    logger.info(G)
    logger.info(D)
    logger.info('Number of params in G: {} D: {}'.format(
        *[sum([p.data.nelement() for p in net.parameters()])
          for net in [G, D]]))
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'config': config
    }

    # If loading from a pre-trained model, load weights
    if config['resume']:
        print('Loading weights...')
        utils.load_weights(G=G,
                           D=D,
                           state_dict=state_dict,
                           weights_root=global_cfg.resume_cfg.weights_root,
                           experiment_name='',
                           name_suffix=config['load_weights']
                           if config['load_weights'] else None,
                           G_ema=G_ema if config['ema'] else None)
        logger.info(f"Resume IS={state_dict['best_IS']}")
        logger.info(f"Resume FID={state_dict['best_FID']}")

    # If parallel, parallelize the GD module
    if config['parallel']:
        GD = nn.DataParallel(GD)
        if config['cross_replica']:
            patch_replication_callback(GD)

    # Prepare loggers for stats; test_log holds test metrics,
    # train_log holds any desired training metrics.
    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'],
                                              experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
    test_log = utils.MetricsLogger(test_metrics_fname,
                                   reinitialize=(not config['resume']))
    print('Training Metrics will be saved to {}'.format(train_metrics_fname))
    train_log = utils.MyLogger(train_metrics_fname,
                               reinitialize=(not config['resume']),
                               logstyle=config['logstyle'])
    # Write metadata
    utils.write_metadata(config['logs_root'], experiment_name, config,
                         state_dict)
    # Prepare data. Only the Discriminator's batch size needs to go to the
    # dataloader, since G does not load data. Each loader iteration yields
    # enough samples for a full D iteration (all D steps and accumulations).
    D_batch_size = (config['batch_size'] * config['num_D_steps'] *
                    config['num_D_accumulations'])
    loaders = utils.get_data_loaders(
        **{
            **config, 'batch_size': D_batch_size,
            'start_itr': state_dict['itr'],
            **getattr(global_cfg, 'train_dataloader', {})
        })

    # Optional validation loader, consumed lazily as an iterator by the
    # training function.
    val_loaders = None
    if hasattr(global_cfg, 'val_dataloader'):
        val_loaders = utils.get_data_loaders(
            **{
                **config, 'batch_size': config['batch_size'],
                'start_itr': state_dict['itr'],
                **global_cfg.val_dataloader
            })[0]
        val_loaders = iter(val_loaders)
    # Prepare inception metrics: FID and IS
    if global_cfg.get('use_unofficial_FID', False):
        get_inception_metrics = inception_utils.prepare_inception_metrics(
            config['inception_file'], config['parallel'], config['no_fid'])
    else:
        get_inception_metrics = inception_utils.prepare_FID_IS(global_cfg)
    # Prepare noise and randomly sampled label arrays.
    # Allow for different batch sizes in G.
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'])
    # Prepare a fixed z & y to see individual sample evolution throughout training
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size,
                                         G.dim_z,
                                         config['n_classes'],
                                         device=device,
                                         fp16=config['G_fp16'])
    fixed_z.sample_()
    fixed_y.sample_()
    # Loaders are loaded, prepare the training function
    if config['which_train_fn'] == 'GAN':
        train = train_fns.GAN_training_function(G, D, GD, z_, y_, ema,
                                                state_dict, config,
                                                val_loaders)
    # 'dummy' selects the debugging train fn
    elif config['which_train_fn'] == 'dummy':
        train = train_fns.dummy_training_function()
    else:
        # Any other value is treated as an importable module name providing
        # its own GAN_training_function.
        train_fns_module = importlib.import_module(config['which_train_fn'])
        train = train_fns_module.GAN_training_function(G, D, GD, z_, y_, ema,
                                                       state_dict, config,
                                                       val_loaders)

    # Prepare Sample function for use with inception metrics
    sample_fn = (utils.sample if global_cfg.get('use_unofficial_FID', False)
                 else utils.sample_imgs)
    sample = functools.partial(
        sample_fn,
        G=(G_ema if config['ema'] and config['use_ema'] else G),
        z_=z_,
        y_=y_,
        config=config)

    state_dict['shown_images'] = state_dict['itr'] * D_batch_size

    # Evaluation-only mode: run one test pass and stop.
    if global_cfg.get('resume_cfg', {}).get('eval', False):
        logger.info('Evaluating model.')
        # G_ema is None when EMA is disabled; calling .eval() on it would crash.
        if G_ema is not None:
            G_ema.eval()
        G.eval()
        train_fns.test(G, D, G_ema, z_, y_, state_dict, config, sample,
                       get_inception_metrics, experiment_name, test_log)
        return

    print('Beginning training at epoch %d...' % state_dict['epoch'])
    # Train for specified number of epochs, although we mostly track G iterations.
    for epoch in range(state_dict['epoch'], config['num_epochs']):
        # Which progressbar to use? TQDM or my own?
        if config['pbar'] == 'mine':
            pbar = utils.progress(loaders[0],
                                  desc=f'Epoch:{epoch}, Itr: ',
                                  displaytype='s1k' if
                                  config['use_multiepoch_sampler'] else 'eta')
        else:
            pbar = tqdm(loaders[0])
        for x, y in pbar:
            # Increment the iteration counter
            state_dict['itr'] += 1
            # Make sure G and D are in training mode, in case they were set to
            # eval. For D, which typically doesn't have BN, this matters little.
            G.train()
            D.train()
            if config['ema']:
                G_ema.train()
            if config['D_fp16']:
                x, y = x.to(device).half(), y.to(device)
            else:
                x, y = x.to(device), y.to(device)

            default_dict = train(x, y)

            state_dict['shown_images'] += D_batch_size

            # Only the 'D_loss' group goes to train_log / the console.
            metrics = default_dict['D_loss']
            train_log.log(itr=int(state_dict['itr']), **metrics)

            # `textlogger` is a module-level name defined elsewhere in the project.
            summary_defaultdict2txtfig(default_dict=default_dict,
                                       prefix='train',
                                       step=state_dict['shown_images'],
                                       textlogger=textlogger)

            # Every sv_log_interval, log singular values
            if (config['sv_log_interval'] > 0) and (
                    not (state_dict['itr'] % config['sv_log_interval'])):
                train_log.log(itr=int(state_dict['itr']),
                              **{
                                  **utils.get_SVs(G, 'G'),
                                  **utils.get_SVs(D, 'D')
                              })

            # If using my progbar, print metrics.
            if config['pbar'] == 'mine':
                print(', '.join(
                    ['itr: %d' % state_dict['itr']] +
                    ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]),
                      end=' ',
                      flush=True)

            # Save weights and copies as configured at specified interval
            if not (state_dict['itr'] % config['save_every']):
                if config['G_eval_mode']:
                    print('Switchin G to eval mode...')
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z,
                                          fixed_y, state_dict, config,
                                          experiment_name)

            # Test on the first iteration, every `test_every` iterations (if
            # positive), and when `shown_images` has just crossed a multiple
            # of `test_every_images`. Without that key the modulus is taken
            # against inf, which leaves `shown_images` unchanged, so the last
            # clause stays false.
            if state_dict['itr'] == 1 or \
                  (config['test_every'] > 0 and state_dict['itr'] % config['test_every'] == 0) or \
                  (state_dict['shown_images'] % global_cfg.get('test_every_images', float('inf'))) < D_batch_size:
                if config['G_eval_mode']:
                    print('Switchin G to eval mode...', flush=True)
                    G.eval()
                print('\n' + config['tl_outdir'])
                train_fns.test(G, D, G_ema, z_, y_, state_dict, config, sample,
                               get_inception_metrics, experiment_name,
                               test_log)
        # Increment epoch counter at end of epoch
        state_dict['epoch'] += 1
예제 #26
0
def run(config):
    """Build G/D/E from ``config`` and run one activation-extraction pass.

    Despite sharing the ``run`` name with the training drivers, this variant
    does not enter a training loop. It derives settings into ``config`` and
    builds the Generator, Discriminator and ImgEncoder (plus an EMA copy of G
    when ``config['ema']`` is set). It optionally loads weights and prepares
    loaders, fixed noise/labels, fixed real images and an optional image
    pool. Next it switches G (and G_ema) to eval mode when
    ``config['G_eval_mode']`` is set. Finally it calls ``activation_extract``
    once with ``normal_eval=True, eval_vc=False, return_mask=False``.

    Args:
        config: Mutable dict of run settings. Derived keys ('resolution',
            'n_classes', 'G_activation', 'D_activation', and 'skip_init' when
            resuming) are written into it before it is passed through
            ``utils.update_config_roots``.

    Returns:
        None.
    """

    # Update the config dict as necessary.
    # This adds settings derived from the user-specified configuration
    # (e.g. number of classes and image size inferred from the dataset, and
    # a pytorch activation object looked up from its string name).
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    # Device is hard-coded; this function assumes CUDA is available.
    device = 'cuda'

    # Seed RNG (before model construction, so the statement order below matters).
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    # Next, build the model: generator, discriminator and image encoder.
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)
    E = model.ImgEncoder(**config).to(device)
    # E = model.Encoder(**config).to(device)

    # If using EMA (exponential moving average of G's weights), prepare it.
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(
            config['ema_decay']))
        G_ema = model.Generator(**{**config, 'skip_init': True,
                                   'no_optim': True}).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None

    # FP16?
    # NOTE(review): E is never cast to half here, even when D_fp16 is set;
    # confirm this is intended.
    if config['G_fp16']:
        print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        print('Casting D to fp16...')
        D = D.half()
        # Consider automatically reducing SN_eps?
    # Combined G/D/E wrapper; wrapped in DataParallel below when requested.
    GDE = model.G_D_E(G, D, E)

    print('Number of params in G: {} D: {} E: {}'.format(
        *[sum([p.data.nelement() for p in net.parameters()]) for net in [G, D, E]]))
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0,
                  'best_IS': 0, 'best_FID': 999999, 'config': config}

    # If loading from a pre-trained model, load weights
    if config['resume']:
        print('Loading weights...')
        utils.load_weights(G, D, E, state_dict,
                           config['weights_root'], experiment_name,
                           config['load_weights'] if config['load_weights'] else None,
                           G_ema if config['ema'] else None)

    # If parallel, parallelize the GDE module
    if config['parallel']:
        GDE = nn.DataParallel(GDE)
        if config['cross_replica']:
            patch_replication_callback(GDE)

    # Prepare data. Only the Discriminator's batch size needs to go to the
    # dataloader, since G does not load data. Each loader iteration yields
    # enough samples for a full D iteration (all D steps and accumulations).
    # NOTE(review): `loaders` is not used later in this function; only
    # `train_dataset` is consumed (by prepare_x_y and ImagePool).
    D_batch_size = (config['batch_size'] * config['num_D_steps']
                    * config['num_D_accumulations'])
    loaders, train_dataset = utils.get_data_loaders(**{**config, 'batch_size': D_batch_size,
                                        'start_itr': state_dict['itr']})

    # Prepare noise and randomly sampled label arrays
    # Allow for different batch sizes in G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'],
                               device=device, fp16=config['G_fp16'])
    # Prepare a fixed z & y to see individual sample evolution throughout training
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size, G.dim_z,
                                         config['n_classes'], device=device,
                                         fp16=config['G_fp16'])
    fixed_z.sample_()
    fixed_y.sample_()
    # Debug output: shape and first 10 entries of the sampled fixed labels.
    print("fixed_y original: {} {}".format(fixed_y.shape, fixed_y[:10]))
    ## TODO: change the sample method to sample x and y
    # Fixed real inputs (and their labels) built from train_dataset; these
    # are what activation_extract receives below.
    fixed_x, fixed_y_of_x = utils.prepare_x_y(G_batch_size, train_dataset, experiment_name, config, device=device)
    

    # Build image pool to prevent mode collapse (disabled when img_pool_size == 0)
    if config['img_pool_size'] != 0:
        img_pool = ImagePool(config['img_pool_size'], train_dataset.num_class,\
                                    save_dir=os.path.join(config['imgbuffer_root'], experiment_name),
                                    resume_buffer=config['resume_buffer'])
    else:
        img_pool = None

    # Loaders are loaded, prepare the training function
    # NOTE(review): `train` and `sample` below are constructed but never used
    # later in this function.
    if config['which_train_fn'] == 'GAN':
        train = train_fns.GAN_training_function(G, D, E, GDE,
                                                ema, state_dict, config, img_pool)
    # Else, assume debugging and use the dummy train fn
    else:
        train = train_fns.dummy_training_function()
    # Prepare Sample function for use with inception metrics
    sample = functools.partial(utils.sample,
                               G=(G_ema if config['ema'] and config['use_ema']
                                   else G),
                               z_=z_, y_=y_, config=config)




    # print('Beginning training at epoch %f...' % (state_dict['itr'] * D_batch_size / len(train_dataset)))
    print("Beginning testing at Epoch {} (iteration {})".format(state_dict['epoch'], state_dict['itr']))

    # NOTE(review): only G / G_ema are switched to eval mode; D and E keep
    # whatever mode they are in. Confirm this is intended.
    if config['G_eval_mode']:
        print('Switchin G to eval mode...')
        G.eval()
        if config['ema']:
            G_ema.eval()
    # vc visualization
    # # print("VC visualization ===============")
    # activation_extract(G, D, E, G_ema, fixed_x, fixed_y_of_x, z_, y_,
    #                             state_dict, config, experiment_name, device, normal_eval=False, eval_vc=True, return_mask=False)
    # normal activation
    print("Normal activation ===============")
    activation_extract(G, D, E, G_ema, fixed_x, fixed_y_of_x, z_, y_,
                                state_dict, config, experiment_name, device, normal_eval=True, eval_vc=False, return_mask=False) # produce normal fully activated images
예제 #27
0
def run(config):
    """Train a BigGAN model according to ``config``.

    Derives dataset-dependent settings into ``config`` and builds G and D.
    When ``config['ema']`` is set, it also builds an EMA copy of G. It can
    resume from saved weights. It then loops over ``config['num_epochs']``
    epochs, calling the training step on every batch. Samples and weights are
    saved every ``config['save_every']`` iterations, and the inception-metric
    test runs every ``config['test_every']`` iterations.

    Args:
        config: Mutable dict of run settings. Derived keys ('resolution',
            'n_classes', 'G_activation', 'D_activation', and 'skip_init' when
            resuming) are written into it before it is passed through
            ``utils.update_config_roots``.

    Note:
        Relies on module-level names defined elsewhere in the project
        (``BigGAN``, ``utils``, ``train_fns``, ``inception_utils``).
    """
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'

    # Seed RNG before any model is constructed so initialisation is reproducible.
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    model = BigGAN

    # Next, build the model
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)

    # If using EMA (exponential moving average of G's weights), prepare it.
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(
            config['ema_decay']))
        G_ema = model.Generator(**{
            **config, 'skip_init': True,
            'no_optim': True
        }).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        # G_ema must be bound even without EMA: save_and_sample and test
        # below receive it unconditionally.
        G_ema, ema = None, None

    # FP16: prepare_z_y(fp16=...) and the D_fp16 input cast in the loop
    # expect half-precision models, so cast them here (Module.half() converts
    # in place, so the EMA helper keeps valid references).
    if config['G_fp16']:
        print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        print('Casting D to fp16...')
        D = D.half()

    GD = model.G_D(G, D)

    print('Number of params in G: {} D: {}'.format(
        *[sum([p.data.nelement() for p in net.parameters()])
          for net in [G, D]]))
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'config': config
    }

    # If loading from a pre-trained model, load weights
    if config['resume']:
        print('Loading weights...')
        utils.load_weights(
            G, D, state_dict, config['weights_root'], experiment_name,
            config['load_weights'] if config['load_weights'] else None,
            G_ema if config['ema'] else None)

    # If parallel, parallelize the GD module
    if config['parallel']:
        GD = nn.DataParallel(GD)

    # Prepare the test-metrics logger.
    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'],
                                              experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
    test_log = utils.MetricsLogger(test_metrics_fname,
                                   reinitialize=(not config['resume']))
    # NOTE: no training-metrics logger is created in this variant; the path
    # is only reported.
    print('Training Metrics will be saved to {}'.format(train_metrics_fname))

    # Write metadata
    utils.write_metadata(config['logs_root'], experiment_name, config,
                         state_dict)

    # Each loader batch carries enough data for a full D iteration
    # (all D steps and accumulations).
    D_batch_size = (config['batch_size'] * config['num_D_steps'] *
                    config['num_D_accumulations'])
    loaders = utils.get_data_loaders(**{
        **config, 'batch_size': D_batch_size,
        'start_itr': state_dict['itr']
    })

    # Prepare inception metrics: FID and IS
    get_inception_metrics = inception_utils.prepare_inception_metrics(
        config['dataset'], config['parallel'], config['no_fid'])

    # Prepare noise and randomly sampled label arrays
    # Allow for different batch sizes in G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'])
    # Prepare a fixed z & y to see individual sample evolution throughout training
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size,
                                         G.dim_z,
                                         config['n_classes'],
                                         device=device,
                                         fp16=config['G_fp16'])
    fixed_z.sample_()
    fixed_y.sample_()
    # Loaders are loaded, prepare the training function
    train = train_fns.GAN_training_function(G, D, GD, z_, y_, ema, state_dict,
                                            config)

    # Prepare Sample function for use with inception metrics
    sample = functools.partial(
        utils.sample,
        G=(G_ema if config['ema'] and config['use_ema'] else G),
        z_=z_,
        y_=y_,
        config=config)

    print('Beginning training at epoch %d...' % state_dict['epoch'])
    # Train for specified number of epochs, although we mostly track G iterations.
    for epoch in range(state_dict['epoch'], config['num_epochs']):

        pbar = utils.progress(
            loaders[0],
            displaytype='s1k' if config['use_multiepoch_sampler'] else 'eta')
        for x, y in pbar:
            # Increment the iteration counter
            state_dict['itr'] += 1
            # Make sure G and D are in training mode, in case they were set to
            # eval. For D, which typically doesn't have BN, this matters little.
            G.train()
            D.train()
            if config['ema']:
                G_ema.train()
            if config['D_fp16']:
                x, y = x.to(device).half(), y.to(device)
            else:
                x, y = x.to(device), y.to(device)
            metrics = train(x, y)

            print(', '.join(
                ['itr: %d' % state_dict['itr']] +
                ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]),
                  end=' ')

            # Save weights and copies as configured at specified interval
            if not (state_dict['itr'] % config['save_every']):
                if config['G_eval_mode']:
                    print('Switchin G to eval mode...')
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z,
                                          fixed_y, state_dict, config,
                                          experiment_name)

            # Test every specified interval
            if not (state_dict['itr'] % config['test_every']):
                if config['G_eval_mode']:
                    print('Switchin G to eval mode...')
                    G.eval()
                train_fns.test(G, D, G_ema, z_, y_, state_dict, config, sample,
                               get_inception_metrics, experiment_name,
                               test_log)
        # Increment epoch counter at end of epoch
        state_dict['epoch'] += 1
예제 #28
0
def run(config):
  """Train a BigGAN-style GAN (PaddlePaddle port) according to ``config``.

  Derives dataset-dependent settings into ``config`` and builds G and D.
  When ``config['ema']`` is set, it also builds an EMA copy of G. It can
  resume from saved weights. It then loops over ``config['num_epochs']``
  epochs, calling the training step on every batch. Metrics go to a text
  logger and, best effort, to a ``LogWriter`` (also used as the test log).
  Samples and weights are saved every ``config['save_every']`` iterations,
  and the inception-metric test runs every ``config['test_every']``
  iterations.

  Args:
    config: Mutable dict of run settings. Derived keys ('resolution',
      'n_classes', 'G_activation', 'D_activation', and 'skip_init' when
      resuming) are written into it before it is passed through
      ``utils.update_config_roots``.
  """
  # Update the config dict as necessary.
  # This adds settings derived from the user-specified configuration
  # (e.g. number of classes and image size inferred from the dataset, and
  # an activation object looked up from its string name).
  config['resolution'] = utils.imsize_dict[config['dataset']]
  config['n_classes'] = utils.nclass_dict[config['dataset']]
  config['G_activation'] = utils.activation_dict[config['G_nl']]
  config['D_activation'] = utils.activation_dict[config['D_nl']]
  # By default, skip init if resuming training.
  if config['resume']:
    print('Skipping initialization for training resumption...')
    config['skip_init'] = True
  config = utils.update_config_roots(config)
  device = 'cuda'

  # Seed RNG before any model is constructed so initialisation is reproducible.
  utils.seed_rng(config['seed'])

  # Prepare root folders if necessary
  utils.prepare_root(config)

  # Import the model--this line allows us to dynamically select different files.
  model = __import__(config['model'])
  experiment_name = (config['experiment_name'] if config['experiment_name']
                     else utils.name_from_config(config))
  print('Experiment name is %s' % experiment_name)

  # Next, build the model
  G = model.Generator(**config).to(device)
  D = model.Discriminator(**config).to(device)

  # If using EMA (exponential moving average of G's weights), prepare it.
  if config['ema']:
    print('Preparing EMA for G with decay of {}'.format(config['ema_decay']))
    G_ema = model.Generator(**{**config, 'skip_init': True,
                               'no_optim': True}).to(device)
    ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
  else:
    G_ema, ema = None, None

  GD = model.G_D(G, D)
  print(G)
  print(D)
  print('Number of params in G: {} D: {}'.format(
    *[sum([np.prod(p.shape) for p in net.parameters()]) for net in [G, D]]))
  # Prepare state dict, which holds things like epoch # and itr #
  state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0,
                'best_IS': 0, 'best_FID': 999999, 'config': config}

  # If loading from a pre-trained model, load weights
  if config['resume']:
    print('Loading weights...')
    utils.load_weights(G, D, state_dict,
                       config['weights_root'], experiment_name,
                       config['load_weights'] if config['load_weights'] else None,
                       G_ema if config['ema'] else None)

  # Prepare loggers for stats. test_log (a LogWriter) receives test metrics
  # and mirrored scalar training metrics; train_log holds training metrics.
  test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'],
                                            experiment_name)
  train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
  print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
  test_log = LogWriter(logdir='%s/%s_log' % (config['logs_root'],
                                             experiment_name))
  print('Training Metrics will be saved to {}'.format(train_metrics_fname))
  train_log = utils.MyLogger(train_metrics_fname,
                             reinitialize=(not config['resume']),
                             logstyle=config['logstyle'])
  # Write metadata
  utils.write_metadata(config['logs_root'], experiment_name, config, state_dict)
  # Prepare data. Only the Discriminator's batch size needs to go to the
  # dataloader, since G does not load data. Each loader iteration yields
  # enough samples for a full D iteration (all D steps and accumulations).
  D_batch_size = (config['batch_size'] * config['num_D_steps']
                  * config['num_D_accumulations'])
  loaders = utils.get_data_loaders(**{**config, 'batch_size': D_batch_size,
                                      'start_itr': state_dict['itr']})

  # Prepare inception metrics: FID and IS
  get_inception_metrics = inception_utils.prepare_inception_metrics(
    config['dataset'], config['parallel'], config['no_fid'])

  # Prepare noise and randomly sampled label arrays
  # Allow for different batch sizes in G
  G_batch_size = max(config['G_batch_size'], config['batch_size'])
  z_, y_ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'],
                             device=device, fp16=config['G_fp16'])
  # Prepare a fixed z & y to see individual sample evolution throughout training
  fixed_z, fixed_y = utils.prepare_z_y(G_batch_size, G.dim_z,
                                       config['n_classes'], device=device,
                                       fp16=config['G_fp16'])
  fixed_z.sample_()
  fixed_y.sample_()
  # Loaders are loaded, prepare the training function
  if config['which_train_fn'] == 'GAN':
    train = train_fns.GAN_training_function(G, D, GD, z_, y_,
                                            ema, state_dict, config)
  # Else, assume debugging and use the dummy train fn
  else:
    train = train_fns.dummy_training_function()
  # Prepare Sample function for use with inception metrics
  sample = functools.partial(utils.sample,
                             G=(G_ema if config['ema'] and config['use_ema']
                                else G),
                             z_=z_, y_=y_, config=config)

  print('Beginning training at epoch %d...' % state_dict['epoch'])
  # Train for specified number of epochs, although we mostly track G iterations.
  for epoch in range(state_dict['epoch'], config['num_epochs']):
    # Which progressbar to use? TQDM or my own?
    if config['pbar'] == 'mine':
      pbar = utils.progress(loaders[0], displaytype='s1k' if config['use_multiepoch_sampler'] else 'eta')
    else:
      pbar = tqdm(loaders[0])
    for x, y in pbar:
      # Increment the iteration counter
      state_dict['itr'] += 1
      # Make sure G and D are in training mode, in case they were set to eval.
      # For D, which typically doesn't have BN, this matters little.
      G.train()
      D.train()
      # Special handling for the paddle dataloader: labels need an int64 cast.
      y = y.astype(np.int64)
      if config['ema']:
        G_ema.train()

      metrics = train(x, y)
      train_log.log(itr=int(state_dict['itr']), **metrics)

      # Mirror training metrics into the LogWriter on a best-effort basis.
      for tag, value in metrics.items():
        try:
          test_log.add_scalar(step=int(state_dict['itr']), tag="train/" + tag,
                              value=float(value))
        except Exception:
          # Skip metrics that cannot be turned into a scalar or that the
          # writer rejects. Unlike a bare `except:`, this lets
          # KeyboardInterrupt / SystemExit propagate.
          continue

      # Every sv_log_interval, log singular values
      if (config['sv_log_interval'] > 0) and (not (state_dict['itr'] % config['sv_log_interval'])):
        train_log.log(itr=int(state_dict['itr']),
                      **{**utils.get_SVs(G, 'G'), **utils.get_SVs(D, 'D')})

      # Report metrics either on stdout (own progbar) or in the tqdm bar.
      status = ', '.join(['itr: %d' % state_dict['itr']]
                         + ['%s : %+4.3f' % (key, metrics[key])
                            for key in metrics])
      if config['pbar'] == 'mine':
        print(status, end=' ')
      else:
        pbar.set_description(status)

      # Save weights and copies as configured at specified interval
      if not (state_dict['itr'] % config['save_every']):
        if config['G_eval_mode']:
          print('Switchin G to eval mode...')
          G.eval()
          if config['ema']:
            G_ema.eval()
        train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z, fixed_y,
                                  state_dict, config, experiment_name)

      # Test every specified interval
      if not (state_dict['itr'] % config['test_every']):
        if config['G_eval_mode']:
          print('Switchin G to eval mode...')
          G.eval()
        train_fns.test(G, D, G_ema, z_, y_, state_dict, config, sample,
                       get_inception_metrics, experiment_name, test_log)
    # Increment epoch counter at end of epoch
    state_dict['epoch'] += 1
예제 #29
0
def run(config):
    """Run GAN training for G and D on an XLA (TPU) device.

    Steps performed, in order: derive dataset-dependent settings into
    ``config``; seed the RNG and prepare root folders; build G, D and
    (if ``config['ema']``) an EMA copy of G; optionally cast to fp16;
    optionally resume from saved weights; move the models to the XLA
    device and create their Adam optimizers; set up loggers, the data
    loader and inception metrics; then train until ``state_dict['itr']``
    reaches ``config['total_steps']``, periodically logging singular
    values, saving/sampling, and testing.

    Args:
        config: mapping of run settings indexed by string keys (e.g.
            'dataset', 'model', 'resume', 'ema', 'batch_size',
            'total_steps', 'save_every', 'test_every'). The caller's
            mapping is mutated in place: 'resolution', 'n_classes',
            'G_activation', 'D_activation' (and 'skip_init' when resuming)
            are written to it before ``config`` is rebound to the result
            of ``utils.update_config_roots``.
    """
    # Give torch_xla's PerDeviceLoader a __len__ that reports the length of
    # the loader it wraps. NOTE(review): this patches the class globally
    # (every PerDeviceLoader in the process) and relies on the private
    # attributes ``_loader._loader`` -- verify against the installed
    # torch_xla version.
    def len_parallelloader(self):
        """Return the length of the loader wrapped by this PerDeviceLoader."""
        return len(self._loader._loader)
    pl.PerDeviceLoader.__len__ = len_parallelloader

    # Derive settings from the user-specified configuration and store them
    # in the config dict for convenience: image resolution and number of
    # classes are looked up from the dataset name, and the activation names
    # ('G_nl' / 'D_nl') are mapped to activation objects.
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        xm.master_print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    # Enable cudnn.benchmark autotuning. NOTE(review): cuDNN is a CUDA
    # library; this flag presumably has no effect on the XLA/TPU device
    # used below.
    torch.backends.cudnn.benchmark = True

    # Import the model module by name so different model files can be
    # selected at runtime. The module must provide Generator, Discriminator
    # and G_D (all used below).
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    xm.master_print('Experiment name is %s' % experiment_name)

    device = xm.xla_device(devkind='TPU')

    # Next, build the models; the whole config is forwarded as kwargs.
    G = model.Generator(**config)
    D = model.Discriminator(**config)

    # If using EMA, prepare a second Generator to hold the averaged weights.
    # It is built with 'skip_init' and 'no_optim' overridden to True;
    # presumably its weights are maintained by utils.ema (created further
    # below, after the device move) rather than by an optimizer.
    if config['ema']:
        xm.master_print(
            'Preparing EMA for G with decay of {}'.format(
                config['ema_decay']))
        G_ema = model.Generator(**{**config, 'skip_init': True,
                                   'no_optim': True})
    else:
        xm.master_print('Not using ema...')
        G_ema, ema = None, None

    # FP16?
    if config['G_fp16']:
        xm.master_print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        xm.master_print('Casting D to fp16...')
        D = D.half()

    # Prepare state dict, which holds the iteration counter and bookkeeping
    # for saves/best scores. 'best_FID' starts at a large sentinel
    # (999999), presumably so the first measured FID is treated as the best.
    state_dict = {'itr': 0, 'save_num': 0, 'save_best_num': 0,
                  'best_IS': 0, 'best_FID': 999999, 'config': config}

    # If resuming, load weights (and state_dict contents) from disk.
    # NOTE(review): G.optim / D.optim are re-created below, after this load.
    # If utils.load_weights restores optimizer state, that state is
    # discarded by the new Adam instances -- confirm this is intended.
    if config['resume']:
        xm.master_print('Loading weights...')
        utils.load_weights(
            G,
            D,
            state_dict,
            config['weights_root'],
            experiment_name,
            config['load_weights'] if config['load_weights'] else None,
            G_ema if config['ema'] else None)

    # Move the models to the TPU. This happens before the optimizers are
    # constructed, as PyTorch recommends for device moves.
    G = G.to(device)
    D = D.to(device)

    # Adam hyperparameters are read from attributes on the model objects.
    G.optim = optim.Adam(params=G.parameters(), lr=G.lr,
                         betas=(G.B1, G.B2), weight_decay=0,
                         eps=G.adam_eps)
    D.optim = optim.Adam(params=D.parameters(), lr=D.lr,
                         betas=(D.B1, D.B2), weight_decay=0,
                         eps=D.adam_eps)

    # Disabled: moving restored Adam state ('exp_avg', 'exp_avg_sq') onto
    # the device. Kept for reference; see the resume NOTE above.
    # for key, val in G.optim.state.items():
    #  G.optim.state[key]['exp_avg'] = G.optim.state[key]['exp_avg'].to(device)
    #  G.optim.state[key]['exp_avg_sq'] = G.optim.state[key]['exp_avg_sq'].to(device)

    # for key, val in D.optim.state.items():
    #  D.optim.state[key]['exp_avg'] = D.optim.state[key]['exp_avg'].to(device)
    #  D.optim.state[key]['exp_avg_sq'] = D.optim.state[key]['exp_avg_sq'].to(device)

    # Create the EMA helper only after both G and G_ema are on the device.
    if config['ema']:
        G_ema = G_ema.to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])

    # Consider automatically reducing SN_eps?
    # GD wraps G and D together for the training function (see model.G_D).
    GD = model.G_D(G, D)
    xm.master_print(G)
    xm.master_print(D)
    xm.master_print('Number of params in G: {} D: {}'.format(
        *[sum([p.data.nelement() for p in net.parameters()]) for net in [G, D]]))

    # Prepare loggers for stats: test_log writes test metrics to
    # '<logs_root>/<experiment>_log.jsonl'; train_log writes training
    # metrics under '<logs_root>/<experiment>'. Both are reinitialized
    # unless resuming.
    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'],
                                              experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    xm.master_print(
        'Test Metrics will be saved to {}'.format(test_metrics_fname))
    test_log = utils.MetricsLogger(test_metrics_fname,
                                   reinitialize=(not config['resume']))
    xm.master_print(
        'Training Metrics will be saved to {}'.format(train_metrics_fname))
    train_log = utils.MyLogger(train_metrics_fname,
                               reinitialize=(not config['resume']),
                               logstyle=config['logstyle'])

    if xm.is_master_ordinal():
        # Write metadata (master replica only).
        utils.write_metadata(
            config['logs_root'],
            experiment_name,
            config,
            state_dict)

    # Prepare data; the Discriminator's batch size is all that needs to be passed
    # to the dataloader, as G doesn't require dataloading.
    # Note that at every loader iteration we pass in enough data to complete
    # a full D iteration (regardless of number of D steps and accumulations)
    D_batch_size = (config['batch_size'] * config['num_D_steps']
                    * config['num_D_accumulations'])
    xm.master_print('Preparing data...')
    # 'start_itr' is forwarded to the loader factory, presumably so a resumed
    # run continues from its saved iteration -- verify in utils.get_data_loaders.
    loader = utils.get_data_loaders(**{**config, 'batch_size': D_batch_size,
                                       'start_itr': state_dict['itr']})

    # Prepare inception metrics: FID and IS
    xm.master_print('Preparing metrics...')

    get_inception_metrics = inception_utils.prepare_inception_metrics(
        config['dataset'], config['parallel'],
        no_inception=config['no_inception'],
        no_fid=config['no_fid'])

    # Prepare noise and randomly sampled label arrays.
    # G's sampling batch may differ from D's but is never smaller than
    # config['batch_size'].
    G_batch_size = max(config['G_batch_size'], config['batch_size'])

    # sample() returns a (z, y) pair of G_batch_size latents and labels from
    # utils.prepare_z_y, created on the XLA device.
    def sample(): return utils.prepare_z_y(G_batch_size, G.dim_z,
                                           config['n_classes'], device=device,
                                           fp16=config['G_fp16'])
    # Prepare a fixed z & y to see individual sample evolution throughout
    # training
    fixed_z, fixed_y = sample()

    # train is a callable invoked once per step as train(x, y) below.
    train = train_fns.GAN_training_function(G, D, GD, sample, ema, state_dict,
                                            config)

    xm.master_print('Beginning training...')

    # Progress bar exists only on the master replica; every later use of
    # pbar is guarded by is_master_ordinal(). NOTE(review): pbar is never
    # closed in the visible code.
    if xm.is_master_ordinal():
        pbar = tqdm(total=config['total_steps'])
        pbar.n = state_dict['itr']
        pbar.refresh()

    # xm.rendezvous blocks until all replicas reach the same tag, so every
    # replica must execute the rendezvous calls below in the same order.
    xm.rendezvous('training_starts')
    # Each pass of the while loop wraps the loader in a fresh ParallelLoader
    # and iterates it once; the loop repeats until total_steps is reached.
    # NOTE(review): if the loader yields no batches this never terminates.
    while (state_dict['itr'] < config['total_steps']):
        pl_loader = pl.ParallelLoader(
            loader, [device]).per_device_loader(device)

        for i, (x, y) in enumerate(pl_loader):
            if xm.is_master_ordinal():
                # Advance the progress bar (master only).
                pbar.update(1)

            # Increment the iteration counter on every replica, so the
            # interval checks below (including the one guarding a
            # rendezvous) evaluate the same way on all replicas.
            state_dict['itr'] += 1
            # Make sure G and D are in training mode, just in case they got set to eval
            # For D, which typically doesn't have BN, this shouldn't matter
            # much.
            G.train()
            D.train()
            if config['ema']:
                G_ema.train()

            xm.rendezvous('data_collection')
            # Per-step training metrics are returned but, with the logging
            # call below commented out, not used here.
            metrics = train(x, y)

            # train_log.log(itr=int(state_dict['itr']), **metrics)

            # Every sv_log_interval, log singular values (written by the
            # master only; all replicas then synchronize).
            if ((config['sv_log_interval'] > 0) and (not (state_dict['itr'] % config['sv_log_interval']))) :
                if xm.is_master_ordinal():
                    train_log.log(itr=int(state_dict['itr']),
                        **{**utils.get_SVs(G, 'G'), **utils.get_SVs(D, 'D')})
                xm.rendezvous('Log SVs.')

            # Save weights and copies as configured at specified interval.
            # NOTE(review): called on every replica; presumably
            # save_and_sample restricts file writes to the master -- verify.
            if (not (state_dict['itr'] % config['save_every'])):
                if config['G_eval_mode']:
                    xm.master_print('Switchin G to eval mode...')
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                train_fns.save_and_sample(
                    G,
                    D,
                    G_ema,
                    sample,
                    fixed_z,
                    fixed_y,
                    state_dict,
                    config,
                    experiment_name)

            # Test every specified interval
            if (not (state_dict['itr'] % config['test_every'])):

                # Evaluate with the EMA generator when both 'ema' and
                # 'use_ema' are set, otherwise with G. Eval mode set here is
                # undone by the train() calls at the next iteration.
                which_G = G_ema if config['ema'] and config['use_ema'] else G
                if config['G_eval_mode']:
                    xm.master_print('Switchin G to eval mode...')
                    which_G.eval()

                def G_sample():
                    """Draw fresh (z, y) and return which_G's output for them.

                    y is passed through which_G.shared first (presumably a
                    shared class embedding -- confirm in the model module).
                    """
                    z, y = sample()
                    return which_G(z, which_G.shared(y))

                train_fns.test(
                    G,
                    D,
                    G_ema,
                    sample,
                    state_dict,
                    config,
                    G_sample,
                    get_inception_metrics,
                    experiment_name,
                    test_log)

            # Debug : Message print
            # if True:
            #     xm.master_print(met.metrics_report())

            # Stop as soon as the step budget is reached, even mid-pass.
            if state_dict['itr'] >= config['total_steps']:
                break