def setup_obj(self):
    """Run the parent objective setup, then attach the biconcept lookups.

    ``self.bc_lookup`` is built from ``self.params['dataset']`` by the
    helper matching ``self.params['obj']``; objectives that match none of
    the cases leave ``bc_lookup`` untouched.
    """
    super(SentibankDataset, self).setup_obj()

    objective = self.params['obj']
    if 'sent' in objective:
        build_lookup = get_bc2sent
    elif objective == 'emo':
        build_lookup = get_bc2emo
    elif objective == 'bc':
        build_lookup = get_bc2idx
    else:
        build_lookup = None

    # The dataset name is only read when a lookup is actually built.
    if build_lookup is not None:
        self.bc_lookup = build_lookup(self.params['dataset'])

    # Some biconcepts are filtered out. The label produced by prepare_data is
    # the index of the biconcept, which can therefore exceed the network's
    # output_dim and make the labels meaningless. When a label is read from
    # the tfrecord it is remapped to an index in [0, output_dim]; that
    # mapping is accumulated in this dict, which is also saved.
    self.bc_labelidx2filteredidx = {}
def get_idx2label(self):
    """Return a dict mapping integer class indices to readable labels.

    The label -> index table is selected by ``self.params['obj']`` (and,
    for the 'emo' objective, by ``self.params['dataset']``) and inverted.

    Returns:
        dict: ``{index: label}``.

    Raises:
        ValueError: if no label table exists for the configured
            objective / dataset combination.
    """
    obj = self.params['obj']
    label2idx = None

    if obj == 'sent_biclass':
        label2idx = SENT_BICLASS_LABEL2INT
    elif obj == 'sent_triclass':
        label2idx = SENT_TRICLASS_LABEL2INT
    elif obj == 'emo':
        dataset = self.params['dataset']
        if dataset == 'Sentibank':
            label2idx = SENTIBANK_EMO_LABEL2INT
        elif dataset == 'MVSO':
            label2idx = MVSO_EMO_LABEL2INT
    elif obj == 'bc':
        label2idx = get_bc2idx(self.params['dataset'])

    # Fail with a clear message instead of an AttributeError on
    # ``None.items()`` when the configuration matched no table.
    if label2idx is None:
        raise ValueError(
            'No label mapping for obj={!r}, dataset={!r}'.format(
                obj, self.params.get('dataset')))

    return {idx: label for label, idx in label2idx.items()}
