# Discriminator outputs: the first BATCH_SIZE rows of `disc_out` score
# generated samples, the remaining rows score real inputs.
disc_gen_out = T.nnet.sigmoid(disc_out[:BATCH_SIZE])
disc_inputs = T.nnet.sigmoid(disc_out[BATCH_SIZE:])

_bce = T.nnet.binary_crossentropy

# Generator objective: make the discriminator call its samples real (1).
gen_cost = _bce(disc_gen_out, swft.floatX(1)).mean()
gen_cost.name = 'gen_cost'

# Discriminator objective: call generated samples fake (0) and real inputs
# real (1); halve the sum so the loss scale matches a single batch.
discrim_cost = (_bce(disc_gen_out, swft.floatX(0)).mean() +
                _bce(disc_inputs, swft.floatX(1)).mean()) / swft.floatX(2.0)
discrim_cost.name = 'discrim_cost'

train_data, dev_data, test_data = swft.mnist.load(BATCH_SIZE)


def _component_params(tag):
    # Predicate for swft.search: shared variables tagged with a `param`
    # attribute whose name marks them as part of the given sub-network.
    return lambda x: hasattr(x, 'param') and tag in x.name

gen_params = swft.search(gen_cost, _component_params('Generator'))
discrim_params = swft.search(discrim_cost, _component_params('Discriminator'))

# Compiled sampler: draws 100 generator samples, used by generate_image.
_sample_fn = theano.function([], generator(100))


def generate_image(epoch):
    """Save a 10x10 grid of generator samples for this epoch."""
    grid = _sample_fn()
    # the transpose is rowx, rowy, height, width -> rowy, height, rowx, width
    # so the final reshape tiles all 100 digits into one 280x280 image.
    grid = grid.reshape((10, 10, 28, 28)).transpose(1, 2, 0, 3)
    grid = grid.reshape((10 * 28, 10 * 28))
    plt.imshow(grid, cmap=plt.get_cmap('gray'), vmin=0, vmax=1)
    plt.savefig('epoch' + str(epoch))
# --- Training graph for a frame-level sequence model (Theano) ---
# NOTE(review): `predict`, `Q_LEVELS`, `GRAD_CLIP`, `swft`, `lasagne`,
# `T` and `theano` are defined elsewhere in this file/project.

# Symbolic integer matrices of shape (batch, time).
sequences = T.imatrix('sequences')
transcripts = T.imatrix('transcripts')
# Initial recurrent hidden state; the updated state is returned as new_h0
# so it can be carried across successive calls.
h0 = T.matrix('h0')

frame_level_outputs, new_h0 = predict(sequences, h0)

# Next-step prediction: the output at time t is scored against the input at
# t+1, so the last output frame and the first input frame are dropped.
cost = T.nnet.categorical_crossentropy(
    T.nnet.softmax(frame_level_outputs[:, :-1].reshape((-1, Q_LEVELS))),
    sequences[:, 1:].flatten()
).mean()

# Convert nats to bits: 1.44269504089 = 1/ln(2).
cost = cost * swft.floatX(1.44269504089)
cost.name = 'cost'

# Collect every parameter the cost depends on (tagged with a `param` attr).
params = swft.search(cost, lambda x: hasattr(x, 'param'))
swft._train._print_paramsets_info([cost], [params])

grads = T.grad(cost, wrt=params, disconnected_inputs='warn')
# Element-wise gradient clipping to stabilize recurrent training.
grads = [T.clip(g, swft.floatX(-GRAD_CLIP), swft.floatX(GRAD_CLIP)) for g in grads]

updates = lasagne.updates.adam(grads, params)

# `transcripts` is accepted but not used by this cost, hence the
# on_unused_input='warn' setting.
train_fn = theano.function(
    [sequences, transcripts, h0],
    [cost, new_h0],
    updates=updates,
    on_unused_input='warn'
)

