示例#1
0
def main(model_name, in_memory, complex_, model, local_data, epochs, fourier,
         stft, fast_load):
    """Load MusicNet into memory, build a model and train it.

    The dataset is loaded completely, the network is obtained from
    ``get_model`` and training runs through ``fit_generator`` with 1000
    steps per epoch and a single worker. Validation/test callbacks write to
    a mimir log at ``models/log_<model_name>.jsonl.gz``.

    Args:
        model_name: Suffix used in the log file name.
        in_memory: Must be truthy; no other loading mode is implemented.
        complex_: Forwarded to ``MusicNet`` as ``complex_``.
        model: Model specification handed to ``get_model``.
        local_data: First positional argument of ``MusicNet`` (presumably
            the dataset location -- confirm against ``MusicNet``).
        epochs: Number of epochs passed to ``fit_generator``.
        fourier: Forwarded to ``MusicNet`` as ``fourier``.
        stft: Forwarded to ``MusicNet`` as ``stft``.
        fast_load: Forwarded to ``MusicNet`` as ``fast_load``.

    Raises:
        ValueError: If ``in_memory`` is falsy.
    """
    # Fail fast, before any expensive work, and say why. The original code
    # checked this twice; the second check could never fire.
    if not in_memory:
        raise ValueError(
            "in_memory must be true: only in-memory loading of the dataset "
            "is implemented")

    rng = numpy.random.RandomState(123)

    # Warning: the full dataset is over 40GB. Make sure you have enough RAM!
    # This can take a few minutes to load
    print('.. loading train data')
    dataset = MusicNet(local_data,
                       complex_=complex_,
                       fourier=fourier,
                       stft=stft,
                       rng=rng,
                       fast_load=fast_load)
    dataset.load()
    print('.. train data loaded')
    Xvalid, Yvalid = dataset.eval_set('valid')
    Xtest, Ytest = dataset.eval_set('test')

    print(".. building model")
    # From here on ``model`` is the built network, not the specification
    # that was passed in.
    model = get_model(model, dataset.feature_dim)

    model.summary()
    print(".. parameters: {:03.2f}M".format(model.count_params() / 1000000.))

    logger = mimir.Logger(filename='models/log_{}.jsonl.gz'.format(model_name))

    it = dataset.train_iterator()

    callbacks = [
        Validation(Xvalid, Yvalid, 'valid', logger),
        Validation(Xtest, Ytest, 'test', logger),
        # NOTE(review): ``name`` receives the built model object because
        # ``model`` was rebound above; ``model_name`` may have been intended.
        # Behaviour is kept unchanged -- confirm against SaveLastModel.
        SaveLastModel("./models/", 1, name=model),
        Performance(logger),
        LearningRateScheduler(schedule)
    ]

    print('.. start training')
    model.fit_generator(it,
                        steps_per_epoch=1000,
                        epochs=epochs,
                        callbacks=callbacks,
                        workers=1)