def train(self):
    """Train the model, writing summaries and checkpoints along the way.

    Builds the dataset graph, model, loss and optimizer inside a
    ``tf.Session`` and runs ``self.params['epochs']`` epochs over the
    training batches. Every ``self.params['val_every_epoch']`` epochs the
    validation split is evaluated too. Summaries are written under
    ``<ckpt_dirpath>/train`` and ``<ckpt_dirpath>/valid``.

    For the 'bc' objective, top-5 / top-10 accuracy is logged as well and
    the top-10 predictions of every training sample are printed.
    """
    self.dataset = get_dataset(self.params)
    self.logger = self._get_logger()
    is_bc = self.params['obj'] == 'bc'

    with tf.Session() as sess:
        # Get data
        self.logger.info('Retrieving training data and setting up graph')
        splits = self.dataset.setup_graph()
        tr_img_batch = splits['train']['img_batch']
        self.tr_label_batch = splits['train']['label_batch']
        va_img_batch = splits['valid']['img_batch']
        va_label_batch = splits['valid']['label_batch']

        # Get model
        self.output_dim = self.dataset.get_output_dim()
        model = self._get_model(sess, tr_img_batch)
        self.model = model

        # Loss
        self._get_loss(model)

        # Optimize - split into two steps (get gradients, then apply them)
        # so that summary variables can be created for the gradients.
        # NOTE: clipping the gradients to [-5, 5] with tf.clip_by_value
        # before applying them was tried here and is currently disabled.
        optimizer = get_optimizer(self.params)
        grads = tf.gradients(self.loss, tf.trainable_variables())
        self.grads_and_vars = list(zip(grads, tf.trainable_variables()))
        train_step = optimizer.apply_gradients(
            grads_and_vars=self.grads_and_vars)

        # Summary ops and writers
        if self.params['tboard_debug']:
            self.img_batch_for_summ = tr_img_batch
        summary_op = self._get_summary_ops()
        tr_summary_writer = tf.summary.FileWriter(
            self.params['ckpt_dirpath'] + '/train',
            graph=tf.get_default_graph())
        va_summary_writer = tf.summary.FileWriter(
            self.params['ckpt_dirpath'] + '/valid')

        # Initialize after optimization - this needs to be done after adam
        coord, threads = self._initialize(sess)

        if is_bc:
            # labelidx2filteredidx is created by the dataset. Use a context
            # manager so the file handle is closed deterministically.
            mapping_path = os.path.join(self.params['ckpt_dirpath'],
                                        'bc_labelidx2filteredidx.pkl')
            with open(mapping_path, 'rb') as mapping_file:
                labelidx2filteredidx = pickle.load(mapping_file)
            filteredidx2labelidx = {
                v: k for k, v in labelidx2filteredidx.items()
            }
            bc2labelidx = get_bc2idx(self.params['dataset'])
            labelidx2bc = {v: k for k, v in bc2labelidx.items()}

        # Training
        saver = tf.train.Saver(max_to_keep=None)
        for i in range(self.params['epochs']):
            self.logger.info('Epoch {}'.format(i))
            # Normally slice_input_producer should have an epoch parameter,
            # but it produces a bug when set, so the number of batches per
            # epoch is obtained from the dataset instead.
            num_tr_batches = self.dataset.get_num_batches('train')
            for j in range(num_tr_batches):
                # Compute CPU / GPU usage timeline on every 100th batch
                compute_timeline = self.params['timeline'] and (j % 100 == 0)
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE
                ) if compute_timeline else None
                run_metadata = tf.RunMetadata() if compute_timeline else None

                if is_bc:
                    # Same thing but with top-k accuracy
                    (_, loss_val, acc_val, top10_indices, ids, labels,
                     top5_acc_val, top10_acc_val, summary) = sess.run(
                         [train_step, self.loss, self.acc,
                          self.top10_indices, splits['train']['id_batch'],
                          self.tr_label_batch, self.top5_acc,
                          self.top10_acc, summary_op],
                         options=run_options,
                         run_metadata=run_metadata)
                    for loop_i, label_idx in enumerate(labels):
                        print('Actual: {}, {}, {}. {}'.format(
                            label_idx, filteredidx2labelidx[label_idx],
                            labelidx2bc[filteredidx2labelidx[label_idx]],
                            ids[loop_i]))
                        for pred_idx in top10_indices[loop_i]:
                            print('------------> {}, {}, {}'.format(
                                pred_idx, filteredidx2labelidx[pred_idx],
                                labelidx2bc[filteredidx2labelidx[pred_idx]]))
                else:
                    _, loss_val, acc_val, summary = sess.run(
                        [train_step, self.loss, self.acc, summary_op],
                        options=run_options,
                        run_metadata=run_metadata)

                self.logger.info(
                    'Train minibatch {} / {} -- Loss: {}'.format(
                        j, num_tr_batches, loss_val))
                self.logger.info(
                    '................... -- Acc: {}'.format(acc_val))
                if is_bc:
                    self.logger.info(
                        '................. -- Top-5 Acc: {}'.format(
                            top5_acc_val))
                    self.logger.info(
                        '................. -- Top-10 Acc: {}'.format(
                            top10_acc_val))

                # Write summary
                if j % 10 == 0:
                    tr_summary_writer.add_summary(summary,
                                                  i * num_tr_batches + j)

                # Save (potentially) before end of epoch so there is no
                # need to wait for a full epoch.
                if j % 100 == 0:
                    save_model(sess, saver, self.params, i, self.logger)

                # Create the Timeline object and write it to a json file.
                # run_metadata is None on steps that were not traced, so
                # only dump when this step was actually traced.
                if compute_timeline:
                    tl = timeline.Timeline(run_metadata.step_stats)
                    ctf = tl.generate_chrome_trace_format()
                    timeline_outpath = os.path.join(
                        self.params['ckpt_dirpath'], 'timeline.json')
                    # The trace is JSON text, so write in text mode.
                    with open(timeline_outpath, 'w') as tfp:
                        tfp.write(ctf)

            # Evaluate on validation set (potentially)
            if (i + 1) % self.params['val_every_epoch'] == 0:
                num_va_batches = self.dataset.get_num_batches('valid')
                for j in range(num_va_batches):
                    # Pull a validation batch, then feed it through the
                    # graph by tensor name.
                    img_batch, label_batch = sess.run(
                        [va_img_batch, va_label_batch])
                    loss_val, acc_val, loss_summary, acc_summary = sess.run(
                        [self.loss, self.acc, self.loss_summary,
                         self.acc_summary],
                        feed_dict={
                            'img_batch:0': img_batch,
                            'label_batch:0': label_batch
                        })

                    self.logger.info(
                        'Valid minibatch {} / {} -- Loss: {}'.format(
                            j, num_va_batches, loss_val))
                    self.logger.info(
                        '................... -- Acc: {}'.format(acc_val))

                    # Write summary
                    if j % 10 == 0:
                        va_summary_writer.add_summary(
                            loss_summary, i * num_tr_batches + j)
                        va_summary_writer.add_summary(
                            acc_summary, i * num_tr_batches + j)

            # Save model at end of epoch (potentially)
            save_model(sess, saver, self.params, i, self.logger)

        coord.request_stop()
        coord.join(threads)

        # Flush any buffered events to disk.
        tr_summary_writer.close()
        va_summary_writer.close()