predict_fn = theano.function(
def train( symbolic_inputs, costs, train_data, dev_data=None, test_data=None, param_sets=None, optimizers=[lasagne.updates.adam], print_vars=None, epochs=10, print_every=10, callback=None ): # TODO write documentation if param_sets == None: param_sets = [ swft.search(costs[0], lambda x: hasattr(x, 'param')) ] assert len(costs)==len(param_sets), "train() needs 1 param set per cost!" _print_paramsets_info(costs, param_sets) print "Building updates..." if print_vars is None: print_vars = [c for c in costs] for cost in costs: print_vars += swft.search(cost, lambda x: hasattr(x, '_print')) # Remove duplicate values in print_vars print_vars = list(set(print_vars)) all_updates = [] for cost, params, optimizer in zip(costs, param_sets, optimizers): grads = T.grad(cost, wrt=params) # Clip gradients elementwise grads = [ T.clip(g, swft.floatX(-1.0), swft.floatX(1.0)) for g in grads ] cost_updates = optimizer(grads, params) for k, v in cost_updates.items(): all_updates.append((k,v)) print "Compiling train function..." train_ = theano.function( symbolic_inputs, print_vars, updates=all_updates, on_unused_input='warn' ) print "Compiling evaluate function..." evaluate = theano.function( symbolic_inputs, print_vars, on_unused_input='warn' ) print "Training!" 
splits = [ ('train', train_, train_data) ] if dev_data is not None: splits.append(('dev', evaluate, dev_data)) if test_data is not None: splits.append(('test', evaluate, test_data)) for epoch in xrange(epochs): for title, fn, data in splits: epoch_totals = [] since_last_print = [] n_inputs = 0 for iteration, inputs in enumerate(data(), start=1): n_inputs += 1 start_time = time.time() outputs_ = fn(*inputs) if iteration == 1: epoch_totals = [o.copy() for o in outputs_] since_last_print = [o.copy() for o in outputs_] else: for i, o in enumerate(outputs_): epoch_totals[i] += o since_last_print[i] += o if iteration % print_every == 0: new_time = time.time() values_to_print = [ ('epoch', epoch), ('input', iteration), ('time_per_input', (time.time() - start_time)) ] for symbolic, totalval in zip(print_vars, since_last_print): values_to_print.append( (str(symbolic), totalval / print_every) ) print "{0}\t".format(title) + "\t".join([ "{0}:{1}".format(name, val) for name, val in values_to_print ]) last_print_time = new_time for i, t in enumerate(since_last_print): since_last_print[i].fill(0) values_to_print = [ ('epoch', epoch), ('n_inputs', n_inputs) ] for symbolic_var, total_val in zip(print_vars, epoch_totals): values_to_print.append( (str(symbolic_var), total_val / n_inputs) ) print "{0} summary\t".format(title) + "\t".join( ["{0}:{1}".format(name, val) for name, val in values_to_print] ) if callback: callback(epoch)
# ... and minimize reconstruction error reconst_cost = T.sqr(reconstructions - images).mean() reconst_cost.name = 'reconst_cost' # this seems to be an important hyperparam, maybe try playing with it more. full_enc_cost = (swft.floatX(100)*reconst_cost) + reg_cost # Decoder objective: minimize reconstruction loss dec_cost = reconst_cost # Discrim objective: push D(latents) to zero, D(noise) to one discrim_cost = T.nnet.binary_crossentropy(discriminator(latents), swft.floatX(0)).mean() discrim_cost += T.nnet.binary_crossentropy(discriminator(noise(BATCH_SIZE)), swft.floatX(1)).mean() discrim_cost.name = 'discrim_cost' enc_params = swft.search(full_enc_cost, lambda x: hasattr(x, 'param') and 'Encoder' in x.name) dec_params = swft.search(dec_cost, lambda x: hasattr(x, 'param') and 'Decoder' in x.name) discrim_params = swft.search(discrim_cost, lambda x: hasattr(x, 'param') and 'Discriminator' in x.name) # Load dataset train_data, dev_data, test_data = swft.mnist.load(BATCH_SIZE) # sample_fn is used by generate_images sample_fn = theano.function( [images], [decoder(noise(100)), decoder(encoder(images[:100])), encoder(images)] ) def generate_images(epoch): """ Save samples and diagnostic images from the model. This function is passed as a callback to `train` and is called after every epoch.
reconst_cost.name = 'reconst_cost'

# Pixel reconstruction is weighted 100x relative to the regularizer; this
# weight is a hyperparameter that seems important and is worth tuning more.
full_enc_cost = (swft.floatX(100) * reconst_cost) + reg_cost

# The decoder's only objective is faithful reconstruction.
dec_cost = reconst_cost

# The discriminator learns to output 0 on encoder posteriors (`latents`)
# and 1 on draws from the prior (`noise`).
discrim_cost = (
    T.nnet.binary_crossentropy(discriminator(latents),
                               swft.floatX(0)).mean()
    + T.nnet.binary_crossentropy(discriminator(noise(BATCH_SIZE)),
                                 swft.floatX(1)).mean()
)
discrim_cost.name = 'discrim_cost'


def _tagged_params(cost, tag):
    # Shared variables tagged `param` whose name marks them as belonging to
    # the given sub-network.
    return swft.search(cost, lambda x: hasattr(x, 'param') and tag in x.name)

enc_params = _tagged_params(full_enc_cost, 'Encoder')
dec_params = _tagged_params(dec_cost, 'Decoder')
discrim_params = _tagged_params(discrim_cost, 'Discriminator')

# MNIST minibatch iterators.
train_data, dev_data, test_data = swft.mnist.load(BATCH_SIZE)

# Compiled sampler used by generate_images: decoded prior samples,
# reconstructions of the first 100 inputs, and every input's latent code.
_prior_decoded = decoder(noise(100))
_reconstructed = decoder(encoder(images[:100]))
_latent_codes = encoder(images)
sample_fn = theano.function(
    [images],
    [_prior_decoded, _reconstructed, _latent_codes]
)
# NOTE(review): GAN training script. `disc_out`, `generator`, `BATCH_SIZE`,
# `symbolic_inputs`, `swft`, `T`, `theano` and `plt` are defined earlier in
# this file, outside this chunk. The first BATCH_SIZE rows of `disc_out`
# score generated samples; the rest score real inputs.
disc_gen_out = T.nnet.sigmoid(disc_out[:BATCH_SIZE])
disc_inputs = T.nnet.sigmoid(disc_out[BATCH_SIZE:])

# Gen objective: push D(G) to one
gen_cost = T.nnet.binary_crossentropy(disc_gen_out, swft.floatX(1)).mean()
gen_cost.name = 'gen_cost'

# Discrim objective: push D(G) to zero, and push D(real) to one
discrim_cost = T.nnet.binary_crossentropy(disc_gen_out, swft.floatX(0)).mean()
discrim_cost += T.nnet.binary_crossentropy(disc_inputs, swft.floatX(1)).mean()
# Halve the sum so the loss scale matches a single cross-entropy term.
discrim_cost /= swft.floatX(2.0)
discrim_cost.name = 'discrim_cost'

train_data, dev_data, test_data = swft.mnist.load(BATCH_SIZE)

# Parameters are tagged with a `param` attribute; sub-network membership is
# encoded in each parameter's name.
gen_params = swft.search(gen_cost, lambda x: hasattr(x, 'param') and 'Generator' in x.name)
discrim_params = swft.search(discrim_cost, lambda x: hasattr(x, 'param') and 'Discriminator' in x.name)

# Compiled sampler: 100 generator samples, used by generate_image below.
_sample_fn = theano.function([], generator(100))

def generate_image(epoch):
    # Save a 10x10 grid of generator samples to a file named 'epoch<N>'.
    sample = _sample_fn()
    # the transpose is rowx, rowy, height, width -> rowy, height, rowx, width
    sample = sample.reshape((10,10,28,28)).transpose(1,2,0,3).reshape((10*28, 10*28))
    plt.imshow(sample, cmap = plt.get_cmap('gray'), vmin=0, vmax=1)
    plt.savefig('epoch'+str(epoch))

# Jointly train generator and discriminator (call truncated in this chunk).
swft.train(
    symbolic_inputs,
    [gen_cost, discrim_cost],
    train_data,
    dev_data=dev_data,
def train(symbolic_inputs, costs, train_data, dev_data=None, test_data=None, param_sets=None, optimizers=[lasagne.updates.adam], print_vars=None, epochs=10, print_every=10, callback=None): # TODO write documentation if param_sets == None: param_sets = [swft.search(costs[0], lambda x: hasattr(x, 'param'))] assert len(costs) == len(param_sets), "train() needs 1 param set per cost!" _print_paramsets_info(costs, param_sets) print "Building updates..." if print_vars is None: print_vars = [c for c in costs] for cost in costs: print_vars += swft.search(cost, lambda x: hasattr(x, '_print')) # Remove duplicate values in print_vars print_vars = list(set(print_vars)) all_updates = [] for cost, params, optimizer in zip(costs, param_sets, optimizers): grads = T.grad(cost, wrt=params) # Clip gradients elementwise grads = [T.clip(g, swft.floatX(-1.0), swft.floatX(1.0)) for g in grads] cost_updates = optimizer(grads, params) for k, v in cost_updates.items(): all_updates.append((k, v)) print "Compiling train function..." train_ = theano.function(symbolic_inputs, print_vars, updates=all_updates, on_unused_input='warn') print "Compiling evaluate function..." evaluate = theano.function(symbolic_inputs, print_vars, on_unused_input='warn') print "Training!" 
splits = [('train', train_, train_data)] if dev_data is not None: splits.append(('dev', evaluate, dev_data)) if test_data is not None: splits.append(('test', evaluate, test_data)) for epoch in xrange(epochs): for title, fn, data in splits: epoch_totals = [] since_last_print = [] n_inputs = 0 for iteration, inputs in enumerate(data(), start=1): n_inputs += 1 start_time = time.time() outputs_ = fn(*inputs) if iteration == 1: epoch_totals = [o.copy() for o in outputs_] since_last_print = [o.copy() for o in outputs_] else: for i, o in enumerate(outputs_): epoch_totals[i] += o since_last_print[i] += o if iteration % print_every == 0: new_time = time.time() values_to_print = [('epoch', epoch), ('input', iteration), ('time_per_input', (time.time() - start_time))] for symbolic, totalval in zip(print_vars, since_last_print): values_to_print.append( (str(symbolic), totalval / print_every)) print "{0}\t".format(title) + "\t".join([ "{0}:{1}".format(name, val) for name, val in values_to_print ]) last_print_time = new_time for i, t in enumerate(since_last_print): since_last_print[i].fill(0) values_to_print = [('epoch', epoch), ('n_inputs', n_inputs)] for symbolic_var, total_val in zip(print_vars, epoch_totals): values_to_print.append( (str(symbolic_var), total_val / n_inputs)) print "{0} summary\t".format(title) + "\t".join( ["{0}:{1}".format(name, val) for name, val in values_to_print]) if callback: callback(epoch)