示例#2
0
def train(args, model_args):
    # Train a walkback model end to end (Python 2 / Theano code).
    #
    # Steps, in order: create the output and log directories, load the
    # dataset selected by args.dataset, build the model graph and the
    # optimizer functions, then run up to max_epochs epochs. Parameters are
    # saved every 20 epochs. Whenever batch_index is a multiple of 200 the
    # current batch is pushed through the forward diffusion and the sampler
    # and the results are plotted.
    #
    # Parameters:
    #   args       -- parsed options; read for the log directories and lr,
    #                 then replaced by a fresh parse_args() call below.
    #   model_args -- also replaced by parse_args(); it is not read directly
    #                 in this function, but it is part of the locals()
    #                 snapshot stored in model_options.
    #
    # Returns 1. as soon as the averaged cost is NaN or inf. Otherwise the
    # function ends in ipdb.set_trace() once the epoch loop has finished.

    #model_id = '/data/lisatmp4/lambalex/lsun_walkback/walkback_'

    # model_dir receives parameter snapshots and plots; model_dir2 holds the
    # mimir training log.
    model_id = '/data/lisatmp4/anirudhg/cifar_walk_back/walkback_'
    model_dir = create_log_dir(args, model_id)
    model_id2 = 'logs/walkback_'
    model_dir2 = create_log_dir(args, model_id2)
    print model_dir
    print model_dir2 + '/' + 'log.jsonl.gz'
    logger = mimir.Logger(filename=model_dir2 + '/log.jsonl.gz',
                          formatter=None)

    # TODO batches_per_epoch should not be hard coded
    # The learning rate is taken from the caller's args, before they are
    # replaced below.
    lrate = args.lr
    import sys
    sys.setrecursionlimit(10000000)
    # NOTE(review): this overwrites both function parameters with freshly
    # parsed values; everything below uses the re-parsed args, while the
    # directories and lrate above came from the caller's args. Confirm this
    # is intended.
    args, model_args = parse_args()

    #trng = RandomStreams(1234)

    # NOTE(review): there is no return after continue_training, so unless
    # that call never returns, the normal setup below still runs.
    if args.resume_file is not None:
        print "Resuming training from " + args.resume_file
        from blocks.scripts import continue_training
        continue_training(args.resume_file)

    ## load the training data
    # Each branch is expected to leave dataset_train (and usually
    # dataset_test, n_colors, spatial_width) defined for the code below.
    if args.dataset == 'MNIST':
        print 'loading MNIST'
        from fuel.datasets import MNIST
        dataset_train = MNIST(['train'], sources=('features', ))
        dataset_test = MNIST(['test'], sources=('features', ))
        n_colors = 1
        spatial_width = 28

    elif args.dataset == 'CIFAR10':
        from fuel.datasets import CIFAR10
        dataset_train = CIFAR10(['train'], sources=('features', ))
        dataset_test = CIFAR10(['test'], sources=('features', ))
        n_colors = 3
        spatial_width = 32

    elif args.dataset == "lsun" or args.dataset == "lsunsmall":

        print "loading lsun class!"

        from load_lsun import load_lsun

        print "loading lsun data!"

        if args.dataset == "lsunsmall":
            dataset_train, dataset_test = load_lsun(args.batch_size,
                                                    downsample=True)
            spatial_width = 32
        else:
            dataset_train, dataset_test = load_lsun(args.batch_size,
                                                    downsample=False)
            spatial_width = 64

        n_colors = 3

    elif args.dataset == "celeba":

        print "loading celeba data"

        from fuel.datasets.celeba import CelebA

        dataset_train = CelebA(which_sets=['train'],
                               which_format="64",
                               sources=('features', ),
                               load_in_memory=False)
        dataset_test = CelebA(which_sets=['test'],
                              which_format="64",
                              sources=('features', ),
                              load_in_memory=False)

        spatial_width = 64
        n_colors = 3

        # Sequential (unshuffled) batching for both splits.
        tr_scheme = SequentialScheme(examples=dataset_train.num_examples,
                                     batch_size=args.batch_size)
        ts_scheme = SequentialScheme(examples=dataset_test.num_examples,
                                     batch_size=args.batch_size)

        train_stream = DataStream.default_stream(dataset_train,
                                                 iteration_scheme=tr_scheme)
        test_stream = DataStream.default_stream(dataset_test,
                                                iteration_scheme=ts_scheme)

        # From here on dataset_train / dataset_test name the streams, not
        # the underlying datasets.
        dataset_train = train_stream
        dataset_test = test_stream

        #epoch_it = train_stream.get_epoch_iterator()

    elif args.dataset == 'Spiral':
        # NOTE(review): this branch assigns neither n_colors nor
        # spatial_width, so the print of spatial_width further down raises
        # UnboundLocalError if execution gets that far. dataset_train is
        # also already a stream here, yet it is wrapped in a stream again
        # below. This path does not look usable in this function.
        print 'loading SPIRAL'
        train_set = Spiral(num_examples=100000,
                           classes=1,
                           cycles=2.,
                           noise=0.01,
                           sources=('features', ))
        dataset_train = DataStream.default_stream(
            train_set,
            iteration_scheme=ShuffledScheme(train_set.num_examples,
                                            args.batch_size))

    else:
        raise ValueError("Unknown dataset %s." % args.dataset)

    # Snapshot of every local variable defined so far (args, model_args,
    # lrate, the dataset objects, n_colors, spatial_width, ...). The model
    # helpers receive this dict, so renaming or removing a local above this
    # line changes what they see.
    model_options = locals().copy()

    # Everything except 'lsun' and 'celeba' is wrapped in a flattened,
    # shuffled stream that drops the final partial batch.
    # NOTE(review): 'lsunsmall' is not excluded here and therefore takes the
    # Flatten path; load_lsun's return type is not visible, confirm this is
    # intended.
    if args.dataset != 'lsun' and args.dataset != 'celeba':
        train_stream = Flatten(
            DataStream.default_stream(
                dataset_train,
                iteration_scheme=ShuffledScheme(
                    examples=dataset_train.num_examples -
                    (dataset_train.num_examples % args.batch_size),
                    batch_size=args.batch_size)))
    else:
        train_stream = dataset_train
        test_stream = dataset_test

    # WIDTH is not defined in this function; it has to come from the
    # enclosing module.
    print "Width", WIDTH, spatial_width

    # shp is not used again in this function.
    shp = next(train_stream.get_epoch_iterator())[0].shape

    print "got epoch iterator"

    # make the training data 0 mean and variance 1
    # TODO compute mean and variance on full dataset, not minibatch
    # A second epoch iterator is opened just to peek at one batch.
    Xbatch = next(train_stream.get_epoch_iterator())[0]
    scl = 1. / np.sqrt(np.mean((Xbatch - np.mean(Xbatch))**2))
    shft = -np.mean(Xbatch * scl)
    # scale is applied before shift
    # The stream wrapping is disabled, so the training data itself is not
    # rescaled; scl and shft are only handed to reverse_time below.
    #train_stream = ScaleAndShift(train_stream, scl, shft)
    #test_stream = ScaleAndShift(test_stream, scl, shft)

    print 'Building model'
    params = init_params(model_options)
    # Optionally overwrite the fresh parameters with saved ones. A missing
    # file is silently ignored.
    if args.reload_:
        print "Trying to reload parameters"
        if os.path.exists(args.saveto_filename):
            print 'Reloading Parameters'
            print args.saveto_filename
            params = load_params(args.saveto_filename, params)
    tparams = init_tparams(params)
    print tparams
    # The triple-quoted block below is a disabled debugging snippet; as a
    # bare string expression it has no effect.
    '''
    x = T.matrix('x', dtype='float32')
    temp  = T.scalar('temp', dtype='float32')
    f=transition_operator(tparams, model_options, x, temp)

    for data in train_stream.get_epoch_iterator():
        print data[0]
        a = f([data[0], 1.0, 1])
        #ipdb.set_trace()
    '''
    # Symbolic input, symbolic cost and symbolic start temperature.
    x, cost, start_temperature = build_model(tparams, model_options)
    inps = [x, start_temperature]

    # Callable used below to corrupt a batch by one forward step.
    x_Data = T.matrix('x_Data', dtype='float32')
    temperature = T.scalar('temperature', dtype='float32')
    forward_diffusion = one_step_diffusion(x_Data, model_options, tparams,
                                           temperature)

    #print 'Building f_cost...',
    #f_cost = theano.function(inps, cost)
    #print 'Done'
    print tparams
    grads = T.grad(cost, wrt=itemlist(tparams))

    #get_grads = theano.function(inps, grads)

    # Replace NaN entries of every gradient with zeros.
    for j in range(0, len(grads)):
        grads[j] = T.switch(T.isnan(grads[j]), T.zeros_like(grads[j]),
                            grads[j])

    # compile the optimizer, the actual computational graph is compiled here
    lr = T.scalar(name='lr')
    print 'Building optimizers...',
    optimizer = args.optimizer

    # The optimizer is looked up by name on the optimizers module. Below,
    # f_grad_shared(batch, temperature) is treated as returning the cost and
    # f_update(lrate) is called right after it.
    f_grad_shared, f_update = getattr(optimizers, optimizer)(lr, tparams,
                                                             grads, inps, cost)
    print 'Done'

    for param in tparams:
        print param
        print tparams[param].get_value().shape

    print 'Buiding Sampler....'
    f_sample = sample(tparams, model_options)
    print 'Done'

    # Loop bookkeeping. estop and bad_counter are never read afterwards.
    uidx = 0
    estop = False
    bad_counter = 0
    max_epochs = 4000
    batch_index = 1
    print 'Number of steps....'
    print args.num_steps
    print "Number of metasteps...."
    print args.meta_steps
    print 'Done'
    count_sample = 1
    for eidx in xrange(max_epochs):
        # Checkpoint the parameters every 20 epochs (including epoch 0).
        if eidx % 20 == 0:
            params = unzip(tparams)
            save_params(params,
                        model_dir + '/' + 'params_' + str(eidx) + '.npz')
        n_samples = 0
        print 'Starting Next Epoch ', eidx
        for data in train_stream.get_epoch_iterator():

            # For CIFAR10 only full batches are used, reshaped to
            # (batch_size, 3 * 32 * 32).
            # NOTE(review): data_use is assigned only on the CIFAR10 path;
            # for any other dataset the first use below raises
            # UnboundLocalError.
            if args.dataset == 'CIFAR10':
                if data[0].shape[0] == args.batch_size:
                    data_use = (data[0].reshape(args.batch_size,
                                                3 * 32 * 32), )
                else:
                    continue
            # t0, t1, t5, ud_start and ud are leftover timing variables;
            # none of them is read for anything but computing ud.
            t0 = time.time()
            batch_index += 1
            n_samples += len(data_use[0])
            uidx += 1
            # NOTE(review): len(data_use[0]) above would already have failed
            # for None, so this check cannot trigger as written.
            if data_use[0] is None:
                print 'No data '
                uidx -= 1
                continue
            ud_start = time.time()

            t1 = time.time()

            # One gradient evaluation and parameter update per meta step.
            # With more than one meta step the batch is replaced by its
            # forward-diffused version and the temperature is multiplied by
            # temperature_factor before the next step.
            data_run = data_use[0]
            temperature_forward = args.temperature
            meta_cost = []
            for meta_step in range(0, args.meta_steps):
                meta_cost.append(f_grad_shared(data_run, temperature_forward))
                f_update(lrate)
                if args.meta_steps > 1:
                    data_run, sigma, _, _ = forward_diffusion(
                        [data_run, temperature_forward, 1])
                    temperature_forward *= args.temperature_factor
            # Rebinds cost from the symbolic expression built earlier to the
            # numeric mean over the meta steps.
            cost = sum(meta_cost) / len(meta_cost)

            ud = time.time() - ud_start

            #gradient_updates_ = get_grads(data_use[0],args.temperature)

            # Abort the whole training run on a non-finite cost.
            if np.isnan(cost) or np.isinf(cost):
                print 'NaN detected'
                return 1.
            t1 = time.time()
            #print time.time() - t1, "time to get grads"
            t1 = time.time()
            logger.log({
                'epoch': eidx,
                'batch_index': batch_index,
                'uidx': uidx,
                'training_error': cost
            })
            #'Norm_1': np.linalg.norm(gradient_updates_[0]),
            #'Norm_2': np.linalg.norm(gradient_updates_[1]),
            #'Norm_3': np.linalg.norm(gradient_updates_[2]),
            #'Norm_4': np.linalg.norm(gradient_updates_[3])})
            #print time.time() - t1, "time to log"

            #print time.time() - t0, "total time in batch"
            t5 = time.time()

            if batch_index % 20 == 0:
                print batch_index, "cost", cost

            # Periodic visualisation: forward corruption of the current
            # batch, the backward walk from the corrupted batch, and
            # sampling from noise.
            if batch_index % 200 == 0:
                count_sample += 1
                # Rebinds temperature (previously the Theano scalar) to a
                # Python number: the temperature reached at the last forward
                # step.
                temperature = args.temperature * (args.temperature_factor**(
                    args.num_steps * args.meta_steps - 1))
                temperature_forward = args.temperature

                # Forward pass: corrupt step by step, plotting each result.
                # INPUT_SIZE and WIDTH are not defined in this function.
                # NOTE(review): the file name of step 0 contains
                # '_corruptedepoch_' while later steps use
                # '_corrupted_epoch_'; the missing underscore looks
                # accidental.
                for num_step in range(args.num_steps * args.meta_steps):
                    print "Forward temperature", temperature_forward
                    if num_step == 0:
                        x_data, sampled, sampled_activation, sampled_preactivation = forward_diffusion(
                            [data_use[0], temperature_forward, 1])
                        x_data = np.asarray(x_data).astype('float32').reshape(
                            args.batch_size, INPUT_SIZE)
                        x_temp = x_data.reshape(args.batch_size, n_colors,
                                                WIDTH, WIDTH)
                        plot_images(
                            x_temp, model_dir + '/' + "batch_" +
                            str(batch_index) + '_corrupted' + 'epoch_' +
                            str(count_sample) + '_time_step_' + str(num_step))
                    else:
                        x_data, sampled, sampled_activation, sampled_preactivation = forward_diffusion(
                            [x_data, temperature_forward, 1])
                        x_data = np.asarray(x_data).astype('float32').reshape(
                            args.batch_size, INPUT_SIZE)
                        x_temp = x_data.reshape(args.batch_size, n_colors,
                                                WIDTH, WIDTH)
                        plot_images(
                            x_temp, model_dir + '/batch_' + str(batch_index) +
                            '_corrupted' + '_epoch_' + str(count_sample) +
                            '_time_step_' + str(num_step))

                    temperature_forward = temperature_forward * args.temperature_factor

                # Plot the uncorrupted batch for comparison.
                x_temp2 = data_use[0].reshape(args.batch_size, n_colors, WIDTH,
                                              WIDTH)
                plot_images(
                    x_temp2, model_dir + '/' + 'orig_' + 'epoch_' + str(eidx) +
                    '_batch_index_' + str(batch_index))

                temperature = args.temperature * (args.temperature_factor**(
                    args.num_steps * args.meta_steps - 1))

                # Backward pass starting from the last corrupted batch. The
                # temperature is divided by temperature_factor after each
                # step until it compares equal to args.temperature, after
                # which it stays there (the first branch is a no-op).
                # NOTE(review): this relies on exact float equality.
                for i in range(args.num_steps * args.meta_steps +
                               args.extra_steps):
                    x_data, sampled, sampled_activation, sampled_preactivation = f_sample(
                        [x_data, temperature, 0])
                    print 'On backward step number, using temperature', i, temperature
                    reverse_time(
                        scl, shft, x_data, model_dir + '/' + "batch_" +
                        str(batch_index) + '_samples_backward_' + 'epoch_' +
                        str(count_sample) + '_time_step_' + str(i))
                    x_data = np.asarray(x_data).astype('float32')
                    x_data = x_data.reshape(args.batch_size, INPUT_SIZE)
                    if temperature == args.temperature:
                        temperature = temperature
                    else:
                        temperature /= args.temperature_factor

                # Starting point for sampling from pure noise.
                # NOTE(review): only the gaussian branch assigns x_sampled.
                # The other branch stores its draw in s, which is never
                # used, so x_sampled below is either unbound or left over
                # from an earlier visualisation.
                if args.noise == "gaussian":
                    x_sampled = np.random.normal(
                        0.5, 2.0,
                        size=(args.batch_size, INPUT_SIZE)).clip(0.0, 1.0)
                else:
                    s = np.random.binomial(1, 0.5, INPUT_SIZE)

                temperature = args.temperature * (args.temperature_factor**(
                    args.num_steps * args.meta_steps - 1))

                # Same backward schedule as above, this time from the noise
                # sample.
                x_data = np.asarray(x_sampled).astype('float32')
                for i in range(args.num_steps * args.meta_steps +
                               args.extra_steps):
                    x_data, sampled, sampled_activation, sampled_preactivation = f_sample(
                        [x_data, temperature, 0])
                    print 'On step number, using temperature', i, temperature
                    reverse_time(
                        scl, shft, x_data, model_dir + '/batch_index_' +
                        str(batch_index) + '_inference_' + 'epoch_' +
                        str(count_sample) + '_step_' + str(i))
                    x_data = np.asarray(x_data).astype('float32')
                    x_data = x_data.reshape(args.batch_size, INPUT_SIZE)
                    if temperature == args.temperature:
                        temperature = temperature
                    else:
                        temperature /= args.temperature_factor

    # NOTE(review): debugger breakpoint left in; it is hit once all epochs
    # have completed and blocks a non-interactive run.
    ipdb.set_trace()