def predict_bc(self):
    """Predict the top-10 biconcepts for video frames and write them to csv.

    Pretty much the same as predict(), except that only the top biconcepts
    are saved. If ``self.params['vid_dirpath']`` itself contains a
    ``frames`` directory, only that video is processed; otherwise every
    directory returned by ``get_all_vidpaths_with_frames`` is processed.

    For each video a csv file is written to ``<dirpath>/preds``. Each row
    holds the biconcept names of the 10 highest-probability classes for one
    batch, ordered from most to least probable.
    """
    self.logger = self._get_logger()

    # labelidx2filteredidx is created by the dataset. Use a context manager
    # so the file handle is closed deterministically.
    mapping_path = os.path.join(self.params['ckpt_dirpath'],
                                'bc_labelidx2filteredidx.pkl')
    with open(mapping_path, 'rb') as mapping_file:
        labelidx2filteredidx = pickle.load(mapping_file)
    filteredidx2labelidx = {v: k for k, v in labelidx2filteredidx.items()}
    bc2labelidx = get_bc2idx(self.params['dataset'])
    labelidx2bc = {v: k for k, v in bc2labelidx.items()}

    # If the given path contains frames/, just predict for that one video.
    # Else walk through the directory and predict for every folder that
    # contains frames/.
    if os.path.exists(os.path.join(self.params['vid_dirpath'], 'frames')):
        dirpaths = [self.params['vid_dirpath']]
    else:
        dirpaths = self.get_all_vidpaths_with_frames(
            self.params['vid_dirpath'])

    for dirpath in dirpaths:
        with tf.Session() as sess:
            # Get data
            self.logger.info(
                'Getting images to predict for {}'.format(dirpath))
            self.dataset = get_dataset(self.params, dirpath)
            self.output_dim = self.dataset.get_output_dim()
            img_batch = self.dataset.setup_graph()

            # Get model
            self.logger.info('Building graph')
            model = self._get_model(sess, img_batch)

            # Initialize
            coord, threads = self._initialize(sess)

            # Restore model now that graph is complete -- loads weights to
            # variables in the existing graph
            self.logger.info('Restoring checkpoint')
            load_model(sess, self.params)

            # Make directory to store predictions
            preds_dir = os.path.join(dirpath, 'preds')
            if not os.path.exists(preds_dir):
                os.mkdir(preds_dir)

            # Output name: <obj>[_<load_epoch>][_conf<batch_size>].csv
            num_batches = self.dataset.get_num_batches('predict')
            fn = self.params['obj']
            if self.params['load_epoch'] is not None:
                fn += '_{}'.format(self.params['load_epoch'])
            if self.params['dropout_conf']:
                fn += '_conf{}'.format(self.params['batch_size'])
            fn += '.csv'

            # Predict, write to file
            with open(os.path.join(preds_dir, fn), 'w') as f:
                for _ in range(num_batches):
                    probs = sess.run(
                        model.probs,
                        feed_dict={'img_batch:0': img_batch.eval()})

                    # Only the first row of the batch is written.
                    # NOTE(review): presumably intentional (one frame per
                    # batch / repeated frame for dropout_conf) -- confirm.
                    #
                    # argsort gives a true descending ranking; argpartition
                    # only guarantees membership in the top 10, not their
                    # order, and fails when there are fewer than 10 classes.
                    top10_filteredidxs = np.argsort(probs[0])[::-1][:10]
                    top10_labelidxs = [
                        filteredidx2labelidx[idx]
                        for idx in top10_filteredidxs
                    ]
                    top10_bc = [labelidx2bc[idx] for idx in top10_labelidxs]
                    f.write('{}\n'.format(','.join(top10_bc)))

            coord.request_stop()
            coord.join(threads)

        # Clear previous video's graph. Done after leaving the session
        # block, because the session keeps its graph installed as the
        # default graph while the block is active.
        tf.reset_default_graph()
