Example #1
    def setup_obj(self):
        super(SentibankDataset, self).setup_obj()
        if 'sent' in self.params['obj']:
            self.bc_lookup = get_bc2sent(self.params['dataset'])
        elif self.params['obj'] == 'emo':
            self.bc_lookup = get_bc2emo(self.params['dataset'])
        elif self.params['obj'] == 'bc':
            self.bc_lookup = get_bc2idx(self.params['dataset'])

            # Some bc's are filtered out. The label from prepare_data is the index of the bc.
            # However, that index can exceed the output_dim of the network, so the labels won't
            # make sense. Thus, when reading the label from the tfrecord, we want to map it to an
            # index in [0, output_dim).
            # This mapping will also be saved.
            self.bc_labelidx2filteredidx = {}
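
A minimal, hypothetical sketch of the remapping described in the comment above, assuming bc_labelidx2filteredidx is filled lazily as labels are read from the tfrecords; the helper name remap_label is not part of the original code.

    # Hypothetical sketch -- not part of the original class. Maps original bc label
    # indices (which can exceed output_dim after filtering) onto a dense [0, output_dim) range.
    def remap_label(self, label_idx):
        if label_idx not in self.bc_labelidx2filteredidx:
            self.bc_labelidx2filteredidx[label_idx] = len(self.bc_labelidx2filteredidx)
        return self.bc_labelidx2filteredidx[label_idx]
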
Example #2
    def get_idx2label(self):
        """Used to turn indices into human readable labels"""
        label2idx = None
        if self.params['obj'] == 'sent_biclass':
            label2idx = SENT_BICLASS_LABEL2INT
        elif self.params['obj'] == 'sent_triclass':
            label2idx = SENT_TRICLASS_LABEL2INT
        elif self.params['obj'] == 'emo':
            if self.params['dataset'] == 'Sentibank':
                label2idx = SENTIBANK_EMO_LABEL2INT
            elif self.params['dataset'] == 'MVSO':
                label2idx = MVSO_EMO_LABEL2INT
        elif self.params['obj'] == 'bc':
            label2idx = get_bc2idx(self.params['dataset'])

        idx2label = {v: k for k, v in label2idx.items()}
        return idx2label
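
A short usage sketch of get_idx2label(); the dataset and probs variables here are assumed, not part of the original code. idx2label inverts the label2idx map, so argmax predictions can be printed as readable labels.

    # Hypothetical usage -- dataset is a dataset object, probs a (batch_size, output_dim) numpy array.
    idx2label = dataset.get_idx2label()
    pred_indices = probs.argmax(axis=1)
    pred_labels = [idx2label[int(idx)] for idx in pred_indices]
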
Example #3
    def train(self):
        """Train"""
        self.dataset = get_dataset(self.params)
        self.logger = self._get_logger()
        with tf.Session() as sess:
            # Get data
            self.logger.info('Retrieving training data and setting up graph')
            splits = self.dataset.setup_graph()
            tr_img_batch = splits['train']['img_batch']
            self.tr_label_batch = splits['train']['label_batch']
            va_img_batch = splits['valid']['img_batch']
            va_label_batch = splits['valid']['label_batch']

            # Get model
            self.output_dim = self.dataset.get_output_dim()
            model = self._get_model(sess, tr_img_batch)
            self.model = model

            # Loss
            self._get_loss(model)

            # Optimize - split into two steps (get gradients and then apply so we can create summary vars)
            optimizer = get_optimizer(self.params)
            grads = tf.gradients(self.loss, tf.trainable_variables())
            self.grads_and_vars = list(zip(grads, tf.trainable_variables()))
            train_step = optimizer.apply_gradients(
                grads_and_vars=self.grads_and_vars)
            # capped_grads_and_vars = [(tf.clip_by_value(gv[0], -5., 5.), gv[1]) for gv in self.grads_and_vars]
            # train_step = optimizer.apply_gradients(grads_and_vars=capped_grads_and_vars)

            # Summary ops and writer
            if self.params['tboard_debug']:
                self.img_batch_for_summ = tr_img_batch
            summary_op = self._get_summary_ops()
            tr_summary_writer = tf.summary.FileWriter(
                self.params['ckpt_dirpath'] + '/train',
                graph=tf.get_default_graph())
            va_summary_writer = tf.summary.FileWriter(
                self.params['ckpt_dirpath'] + '/valid')

            # Initialize after building the optimizer -- Adam creates slot variables that also need initializing
            coord, threads = self._initialize(sess)

            if self.params['obj'] == 'bc':
                # labelidx2filteredidx created by dataset
                labelidx2filteredidx = pickle.load(
                    open(
                        os.path.join(self.params['ckpt_dirpath'],
                                     'bc_labelidx2filteredidx.pkl'), 'rb'))
                filteredidx2labelidx = {
                    v: k
                    for k, v in labelidx2filteredidx.items()
                }
                bc2labelidx = get_bc2idx(self.params['dataset'])
                labelidx2bc = {v: k for k, v in bc2labelidx.items()}

            # Training
            saver = tf.train.Saver(max_to_keep=None)
            for i in range(self.params['epochs']):
                self.logger.info('Epoch {}'.format(i))
                # Normally slice_input_producer should have an epoch parameter, but it produces a bug when set.
                # So we loop over the number of batches per epoch manually.
                num_tr_batches = self.dataset.get_num_batches('train')
                for j in range(num_tr_batches):
                    # Compute CPU GPU usage timeline
                    compute_timeline = self.params['timeline'] and (j % 100 == 0)
                    run_options = (tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                                   if compute_timeline else None)
                    run_metadata = tf.RunMetadata() if compute_timeline else None

                    if self.params['obj'] == 'bc':  # same thing but with top-k accuracy
                        _, imgs, last_fc, loss_val, acc_val,\
                        top10_indices, ids, labels,\
                        top5_acc_val, top10_acc_val, summary = sess.run(
                            [train_step, tr_img_batch, model.last_fc, self.loss, self.acc,
                             self.top10_indices, splits['train']['id_batch'], self.tr_label_batch,
                             self.top5_acc, self.top10_acc, summary_op],
                            options=run_options, run_metadata=run_metadata)

                        for loop_i, label_idx in enumerate(labels):
                            print 'Actual: {}, {}, {}. {}'.format(
                                label_idx, filteredidx2labelidx[label_idx],
                                labelidx2bc[filteredidx2labelidx[label_idx]],
                                ids[loop_i])
                            for pred_idx in top10_indices[loop_i]:
                                print '------------> {}, {}, {}'.format(
                                    pred_idx, filteredidx2labelidx[pred_idx],
                                    labelidx2bc[
                                        filteredidx2labelidx[pred_idx]])
                    else:
                        _, imgs, last_fc, loss_val, acc_val, summary = sess.run(
                            [
                                train_step, tr_img_batch, model.last_fc,
                                self.loss, self.acc, summary_op
                            ],
                            options=run_options,
                            run_metadata=run_metadata)

                    self.logger.info(
                        'Train minibatch {} / {} -- Loss: {}'.format(
                            j, num_tr_batches, loss_val))
                    self.logger.info(
                        '................... -- Acc: {}'.format(acc_val))

                    if self.params['obj'] == 'bc':
                        self.logger.info(
                            '................. -- Top-5 Acc: {}'.format(
                                top5_acc_val))
                        self.logger.info(
                            '................. -- Top-10 Acc: {}'.format(
                                top10_acc_val))

                    # Write summary
                    if j % 10 == 0:
                        tr_summary_writer.add_summary(summary,
                                                      i * num_tr_batches + j)

                    # if j == 10:
                    #     break

                    # Save (potentially) before end of epoch just so I don't have to wait
                    if j % 100 == 0:
                        save_model(sess, saver, self.params, i, self.logger)

                        # Create the Timeline object, and write it to a json
                        if self.params['timeline']:
                            tl = timeline.Timeline(run_metadata.step_stats)
                            ctf = tl.generate_chrome_trace_format()
                            timeline_outpath = os.path.join(
                                self.params['ckpt_dirpath'], 'timeline.json')
                            with open(timeline_outpath, 'wb') as tfp:
                                tfp.write(ctf)

                # Evaluate on validation set (potentially)
                if (i + 1) % self.params['val_every_epoch'] == 0:
                    num_va_batches = self.dataset.get_num_batches('valid')
                    for j in range(num_va_batches):
                        img_batch, label_batch = sess.run(
                            [va_img_batch, va_label_batch])
                        loss_val, acc_val, loss_summary, acc_summary = sess.run(
                            [
                                self.loss, self.acc, self.loss_summary,
                                self.acc_summary
                            ],
                            feed_dict={
                                'img_batch:0': img_batch,
                                'label_batch:0': label_batch
                            })

                        self.logger.info(
                            'Valid minibatch {} / {} -- Loss: {}'.format(
                                j, num_va_batches, loss_val))
                        self.logger.info(
                            '................... -- Acc: {}'.format(acc_val))

                        # Write summary
                        if j % 10 == 0:
                            va_summary_writer.add_summary(
                                loss_summary, i * num_tr_batches + j)
                            va_summary_writer.add_summary(
                                acc_summary, i * num_tr_batches + j)

                        # if j == 5:
                        #     break

                # Save model at end of epoch (potentially)
                save_model(sess, saver, self.params, i, self.logger)

            coord.request_stop()
            coord.join(threads)
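
The two-step optimization above (tf.gradients followed by apply_gradients) exists so that summaries can be attached to the gradients. The actual _get_summary_ops is not shown in these examples; the snippet below is only a sketch of how such summaries could be built from the (gradient, variable) pairs, under that assumption.

# Hypothetical sketch -- not the original _get_summary_ops. grads_and_vars is the
# list of (gradient, variable) pairs built in train().
import tensorflow as tf

def gradient_summaries(grads_and_vars):
    for grad, var in grads_and_vars:
        if grad is not None:
            tf.summary.histogram(var.op.name + '/gradient', grad)
            tf.summary.histogram(var.op.name + '/weight', var)
    return tf.summary.merge_all()
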
Example #4
    def predict_bc(self):
        """
        Predict for biconcept classification. Pretty much the same as predict(), except it only saves the top k
        biconcepts. Function getting messy so just making a separate one for now.
        """
        self.logger = self._get_logger()

        # labelidx2filteredidx created by dataset
        labelidx2filteredidx = pickle.load(
            open(
                os.path.join(self.params['ckpt_dirpath'],
                             'bc_labelidx2filteredidx.pkl'), 'rb'))
        filteredidx2labelidx = {v: k for k, v in labelidx2filteredidx.items()}
        bc2labelidx = get_bc2idx(self.params['dataset'])
        labelidx2bc = {v: k for k, v in bc2labelidx.items()}

        # If the given path contains frames/, just predict for that one video.
        # Otherwise, walk through the directory and predict for every folder that contains frames/.
        dirpaths = None
        if os.path.exists(os.path.join(self.params['vid_dirpath'], 'frames')):
            dirpaths = [self.params['vid_dirpath']]
        else:
            dirpaths = self.get_all_vidpaths_with_frames(
                self.params['vid_dirpath'])

        for dirpath in dirpaths:
            # Skip if exists
            # if os.path.exists(os.path.join(dirpath, 'preds', 'sent_biclass_19.csv')):
            #     print 'Skip: {}'.format(dirpath)
            #     continue
            with tf.Session() as sess:
                # Get data
                self.logger.info(
                    'Getting images to predict for {}'.format(dirpath))
                self.dataset = get_dataset(self.params, dirpath)
                self.output_dim = self.dataset.get_output_dim()
                img_batch = self.dataset.setup_graph()

                # Get model
                self.logger.info('Building graph')
                model = self._get_model(sess, img_batch)

                # Initialize
                coord, threads = self._initialize(sess)

                # Restore model now that graph is complete -- loads weights to variables in existing graph
                self.logger.info('Restoring checkpoint')
                saver = load_model(sess, self.params)

                # Make directory to store predictions
                preds_dir = os.path.join(dirpath, 'preds')
                if not os.path.exists(preds_dir):
                    os.mkdir(preds_dir)

                # Predict, write to file
                num_batches = self.dataset.get_num_batches('predict')
                fn = self.params['obj']
                if self.params['load_epoch'] is not None:
                    fn += '_{}'.format(self.params['load_epoch'])
                if self.params['dropout_conf']:
                    fn += '_conf{}'.format(self.params['batch_size'])
                fn += '.csv'

                with open(os.path.join(preds_dir, fn), 'w') as f:
                    for j in range(num_batches):
                        last_fc, probs = sess.run(
                            [model.last_fc, model.probs],
                            feed_dict={'img_batch:0': img_batch.eval()})
                        top10_filteredidxs = np.argpartition(probs[0],
                                                             -10)[-10:][::-1]
                        top10_labelidxs = [
                            filteredidx2labelidx[idx]
                            for idx in top10_filteredidxs
                        ]
                        top10_bc = [
                            labelidx2bc[idx] for idx in top10_labelidxs
                        ]
                        f.write('{}\n'.format(','.join(top10_bc)))

                coord.request_stop()
                coord.join(threads)

            # Clear previous video's graph
            tf.reset_default_graph()
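
One detail worth noting in the prediction loop: np.argpartition only guarantees that the last ten positions hold the ten largest probabilities, so reversing that slice does not order them by probability. If a rank-ordered top-10 list is wanted, a small standalone sketch (probs here is a stand-in for the model.probs output):

import numpy as np

probs = np.random.rand(1, 1000)  # stand-in for model.probs: (batch_size, output_dim)

# Top-10 indices of probs[0], sorted by descending probability
# (argpartition alone leaves the selected ten in arbitrary order).
top10_unsorted = np.argpartition(probs[0], -10)[-10:]
top10_sorted = top10_unsorted[np.argsort(probs[0][top10_unsorted])[::-1]]
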
Example #5
    def test(self):
        """Test"""
        self.dataset = get_dataset(self.params)
        self.logger = self._get_logger()
        with tf.Session() as sess:
            # Get data
            self.logger.info('Getting test set')
            if self.params['save_preds_for_prog_finetune']:
                self.logger.info(
                    'Saving predictions for prog_finetune: testing on train set'
                )
                splits = self.dataset.setup_graph()
                te_img_batch = splits['train']['img_batch']
                self.te_label_batch = splits['train']['label_batch']
                te_id_batch = splits['train']['id_batch']
                num_batches = self.dataset.get_num_batches('train')
                id2pred = {}
            else:
                te_img_batch, self.te_label_batch, te_id_batch = self.dataset.setup_graph()
                num_batches = self.dataset.get_num_batches('test')

            # Get model
            self.logger.info('Building graph')
            self.output_dim = self.dataset.get_output_dim()
            model = self._get_model(sess, te_img_batch)
            self.model = model

            # Loss
            self._get_loss(model)

            # Weights and gradients
            grads = tf.gradients(self.loss, tf.trainable_variables())
            self.grads_and_vars = list(zip(grads, tf.trainable_variables()))

            # Summary ops and writer
            summary_op = self._get_summary_ops()
            summary_writer = tf.summary.FileWriter(
                self.params['ckpt_dirpath'] + '/test',
                graph=tf.get_default_graph())

            # Initialize
            coord, threads = self._initialize(sess)

            # Restore model now that graph is complete -- loads weights to variables in existing graph
            self.logger.info('Restoring checkpoint')
            saver = load_model(sess, self.params)

            # print sess.run(tf.trainable_variables()[0])

            if self.params['obj'] == 'bc':
                # labelidx2filteredidx created by dataset
                labelidx2filteredidx = pickle.load(
                    open(
                        os.path.join(self.params['ckpt_dirpath'],
                                     'bc_labelidx2filteredidx.pkl'), 'rb'))
                filteredidx2labelidx = {
                    v: k
                    for k, v in labelidx2filteredidx.items()
                }
                bc2labelidx = get_bc2idx(self.params['dataset'])
                labelidx2bc = {v: k for k, v in bc2labelidx.items()}

            # Test
            overall_correct = 0
            overall_num = 0
            if self.params['obj'] == 'bc':
                top5_overall_correct = 0
                top10_overall_correct = 0
            for j in range(num_batches):
                if self.params['save_preds_for_prog_finetune']:
                    probs, ids, loss_val, acc_val, summary = sess.run([
                        model.probs, te_id_batch, self.loss, self.acc,
                        summary_op
                    ])
                    for k in range(len(ids)):
                        id2pred[ids[k]] = probs[k]
                elif self.params['scramble_img_mode']:
                    fn = {
                        'uniform': scramble_img,
                        'recursive': scramble_img_recursively
                    }[self.params['scramble_img_mode']]
                    img_batch, label_batch = sess.run(
                        [te_img_batch, self.te_label_batch])
                    for k in range(len(img_batch)):
                        img_batch[k] = fn(img_batch[k],
                                          self.params['scramble_blocksize'])
                    loss_val, acc_val, summary = sess.run(
                        [self.loss, self.acc, summary_op],
                        feed_dict={
                            'img_batch:0': img_batch,
                            'label_batch:0': label_batch
                        })
                elif self.params['obj'] == 'bc':

                    (probs, loss_val, acc_val, top5_acc_val, top10_acc_val,
                     top10_indices, ids, labels, summary) = sess.run(
                         [model.probs, self.loss, self.acc, self.top5_acc,
                          self.top10_acc, self.top10_indices, te_id_batch,
                          self.te_label_batch, summary_op])
                    # probs, loss_val, acc_val, top5_acc_val, top10_acc_val, labels, summary = sess.run(
                    #     [model.probs, self.loss, self.acc, self.top5_acc, self.top10_acc, self.te_label_batch,
                    #      summary_op])

                    for loop_idx in range(25):
                        print probs[0][loop_idx]
                else:
                    loss_val, acc_val, summary = sess.run(
                        [self.loss, self.acc, summary_op])
                    labels, ids, last_fc, probs = sess.run([
                        self.te_label_batch, te_id_batch, model.last_fc,
                        model.probs
                    ])
                    for i in range(len(probs)):
                        print probs[i], labels[i], ids[i]
                    # print probs[0]

                overall_correct += int(acc_val *
                                       te_img_batch.get_shape().as_list()[0])
                overall_num += te_img_batch.get_shape().as_list()[0]
                overall_acc = float(overall_correct) / overall_num

                self.logger.info('Test minibatch {} / {} -- Loss: {}'.format(
                    j, num_batches, loss_val))
                self.logger.info(
                    '................... -- Acc: {}'.format(acc_val))
                if self.params['obj'] == 'bc':
                    self.logger.info(
                        '................... -- Top-5 Acc: {}'.format(
                            top5_acc_val))
                    self.logger.info(
                        '................... -- Top-10 Acc: {}'.format(
                            top10_acc_val))
                self.logger.info(
                    '................... -- Overall acc: {}'.format(
                        overall_acc))

                if self.params['obj'] == 'bc':
                    top5_overall_correct += int(
                        top5_acc_val * te_img_batch.get_shape().as_list()[0])
                    top10_overall_correct += int(
                        top10_acc_val * te_img_batch.get_shape().as_list()[0])
                    top5_overall_acc = float(
                        top5_overall_correct) / overall_num
                    top10_overall_acc = float(
                        top10_overall_correct) / overall_num
                    self.logger.info(
                        '................... -- Top5 overall acc: {}'.format(
                            top5_overall_acc))
                    self.logger.info(
                        '................... -- Top10 overall acc: {}'.format(
                            top10_overall_acc))
                    for loop_i, label_idx in enumerate(labels):
                        print 'Actual: {}, {}, {}. {}'.format(
                            label_idx, filteredidx2labelidx[label_idx],
                            labelidx2bc[filteredidx2labelidx[label_idx]],
                            ids[loop_i])
                        for pred_idx in top10_indices[loop_i]:
                            print '------------> {}, {}, {}'.format(
                                pred_idx, filteredidx2labelidx[pred_idx],
                                labelidx2bc[filteredidx2labelidx[pred_idx]])

                # Write summary
                if j % 10 == 0:
                    summary_writer.add_summary(summary, j)

            if self.params['save_preds_for_prog_finetune']:
                outpath = os.path.join(
                    self.params['ckpt_dirpath'],
                    '{}-{}-id2pred.pkl'.format(self.params['dataset'],
                                               self.params['obj']))
                # Open in binary mode -- pickle protocol 2 is a binary format
                with open(outpath, 'wb') as f:
                    pickle.dump(id2pred, f, protocol=2)

            coord.request_stop()
            coord.join(threads)
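
The top-5 and top-10 accuracy tensors used above (self.top5_acc, self.top10_acc) are defined outside these examples. A minimal sketch, assuming they are computed from the model logits and integer labels with tf.nn.in_top_k:

# Hypothetical sketch of the top-k accuracy ops -- logits stands in for model.last_fc,
# labels for the integer label batch.
import tensorflow as tf

def topk_accuracy(logits, labels, k):
    in_topk = tf.nn.in_top_k(predictions=logits, targets=labels, k=k)
    return tf.reduce_mean(tf.cast(in_topk, tf.float32))

# e.g. top5_acc = topk_accuracy(model.last_fc, label_batch, k=5)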