def train(args, model_args):
    # Train a walkback model and visualise it on two-dimensional data
    # (Python 2 / Theano code).
    #
    # Steps, in order: create the output and log directories, load the
    # dataset selected by args.dataset, build the model graph and the
    # optimizer functions, then run up to max_epochs epochs. Parameters are
    # saved every 20 epochs. Whenever batch_index is a multiple of 8 the
    # current batch is corrupted step by step, walked back with the sampler,
    # and plotted with plot_images, plot_2D and plot_grad.
    #
    # Parameters:
    #   args       -- parsed options; read for the log directories and lr,
    #                 then replaced by a fresh parse_args() call below.
    #   model_args -- also replaced by parse_args(); it is not read directly
    #                 in this function, but it is part of the locals()
    #                 snapshot stored in model_options.
    #
    # Returns 1. as soon as the averaged cost is NaN or inf. Otherwise the
    # function ends in ipdb.set_trace() once the epoch loop has finished.

    # model_dir receives parameter snapshots and plots; model_dir2 holds the
    # mimir training log.
    model_id = '/data/lisatmp4/anirudhg/spiral_walk_back/walkback_'
    model_dir = create_log_dir(args, model_id)
    model_id2 = 'logs/walkback_'
    model_dir2 = create_log_dir(args, model_id2)
    print model_dir
    print model_dir2 + '/' + 'log.jsonl.gz'
    logger = mimir.Logger(filename=model_dir2 + '/log.jsonl.gz',
                          formatter=None)

    # TODO batches_per_epoch should not be hard coded
    # The learning rate is taken from the caller's args, before they are
    # replaced below.
    lrate = args.lr
    import sys
    sys.setrecursionlimit(10000000)
    # NOTE(review): this overwrites both function parameters with freshly
    # parsed values; everything below uses the re-parsed args, while the
    # directories and lrate above came from the caller's args. Confirm this
    # is intended.
    args, model_args = parse_args()

    #trng = RandomStreams(1234)

    # NOTE(review): there is no return after continue_training, so unless
    # that call never returns, the normal setup below still runs.
    if args.resume_file is not None:
        print "Resuming training from " + args.resume_file
        from blocks.scripts import continue_training
        continue_training(args.resume_file)

    ## load the training data
    # NOTE(review): after this if-chain dataset_train is used directly as
    # the training stream. Only the celeba, Spiral and Circle branches wrap
    # it in a DataStream here; confirm the objects from the other branches
    # expose get_epoch_iterator().
    if args.dataset == 'MNIST':
        print 'loading MNIST'
        from fuel.datasets import MNIST
        dataset_train = MNIST(['train'], sources=('features', ))
        dataset_test = MNIST(['test'], sources=('features', ))
        n_colors = 1
        spatial_width = 28

    elif args.dataset == 'CIFAR10':
        from fuel.datasets import CIFAR10
        dataset_train = CIFAR10(['train'], sources=('features', ))
        dataset_test = CIFAR10(['test'], sources=('features', ))
        n_colors = 3
        spatial_width = 32

    elif args.dataset == "lsun" or args.dataset == "lsunsmall":

        print "loading lsun class!"

        from load_lsun import load_lsun

        print "loading lsun data!"

        if args.dataset == "lsunsmall":
            dataset_train, dataset_test = load_lsun(args.batch_size,
                                                    downsample=True)
            spatial_width = 32
        else:
            dataset_train, dataset_test = load_lsun(args.batch_size,
                                                    downsample=False)
            spatial_width = 64

        n_colors = 3

    elif args.dataset == "celeba":

        print "loading celeba data"

        from fuel.datasets.celeba import CelebA

        dataset_train = CelebA(which_sets=['train'],
                               which_format="64",
                               sources=('features', ),
                               load_in_memory=False)
        dataset_test = CelebA(which_sets=['test'],
                              which_format="64",
                              sources=('features', ),
                              load_in_memory=False)

        spatial_width = 64
        n_colors = 3

        # Sequential (unshuffled) batching for both splits.
        tr_scheme = SequentialScheme(examples=dataset_train.num_examples,
                                     batch_size=args.batch_size)
        ts_scheme = SequentialScheme(examples=dataset_test.num_examples,
                                     batch_size=args.batch_size)

        train_stream = DataStream.default_stream(dataset_train,
                                                 iteration_scheme=tr_scheme)
        test_stream = DataStream.default_stream(dataset_test,
                                                iteration_scheme=ts_scheme)

        # From here on dataset_train / dataset_test name the streams, not
        # the underlying datasets.
        dataset_train = train_stream
        dataset_test = test_stream

        #epoch_it = train_stream.get_epoch_iterator()

    elif args.dataset == 'Spiral':
        print 'loading SPIRAL'
        train_set = Spiral(num_examples=20000,
                           classes=1,
                           cycles=1.,
                           noise=0.01,
                           sources=('features', ))
        dataset_train = DataStream.default_stream(
            train_set,
            iteration_scheme=ShuffledScheme(train_set.num_examples,
                                            args.batch_size))
    elif args.dataset == 'Circle':
        print 'loading Circle'
        train_set = Circle(num_examples=20000,
                           classes=1,
                           cycles=1.,
                           noise=0.0,
                           sources=('features', ))
        dataset_train = DataStream.default_stream(
            train_set,
            iteration_scheme=ShuffledScheme(train_set.num_examples,
                                            args.batch_size))
        # Only this branch defines iter_per_epoch; it is not read again in
        # this function but ends up in the locals() snapshot below.
        iter_per_epoch = train_set.num_examples
    else:
        raise ValueError("Unknown dataset %s." % args.dataset)

    # Snapshot of every local variable defined so far (args, model_args,
    # lrate, the dataset objects, ...). The model helpers receive this dict,
    # so renaming or removing a local above this line changes what they see.
    model_options = locals().copy()

    train_stream = dataset_train

    # shp is not used again in this function.
    shp = next(train_stream.get_epoch_iterator())[0].shape

    print "got epoch iterator"

    # make the training data 0 mean and variance 1
    # TODO compute mean and variance on full dataset, not minibatch
    # A second epoch iterator is opened just to peek at one batch. With the
    # stream wrapping below disabled, scl and shft are computed but never
    # used in this function.
    Xbatch = next(train_stream.get_epoch_iterator())[0]
    scl = 1. / np.sqrt(np.mean((Xbatch - np.mean(Xbatch))**2))
    shft = -np.mean(Xbatch * scl)
    # scale is applied before shift
    #train_stream = ScaleAndShift(train_stream, scl, shft)
    #test_stream = ScaleAndShift(test_stream, scl, shft)

    print 'Building model'
    params = init_params(model_options)
    # Optionally overwrite the fresh parameters with saved ones. A missing
    # file is silently ignored.
    if args.reload_:
        print "Trying to reload parameters"
        if os.path.exists(args.saveto_filename):
            print 'Reloading Parameters'
            print args.saveto_filename
            params = load_params(args.saveto_filename, params)
    tparams = init_tparams(params)
    print tparams
    # Symbolic input, symbolic cost and symbolic start temperature.
    x, cost, start_temperature = build_model(tparams, model_options)
    inps = [x, start_temperature]

    # Callable used below to corrupt a batch by one forward step.
    x_Data = T.matrix('x_Data', dtype='float32')
    temperature = T.scalar('temperature', dtype='float32')
    forward_diffusion = one_step_diffusion(x_Data, model_options, tparams,
                                           temperature)

    #print 'Building f_cost...',
    #f_cost = theano.function(inps, cost)
    #print 'Done'
    print tparams
    grads = T.grad(cost, wrt=itemlist(tparams))

    #get_grads = theano.function(inps, grads)

    # Replace NaN entries of every gradient with zeros.
    for j in range(0, len(grads)):
        grads[j] = T.switch(T.isnan(grads[j]), T.zeros_like(grads[j]),
                            grads[j])

    # compile the optimizer, the actual computational graph is compiled here
    lr = T.scalar(name='lr')
    print 'Building optimizers...',
    optimizer = args.optimizer

    # The optimizer is looked up by name on the optimizers module. Below,
    # f_grad_shared(batch, temperature) is treated as returning the cost and
    # f_update(lrate) is called right after it.
    f_grad_shared, f_update = getattr(optimizers, optimizer)(lr, tparams,
                                                             grads, inps, cost)
    print 'Done'

    print 'Buiding Sampler....'
    f_sample = sample(tparams, model_options)
    print 'Done'
    # Loop bookkeeping. estop and bad_counter are never read afterwards, and
    # batch_index is initialised twice.
    uidx = 0
    estop = False
    bad_counter = 0
    max_epochs = 4000
    batch_index = 0
    print 'Number of steps....', args.num_steps
    print 'Done'
    count_sample = 1
    batch_index = 0
    for eidx in xrange(max_epochs):
        # Checkpoint the parameters every 20 epochs (including epoch 0).
        if eidx % 20 == 0:
            params = unzip(tparams)
            save_params(params,
                        model_dir + '/' + 'params_' + str(eidx) + '.npz')
            # NOTE(review): 30 is not a multiple of 20, so this breakpoint
            # can never be reached from inside the enclosing condition.
            if eidx == 30:
                ipdb.set_trace()
        n_samples = 0
        print 'Starting Next Epoch ', eidx

        for data in train_stream.get_epoch_iterator():
            batch_index += 1
            n_samples += len(data[0])
            uidx += 1
            # NOTE(review): len(data[0]) above would already have failed for
            # None, so this check cannot trigger as written.
            if data[0] is None:
                print 'No data '
                uidx -= 1
                continue
            # One gradient evaluation and parameter update per meta step.
            # With more than one meta step the batch is replaced by its
            # forward-diffused version and the temperature is multiplied by
            # temperature_factor before the next step.
            data_run = data[0]
            temperature_forward = args.temperature
            meta_cost = []
            for meta_step in range(0, args.meta_steps):
                meta_cost.append(f_grad_shared(data_run, temperature_forward))
                f_update(lrate)
                if args.meta_steps > 1:
                    data_run, sigma, _, _ = forward_diffusion(
                        data_run, temperature_forward)
                    temperature_forward *= args.temperature_factor
            # Rebinds cost from the symbolic expression built earlier to the
            # numeric mean over the meta steps.
            cost = sum(meta_cost) / len(meta_cost)
            # Abort the whole training run on a non-finite cost.
            if np.isnan(cost) or np.isinf(cost):
                print 'NaN detected'
                return 1.
            logger.log({
                'epoch': eidx,
                'batch_index': batch_index,
                'uidx': uidx,
                'training_error': cost
            })
            # Per-batch accumulators for the plots below.
            # NOTE(review): every element of spiral_x is the same list
            # object (empty), and spiral_x is never used afterwards.
            empty = []
            spiral_x = [empty for i in range(args.num_steps)]
            spiral_corrupted = []
            spiral_sampled = []
            grad_forward = []
            grad_back = []
            x_data_time = []
            x_tilt_time = []
            # Periodic visualisation, every 8th batch.
            if batch_index % 8 == 0:
                count_sample += 1
                # Rebinds temperature (previously the Theano scalar) to a
                # Python number: the temperature reached at the last forward
                # step.
                temperature = args.temperature * (args.temperature_factor
                                                  **(args.num_steps - 1))
                temperature_forward = args.temperature
                # Forward pass: corrupt step by step. For every step the
                # previous and the new points are concatenated along axis 1
                # and collected in grad_forward. INPUT_SIZE is not defined
                # in this function. mu_data is reshaped but never used
                # afterwards; its two reshapes can only both succeed when
                # INPUT_SIZE == 2.
                for num_step in range(args.num_steps):
                    if num_step == 0:
                        x_data_time.append(data[0])
                        plot_images(
                            data[0], model_dir + '/' + 'orig_' + 'epoch_' +
                            str(count_sample) + '_batch_' + str(batch_index))
                        x_data, mu_data, _, _ = forward_diffusion(
                            data[0], temperature_forward)

                        plot_images(
                            x_data, model_dir + '/' + 'corrupted_' + 'epoch_' +
                            str(count_sample) + '_batch_' + str(batch_index) +
                            '_time_step_' + str(num_step))
                        x_data_time.append(x_data)
                        temp_grad = np.concatenate(
                            (x_data_time[-2], x_data_time[-1]), axis=1)
                        grad_forward.append(temp_grad)

                        x_data = np.asarray(x_data).astype('float32').reshape(
                            args.batch_size, INPUT_SIZE)
                        spiral_corrupted.append(x_data)
                        mu_data = np.asarray(mu_data).astype(
                            'float32').reshape(args.batch_size, INPUT_SIZE)
                        mu_data = mu_data.reshape(args.batch_size, 2)
                    else:
                        x_data_time.append(x_data)
                        x_data, mu_data, _, _ = forward_diffusion(
                            x_data, temperature_forward)
                        plot_images(
                            x_data, model_dir + '/' + 'corrupted_' + 'epoch_' +
                            str(count_sample) + '_batch_' + str(batch_index) +
                            '_time_step_' + str(num_step))
                        x_data = np.asarray(x_data).astype('float32').reshape(
                            args.batch_size, INPUT_SIZE)
                        spiral_corrupted.append(x_data)

                        mu_data = np.asarray(mu_data).astype(
                            'float32').reshape(args.batch_size, INPUT_SIZE)
                        mu_data = mu_data.reshape(args.batch_size, 2)
                        x_data_time.append(x_data)
                        temp_grad = np.concatenate(
                            (x_data_time[-2], x_data_time[-1]), axis=1)
                        grad_forward.append(temp_grad)
                    temperature_forward = temperature_forward * args.temperature_factor

                # Statistics of the final corrupted batch, used further down
                # to draw the noise sample.
                mean_sampled = x_data.mean()
                var_sampled = x_data.var()

                x_temp2 = data[0].reshape(args.batch_size, 2)
                plot_2D(
                    spiral_corrupted, args.num_steps,
                    model_dir + '/' + 'corrupted_' + 'epoch_' +
                    str(count_sample) + '_batch_' + str(batch_index))
                plot_2D(
                    x_temp2, 1, model_dir + '/' + 'orig_' + 'epoch_' +
                    str(count_sample) + '_batch_index_' + str(batch_index))
                plot_grad(
                    grad_forward,
                    model_dir + '/' + 'grad_forward_' + 'epoch_' +
                    str(count_sample) + '_batch_' + str(batch_index))
                # Backward pass starting from the last corrupted batch.
                # x_tilt_time receives two entries per step (the points
                # before and after sampling). The temperature is divided by
                # temperature_factor after each step until it compares equal
                # to args.temperature, after which it stays there (the first
                # branch is a no-op).
                # NOTE(review): this relies on exact float equality.
                for i in range(args.num_steps + args.extra_steps):
                    x_tilt_time.append(x_data)
                    x_data, sampled_mean = f_sample(x_data, temperature)
                    plot_images(
                        x_data, model_dir + '/' + 'sampled_' + 'epoch_' +
                        str(count_sample) + '_batch_' + str(batch_index) +
                        '_time_step_' + str(i))
                    x_tilt_time.append(x_data)
                    temp_grad = np.concatenate(
                        (x_tilt_time[-2], x_tilt_time[-1]), axis=1)
                    grad_back.append(temp_grad)

                    ###print 'Recons, On step number, using temperature', i, temperature
                    x_data = np.asarray(x_data).astype('float32')
                    x_data = x_data.reshape(args.batch_size, INPUT_SIZE)
                    if temperature == args.temperature:
                        temperature = temperature
                    else:
                        temperature /= args.temperature_factor

                plot_grad(
                    grad_back, model_dir + '/' + 'grad_back_' + 'epoch_' +
                    str(count_sample) + '_batch_' + str(batch_index))
                plot_2D(
                    x_tilt_time, args.num_steps,
                    model_dir + '/' + 'sampled_' + 'epoch_' +
                    str(count_sample) + '_batch_' + str(batch_index))

                # Noise sample matching the statistics of the corrupted data.
                # NOTE(review): the second argument of np.random.normal is
                # the standard deviation, but a variance is passed here;
                # np.sqrt(var_sampled) may have been intended.
                s = np.random.normal(mean_sampled, var_sampled,
                                     [args.batch_size, 2])
                x_sampled = s

                temperature = args.temperature * (args.temperature_factor
                                                  **(args.num_steps - 1))
                # Same backward schedule as above, this time from the noise
                # sample; every intermediate result is kept for plot_2D.
                x_data = np.asarray(x_sampled).astype('float32')
                for i in range(args.num_steps + args.extra_steps):
                    x_data, sampled_mean = f_sample(x_data, temperature)
                    spiral_sampled.append(x_data)
                    x_data = np.asarray(x_data).astype('float32')
                    x_data = x_data.reshape(args.batch_size, INPUT_SIZE)
                    if temperature == args.temperature:
                        temperature = temperature
                    else:
                        temperature /= args.temperature_factor
                plot_2D(
                    spiral_sampled, args.num_steps,
                    model_dir + '/' + 'inference_' + 'epoch_' +
                    str(count_sample) + '_batch_' + str(batch_index))
    # NOTE(review): debugger breakpoint left in; it is hit once all epochs
    # have completed and blocks a non-interactive run.
    ipdb.set_trace()