def test(self):
    """Evaluate a restored checkpoint, logging per-batch and running accuracy.

    Which split is used and what is done per batch depends on
    ``self.params``:

    * ``save_preds_for_prog_finetune``: run on the *train* split and pickle
      an ``{id: probs}`` dict to
      ``<ckpt_dirpath>/<dataset>-<obj>-id2pred.pkl``.
    * ``scramble_img_mode`` ('uniform' or 'recursive'): scramble every image
      of the batch before feeding it to the network.
    * ``obj == 'bc'``: additionally track top-5 / top-10 accuracy and print
      the top-10 predictions of every sample.
    * otherwise: plain evaluation, printing probabilities, label and id of
      each sample.

    Summaries are written under ``<ckpt_dirpath>/test``.
    """
    self.dataset = get_dataset(self.params)
    self.logger = self._get_logger()
    save_preds = self.params['save_preds_for_prog_finetune']
    is_bc = self.params['obj'] == 'bc'

    with tf.Session() as sess:
        # Get data
        self.logger.info('Getting test set')
        if save_preds:
            self.logger.info(
                'Saving predictions for prog_finetune: testing on train set'
            )
            splits = self.dataset.setup_graph()
            te_img_batch = splits['train']['img_batch']
            self.te_label_batch = splits['train']['label_batch']
            te_id_batch = splits['train']['id_batch']
            num_batches = self.dataset.get_num_batches('train')
            id2pred = {}
        else:
            te_img_batch, self.te_label_batch, te_id_batch = \
                self.dataset.setup_graph()
            num_batches = self.dataset.get_num_batches('test')

        # Get model
        self.logger.info('Building graph')
        self.output_dim = self.dataset.get_output_dim()
        model = self._get_model(sess, te_img_batch)
        self.model = model

        # Loss
        self._get_loss(model)

        # Weights and gradients
        # (presumably consumed by _get_summary_ops -- confirm)
        grads = tf.gradients(self.loss, tf.trainable_variables())
        self.grads_and_vars = list(zip(grads, tf.trainable_variables()))

        # Summary ops and writer
        summary_op = self._get_summary_ops()
        summary_writer = tf.summary.FileWriter(
            self.params['ckpt_dirpath'] + '/test',
            graph=tf.get_default_graph())

        # Initialize
        coord, threads = self._initialize(sess)

        # Restore model now that graph is complete -- loads weights to
        # variables in the existing graph
        self.logger.info('Restoring checkpoint')
        load_model(sess, self.params)

        if is_bc:
            # labelidx2filteredidx is created by the dataset. Use a context
            # manager so the file handle is closed deterministically.
            mapping_path = os.path.join(self.params['ckpt_dirpath'],
                                        'bc_labelidx2filteredidx.pkl')
            with open(mapping_path, 'rb') as mapping_file:
                labelidx2filteredidx = pickle.load(mapping_file)
            filteredidx2labelidx = {
                v: k for k, v in labelidx2filteredidx.items()
            }
            bc2labelidx = get_bc2idx(self.params['dataset'])
            labelidx2bc = {v: k for k, v in bc2labelidx.items()}

        # Test
        overall_correct = 0
        overall_num = 0
        if is_bc:
            top5_overall_correct = 0
            top10_overall_correct = 0

        # Static batch size of the input tensor (loop invariant).
        batch_size = te_img_batch.get_shape().as_list()[0]

        for j in range(num_batches):
            # Top-k values only exist when the 'bc' branch below ran; the
            # other modes take precedence even when obj == 'bc'.
            has_topk = False

            if save_preds:
                probs, ids, loss_val, acc_val, summary = sess.run([
                    model.probs, te_id_batch, self.loss, self.acc,
                    summary_op
                ])
                for sample_id, sample_probs in zip(ids, probs):
                    id2pred[sample_id] = sample_probs
            elif self.params['scramble_img_mode']:
                scramble = {
                    'uniform': scramble_img,
                    'recursive': scramble_img_recursively
                }[self.params['scramble_img_mode']]
                # Pull a batch, scramble it, then feed it back by name.
                img_batch, label_batch = sess.run(
                    [te_img_batch, self.te_label_batch])
                for k in range(len(img_batch)):
                    img_batch[k] = scramble(
                        img_batch[k], self.params['scramble_blocksize'])
                loss_val, acc_val, summary = sess.run(
                    [self.loss, self.acc, summary_op],
                    feed_dict={
                        'img_batch:0': img_batch,
                        'label_batch:0': label_batch
                    })
            elif is_bc:
                (probs, loss_val, acc_val, top5_acc_val, top10_acc_val,
                 top10_indices, ids, labels, summary) = sess.run([
                     model.probs, self.loss, self.acc, self.top5_acc,
                     self.top10_acc, self.top10_indices, te_id_batch,
                     self.te_label_batch, summary_op
                 ])
                has_topk = True
                # First (up to) 25 class probabilities of the first sample;
                # slicing avoids an IndexError for smaller outputs.
                for prob in probs[0][:25]:
                    print(prob)
            else:
                # Fetch everything in a single run so the printed
                # predictions belong to the batch the loss / accuracy were
                # computed on. NOTE(review): assuming the inputs come from
                # an input queue, a second run would pull a different batch
                # -- confirm against the dataset implementation.
                loss_val, acc_val, summary, labels, ids, probs = sess.run([
                    self.loss, self.acc, summary_op, self.te_label_batch,
                    te_id_batch, model.probs
                ])
                for sample_probs, label, sample_id in zip(
                        probs, labels, ids):
                    print('{} {} {}'.format(sample_probs, label, sample_id))

            overall_correct += int(acc_val * batch_size)
            overall_num += batch_size
            overall_acc = float(overall_correct) / overall_num

            self.logger.info('Test minibatch {} / {} -- Loss: {}'.format(
                j, num_batches, loss_val))
            self.logger.info(
                '................... -- Acc: {}'.format(acc_val))
            if has_topk:
                self.logger.info(
                    '................... -- Top-5 Acc: {}'.format(
                        top5_acc_val))
                self.logger.info(
                    '................... -- Top-10 Acc: {}'.format(
                        top10_acc_val))
            self.logger.info(
                '................... -- Overall acc: {}'.format(
                    overall_acc))

            if has_topk:
                top5_overall_correct += int(top5_acc_val * batch_size)
                top10_overall_correct += int(top10_acc_val * batch_size)
                top5_overall_acc = float(top5_overall_correct) / overall_num
                top10_overall_acc = float(
                    top10_overall_correct) / overall_num
                self.logger.info(
                    '................... -- Top5 overall acc: {}'.format(
                        top5_overall_acc))
                self.logger.info(
                    '................... -- Top10 overall acc: {}'.format(
                        top10_overall_acc))

                for loop_i, label_idx in enumerate(labels):
                    print('Actual: {}, {}, {}. {}'.format(
                        label_idx, filteredidx2labelidx[label_idx],
                        labelidx2bc[filteredidx2labelidx[label_idx]],
                        ids[loop_i]))
                    for pred_idx in top10_indices[loop_i]:
                        print('------------> {}, {}, {}'.format(
                            pred_idx, filteredidx2labelidx[pred_idx],
                            labelidx2bc[filteredidx2labelidx[pred_idx]]))

            # Write summary
            if j % 10 == 0:
                summary_writer.add_summary(summary, j)

        if save_preds:
            id2pred_path = os.path.join(
                self.params['ckpt_dirpath'],
                '{}-{}-id2pred.pkl'.format(self.params['dataset'],
                                           self.params['obj']))
            # Pickle protocol 2 is a binary format: open in binary mode.
            with open(id2pred_path, 'wb') as f:
                pickle.dump(id2pred, f, protocol=2)

        coord.request_stop()
        coord.join(threads)

        # Flush any buffered events to disk.
        summary_writer.close()