def trainloop(args_dict, model, suff_name='', model_val=None, epoch_start=0):
    """Train `model`, early-stopping on either val loss or a caption metric.

    Two regimes, selected by `args_dict.es_metric`:
      * 'loss'  -> delegate to Keras `fit_generator` with EarlyStopping +
                   ModelCheckpoint callbacks monitoring `val_loss`.
      * other   -> manual epoch loop: train batch-by-batch, compute val loss,
                   generate captions on the val split, score them with
                   `get_metric`, and checkpoint/stop on that metric.

    Args:
        args_dict: config namespace (bs, nepochs, pat, data_folder,
            model_name, es_metric, es_prev_words, workers, json_file, ...).
        model: compiled (stateful) Keras model to train.
        suff_name: suffix appended to saved-weight filenames.
        model_val: separate inference-mode model used to generate captions
            when `es_prev_words != 'gt'` (weights are copied into it via a
            temp file each epoch).
        epoch_start: offset added to the printed epoch counter only.

    Returns:
        (model, model_name): the trained model and the path of the last
        saved checkpoint.
    """
    ## DataLoaders
    dataloader = DataLoader(args_dict)
    N_train, N_val, N_test = dataloader.get_dataset_size()
    train_gen = dataloader.generator('train', args_dict.bs)
    val_gen = dataloader.generator('val', args_dict.bs)

    if args_dict.es_metric == 'loss':
        # Checkpoint filename template is filled in by ModelCheckpoint.
        model_name = os.path.join(
            args_dict.data_folder, 'models',
            args_dict.model_name + suff_name
            + '_weights.{epoch:02d}-{val_loss:.2f}.h5')
        ep = EarlyStopping(monitor='val_loss', patience=args_dict.pat,
                           verbose=0, mode='auto')
        mc = ModelCheckpoint(model_name, monitor='val_loss', verbose=0,
                             save_best_only=True, save_weights_only=True,
                             mode='auto')
        # NOTE(review): tb is created but not passed to callbacks below —
        # TensorBoard logging is effectively disabled; confirm if intended.
        tb = TensorBoard(log_dir='./logs_tb')
        # reset states after each batch (bcs stateful)
        rs = ResetStatesCallback()
        model.fit_generator(train_gen, nb_epoch=args_dict.nepochs,
                            samples_per_epoch=N_train,
                            validation_data=val_gen,
                            nb_val_samples=N_val,
                            callbacks=[mc, ep, rs],
                            verbose=1,
                            nb_worker=args_dict.workers,
                            pickle_safe=False)
    else:
        # models saved based on other metrics - manual train loop
        # validation generator in test mode to output image names
        val_gen_test = dataloader.generator('val', args_dict.bs,
                                            train_flag=False)

        # load vocab to convert captions to words and compute cider
        # FIX: use a context manager so the file handle is closed.
        with open(os.path.join(args_dict.data_folder, 'data',
                               args_dict.json_file), 'r') as f:
            data = json.load(f)
        vocab_src = data['ix_to_word']
        inv_vocab = {int(idx): word for idx, word in vocab_src.items()}
        # (removed dead local: the word->index map built here was never used)

        # init waiting param and best metric values
        wait = 0
        best_metric = -np.inf

        for e in range(args_dict.nepochs):
            print("Epoch %d/%d" % (e + 1 + epoch_start,
                                   args_dict.nepochs + epoch_start))
            prog = Progbar(target=N_train)

            # ---- one training epoch ----
            samples = 0
            for x, y, sw in train_gen:
                loss = model.train_on_batch(x=x, y=y, sample_weight=sw)
                model.reset_states()  # stateful model: clear per-batch state
                samples += args_dict.bs
                if samples >= N_train:
                    break
                prog.update(current=samples, values=[('loss', loss)])

            # ---- forward val images to get loss ----
            samples = 0
            val_losses = []
            for x, y, sw in val_gen:
                val_losses.append(model.test_on_batch(x, y, sw))
                model.reset_states()
                samples += args_dict.bs
                # FIX: was `samples > N_val`, which consumed one extra batch
                # when samples hit N_val exactly; now matches the train loop.
                if samples >= N_val:
                    break

            # forward val images to get captions and compute metric
            # this can either be done with true prev words or gen prev words:
            # args_dict.es_prev_words set to 'gt' or 'gen'
            if args_dict.es_prev_words == 'gt':
                # FIX: original passed undefined names (cnn, lang_model) here,
                # which raised NameError; call gencaps with the training model
                # using the same 5-arg shape as the branch below.
                results_file = gencaps(args_dict, model, val_gen_test,
                                       inv_vocab, N_val)
            else:
                # copy current weights into the inference model via temp file
                aux_model = os.path.join(args_dict.data_folder, 'tmp',
                                         args_dict.model_name + '_aux.h5')
                model.save_weights(aux_model, overwrite=True)
                model_val.load_weights(aux_model)
                results_file = gencaps(args_dict, model_val, val_gen_test,
                                       inv_vocab, N_val)

            # get merged ground truth file to eval caps
            ann_file = './utils/captions_merged.json'
            # score captions and return requested metric
            metric = get_metric(args_dict, results_file, ann_file)
            prog.update(current=N_train,
                        values=[('loss', loss),
                                ('val_loss', np.mean(val_losses)),
                                (args_dict.es_metric, metric)])

            # decide if we save checkpoint and/or stop training
            if metric > best_metric:
                best_metric = metric
                wait = 0
                model_name = os.path.join(
                    args_dict.data_folder, 'models',
                    args_dict.model_name + suff_name + '_weights_e' + str(e)
                    + '_' + args_dict.es_metric + "%0.2f" % metric + '.h5')
                model.save_weights(model_name)
            else:
                wait += 1
                if wait > args_dict.pat:
                    break

    args_dict.mode = 'train'
    return model, model_name
model.compile(optimizer=opt, loss='categorical_crossentropy') # load vocab to convert captions to words and compute cider data = json.load( open(os.path.join(args_dict.data_folder, 'data', args_dict.json_file), 'r')) vocab_src = data['ix_to_word'] inv_vocab = {} for idx in vocab.keys(): inv_vocab[int(idx)] = vocab_src[idx] vocab = {v: k for k, v in inv_vocab.items()} dataloader = DataLoader(args_dict) N_train, N_val, N_test, _ = dataloader.get_dataset_size() N = args_dict.bs gen = dataloader.generator('test', batch_size=args_dict.bs, train_flag=False) captions = [] num_samples = 0 print_every = 100 t = time.time() for [ims, prevs], caps, _, imids in gen: if args_dict.bsize > 1: # beam search word_idxs = np.zeros((args_dict.bsize, args_dict.seqlen)) word_idxs[:, :] = 2 ### beam search caps ### conv_feats = cnn.predict_on_batch(ims) seqs, scores = beamsearch(model=lang_model, image=conv_feats,