示例#4
0
def train(args, model_args, lrate):

    model_id = '/data/lisatmp4/anirudhg/minst_walk_back/walkback_'
    model_dir = create_log_dir(args, model_id)
    model_id2 = 'walkback_'
    model_dir2 = create_log_dir(args, model_id2)
    print model_dir
    logger = mimir.Logger(filename=model_dir2 + '/' + model_id2 +
                          'log.jsonl.gz',
                          formatter=None)

    # TODO batches_per_epoch should not be hard coded
    lrate = args.lr
    import sys
    sys.setrecursionlimit(10000000)
    args, model_args = parse_args()

    #trng = RandomStreams(1234)

    if args.resume_file is not None:
        print "Resuming training from " + args.resume_file
        from blocks.scripts import continue_training
        continue_training(args.resume_file)

    ## load the training data
    if args.dataset == 'MNIST':
        print 'loading MNIST'
        from fuel.datasets import MNIST
        dataset_train = MNIST(['train'], sources=('features', ))
        dataset_test = MNIST(['test'], sources=('features', ))
        n_colors = 1
        spatial_width = 28

    elif args.dataset == 'CIFAR10':
        from fuel.datasets import CIFAR10
        dataset_train = CIFAR10(['train'], sources=('features', ))
        dataset_test = CIFAR10(['test'], sources=('features', ))
        n_colors = 3
        spatial_width = 32

    elif args.dataset == 'Spiral':
        print 'loading SPIRAL'
        train_set = Spiral(num_examples=100000,
                           classes=1,
                           cycles=2.,
                           noise=0.01,
                           sources=('features', ))
        dataset_train = DataStream.default_stream(
            train_set,
            iteration_scheme=ShuffledScheme(train_set.num_examples,
                                            args.batch_size))

    else:
        raise ValueError("Unknown dataset %s." % args.dataset)

    model_options = locals().copy()

    train_stream = Flatten(
        DataStream.default_stream(dataset_train,
                                  iteration_scheme=ShuffledScheme(
                                      examples=dataset_train.num_examples,
                                      batch_size=args.batch_size)))

    shp = next(train_stream.get_epoch_iterator())[0].shape
    # make the training data 0 mean and variance 1
    # TODO compute mean and variance on full dataset, not minibatch
    Xbatch = next(train_stream.get_epoch_iterator())[0]
    scl = 1. / np.sqrt(np.mean((Xbatch - np.mean(Xbatch))**2))
    shft = -np.mean(Xbatch * scl)
    # scale is applied before shift
    #train_stream = ScaleAndShift(train_stream, scl, shft)
    #test_stream = ScaleAndShift(test_stream, scl, shft)

    print 'Building model'
    params = init_params(model_options)
    if args.reload_ and os.path.exists(args.saveto_filename):
        print 'Reloading Parameters'
        print args.saveto_filename
        params = load_params(args.saveto_filename, params)
    tparams = init_tparams(params)
    '''
    x = T.matrix('x', dtype='float32')
    f=transition_operator(tparams, model_options, x, 1)

    for data in train_stream.get_epoch_iterator():
        print data[0]
        a = f(data[0])
        print a
        ipdb.set_trace()
    '''
    x, cost = build_model(tparams, model_options)
    inps = [x]

    x_Data = T.matrix('x_Data', dtype='float32')
    temperature = T.scalar('temperature', dtype='float32')
    forward_diffusion = one_step_diffusion(x_Data, model_options, tparams,
                                           temperature)

    print 'Building f_cost...',
    f_cost = theano.function(inps, cost)
    print 'Done'
    print tparams
    grads = T.grad(cost, wrt=itemlist(tparams))

    get_grads = theano.function(inps, grads)

    for j in range(0, len(grads)):
        grads[j] = T.switch(T.isnan(grads[j]), T.zeros_like(grads[j]),
                            grads[j])

    # compile the optimizer, the actual computational graph is compiled here
    lr = T.scalar(name='lr')
    print 'Building optimizers...',
    optimizer = args.optimizer

    f_grad_shared, f_update = getattr(optimizers, optimizer)(lr, tparams,
                                                             grads, inps, cost)
    print 'Done'

    print 'Buiding Sampler....'
    f_sample = sample(tparams, model_options)
    print 'Done'

    uidx = 0
    estop = False
    bad_counter = 0
    max_epochs = 4000
    batch_index = 0
    print 'Number of steps....'
    print args.num_steps
    print 'Done'
    count_sample = 1
    for eidx in xrange(max_epochs):
        n_samples = 0
        print 'Starting Next Epoch ', eidx
        for data in train_stream.get_epoch_iterator():
            batch_index += 1
            n_samples += len(data[0])
            uidx += 1
            if data[0] is None:
                print 'No data '
                uidx -= 1
                continue
            ud_start = time.time()
            cost = f_grad_shared(data[0])
            f_update(lrate)
            ud = time.time() - ud_start

            if batch_index % 1 == 0:
                print 'Cost is this', cost
                count_sample += 1

                from impainting import change_image, inpainting
                train_temp = data[0]
                print data[0].shape
                change_image(train_temp.reshape(args.batch_size, 1, 28, 28), 3)
                train_temp = train_temp.reshape(args.batch_size, 784)
                output = inpainting(train_temp)
                change_image(output.reshape(args.batch_size, 1, 28, 28), 1)

                reverse_time(
                    scl, shft, output,
                    model_dir + '/' + 'impainting_orig_' + 'epoch_' +
                    str(count_sample) + '_batch_index_' + str(batch_index))
                x_data = np.asarray(output).astype('float32')
                temperature = args.temperature * (args.temperature_factor
                                                  **(args.num_steps - 1))
                temperature = args.temperature  #* (args.temperature_factor ** (args.num_steps -1 ))
                orig_impainted_data = np.asarray(data[0]).astype('float32')

                for i in range(args.num_steps + args.extra_steps + 5):
                    x_data, sampled, sampled_activation, sampled_preactivation = f_sample(
                        x_data, temperature)
                    print 'Impainting using temperature', i, temperature
                    x_data = do_half_image(x_data, orig_impainted_data)
                    reverse_time(
                        scl, shft, x_data, model_dir + '/' +
                        'impainting_orig_' + 'epoch_' + str(count_sample) +
                        '_batch_index_' + str(batch_index) + 'step_' + str(i))
                    x_data = np.asarray(x_data).astype('float32')
                    x_data = x_data.reshape(args.batch_size, INPUT_SIZE)
                    if temperature == args.temperature:
                        temperature = temperature
                    else:
                        temperature = temperature
                        #temperature /= args.temperature_factor
    ipdb.set_trace()
