Example #1
def main() -> None:
    args = parse_arguments()

    u.set_gpu(args.gpu)

    # create output folder and save arguments in a .txt file
    outpath = os.path.join(
        './results/',
        args.outdir if args.outdir is not None else u.random_code())
    os.makedirs(outpath, exist_ok=True)
    print(colored('Saving to %s' % outpath, 'yellow'))
    u.write_args(os.path.join(outpath, 'args.txt'), args)

    # get a list of patches organized as dictionaries with image, mask and name fields
    patches = extract_patches(args)

    print(colored('Processing %d patches' % len(patches), 'yellow'))

    # instantiate a trainer
    T = Training(args, outpath)

    # interpolation
    for i, patch in enumerate(patches):

        print(
            colored('\nThe data shape is %s' % str(patch['image'].shape),
                    'cyan'))

        std = T.load_data(patch)
        print(colored('the std of coarse data is %.2e, ' % std, 'cyan'),
              end="")

        if np.isclose(std, 0., atol=1e-12):  # all the data are corrupted
            print(colored('skipping...', 'cyan'))
            T.out_best = T.img * T.mask
            T.elapsed = 0.
        else:
            # TODO add the transfer learning option
            if i == 0 or (args.start_from_prev and T.net is None):
                T.build_model()
            T.build_input()
            T.optimize()

        T.save_result()
        T.clean()

    print(colored('Interpolation done! Saved to %s' % outpath, 'yellow'))
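
# The helper `u.write_args` above is project-specific and not shown in this snippet.
# A minimal sketch of what such a helper is assumed to look like (hypothetical, not the
# project's actual implementation): one 'key: value' line per parsed argument, so a run's
# configuration can be reproduced later.
def write_args_sketch(path, args):
    with open(path, 'w') as f:
        for key, value in sorted(vars(args).items()):
            f.write('%s: %s\n' % (key, value))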
Example #2
    mvnet = MVNet(vmin=-0.5,
                  vmax=0.5,
                  vox_bs=args.batch_size,
                  im_bs=args.im_batch,
                  grid_size=args.nvox,
                  im_h=args.im_h,
                  im_w=args.im_w,
                  norm=args.norm,
                  mode="TRAIN")

    # Define graph
    mvnet = model_dlsm(mvnet,
                       im_nets[args.im_net],
                       grid_nets[args.grid_net],
                       conv_rnns[args.rnn],
                       im_skip=args.im_skip,
                       ray_samples=args.ray_samples,
                       sepup=args.sepup,
                       proj_x=args.proj_x)

    # Set things up
    mkdir_p(log_dir)
    write_args(args, osp.join(log_dir, 'args.json'))
    logger.info('Logging to {:s}'.format(log_dir))
    logger.info('\nUsing args:')
    pprint(vars(args))
    mvnet.print_net()

    train(mvnet)
Example #3
def train_model(model_class,
                run_func,
                args,
                quiet=False,
                splits=None,
                abs_output_dir=False):
    output_dir = args.output_dir

    val_stat = args.val_stat
    # Keeps track of certain stats for all the data splits
    all_stats = {
        'val_%s' % val_stat: [],
        'test_%s' % val_stat: [],
        'best_epoch': [],
        'train_last': [],
        'train_best': [],
        'nce': [],
    }

    # Iterate over splits
    splits_iter = splits if splits is not None else range(args.n_splits)
    # Iterates through each split of the data
    for split_idx in splits_iter:
        # print('Training split idx: %d' % split_idx)

        # Creates the output directory for the run of the current split
        if not abs_output_dir:
            args.output_dir = output_dir + '/run_%d' % split_idx
        args.model_dir = args.output_dir + '/models'
        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)
        if not os.path.exists(args.model_dir):
            os.makedirs(args.model_dir)
        write_args(args)

        # Create model and optimizer
        model = model_class(args)
        model.to(args.device)

        if args.separate_lr:
            optim = model.get_model_optim()
        else:
            optim = torch.optim.Adam(model.parameters(), lr=args.lr)

        if split_idx == 0:
            # Print the number of parameters
            num_params = get_num_params(model)
            if not quiet:
                print('Initialized model with %d params' % num_params)

        # Load the train, val, test data
        dataset_loaders = {}
        for data_type in ['train', 'val', 'test']:
            dataset_loaders[data_type] = get_loader(
                args.data_dir,
                data_type=data_type,
                batch_size=args.batch_size,
                shuffle=data_type == 'train',
                split=split_idx,
                n_labels=args.n_labels)

        # Keeps track of stats across all the epochs
        train_m, val_m = StatsManager(), StatsManager()

        # Tensorboard logging, only for the first run split
        if args.log_tb and split_idx == 0:
            log_dir = output_dir + '/logs'
            tb_writer = SummaryWriter(log_dir, max_queue=1, flush_secs=60)
            log_tensorboard(tb_writer, {'params': num_params}, '', 0)
        else:
            args.log_tb = False

        # Training loop
        args.latest_train_stat = 0
        args.latest_val_stat = 0  # Keeps track of the latest relevant stat
        patience_idx = 0
        for epoch_idx in range(args.n_epochs):
            args.epoch = epoch_idx
            train_stats = run_func(model=model,
                                   optim=optim,
                                   data_loader=dataset_loaders['train'],
                                   data_type='train',
                                   args=args,
                                   write_path=None,
                                   quiet=quiet)
            should_write = epoch_idx % args.write_every == 0
            val_stats = run_func(
                model=model,
                optim=None,
                data_loader=dataset_loaders['val'],
                data_type='val',
                args=args,
                write_path='%s/val_output_%d.jsonl' %
                (args.output_dir, epoch_idx) if should_write else None,
                quiet=quiet)

            if not quiet:
                train_stats.print_stats('Train %d: ' % epoch_idx)
                val_stats.print_stats('Val   %d: ' % epoch_idx)

            if args.log_tb:
                log_tensorboard(tb_writer, train_stats.get_stats(), 'train',
                                epoch_idx)
                log_tensorboard(tb_writer, val_stats.get_stats(), 'val',
                                epoch_idx)

            train_stats.add_stat('epoch', epoch_idx)
            val_stats.add_stat('epoch', epoch_idx)

            train_m.add_stats(train_stats.get_stats())
            val_m.add_stats(val_stats.get_stats())

            # checkpoint whenever the validation stat reaches a new best (minimum);
            # otherwise count the epoch toward early stopping
            if val_stats.get_stats()[val_stat] == min(val_m.stats[val_stat]):
                save_model(model,
                           args,
                           args.model_dir,
                           epoch_idx,
                           should_print=not quiet)
                patience_idx = 0
            else:
                patience_idx += 1
                if args.patience != -1 and patience_idx >= args.patience:
                    print(
                        'Validation error has not improved in %d epochs, stopping at epoch %d'
                        % (args.patience, args.epoch))
                    break

            # Keep track of the latest epoch stats
            args.latest_train_stat = train_stats.get_stats()[val_stat]
            args.latest_val_stat = val_stats.get_stats()[val_stat]

        # Load and save the best model
        best_epoch = val_m.get_best_epoch_for_stat(args.val_stat)
        best_model_path = '%s/model_%d' % (args.model_dir, best_epoch)
        model, _ = load_model(best_model_path,
                              model_class=model_class,
                              device=args.device)
        if not quiet:
            print('Loading model from %s' % best_model_path)

        save_model(model, args, args.model_dir, 'best', should_print=not quiet)

        # Test model
        test_stats = run_func(model=model,
                              optim=None,
                              data_loader=dataset_loaders['test'],
                              data_type='test',
                              args=args,
                              write_path='%s/test_output.jsonl' %
                              args.output_dir,
                              quiet=quiet)
        if not quiet:
            test_stats.print_stats('Test: ')

        if args.log_tb:
            log_tensorboard(tb_writer, test_stats.get_stats(), 'test', 0)
            tb_writer.close()

        # Write test output to a summary file
        with open('%s/summary.txt' % args.output_dir, 'w+') as summary_file:
            for k, v in test_stats.get_stats().items():
                summary_file.write('%s: %.3f\n' % (k, v))

        # Aggregate relevant stats
        all_stats['val_%s' % val_stat].append(min(val_m.stats[val_stat]))
        all_stats['test_%s' % val_stat].append(
            test_stats.get_stats()[val_stat])
        all_stats['best_epoch'].append(best_epoch)

        all_stats['train_last'].append(train_m.stats[val_stat][-1])
        all_stats['train_best'].append(train_m.stats[val_stat][best_epoch])

        if args.nce_coef > 0:
            all_stats['nce'].append(train_m.stats['nce_reg'][best_epoch])

    # Write the stats aggregated across all splits
    with open('%s/summary.txt' % (output_dir), 'w+') as summary_file:
        summary_file.write('Num epochs trained: %d\n' % args.epoch)
        for name, stats_arr in all_stats.items():
            if not stats_arr:
                continue
            stats_arr = np.array(stats_arr)
            stats_mean = np.mean(stats_arr)
            stats_std = np.std(stats_arr)
            summary_file.write('%s: %s, mean: %.3f, std: %.3f\n' %
                               (name, str(stats_arr), stats_mean, stats_std))

    all_val_stats = np.array(all_stats['val_%s' % val_stat])
    all_test_stats = np.array(all_stats['test_%s' % val_stat])

    val_mean, val_std = np.mean(all_val_stats), np.std(all_val_stats)
    test_mean, test_std = np.mean(all_test_stats), np.std(all_test_stats)

    train_last = np.mean(np.array(all_stats['train_last']))
    train_best = np.mean(np.array(all_stats['train_best']))

    if args.nce_coef > 0:
        nce_loss = np.mean(np.array(all_stats['nce']))
    else:
        nce_loss = 0

    # Return stats
    return (val_mean, val_std), (test_mean, test_std), (train_last,
                                                        train_best), nce_loss
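
# Hypothetical usage sketch (the model class and run function names are assumptions, not
# part of this snippet): train_model drives the full cross-split loop and returns the
# (mean, std) of the validation and test stats, the last/best train stats, and the NCE
# regularizer averaged over splits (0 when args.nce_coef == 0), e.g.
#   (val_mean, val_std), (test_mean, test_std), (train_last, train_best), nce = \
#       train_model(MyModel, run_epoch, args)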
Example #4
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('-cuda',
                        action='store_true',
                        default=False,
                        help='Use gpu')

    # Data/Output Directories Params
    parser.add_argument('-data', required=True, help='Task, see above')
    parser.add_argument('-output_dir', default='', help='Output directory')
    parser.add_argument('-log_tb',
                        action='store_true',
                        default=False,
                        help='Use tensorboard to log')
    parser.add_argument('-write_every',
                        type=int,
                        default=20,
                        help='Write val results every this many epochs')

    # Pre-trained Model Params
    parser.add_argument('-pretrain_gcn',
                        type=str,
                        default=None,
                        help='path to pretrained gcn to use in another model')
    parser.add_argument('-pretrain_model',
                        type=str,
                        default=None,
                        help='path to pretrained model to load')

    # General Model Params
    parser.add_argument('-n_splits',
                        type=int,
                        default=1,
                        help='Number of data splits to train on')
    parser.add_argument('-n_epochs',
                        type=int,
                        default=100,
                        help='Number of epochs to train on')
    parser.add_argument('-lr',
                        type=float,
                        default=1e-3,
                        help='Static learning rate of the optimizer')
    parser.add_argument('-separate_lr',
                        action='store_true',
                        default=False,
                        help='Whether to use a separate learning rate for the point clouds')
    parser.add_argument('-lr_pc',
                        type=float,
                        default=1e-2,
                        help='The learning rate for point clouds')
    parser.add_argument('-pc_xavier_std', type=float, default=0.1)
    parser.add_argument('-batch_size',
                        type=int,
                        default=48,
                        help='Number of examples in each batch')
    parser.add_argument('-max_grad_norm',
                        type=float,
                        default=10,
                        help='Clip gradients whose norm exceeds this value')
    parser.add_argument('-patience',
                        type=int,
                        default=-1,
                        help='Stop training if not improved for this many epochs (-1 disables)')

    # GCN Params
    parser.add_argument('-n_layers',
                        type=int,
                        default=5,
                        help='Number of layers in model')
    parser.add_argument('-n_hidden',
                        type=int,
                        default=128,
                        help='Size of hidden dimension for model')
    parser.add_argument('-n_ffn_hidden', type=int, default=100)
    parser.add_argument('-dropout_gcn',
                        type=float,
                        default=0.,
                        help='Amount of dropout for the model')
    parser.add_argument('-dropout_ffn',
                        type=float,
                        default=0.,
                        help='Dropout for final ffn layer')
    parser.add_argument('-agg_func',
                        type=str,
                        choices=['sum', 'mean'],
                        default='sum',
                        help='aggregator function for atoms')
    parser.add_argument('-batch_norm',
                        action='store_true',
                        default=False,
                        help='Whether or not to normalize atom embeds')

    # Prototype Params
    parser.add_argument('-init_method',
                        default='none',
                        choices=['none', 'various', 'data'])
    parser.add_argument('-distance_metric',
                        type=str,
                        default='wasserstein',
                        choices=['l2', 'wasserstein', 'dot'])
    parser.add_argument('-n_pc',
                        type=int,
                        default=2,
                        help='Number of point clouds')
    parser.add_argument('-pc_size',
                        type=int,
                        default=20,
                        help='Number of points in each point cloud')
    parser.add_argument(
        '-pc_hidden',
        type=int,
        default=-1,
        help='Hidden dim for point clouds, different from GCN hidden dim')
    parser.add_argument('-pc_free_epoch',
                        type=int,
                        default=0,
                        help='If initialized with data, when to free the point clouds')
    parser.add_argument('-ffn_activation',
                        type=str,
                        choices=['ReLU', 'LeakyReLU'],
                        default='LeakyReLU')
    parser.add_argument('-mult_num_atoms',
                        action='store_true',
                        default=True,
                        help='Whether to multiply the distance by the number of atoms')

    # OT Params
    parser.add_argument('-opt_method',
                        type=str,
                        default='sinkhorn_stabilized',
                        choices=[
                            'sinkhorn', 'sinkhorn_stabilized', 'emd',
                            'greenkhorn', 'sinkhorn_epsilon_scaling'
                        ])
    parser.add_argument('-cost_distance',
                        type=str,
                        choices=['l2', 'dot'],
                        default='l2',
                        help='Distance computed for cost matrix')
    parser.add_argument('-sinkhorn_entropy',
                        type=float,
                        default=1e-1,
                        help='Entropy regularization term for sinkhorn')
    parser.add_argument('-sinkhorn_max_it',
                        type=int,
                        default=1000,
                        help='Max num it for sinkhorn')
    parser.add_argument('-unbalanced', action='store_true', default=False)
    parser.add_argument('-nce_coef', type=float, default=0.)

    # Plot Params
    parser.add_argument('-plot_pc',
                        action='store_true',
                        default=False,
                        help='Whether to plot the point clouds')
    parser.add_argument('-plot_num_ex',
                        type=int,
                        default=5,
                        help='Number of molecule examples to plot')
    parser.add_argument('-plot_freq',
                        type=int,
                        default=10,
                        help='Frequency of plotting')
    parser.add_argument('-plot_max',
                        type=int,
                        default=1000,
                        help='Maximum number of plots to make')

    args = parser.parse_args()
    args.device = 'cuda:0' if args.cuda else 'cpu'

    # Add path to data dir
    assert args.data in DATA_TASK
    # task specifies the kind of data
    args.task = DATA_TASK[args.data]

    # get the number of labels from the dataset, split by commas
    args.n_labels = len(args.task.split(','))

    if args.n_labels == 1:
        # val_stat is the stat to select the best model
        args.val_stat = args.task
    else:
        # with multiple labels, select on a "multi-objective" stat: an average of the individual objectives
        args.val_stat = 'multi_obj'
        args.label_task_list = args.task.split(',')

    args.data_dir = 'data/%s' % args.data

    # hacky way to create the base output directory initially
    if '/' in args.output_dir:
        base_output_dir = args.output_dir.split('/')[0]
        if not os.path.exists(base_output_dir):
            os.makedirs(base_output_dir)

    if args.output_dir != '' and not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    if args.output_dir != '':
        write_args(args)
    return args
Example #5
def main():

    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    system_type = args.system_type
    noiseless_dataset_index = args.noiseless_dataset_index
    train_noise_level = args.train_noise_level
    test_noise_level = args.test_noise_level
    run_index = args.run_index
    batch_size = args.batch_size
    n_epochs = args.n_epochs
    T = args.T
    dt = args.dt
    lr = args.lr
    n_layers = args.n_layers
    n_hidden = args.n_hidden
    n_samples = args.n_samples
    n_test_samples = args.n_test_samples
    n_val_samples = args.n_val_samples
    T_test = args.T_test
    shorten = args.shorten
    T_short = args.T_short
    T_total = T + T_test
    T_init_seq = args.T_init_seq
    max_iters_init_train = args.max_iters_init_train
    max_iters_init_test = args.max_iters_init_test
    scheduler_type = args.scheduler_type
    scheduler_patience = args.scheduler_patience
    scheduler_factor = args.scheduler_factor
    coarsening_factor_test = args.coarsening_factor_test
    test_freq = args.test_freq

    if (train_noise_level > 0):
        dataset_index = noiseless_dataset_index + '_n' + str(train_noise_level)
    else:
        dataset_index = noiseless_dataset_index
    if (test_noise_level > 0):
        test_dataset_index = noiseless_dataset_index + '_n' + str(test_noise_level)
    else:
        test_dataset_index = noiseless_dataset_index

    data_dir = './data/' + system_type
    model_dir = './models/' + system_type + '_' + str(dataset_index) + str(run_index)
    pred_dir = './predictions/' + system_type + '_' + str(dataset_index) + str(run_index)
    log_dir_together = './logs/' + system_type + '_' + str(dataset_index) + str(run_index)

    if (not os.path.isdir(data_dir)):
        os.mkdir(data_dir)
    if (not os.path.isdir(model_dir)):
        os.mkdir(model_dir)
    if (not os.path.isdir(pred_dir)):
        os.mkdir(pred_dir)
    if (not os.path.isdir(log_dir_together)):
        os.mkdir(log_dir_together)

    print('dataset', dataset_index)
    print('run', run_index)

    write_args(args, log_dir_together)
    print(vars(args))

    train_data_npy = np.load(data_dir + '/train_data_' + system_type + '_' + dataset_index + '.npy')
    test_data_npy = np.load(data_dir + '/test_data_' + system_type + '_' + test_dataset_index  + '.npy')
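    # the loaded arrays are indexed as (time step, trajectory, state dimension)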

    train_data = torch.from_numpy(train_data_npy[:, :n_samples, :])
    test_data = torch.from_numpy(test_data_npy[:, :n_test_samples, :])

    if (system_type == '3body'):
        if (dt == 0.1):
            print('coarsening to dt=0.1')
            train_data = train_data[np.arange(T) * 10, :, :]
            test_data = test_data[np.arange(T_test) * 10, :, :]
        elif (dt == 1):
            print('coarsening to dt=1')
            train_data = train_data[np.arange(T) * 100, :, :]
            test_data = test_data[np.arange(T_test) * 100, :, :]
    else:
        train_data = train_data[:T, :, :]
        test_data = test_data[:T_test, :, :]
    
    ## augmenting the number of training trajectories while shortening their lengths
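    ## shorten == 1 chops each trajectory into non-overlapping windows of length T_short;
    ## shorten == 2 uses overlapping windows (stride 1) of length max(T_short, T_init_seq)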
    if (shorten == 1):
        train_data_shortened = torch.zeros(T_short, n_samples * int(T / T_short), train_data_npy.shape[2])
        for i in range(int(T / T_short)):
            train_data_shortened[:, i * n_samples : (i+1) * n_samples, :] = train_data[i * (T_short) : (i+1) * T_short, :n_samples, :]
        train_data = train_data_shortened
        T = T_short
        n_samples = train_data_shortened.shape[1]
    elif (shorten == 2):
        train_data_shortened = torch.zeros(max(T_short, T_init_seq), n_samples * (T - max(T_short, T_init_seq) + 1), train_data_npy.shape[2])
        for i in range(T - max(T_short, T_init_seq) + 1):
            train_data_shortened[:, i * n_samples : (i+1) * n_samples, :] = train_data[i : i + max(T_short, T_init_seq), :n_samples, :]
        train_data = train_data_shortened
        T = T_short
        n_samples = train_data_shortened.shape[1]

    print('new number of samples', n_samples)

    if (args.model_type == 'ONET'):
        method = 1
    elif (args.model_type == 'HNET'):
        method = 5
    elif (args.model_type == 'RNN'):
        method = 3
    elif (args.model_type == 'LSTM'):
        method = 4
    else:
        raise ValueError('model_type not supported')
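    # the integer codes above are the method identifiers passed to train() and predict() below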

    if (args.leapfrog_train == 'true'):
        integrator_train = 'leapfrog'
    else:
        integrator_train = 'euler'
    if (args.leapfrog_test == 'true'):
        integrator_test = 'leapfrog'
    else:
        integrator_test = 'euler'


    log_dir = './logs/' + system_type + '_' + str(dataset_index) + str(run_index) + '/m' + str(method)
    pred_out_string = pred_dir + '/traj_pred_' + str(system_type) + '_' + str(method) + '_' + str(test_dataset_index) + '_' + str(run_index)

    if (not os.path.isdir(log_dir)):
        os.mkdir(log_dir)

    logger0 = Logger(os.path.join(log_dir, 'trloss.log'), print_out=True)
    logger1 = Logger(os.path.join(log_dir, 'teloss.log'), print_out=True)

    start = time.time()

    ## Training the model
    model, loss_record, val_loss_record = train(train_data, method=method, T=T, batch_size=batch_size, \
        n_epochs=n_epochs, n_samples=n_samples, n_val_samples=args.n_val_samples, dt=dt, lr=lr, n_layers=n_layers, \
        n_hidden=n_hidden, integrator_train=integrator_train, integrator_test=integrator_test, logger=logger0, device=device, \
        test_data=test_data, n_test_samples=n_test_samples, T_test=T_test, scheduler_type=scheduler_type, \
        scheduler_patience=scheduler_patience, scheduler_factor=scheduler_factor, pred_out_string=pred_out_string, \
        max_iters_init_train=max_iters_init_train, max_iters_init_test=max_iters_init_test, T_init_seq=T_init_seq, \
        coarsening_factor_test=coarsening_factor_test, test_freq=test_freq)


    with open(os.path.join(log_dir, 'loss.pkl'), 'wb') as f:
        pickle.dump([np.array(loss_record), np.array(val_loss_record)], f)
    loss_plot(os.path.join(log_dir, 'loss.pkl'), log_dir, name=['', '', 'loss p', 'loss q'], teplotfreq=test_freq)
    loss_plot_restricted(os.path.join(log_dir, 'loss.pkl'), log_dir, name=['', '', 'loss p', 'loss q'], teplotfreq=test_freq)

    train_time = time.time() - start
    print('training with method ' + str(method) + ' took', train_time, 'seconds')

    ## Predicting the test trajectories
    traj_pred = predict(test_data_init_seq=test_data[:T_init_seq, :, :], model=model, method=method, T_test=T_test, \
        n_test_samples=n_test_samples, dt=dt, integrator=integrator_test, device=device, max_iters_init=max_iters_init_test, \
        coarsening_factor=coarsening_factor_test)
    
    pred_time = time.time() - start
    print('making predictions with method ' + str(method) + ' took', pred_time, 'seconds')

    np.save(pred_out_string + '.npy', traj_pred.cpu().data.numpy())
    torch.save(model.cpu(), model_dir + '/model_' + system_type + '_' + str(method) + '_' + str(dataset_index) + '_' + str(run_index))

    print('done saving the predicted trajectory and trained model')
Example #6
parser.add_argument('--lr', type=float, default=0.)
parser.add_argument('--wd', type=float, default=0.)
parser.add_argument('--batch_size', type=int, default=0)
parser.add_argument('--n_epoch', type=int, default=0)
args = parser.parse_args()

np.set_printoptions(linewidth=150, precision=4, suppress=True)
th.set_printoptions(linewidth=150, precision=4)

FN = th.from_numpy
join = os.path.join
logger = logging.getLogger()

utils.prepare_directory(args.exp_root, force_delete=True)
utils.init_logger(join(args.exp_root, 'program.log'))
utils.write_args(args)

dset = data.XianDataset(args.data_dir,
                        args.mode,
                        feature_norm=args.feature_norm)
_X_s_tr = FN(dset.X_s_tr).to(args.device)
_Y_s_tr = FN(dset.Y_s_tr).to(args.device)
_X_s_te = FN(dset.X_s_te).to(args.device)
_Y_s_te = FN(dset.Y_s_te).to(args.device)
_X_u_te = FN(dset.X_u_te).to(args.device)
_Y_u_te = FN(dset.Y_u_te).to(args.device)
_Cu = FN(dset.Cu).to(args.device)
_Sall = FN(dset.Sall).to(args.device)

train_iter = data.Iterator([_X_s_tr, _Y_s_tr],
                           args.batch_size,
Example #7
def main():

    utils.prepare_directory(args.exp_dir, force_delete=False)
    utils.init_logger(join(args.exp_dir, 'program.log'))
    utils.write_args(args)

    # **************************************** load dataset ****************************************
    dset = data.XianDataset(args.data_dir,
                            args.mode,
                            feature_norm=args.feature_norm)
    _X_s_tr = FN(dset.X_s_tr).to(args.device)
    _Y_s_tr_ix = FN(dil(dset.Y_s_tr,
                        dset.Cs)).to(args.device)  # indexed labels
    _Ss = FN(dset.Sall[dset.Cs]).to(args.device)
    _Su = FN(dset.Sall[dset.Cu]).to(args.device)
    if args.d_noise == 0: args.d_noise = dset.d_attr

    # **************************************** create data loaders ****************************************
    _sampling_weights = None
    if args.dataset != 'SUN':
        _sampling_weights = data.compute_sampling_weights(
            dil(dset.Y_s_tr, dset.Cs)).to(args.device)
    xy_iter = data.Iterator([_X_s_tr, _Y_s_tr_ix],
                            args.batch_size,
                            sampling_weights=_sampling_weights)
    label_iter = data.Iterator([torch.arange(dset.n_Cs, device=args.device)],
                               args.batch_size)
    class_iter = data.Iterator([torch.arange(dset.n_Cs)], 1)

    # **************************************** per-class means and stds ****************************************
    # per class samplers and first 2 class moments
    per_class_iters = []
    Xs_tr_mean, Xs_tr_std = [], []
    Xs_te_mean, Xs_te_std = [], []
    Xu_te_mean, Xu_te_std = [], []
    for c_ix, c in enumerate(dset.Cs):
        # training samples of seen classes
        _inds = np.where(dset.Y_s_tr == c)[0]
        assert _inds.shape[0] > 0
        _X = dset.X_s_tr[_inds]
        Xs_tr_mean.append(_X.mean(axis=0, keepdims=True))
        Xs_tr_std.append(_X.std(axis=0, keepdims=True))

        if args.n_gm_iter > 0:
            _y = np.ones([_inds.shape[0]], np.int64) * c_ix
            per_class_iters.append(
                data.Iterator([FN(_X).to(args.device),
                               FN(_y).to(args.device)],
                              args.per_class_batch_size))

        # test samples of seen classes
        _inds = np.where(dset.Y_s_te == c)[0]
        assert _inds.shape[0] > 0
        _X = dset.X_s_te[_inds]
        Xs_te_mean.append(_X.mean(axis=0, keepdims=True))
        Xs_te_std.append(_X.std(axis=0, keepdims=True))

    # test samples of unseen classes
    for c_ix, c in enumerate(dset.Cu):
        _inds = np.where(dset.Y_u_te == c)[0]
        assert _inds.shape[0] > 0
        _X = dset.X_u_te[_inds]
        Xu_te_mean.append(_X.mean(axis=0, keepdims=True))
        Xu_te_std.append(_X.std(axis=0, keepdims=True))
    del _X, _inds, c_ix, c

    Xs_tr_mean = FN(np.concatenate(Xs_tr_mean, axis=0)).to(args.device)
    Xs_tr_std = FN(np.concatenate(Xs_tr_std, axis=0)).to(args.device)
    Xs_te_mean = FN(np.concatenate(Xs_te_mean, axis=0)).to(args.device)
    Xs_te_std = FN(np.concatenate(Xs_te_std, axis=0)).to(args.device)
    Xu_te_mean = FN(np.concatenate(Xu_te_mean, axis=0)).to(args.device)
    Xu_te_std = FN(np.concatenate(Xu_te_std, axis=0)).to(args.device)

    # **************************************** create networks ****************************************
    g_net = modules.get_generator(args.gen_type)(
        dset.d_attr, args.d_noise, args.n_g_hlayer, args.n_g_hunit,
        args.normalize_noise, args.dp_g, args.leakiness_g).to(args.device)
    g_optim = optim.Adam(g_net.parameters(),
                         args.gan_optim_lr_g,
                         betas=(args.gan_optim_beta1, args.gan_optim_beta2),
                         weight_decay=args.gan_optim_wd)

    d_net = modules.ConditionalDiscriminator(dset.d_attr, args.n_d_hlayer,
                                             args.n_d_hunit,
                                             args.d_normalize_ft, args.dp_d,
                                             args.leakiness_d).to(args.device)
    d_optim = optim.Adam(d_net.parameters(),
                         args.gan_optim_lr_d,
                         betas=(args.gan_optim_beta1, args.gan_optim_beta2),
                         weight_decay=args.gan_optim_wd)
    start_it = 1

    utils.model_info(g_net, 'g_net', args.exp_dir)
    utils.model_info(d_net, 'd_net', args.exp_dir)

    if args.n_gm_iter > 0:
        if args.clf_type == 'bilinear-comp':
            clf = classifiers.BilinearCompatibility(dset.d_ft, dset.d_attr,
                                                    args)
        elif args.clf_type == 'mlp':
            clf = classifiers.MLP(dset.d_ft, dset.n_Cs, args)
        utils.model_info(clf.net, 'clf', args.exp_dir)

    pret_clf = None
    if os.path.isfile(args.pretrained_clf_ckpt):
        logger.info('Loading pre-trained {} checkpoint at {} ...'.format(
            args.clf_type, args.pretrained_clf_ckpt))
        ckpt = torch.load(args.pretrained_clf_ckpt, map_location=args.device)
        pret_clf = classifiers.BilinearCompatibility(dset.d_ft, dset.d_attr,
                                                     args)
        pret_clf.net.load_state_dict(ckpt[args.clf_type])
        pret_clf.net.eval()
        for p in pret_clf.net.parameters():
            p.requires_grad = False

    pret_regg = None
    if os.path.isfile(args.pretrained_regg_ckpt):
        logger.info(
            'Loading pre-trained regressor checkpoint at {} ...'.format(
                args.pretrained_regg_ckpt))
        ckpt = torch.load(args.pretrained_regg_ckpt, map_location=args.device)
        pret_regg = classifiers.Regressor(args, dset.d_ft, dset.d_attr)
        pret_regg.net.load_state_dict(ckpt['regressor'])
        pret_regg.net.eval()
        for p in pret_regg.net.parameters():
            p.requires_grad = False

    training_log_titles = [
        'd/loss',
        'd/real',
        'd/fake',
        'd/penalty',
        'gm/loss',
        'gm/real_loss',
        'gm/fake_loss',
        'g/fcls_loss',
        'g/cycle_loss',
        'clf/train_loss',
        'clf/train_acc',
        'mmad/X_s_tr',
        'mmad/X_s_te',
        'mmad/X_u_te',
        'smad/X_s_tr',
        'smad/X_s_te',
        'smad/X_u_te',
    ]
    if args.n_gm_iter > 0:
        training_log_titles.extend([
            'grad-cossim/{}'.format(n) for n, p in clf.net.named_parameters()
        ])
        training_log_titles.extend(
            ['grad-mse/{}'.format(n) for n, p in clf.net.named_parameters()])
    training_logger = utils.Logger(os.path.join(args.exp_dir, 'training-logs'),
                                   'logs', training_log_titles)

    t0 = time.time()

    logger.info('may the penguins not die')
    for it in range(start_it, args.n_iter + 1):

        # **************************************** Discriminator updates ****************************************
        for p in d_net.parameters():
            p.requires_grad = True
        for p in g_net.parameters():
            p.requires_grad = False
        for _ in range(args.n_d_iter):
            x_real, y_ix = next(xy_iter)
            s = _Ss[y_ix]
            x_fake = g_net(s)

            d_real = d_net(x_real, s).mean()
            d_fake = d_net(x_fake, s).mean()
            d_penalty = modules.gradient_penalty(d_net, x_real, x_fake, s)
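            # WGAN-GP critic loss: push D(real) up and D(fake) down, with the gradient penalty weighted by args.L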
            d_loss = d_fake - d_real + args.L * d_penalty

            d_optim.zero_grad()
            d_loss.backward()
            d_optim.step()

            training_logger.update_meters(
                ['d/real', 'd/fake', 'd/loss', 'd/penalty'], [
                    d_real.mean().item(),
                    d_fake.mean().item(),
                    d_loss.item(),
                    d_penalty.item()
                ], x_real.size(0))

        # **************************************** Generator updates ****************************************
        for p in d_net.parameters():
            p.requires_grad = False
        for p in g_net.parameters():
            p.requires_grad = True
        g_optim.zero_grad()

        [y_fake] = next(label_iter)
        s = _Ss[y_fake]
        x_fake = g_net(s)

        # wgan loss
        d_fake = d_net(x_fake, s).mean()
        g_wganloss = -d_fake

        # f-cls loss
        fcls_loss = 0.0
        if pret_clf is not None:
            fcls_loss = pret_clf.loss(x_fake, _Ss, y_fake)
            training_logger.update_meters(['g/fcls_loss'], [fcls_loss.item()],
                                          x_fake.size(0))

        # cycle-loss
        cycle_loss = 0.0
        if pret_regg is not None:
            cycle_loss = pret_regg.loss(x_fake, s)
            training_logger.update_meters(['g/cycle_loss'],
                                          [cycle_loss.item()], x_fake.size(0))

        g_loss = args.C * fcls_loss + args.R * cycle_loss + g_wganloss
        g_loss.backward()

        # gmn iterations
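        # each gradient-matching step compares the classifier's gradients on a real batch and a
        # generated batch of the same class: cosine-similarity terms are weighted by args.Q and
        # squared-error terms by args.Z in gm_loss below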
        for _ in range(args.n_gm_iter):
            c = next(class_iter)[0].item()
            x_real, y_real = next(per_class_iters[c])
            y_fake = y_real.detach().repeat(args.gm_fake_repeat)
            s = _Ss[y_fake]
            x_fake = g_net(s)

            # gm loss
            clf.net.zero_grad()
            if args.clf_type == 'bilinear-comp':
                real_loss = clf.loss(x_real, _Ss, y_real)
                fake_loss = clf.loss(x_fake, _Ss, y_fake)
            elif args.clf_type == 'mlp':
                real_loss = clf.loss(x_real, y_real)
                fake_loss = clf.loss(x_fake, y_fake)

            grad_cossim = []
            grad_mse = []
            for n, p in clf.net.named_parameters():
                # if len(p.shape) == 1: continue

                real_grad = grad([real_loss], [p],
                                 create_graph=True,
                                 only_inputs=True)[0]
                fake_grad = grad([fake_loss], [p],
                                 create_graph=True,
                                 only_inputs=True)[0]

                if len(p.shape) > 1:
                    _cossim = F.cosine_similarity(fake_grad, real_grad,
                                                  dim=1).mean()
                else:
                    _cossim = F.cosine_similarity(fake_grad, real_grad, dim=0)

                # _cossim = F.cosine_similarity(fake_grad, real_grad, dim=1).mean()
                _mse = F.mse_loss(fake_grad, real_grad)
                grad_cossim.append(_cossim)
                grad_mse.append(_mse)

                training_logger.update_meters(
                    ['grad-cossim/{}'.format(n), 'grad-mse/{}'.format(n)],
                    [_cossim.item(), _mse.item()], x_real.size(0))

            grad_cossim = torch.stack(grad_cossim)
            grad_mse = torch.stack(grad_mse)
            gm_loss = (1.0 -
                       grad_cossim).sum() * args.Q + grad_mse.sum() * args.Z
            gm_loss.backward()

            training_logger.update_meters(
                ['gm/real_loss', 'gm/fake_loss'],
                [real_loss.item(), fake_loss.item()], x_real.size(0))

        g_optim.step()

        # **************************************** Classifier update ****************************************
        if args.n_gm_iter > 0:
            if it % args.clf_reset_iter == 0:
                if args.clf_reset_iter == 1:
                    # no need to generate optimizer each time
                    clf.init_params()
                else:
                    clf.reset()
            else:
                x, y_ix = next(xy_iter)
                if args.clf_type == 'bilinear-comp':
                    clf_acc, clf_loss = clf.train_step(x, _Ss, y_ix)
                else:
                    clf_acc, clf_loss = clf.train_step(x, y_ix)
                training_logger.update_meters(
                    ['clf/train_loss', 'clf/train_acc'], [clf_loss, clf_acc],
                    x.size(0))

        # **************************************** Log ****************************************
        if it % 1000 == 0:
            g_net.eval()

            # synthesize samples for the seen and unseen classes and compute their first 2 moments
            def compute_firsttwo_moments(S, n_classes):
                X_mean, X_std = [], []
                with torch.no_grad():
                    for c in range(n_classes):
                        y = torch.ones(
                            256, device=args.device, dtype=torch.long) * c
                        a = S[y]
                        x_fake = g_net(a)
                        X_mean.append(x_fake.mean(dim=0, keepdim=True))
                        X_std.append(x_fake.std(dim=0, keepdim=True))
                return torch.cat(X_mean), torch.cat(X_std)

            Xs_fake_mean, Xs_fake_std = compute_firsttwo_moments(_Ss, dset.n_Cs)
            Xu_fake_mean, Xu_fake_std = compute_firsttwo_moments(_Su, dset.n_Cu)

            g_net.train()

            training_logger.update_meters([
                'mmad/X_s_tr', 'smad/X_s_tr', 'mmad/X_s_te', 'smad/X_s_te',
                'mmad/X_u_te', 'smad/X_u_te'
            ], [
                torch.abs(Xs_tr_mean - Xs_fake_mean).sum(dim=1).mean().item(),
                torch.abs(Xs_tr_std - Xs_fake_std).sum(dim=1).mean().item(),
                torch.abs(Xs_te_mean - Xs_fake_mean).sum(dim=1).mean().item(),
                torch.abs(Xs_te_std - Xs_fake_std).sum(dim=1).mean().item(),
                torch.abs(Xu_te_mean - Xu_fake_mean).sum(dim=1).mean().item(),
                torch.abs(Xu_te_std - Xu_fake_std).sum(dim=1).mean().item()
            ])

            training_logger.flush_meters(it)

            elapsed = time.time() - t0
            per_iter = elapsed / it
            apprx_rem = (args.n_iter - it) * per_iter
            logger.info('Iter:{:06d}/{:06d}, '\
                         '[ET:{:.1e}(min)], ' \
                         '[IT:{:.1f}(ms)], ' \
                         '[REM:{:.1e}(min)]'.format(
                            it, args.n_iter, elapsed / 60., per_iter * 1000., apprx_rem / 60))

        if it % 10000 == 0:
            utils.save_checkpoint(
                {
                    'g_net': g_net.state_dict(),
                    'd_net': d_net.state_dict(),
                    'g_optim': g_optim.state_dict(),
                    'd_optim': d_optim.state_dict(),
                    'iteration': it
                },
                args.exp_dir,
                None,
                it if it % (args.n_iter // args.n_ckpt) == 0 else None,
            )

    training_logger.close()
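
# Self-contained sketch of the gradient-matching idea used in the GM iterations above
# (an illustration under assumptions, not this project's API): gradients of a classifier
# loss on a real batch and a generated batch are compared with cosine similarity and MSE.
import torch
import torch.nn.functional as F
from torch.autograd import grad


def gradient_matching_loss(clf, real_loss, fake_loss, q=1.0, z=0.0):
    """Return q * sum(1 - cos) + z * sum(mse) over per-parameter gradient pairs."""
    cos_terms, mse_terms = [], []
    for p in clf.parameters():
        g_real = grad([real_loss], [p], create_graph=True)[0]
        g_fake = grad([fake_loss], [p], create_graph=True)[0]
        dim = 1 if g_real.dim() > 1 else 0
        cos_terms.append(F.cosine_similarity(g_fake, g_real, dim=dim).mean())
        mse_terms.append(F.mse_loss(g_fake, g_real))
    return (1.0 - torch.stack(cos_terms)).sum() * q + torch.stack(mse_terms).sum() * z


if __name__ == '__main__':
    # toy check: a linear classifier and two random batches sharing the same labels
    clf = torch.nn.Linear(8, 3)
    x_real, x_fake = torch.randn(16, 8), torch.randn(16, 8)
    y = torch.randint(0, 3, (16,))
    real_loss = F.cross_entropy(clf(x_real), y)
    fake_loss = F.cross_entropy(clf(x_fake), y)
    print(gradient_matching_loss(clf, real_loss, fake_loss, q=1.0, z=0.1))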