def RNNGenCost(batch, model, params, misc):
    """Cost function: return the cost, gradients and perplexity for a batch.

    Args:
        batch: sequence of image/sentence pairs; each pair must provide
            pair['sentence']['tokens'] (an iterable of words).
        model: dict of parameter matrices, indexed by the names listed in
            misc['regularize'] when L2 regularization is applied.
        params: options dict; 'regc' (L2 regularization strength) is read
            here, and the whole dict is handed to batchDecodeGenerator and
            to its forward pass.
        misc: dict holding 'wordtoix' (word -> index) and 'regularize'
            (names of the parameters to regularize).

    Returns:
        dict with three entries:
            'cost'  -> {'reg_cost', 'loss_cost', 'total_cost'}, each divided
                       by the batch size,
            'grad'  -> gradients returned by the generator's backward pass,
                       with the L2 term added and divided by the batch size,
            'stats' -> {'ppl2': 2 ** (mean negative log2 probability of the
                       ground-truth tokens)}.

    Raises:
        ZeroDivisionError: if the batch is empty (no sentences, hence no
            tokens to average over).
    """
    regc = params['regc']  # regularization strength
    BatchGenerator = batchDecodeGenerator(params)
    wordtoix = misc['wordtoix']

    print('batch...')
    print(len(batch))

    # Forward the RNN on each image/sentence pair. The generator returns one
    # score matrix per pair plus the cache objects needed for backprop.
    Ys, gen_caches = BatchGenerator.forward(batch, model, params, misc,
                                            predict_mode=False)
    # Two spaces on purpose: reproduces the output of the original
    # two-argument print statement.
    print('batch generator output  %s' % (np.shape(Ys),))

    # Compute the softmax cost of every sentence and the gradient on top.
    loss_cost = 0.0
    dYs = []
    logppl = 0.0
    logppln = 0
    for i, pair in enumerate(batch):
        # Ground-truth word indices we expect to see; out-of-vocabulary
        # words are dropped.
        gtix = [wordtoix[w] for w in pair['sentence']['tokens'] if w in wordtoix]
        gtix.append(0)  # the END token (index 0) must be predicted last

        # Row-wise softmax, shifted by the row maximum for numerical stability.
        Y = Ys[i]
        maxes = np.amax(Y, axis=1, keepdims=True)
        e = np.exp(Y - maxes)
        P = e / np.sum(e, axis=1, keepdims=True)

        # Probability assigned to the correct word at every time step,
        # smoothed so that the logarithms never hit -inf.
        rows = np.arange(len(gtix))
        p_true = 1e-20 + P[rows, gtix]
        loss_cost -= np.sum(np.log(p_true))
        logppl -= np.sum(np.log2(p_true))  # accumulated for the perplexity
        logppln += len(gtix)

        # Softmax gradient, derived in place: dL/dY = P - onehot(gtix).
        # Every (row, column) pair is distinct, so one indexed subtraction
        # is equivalent to decrementing the entries one at a time.
        P[rows, gtix] -= 1
        dYs.append(P)

    # Backprop the RNN: dYs are the output errors, gen_caches the values
    # stored during the forward pass.
    grads = BatchGenerator.backward(dYs, gen_caches)

    # Add the L2 regularization cost and its gradient.
    reg_cost = 0.0
    if regc > 0:
        for p in misc['regularize']:
            mat = model[p]
            reg_cost += 0.5 * regc * np.sum(mat * mat)
            grads[p] += regc * mat

    # Normalize the cost and the gradients by the batch size.
    batch_size = len(batch)
    reg_cost /= batch_size
    loss_cost /= batch_size
    for k in grads:
        grads[k] /= batch_size

    # Package everything in a JSON-like structure.
    out = {}
    out['cost'] = {
        'reg_cost': reg_cost,
        'loss_cost': loss_cost,
        'total_cost': loss_cost + reg_cost,
    }
    out['grad'] = grads
    out['stats'] = {'ppl2': 2 ** (logppl / logppln)}
    return out