def train(args, model_args):

    model_id = '/data/lisatmp3/anirudhg/spiral_walk_back/walkback_'
    model_dir = create_log_dir(args, model_id)
    model_id2 = 'walkback_'
    model_dir2 = create_log_dir(args, model_id2)
    print model_dir, model_dir2
    logger = mimir.Logger(filename=model_dir2 + '/' + model_id2 +
                          'log.jsonl.gz',
                          formatter=None)

    # TODO batches_per_epoch should not be hard coded
    lrate = args.lr
    import sys
    sys.setrecursionlimit(10000000)
    args, model_args = parse_args()

    #trng = RandomStreams(1234)

    if args.resume_file is not None:
        print "Resuming training from " + args.resume_file
        from blocks.scripts import continue_training
        continue_training(args.resume_file)

    ## load the training data
    if args.dataset == 'MNIST':
        print 'loading MNIST'
        from fuel.datasets import MNIST
        dataset_train = MNIST(['train'], sources=('features', ))
        dataset_test = MNIST(['test'], sources=('features', ))
        n_colors = 1
        spatial_width = 28

    elif args.dataset == 'Spiral':
        print 'loading SPIRAL'
        train_set = Spiral(num_examples=20000,
                           classes=1,
                           cycles=1.,
                           noise=0.0,
                           sources=('features', ))
        dataset_train = DataStream.default_stream(
            train_set,
            iteration_scheme=ShuffledScheme(train_set.num_examples,
                                            args.batch_size))
        iter_per_epoch = train_set.num_examples
        print iter_per_epoch
    elif args.dataset == 'Circle':
        print 'loading SPIRAL'
        train_set = Circle(num_examples=20000,
                           classes=1,
                           cycles=1.,
                           noise=0.0,
                           sources=('features', ))
        dataset_train = DataStream.default_stream(
            train_set,
            iteration_scheme=ShuffledScheme(train_set.num_examples,
                                            args.batch_size))
        iter_per_epoch = train_set.num_examples
        print iter_per_epoch
    else:
        raise ValueError("Unknown dataset %s." % args.dataset)

    model_options = locals().copy()

    train_stream = dataset_train

    shp = next(train_stream.get_epoch_iterator())[0].shape
    # make the training data 0 mean and variance 1
    # TODO compute mean and variance on full dataset, not minibatch
    Xbatch = next(train_stream.get_epoch_iterator())[0]
    scl = 1. / np.sqrt(np.mean((Xbatch - np.mean(Xbatch))**2))
    shft = -np.mean(Xbatch * scl)
    print 'Building model'
    params = init_params(model_options)
    if args.reload_ and os.path.exists(args.saveto_filename):
        print 'Reloading Parameters'
        print args.saveto_filename
        params = load_params(args.saveto_filename, params)
    tparams = init_tparams(params)

    x = T.matrix('x', dtype='float32')
    x, cost, cost_each_step = build_model(tparams, model_options)
    inps = [x]

    x_Data = T.matrix('x_Data', dtype='float32')
    temperature = T.scalar('temperature', dtype='float32')
    forward_diffusion, debug_function = one_step_diffusion(
        x_Data, model_options, tparams, temperature)

    print 'Building f_cost...',
    f_cost = theano.function(inps, cost)
    f_cost_each_step = theano.function(inps, cost_each_step)
    print 'Done'
    grads = T.grad(cost, wrt=itemlist(tparams))
    get_grads = theano.function(inps, grads)

    for j in range(0, len(grads)):
        grads[j] = T.switch(T.isnan(grads[j]), T.zeros_like(grads[j]),
                            grads[j])

    # compile the optimizer, the actual computational graph is compiled here
    lr = T.scalar(name='lr')
    print 'Building optimizers...',
    optimizer = args.optimizer

    f_grad_shared, f_update = getattr(optimizers, optimizer)(lr, tparams,
                                                             grads, inps, cost)
    print 'Done'

    print 'Buiding Sampler....'
    f_sample = sample(tparams, model_options)
    print 'Done'

    uidx = 0
    estop = False
    bad_counter = 0
    max_epochs = 4000
    batch_index = 0
    print 'Number of steps....', args.num_steps
    print 'Done'
    count_sample = 1
    for eidx in xrange(max_epochs):
        if eidx % 20 == 0:
            params = unzip(tparams)
            save_params(params,
                        model_dir + '/' + 'params_' + str(eidx) + '.npz')
            if eidx == 30:
                ipdb.set_trace()
        n_samples = 0
        batch_index = 0
        print 'Starting Next Epoch ', eidx

        for data in train_stream.get_epoch_iterator():
            batch_index += 1
            n_samples += len(data[0])
            uidx += 1
            if data[0] is None:
                print 'No data '
                uidx -= 1
                continue
            plt.plot(*zip(*data[0]))
            plt.show()
            ud_start = time.time()
            cost = f_grad_shared(data[0])
            output_cost = f_cost_each_step(data[0])
            output_cost = np.asarray(output_cost)
            #cost_time_step_1, cost_time_step_2 = (output_cost[0,:].sum() *1.0) /args.batch_size, (output_cost[1,:].sum() * 1.0 )/args.batch_size
            print 'Loss is', cost
            f_update(lrate)
            ud = time.time() - ud_start
            if np.isnan(cost) or np.isinf(cost):
                print 'NaN detected'
                return 1.
            gradient_updates_ = get_grads(data[0])

            logger.log({
                'epoch': eidx,
                'batch_index': batch_index,
                'uidx': uidx,
                'training_error': cost
            })

            empty = []
            spiral_x = [empty for i in range(args.num_steps)]
            spiral_corrupted = []
            spiral_sampled = []
            grad_forward = []
            grad_back = []
            x_data_time = []
            x_tilt_time = []
            if eidx % 10 == 0:
                count_sample += 1
                temperature = args.temperature * (args.temperature_factor
                                                  **(args.num_steps - 1))
                temperature_forward = args.temperature
                for num_step in range(args.num_steps):
                    if num_step == 0:
                        x_data_time.append(data[0])
                        plot_images(
                            data[0], model_dir + '/' + 'orig_' + 'epoch_' +
                            str(count_sample) + '_batch_' + str(batch_index))
                        x_data, mu_data = forward_diffusion(
                            data[0], temperature_forward)

                        plot_images(
                            x_data, model_dir + '/' + 'corrupted_' + 'epoch_' +
                            str(count_sample) + '_batch_' + str(batch_index) +
                            '_time_step_' + str(num_step))
                        x_data_time.append(x_data)
                        temp_grad = np.concatenate(
                            (x_data_time[-2], x_data_time[-1]), axis=1)
                        grad_forward.append(temp_grad)

                        x_data = np.asarray(x_data).astype('float32').reshape(
                            args.batch_size, INPUT_SIZE)
                        spiral_corrupted.append(x_data)
                        mu_data = np.asarray(mu_data).astype(
                            'float32').reshape(args.batch_size, INPUT_SIZE)
                        mu_data = mu_data.reshape(args.batch_size, 2)
                    else:
                        x_data_time.append(x_data)
                        x_data, mu_data = forward_diffusion(
                            x_data, temperature_forward)
                        plot_images(
                            x_data, model_dir + '/' + 'corrupted_' + 'epoch_' +
                            str(count_sample) + '_batch_' + str(batch_index) +
                            '_time_step_' + str(num_step))
                        x_data = np.asarray(x_data).astype('float32').reshape(
                            args.batch_size, INPUT_SIZE)
                        spiral_corrupted.append(x_data)

                        mu_data = np.asarray(mu_data).astype(
                            'float32').reshape(args.batch_size, INPUT_SIZE)
                        mu_data = mu_data.reshape(args.batch_size, 2)
                        x_data_time.append(x_data)
                        temp_grad = np.concatenate(
                            (x_data_time[-2], x_data_time[-1]), axis=1)
                        grad_forward.append(temp_grad)
                    temperature_forward = temperature_forward * args.temperature_factor

                x_temp2 = data[0].reshape(args.batch_size, 2)
                plot_2D(
                    spiral_corrupted, args.num_steps,
                    model_dir + '/' + 'corrupted_' + 'epoch_' +
                    str(count_sample) + '_batch_' + str(batch_index))
                plot_2D(
                    x_temp2, 1, model_dir + '/' + 'orig_' + 'epoch_' +
                    str(count_sample) + '_batch_index_' + str(batch_index))
                plot_grad(
                    grad_forward,
                    model_dir + '/' + 'grad_forward_' + 'epoch_' +
                    str(count_sample) + '_batch_' + str(batch_index))
                for i in range(args.num_steps + args.extra_steps):
                    x_tilt_time.append(x_data)
                    x_data, sampled_mean = f_sample(x_data, temperature)
                    plot_images(
                        x_data, model_dir + '/' + 'sampled_' + 'epoch_' +
                        str(count_sample) + '_batch_' + str(batch_index) +
                        '_time_step_' + str(i))
                    x_tilt_time.append(x_data)
                    temp_grad = np.concatenate(
                        (x_tilt_time[-2], x_tilt_time[-1]), axis=1)
                    grad_back.append(temp_grad)

                    ###print 'Recons, On step number, using temperature', i, temperature
                    x_data = np.asarray(x_data).astype('float32')
                    x_data = x_data.reshape(args.batch_size, INPUT_SIZE)
                    if temperature == args.temperature:
                        temperature = temperature
                    else:
                        temperature /= args.temperature_factor

                plot_grad(
                    grad_back, model_dir + '/' + 'grad_back_' + 'epoch_' +
                    str(count_sample) + '_batch_' + str(batch_index))
                np.savez('grad_backward.npz', grad_back)
                plot_2D(
                    x_tilt_time, args.num_steps,
                    model_dir + '/' + 'sampled_' + 'epoch_' +
                    str(count_sample) + '_batch_' + str(batch_index))

                s = np.random.normal(0, 1, [args.batch_size, 2])
                x_sampled = s

                temperature = args.temperature * (args.temperature_factor
                                                  **(args.num_steps - 1))
                x_data = np.asarray(x_sampled).astype('float32')
                for i in range(args.num_steps + args.extra_steps):
                    x_data, sampled_mean = f_sample(x_data, temperature)
                    spiral_sampled.append(x_data)
                    x_data = np.asarray(x_data).astype('float32')
                    x_data = x_data.reshape(args.batch_size, INPUT_SIZE)
                    if temperature == args.temperature:
                        temperature = temperature
                    else:
                        temperature /= args.temperature_factor
                plot_2D(
                    spiral_sampled, args.num_steps,
                    model_dir + '/' + 'inference_' + 'epoch_' +
                    str(count_sample) + '_batch_' + str(batch_index))
    ipdb.set_trace()
def train(args, model_args):

    #model_id = '/data/lisatmp4/lambalex/lsun_walkback/walkback_'

    model_id = '/data/lisatmp4/anirudhg/cifar_walk_back/walkback_'
    model_dir = create_log_dir(args, model_id)
    model_id2 = 'logs/walkback_'
    model_dir2 = create_log_dir(args, model_id2)
    print model_dir
    print model_dir2 + '/' + 'log.jsonl.gz'
    logger = mimir.Logger(filename=model_dir2 + '/log.jsonl.gz',
                          formatter=None)

    # TODO batches_per_epoch should not be hard coded
    lrate = args.lr
    import sys
    sys.setrecursionlimit(10000000)
    args, model_args = parse_args()

    #trng = RandomStreams(1234)

    if args.resume_file is not None:
        print "Resuming training from " + args.resume_file
        from blocks.scripts import continue_training
        continue_training(args.resume_file)

    ## load the training data
    if args.dataset == 'MNIST':
        print 'loading MNIST'
        from fuel.datasets import MNIST
        dataset_train = MNIST(['train'], sources=('features', ))
        dataset_test = MNIST(['test'], sources=('features', ))
        n_colors = 1
        spatial_width = 28

    elif args.dataset == 'CIFAR10':
        from fuel.datasets import CIFAR10
        dataset_train = CIFAR10(['train'], sources=('features', ))
        dataset_test = CIFAR10(['test'], sources=('features', ))
        n_colors = 3
        spatial_width = 32

    elif args.dataset == "lsun" or args.dataset == "lsunsmall":

        print "loading lsun class!"

        from load_lsun import load_lsun

        print "loading lsun data!"

        if args.dataset == "lsunsmall":
            dataset_train, dataset_test = load_lsun(args.batch_size,
                                                    downsample=True)
            spatial_width = 32
        else:
            dataset_train, dataset_test = load_lsun(args.batch_size,
                                                    downsample=False)
            spatial_width = 64

        n_colors = 3

    elif args.dataset == "celeba":

        print "loading celeba data"

        from fuel.datasets.celeba import CelebA

        dataset_train = CelebA(which_sets=['train'],
                               which_format="64",
                               sources=('features', ),
                               load_in_memory=False)
        dataset_test = CelebA(which_sets=['test'],
                              which_format="64",
                              sources=('features', ),
                              load_in_memory=False)

        spatial_width = 64
        n_colors = 3

        tr_scheme = SequentialScheme(examples=dataset_train.num_examples,
                                     batch_size=args.batch_size)
        ts_scheme = SequentialScheme(examples=dataset_test.num_examples,
                                     batch_size=args.batch_size)

        train_stream = DataStream.default_stream(dataset_train,
                                                 iteration_scheme=tr_scheme)
        test_stream = DataStream.default_stream(dataset_test,
                                                iteration_scheme=ts_scheme)

        dataset_train = train_stream
        dataset_test = test_stream

        #epoch_it = train_stream.get_epoch_iterator()

    elif args.dataset == 'Spiral':
        print 'loading SPIRAL'
        train_set = Spiral(num_examples=100000,
                           classes=1,
                           cycles=2.,
                           noise=0.01,
                           sources=('features', ))
        dataset_train = DataStream.default_stream(
            train_set,
            iteration_scheme=ShuffledScheme(train_set.num_examples,
                                            args.batch_size))

    else:
        raise ValueError("Unknown dataset %s." % args.dataset)

    model_options = locals().copy()

    if args.dataset != 'lsun' and args.dataset != 'celeba':
        train_stream = Flatten(
            DataStream.default_stream(
                dataset_train,
                iteration_scheme=ShuffledScheme(
                    examples=dataset_train.num_examples -
                    (dataset_train.num_examples % args.batch_size),
                    batch_size=args.batch_size)))
    else:
        train_stream = dataset_train
        test_stream = dataset_test

    print "Width", WIDTH, spatial_width

    shp = next(train_stream.get_epoch_iterator())[0].shape

    print "got epoch iterator"

    Xbatch = next(train_stream.get_epoch_iterator())[0]
    scl = 1. / np.sqrt(np.mean((Xbatch - np.mean(Xbatch))**2))
    shft = -np.mean(Xbatch * scl)

    print 'Building model'
    params = init_params(model_options)
    if args.reload_:
        print "Trying to reload parameters"
        if os.path.exists(args.saveto_filename):
            print 'Reloading Parameters'
            print args.saveto_filename
            params = load_params(args.saveto_filename, params)
    tparams = init_tparams(params)
    print tparams
    x, cost, start_temperature, step_chain = build_model(
        tparams, model_options)
    inps = [x.astype('float32'), start_temperature, step_chain]

    x_Data = T.matrix('x_Data', dtype='float32')
    temperature = T.scalar('temperature', dtype='float32')
    step_chain_part = T.scalar('step_chain_part', dtype='int32')

    forward_diffusion = one_step_diffusion(x_Data, model_options, tparams,
                                           temperature, step_chain_part)

    print tparams
    grads = T.grad(cost, wrt=itemlist(tparams))

    #get_grads = theano.function(inps, grads)

    for j in range(0, len(grads)):
        grads[j] = T.switch(T.isnan(grads[j]), T.zeros_like(grads[j]),
                            grads[j])

    # compile the optimizer, the actual computational graph is compiled here
    lr = T.scalar(name='lr')
    print 'Building optimizers...',
    optimizer = args.optimizer

    f_grad_shared, f_update = getattr(optimizers, optimizer)(lr, tparams,
                                                             grads, inps, cost)
    print 'Done'

    #for param in tparams:
    #    print param
    #    print tparams[param].get_value().shape

    print 'Buiding Sampler....'
    f_sample = sample(tparams, model_options)
    print 'Done'

    uidx = 0
    estop = False
    bad_counter = 0
    max_epochs = 4000
    batch_index = 1
    print 'Number of steps....'
    print args.num_steps
    print "Number of metasteps...."
    print args.meta_steps
    print 'Done'
    count_sample = 1
    save_data = []
    for eidx in xrange(1):
        print 'Starting Next Epoch ', eidx
        for data_ in range(500):  #train_stream.get_epoch_iterator():
            if args.noise == "gaussian":
                x_sampled = np.random.normal(0.5,
                                             2.0,
                                             size=(args.batch_size,
                                                   INPUT_SIZE)).clip(0.0, 1.0)
            else:
                s = np.random.binomial(1, 0.5, INPUT_SIZE)

            temperature = args.temperature * (args.temperature_factor**(
                args.num_steps * args.meta_steps - 1))
            x_data = np.asarray(x_sampled).astype('float32')
            for i in range(args.num_steps * args.meta_steps +
                           args.extra_steps):
                x_data, sampled, sampled_activation, sampled_preactivation = f_sample(
                    x_data.astype('float32'), temperature,
                    args.num_steps * args.meta_steps - i - 1)
                print 'On step number, using temperature', i, temperature
                #reverse_time(scl, shft, x_data, model_dir + '/batch_index_' + str(batch_index) + '_inference_' + 'epoch_' + str(count_sample) + '_step_' + str(i))
                x_data = np.asarray(x_data).astype('float32')
                x_data = x_data.reshape(args.batch_size, INPUT_SIZE)
                if temperature == args.temperature:
                    temperature = temperature
                else:
                    temperature /= args.temperature_factor

            count_sample = count_sample + 1
            save_data.append(x_data)
            fname = model_dir + '/batch_index_' + str(
                batch_index) + '_inference_' + 'epoch_' + str(
                    count_sample) + '_step_' + str(i)
            np.savez(fname + '.npz', x_data)

        save2_data = np.asarray(save_data).astype('float32')
        fname = model_dir + '/generted_images_50000'  #+ args.saveto_filename
        np.savez(fname + '.npz', save_data)

    ipdb.set_trace()