def main(params):
    """Train the sentence generator, reporting status and saving checkpoints.

    Builds the vocabulary from the training sentences, initializes the model
    through batchDecodeGenerator (or loads it from a checkpoint), then runs
    Solver steps on randomly sampled batches. Along the way it writes a JSON
    status report about once a minute, evaluates validation perplexity every
    'eval_period' epochs and on the last iteration, and pickles a checkpoint
    whenever the validation perplexity improves.

    Args:
        params: options dict. Keys read here: 'batch_size', 'dataset',
            'word_count_threshold', 'do_grad_check', 'max_epochs',
            'eval_period', 'min_ppl_or_abort',
            'write_checkpoint_ppl_threshold', 'fappend',
            'worker_status_output_directory', 'checkpoint_output_directory'
            and, optionally, 'init_model_from'. The dict is also forwarded
            to the generator, the solver and the evaluation routine, and
            'drop_prob_encoder' / 'drop_prob_decoder' are overwritten with 0
            when a gradient check is run.

    Raises:
        SystemExit: after the gradient check when 'do_grad_check' is set.
    """
    batch_size = params['batch_size']
    dataset = params['dataset']
    word_count_threshold = params['word_count_threshold']
    do_grad_check = params['do_grad_check']
    max_epochs = params['max_epochs']
    host = socket.gethostname()  # tags the status and checkpoint file names

    # Fetch the data provider.
    dp = getDataProvider(dataset)

    misc = {}  # miscellaneous items that get passed around the framework

    # Go over all training sentences and find the vocabulary we want to use,
    # i.e. the words that occur at least word_count_threshold times.
    misc['wordtoix'], misc['ixtoword'], bias_init_vector = preProBuildWordVocab(
        dp.iterSentences('train'), word_count_threshold)

    # Delegate the initialization of the model to the generator class.
    BatchGenerator = batchDecodeGenerator(params)
    init_struct = BatchGenerator.init(params, misc)
    model = init_struct['model']
    misc['update'] = init_struct['update']
    misc['regularize'] = init_struct['regularize']

    # Force-overwrite the decoder bias with the vocabulary-derived vector.
    # This is a bit of a hack.
    model['bd'] = bias_init_vector.reshape(1, bias_init_vector.size)

    print('model init done.')
    print('model has keys: ' + ', '.join(model.keys()))
    print('updating: ' + ', '.join(
        '%s [%dx%d]' % (k, model[k].shape[0], model[k].shape[1])
        for k in misc['update']))
    print('regularizing: ' + ', '.join(
        '%s [%dx%d]' % (k, model[k].shape[0], model[k].shape[1])
        for k in misc['regularize']))
    print('number of learnable parameters total: %d' % (
        sum(model[k].shape[0] * model[k].shape[1] for k in misc['update']), ))

    if params.get('init_model_from', ''):
        # Load a checkpoint and overwrite the freshly initialized model.
        # NOTE: unpickling can execute arbitrary code; only load checkpoint
        # files that come from a trusted source.
        with open(params['init_model_from'], 'rb') as f:
            checkpoint = pickle.load(f)
        model = checkpoint['model']

    # Initialize the Solver and the cost function.
    solver = Solver()

    def costfun(batch, model):
        # Wrap the cost function to abstract params/misc away from the Solver.
        return RNNGenCost(batch, model, params, misc)

    # Calculate how many iterations we need.
    num_sentences_total = dp.getSplitSize('train', ofwhat='sentences')
    num_iters_one_epoch = num_sentences_total // batch_size
    max_iters = max_epochs * num_iters_one_epoch
    eval_period_in_epochs = params['eval_period']
    eval_period_in_iters = max(1, int(num_iters_one_epoch * eval_period_in_epochs))

    abort = False
    top_val_ppl2 = -1  # negative means "no validation record yet"
    smooth_train_ppl2 = len(misc['ixtoword'])  # initially the dictionary size
    val_ppl2 = len(misc['ixtoword'])
    last_status_write_time = 0  # for writing worker job status reports
    json_worker_status = {}
    json_worker_status['params'] = params
    json_worker_status['history'] = []

    for it in xrange(max_iters):
        if abort:
            break
        t0 = time.time()

        # Fetch a batch of data.
        batch = [dp.sampleImageSentencePair() for _ in xrange(batch_size)]

        # Evaluate cost and gradient, and perform the parameter update.
        step_struct = solver.step(batch, model, costfun, **params)
        cost = step_struct['cost']
        dt = time.time() - t0

        # Print training statistics.
        train_ppl2 = step_struct['stats']['ppl2']
        # Exponentially decaying moving average of the training perplexity.
        smooth_train_ppl2 = 0.99 * smooth_train_ppl2 + 0.01 * train_ppl2
        if it == 0:
            smooth_train_ppl2 = train_ppl2  # start out where we start out
        epoch = it * 1.0 / num_iters_one_epoch
        print('%d/%d batch done in %.3fs. at epoch %.2f. loss cost = %f, reg cost = %f, ppl2 = %.2f (smooth %.2f)'
              % (it, max_iters, dt, epoch, cost['loss_cost'], cost['reg_cost'],
                 train_ppl2, smooth_train_ppl2))

        # Perform a gradient check if desired, after a short burn-in
        # (10 iterations), then stop.
        if it == 10 and do_grad_check:
            print('disabling dropout for gradient check...')
            params['drop_prob_encoder'] = 0
            params['drop_prob_decoder'] = 0
            solver.gradCheck(batch, model, costfun)
            print('done gradcheck, exitting.')
            sys.exit()

        # Detect whether the loss is exploding and kill the job if so.
        total_cost = cost['total_cost']
        if it == 0:
            total_cost0 = total_cost  # store the initial cost as reference
        if total_cost > total_cost0 * 2:
            print('Aboring, cost seems to be exploding. Run gradcheck? Lower the learning rate?')
            abort = True  # the loop breaks at the top of the next iteration

        # Logging: write JSON files for visual inspection of the training.
        tnow = time.time()
        if tnow > last_status_write_time + 60 * 1:  # at most once a minute
            last_status_write_time = tnow
            jstatus = {}
            jstatus['time'] = datetime.datetime.now().isoformat()
            jstatus['iter'] = (it, max_iters)
            jstatus['epoch'] = (epoch, max_epochs)
            jstatus['time_per_batch'] = dt
            jstatus['smooth_train_ppl2'] = smooth_train_ppl2
            jstatus['val_ppl2'] = val_ppl2  # just the last available one
            jstatus['train_ppl2'] = train_ppl2
            json_worker_status['history'].append(jstatus)
            status_file = os.path.join(
                params['worker_status_output_directory'], host + '_status.json')
            # Best effort: a failed status write must not stop the training.
            try:
                with open(status_file, 'w') as f:
                    json.dump(json_worker_status, f)
            except Exception as e:  # todo be more clever here
                print('tried to write worker status into %s but got error:' % (status_file, ))
                print(e)

        # Evaluate perplexity on the validation set and save a model
        # checkpoint if it is good. Periodic evaluations are skipped in the
        # last few iterations because the final one always evaluates.
        is_last_iter = (it + 1) == max_iters
        is_eval_iter = ((it + 1) % eval_period_in_iters) == 0 and it < max_iters - 5
        if is_eval_iter or is_last_iter:
            val_ppl2 = eval_split('val', dp, model, params, misc)
            print('validation perplexity = %f' % (val_ppl2, ))

            # Abort training if the perplexity is no good.
            min_ppl_or_abort = params['min_ppl_or_abort']
            if val_ppl2 > min_ppl_or_abort and min_ppl_or_abort > 0:
                print('aborting job because validation perplexity %f > %f' % (val_ppl2, min_ppl_or_abort))
                abort = True

            # Save only if we beat the previous record (or this is the first
            # evaluation) AND we beat the user-defined threshold (or none is
            # set, signalled by a negative value).
            write_checkpoint_ppl_threshold = params['write_checkpoint_ppl_threshold']
            beats_record = val_ppl2 < top_val_ppl2 or top_val_ppl2 < 0
            beats_threshold = (val_ppl2 < write_checkpoint_ppl_threshold
                               or write_checkpoint_ppl_threshold < 0)
            if beats_record and beats_threshold:
                top_val_ppl2 = val_ppl2
                filename = 'model_checkpoint_%s_%s_%s_%.2f.p' % (
                    dataset, host, params['fappend'], val_ppl2)
                filepath = os.path.join(params['checkpoint_output_directory'], filename)
                checkpoint = {}
                checkpoint['it'] = it
                checkpoint['epoch'] = epoch
                checkpoint['model'] = model
                checkpoint['params'] = params
                checkpoint['perplexity'] = val_ppl2
                checkpoint['wordtoix'] = misc['wordtoix']
                checkpoint['ixtoword'] = misc['ixtoword']
                # Best effort: report a failed write and keep training.
                try:
                    with open(filepath, 'wb') as f:
                        pickle.dump(checkpoint, f)
                    print('saved checkpoint in %s' % (filepath, ))
                except Exception as e:  # todo be more clever here
                    print('tried to write checkpoint into %s but got error: ' % (filepath, ))
                    print(e)