コード例 #1
0
    def build_model(self, batch_queue, tower, opt, scope):
        """Build one tower of the autoencoder/generator/discriminator graph.

        Args:
            batch_queue: input queue; one batch is taken with ``dequeue()``.
            tower: tower index; ``self.model.net`` is called with
                ``reuse=True`` for every tower after the first.
            opt: optimizer used only to compute (not apply) gradients.
            scope: name scope of this tower, used to collect its summaries.

        Returns:
            ``(disc_loss + gen_loss, grads_and_vars, layers)`` where
            ``grads_and_vars`` lists generator gradients followed by
            discriminator gradients, with gradients of variables whose op
            name ends in ``'biases'`` multiplied by 2, and ``layers`` is the
            third output of ``self.model.net``.
        """
        def add_montage(tag, images):
            # Tile the batch into a 2x8 grid and log it as a single image.
            tf.summary.image(tag, montage_tf(images, 2, 8), max_outputs=1)

        batch = batch_queue.dequeue()
        add_montage('images/train', batch)

        # Variables are created by tower 0 and shared by all later towers.
        share_vars = True if tower > 0 else None
        reconstruction, generated, layers = self.model.net(batch,
                                                           reuse=share_vars)
        add_montage('images/autoencoder', reconstruction)
        add_montage('images/generator', generated)

        if self.weights_summary:
            # Visualize the first discriminator conv kernels.
            with tf.variable_scope('discriminator', reuse=True):
                first_conv = slim.variable('conv_1/weights')
            tf.summary.image('weights/conv_1',
                             weights_montage(first_conv, 6, 16),
                             max_outputs=1)

        loss_d = self.model.discriminator_loss(scope, tower)
        loss_g = self.model.generator_loss(batch, scope, tower)
        tf.get_variable_scope().reuse_variables()

        pending_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        if pending_updates:
            # Batch-norm statistic updates must run before either loss is read.
            barrier = tf.group(*pending_updates)
            loss_g = control_flow_ops.with_dependencies([barrier], loss_g)
            loss_d = control_flow_ops.with_dependencies([barrier], loss_d)

        # Generator gradients first, discriminator gradients second.
        grads = (
            opt.compute_gradients(loss_g, get_variables_to_train('generator'))
            + opt.compute_gradients(loss_d,
                                    get_variables_to_train('discriminator')))

        # Bias gradients get a 2x learning-rate multiplier.
        multipliers = {}
        for _, variable in grads:
            name = variable.op.name
            multipliers[name] = 2.0 if name.endswith('biases') else 1.0
        print('Gradient multipliers: {}'.format(multipliers))
        grads = tf.contrib.training.multiply_gradients(grads, multipliers)

        self.summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope)
        return loss_d + loss_g, grads, layers
コード例 #2
0
    def build_model(self, batch_queue, tower, opt, scope):
        """Build one autoencoder tower and compute its gradients.

        Args:
            batch_queue: input queue; one batch is taken with ``dequeue()``.
            tower: tower index; ``self.model.net`` reuses variables for
                every tower after the first.
            opt: optimizer used only to compute (not apply) gradients.
            scope: name scope of this tower, used to collect its summaries.

        Returns:
            ``(loss, grads_and_vars, None)``. Gradients of variables whose op
            name ends in ``'biases'`` are multiplied by 2. The third slot is
            always ``None``.
        """
        batch = batch_queue.dequeue()
        tf.summary.image('images/train', montage_tf(batch, 2, 8),
                         max_outputs=1)

        share_vars = True if tower > 0 else None
        reconstruction = self.model.net(batch, reuse=share_vars)
        tf.summary.image('images/autoencoder', montage_tf(reconstruction, 2, 8),
                         max_outputs=1)

        ae_loss = self.model.ae_loss(scope, tower)
        tf.get_variable_scope().reuse_variables()

        # Make the loss depend on the batch-norm statistic updates, if any.
        bn_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        if bn_updates:
            ae_loss = control_flow_ops.with_dependencies(
                [tf.group(*bn_updates)], ae_loss)

        grads_and_vars = opt.compute_gradients(ae_loss,
                                               get_variables_to_train())
        multipliers = {}
        for _, variable in grads_and_vars:
            is_bias = variable.op.name.endswith('biases')
            multipliers[variable.op.name] = 2.0 if is_bias else 1.0
        print('Gradient multipliers: {}'.format(multipliers))
        grads_and_vars = tf.contrib.training.multiply_gradients(
            grads_and_vars, multipliers)

        self.summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope)
        return ae_loss, grads_and_vars, None
コード例 #3
0
ファイル: GanTrainer.py プロジェクト: yqGANs/dfgan
    def build_discriminator(self, batch_queue, opt, scope):
        """Build the discriminator part of a GAN tower.

        A real batch is taken from ``batch_queue``, and a fake batch is
        produced by ``self.model.gen`` from ``self.get_noise_sample()``. Both
        are scored by ``self.model.disc``: first the fakes with
        ``reuse=None``, then the real images with ``reuse=True``.

        Args:
            batch_queue: dataset iterator; ``get_next()`` yields
                ``(images, _)``.
            opt: optimizer used only to compute gradients.
            scope: name scope of this tower, used to collect its summaries.

        Returns:
            ``(loss, grads_and_vars, {})`` with gradients over the trainable
            variables selected by ``'discriminator'``.
        """
        real_imgs, _ = batch_queue.get_next()
        real_imgs.set_shape([self.model.batch_size] + self.model.im_shape)

        fake_imgs = self.model.gen(self.get_noise_sample())
        for tag, images in (('imgs/train', real_imgs), ('imgs/fake', fake_imgs)):
            tf.summary.image(tag, montage_tf(images, 4, 16), max_outputs=1)

        fake_scores = self.model.disc(fake_imgs, reuse=None)
        real_scores = self.model.disc(real_imgs, reuse=True)

        loss = self.model.d_loss(scope, fake_scores, real_scores)
        tf.get_variable_scope().reuse_variables()

        # Make the loss depend on the batch-norm statistic updates, if any.
        bn_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        if bn_updates:
            loss = control_flow_ops.with_dependencies([tf.group(*bn_updates)],
                                                      loss)

        grads_and_vars = opt.compute_gradients(
            loss, get_variables_to_train('discriminator'))
        self.summaries += tf.get_collection(tf.GraphKeys.SUMMARIES, scope)
        return loss, grads_and_vars, {}
コード例 #4
0
    def build_generator(self, batch_queue, opt, scope):
        """Build the generator part of a GAN tower.

        The discriminator scores one concatenated batch. The first half is
        ``self.make_fake(fake_imgs, noise)`` and the second half is the raw
        generator output. The two halves are split back apart and passed to
        ``self.model.g_loss``.

        Args:
            batch_queue: unused here (kept for interface compatibility).
            opt: optimizer used only to compute gradients.
            scope: name scope of this tower, used to collect its summaries.

        Returns:
            ``(loss_g, grads_and_vars, {})`` with gradients over the
            trainable variables selected by ``'generator'``.
        """
        z = self.get_noise_sample()
        generated = self.model.gen(z)
        gen_noise = self.model.gen_noise(z)

        altered = self.make_fake(generated, gen_noise)
        disc_preds = self.model.disc(tf.concat([altered, generated], 0))
        preds_altered, preds_generated = tf.split(disc_preds, 2, 0)

        g_loss = self.model.g_loss(scope, preds_generated, preds_altered)
        tf.get_variable_scope().reuse_variables()

        # Make the loss depend on the batch-norm statistic updates, if any.
        bn_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        if bn_updates:
            g_loss = control_flow_ops.with_dependencies(
                [tf.group(*bn_updates)], g_loss)

        grads_and_vars = opt.compute_gradients(
            g_loss, get_variables_to_train('generator'))
        self.summaries += tf.get_collection(tf.GraphKeys.SUMMARIES, scope)
        return g_loss, grads_and_vars, {}
コード例 #5
0
 def make_train_op(self, loss, vars2train=None, scope=None):
     """Create a slim train op that minimizes ``loss``.

     Args:
         loss: loss tensor to minimize.
         vars2train: variables to update; ignored when ``scope`` is truthy.
         scope: optional value passed to ``get_variables_to_train`` as
             ``trainable_scopes``; when given, its result replaces
             ``vars2train``.

     Returns:
         The op from ``slim.learning.create_train_op``, built with a fresh
         ``self.optimizer()``, ``self.global_step`` and gradient summaries
         disabled.
     """
     selected = (get_variables_to_train(trainable_scopes=scope)
                 if scope else vars2train)
     return slim.learning.create_train_op(
         loss,
         self.optimizer(),
         variables_to_train=selected,
         global_step=self.global_step,
         summarize_gradients=False)
コード例 #6
0
    def build_model(self, batch_queue, opt, scope, tower_id):
        """Build one video-classification tower and compute its gradients.

        Args:
            batch_queue: dataset iterator; ``get_next()`` yields three items
                (videos, labels, and a third item unused here).
            opt: optimizer used only to compute gradients.
            scope: name scope of this tower, used to collect its summaries.
            tower_id: tower index (not used in this method).

        Returns:
            ``(loss, grads_and_vars, {})`` with gradients over the variables
            selected by ``self.train_scopes``.
        """
        def frames_summary(tag, clips):
            # Unstack along axis 1, stack the pieces along axis 0 and log
            # them as a single montage image.
            stacked = tf.concat(tf.unstack(clips, axis=1), 0)
            tf.compat.v1.summary.image(
                tag, montage_tf(stacked, 8, self.batch_size), max_outputs=1)

        clips, targets, _ = batch_queue.get_next()
        clips.set_shape((self.batch_size, ) + self.pre_processor.out_shape)
        targets.set_shape((self.batch_size, ))
        print('vids_train : {}'.format(clips.get_shape().as_list()))

        tf.compat.v1.summary.histogram('vids_train', clips)
        tf.compat.v1.summary.histogram('labels_train', targets)
        frames_summary('imgs/frames_train', clips)

        clips = self.pre_processor.augment_train(clips)
        preds = self.model.model(clips, self.dataset.num_classes)
        frames_summary('imgs/frames_train_augmented', clips)

        loss = self.model.loss(scope, preds,
                               self.dataset.format_labels(targets))

        # Make the loss depend on the batch-norm statistic updates, if any.
        bn_updates = tf.compat.v1.get_collection(
            tf.compat.v1.GraphKeys.UPDATE_OPS)
        if bn_updates:
            loss = control_flow_ops.with_dependencies([tf.group(*bn_updates)],
                                                      loss)

        grads_and_vars = opt.compute_gradients(
            loss, get_variables_to_train(self.train_scopes))
        self.summaries += tf.compat.v1.get_collection(
            tf.compat.v1.GraphKeys.SUMMARIES, scope)
        return loss, grads_and_vars, {}
コード例 #7
0
    def train_model(self, chpt_path=None):
        """Build the multi-GPU training graph and run the training loop.

        One tower per GPU is built with ``self.build_model``. Tower gradients
        are averaged and applied once per step through
        ``opt.apply_gradients`` with ``decay_var_list=wd_vars``. Progress is
        printed about 2000 times, summaries are written about 200 times and
        checkpoints are saved about 40 times over ``self.num_train_steps``;
        a checkpoint is also saved at the last step. When there are fewer
        steps than an interval's divisor, that action runs every step.

        Args:
            chpt_path: optional checkpoint handed to ``self.make_init_fn``.
                The returned init function, if any, runs after the global
                variable initializer.

        Raises:
            AssertionError: if the loss becomes NaN (when assertions are
                enabled).
        """
        g = tf.Graph()
        with g.as_default():
            with tf.device('/cpu:0'):
                tf.compat.v1.random.set_random_seed(123)

                # Init global step
                self.global_step = tf.compat.v1.train.create_global_step()

                batch_queue = self.get_data_queue()
                opt = self.optimizer()

                # Calculate the gradients for each model tower.
                tower_grads = []
                loss = 0.

                with tf.compat.v1.variable_scope(
                        tf.compat.v1.get_variable_scope()):
                    for i in range(self.num_gpus):
                        with tf.device('/gpu:%d' % i):
                            with tf.name_scope('tower_{}'.format(i)) as scope:
                                loss_, grads_, layers_ = self.build_model(
                                    batch_queue, opt, scope, i)
                                # Reported loss is the mean over towers.
                                loss += loss_ / self.num_gpus

                            tower_grads.append(grads_)

                grad = average_gradients(tower_grads)

                # Make summaries (layers_ comes from the last tower built).
                self.make_summaries(grad, layers_)

                # Select the variables that receive weight decay.
                print(
                    '========================================WD VARS==============================================='
                )
                wd_vars = get_variables_to_train(self.train_scopes)
                if self.excl_gamma_wd:
                    wd_vars = [v for v in wd_vars if 'gamma' not in v.op.name]
                if self.excl_beta_wd:
                    wd_vars = [v for v in wd_vars if 'beta' not in v.op.name]
                if self.excl_bias_wd:
                    wd_vars = [v for v in wd_vars if 'biases' not in v.op.name]
                print('WD variables: {}'.format([v.op.name for v in wd_vars]))
                print(
                    '=============================================================================================='
                )

                # Apply the gradients to adjust the shared variables.
                train_op = opt.apply_gradients(grad,
                                               global_step=self.global_step,
                                               decay_var_list=wd_vars)

                # Running train_op applies the update and yields the loss.
                train_op = control_flow_ops.with_dependencies([train_op], loss)

                # Create a saver.
                saver = tf.compat.v1.train.Saver(
                    tf.compat.v1.global_variables())
                init_fn = self.make_init_fn(chpt_path)

                # Build the summary operation from the last tower summaries.
                summary_op = tf.compat.v1.summary.merge(self.summaries)

                # Build an initialization operation to run below.
                init = tf.compat.v1.global_variables_initializer()

                # Start running operations on the Graph.
                sess = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(
                    allow_soft_placement=True, log_device_placement=False),
                                            graph=g)
                try:
                    sess.run(init)
                    if init_fn:
                        init_fn(sess)

                    summary_writer = tf.compat.v1.summary.FileWriter(
                        self.get_save_dir(), sess.graph)
                    try:
                        init_step = sess.run(self.global_step)
                        print('Start training at step: {}'.format(init_step))

                        # Clamp each interval to >= 1 so that short runs
                        # (num_train_steps < divisor) do not hit a modulo
                        # by zero.
                        log_every = max(1, self.num_train_steps // 2000)
                        summary_every = max(1, self.num_train_steps // 200)
                        save_every = max(1, self.num_train_steps // 40)
                        # Assumes each tower's build_model pulls its own batch
                        # of self.batch_size from batch_queue, so one step
                        # consumes num_gpus batches.
                        num_examples_per_step = self.batch_size * self.num_gpus

                        for step in range(init_step, self.num_train_steps):

                            start_time = time.time()
                            _, loss_value = sess.run([train_op, loss])

                            duration = time.time() - start_time

                            assert not np.isnan(
                                loss_value), 'Model diverged with loss = NaN'

                            if step % log_every == 0:
                                examples_per_sec = (num_examples_per_step /
                                                    duration)
                                sec_per_batch = duration
                                print(
                                    '{}: step {}/{}, loss = {} ({} examples/sec; {} sec/batch)'
                                    .format(datetime.now(), step,
                                            self.num_train_steps, loss_value,
                                            examples_per_sec, sec_per_batch))
                                sys.stdout.flush()

                            if step % summary_every == 0:
                                print('Writing summaries...')
                                summary_str = sess.run(summary_op)
                                summary_writer.add_summary(summary_str, step)

                            # Save the model checkpoint periodically.
                            if step % save_every == 0 or (
                                    step + 1) == self.num_train_steps:
                                checkpoint_path = os.path.join(
                                    self.get_save_dir(), 'model.ckpt')
                                print('Saving checkpoint to: {}'.format(
                                    checkpoint_path))
                                saver.save(sess,
                                           checkpoint_path,
                                           global_step=step)
                    finally:
                        # Flush pending summary events to disk.
                        summary_writer.close()
                finally:
                    sess.close()
コード例 #8
0
    def build_model(self, batch_queue, tower, opt, scope):
        """Build one tower using bilevel mini-batch gradient weighting.

        The batch is split into ``self.data_generator.batch_splits`` parts,
        and the model loss is computed on each part.

        Split 0 serves as the reference ("validation") gradient ``g_val``.
        Every other split ``i`` gets the weight
        ``w_i = <g_val, g_i> / (<g_i, g_i> + self.mu)``. The weights are
        divided by their L1 norm, and the weighted split gradients are summed
        per variable. Finally, ``self.model.weight_decay * v`` is added for
        every variable whose op name ends in ``'weights'``.

        Args:
            batch_queue: dataset iterator; ``get_next()`` yields
                ``(images, labels)``.
            tower: tower index; variables are reused when ``tower > 0``.
            opt: optimizer used only for ``compute_gradients``.
            scope: name scope of this tower, used to collect its summaries.

        Returns:
            ``(mean of the per-split losses, [(grad, var), ...], layers)``,
            where ``layers`` comes from the last split's ``self.model.net``
            call.

        NOTE(review):
            * This needs at least two splits. With a single split ``w`` is
              empty and ``tf.accumulate_n`` would receive an empty list.
            * A ``None`` gradient (a variable not reached by the loss) would
              break the ``tf.reshape`` flattening below.
        """
        imgs_train, labels_train = batch_queue.get_next()

        tf.summary.histogram('labels', labels_train)

        # We split the training batches in the pre-defined splits (each containing the same label distribution)
        num_split = self.data_generator.batch_splits
        imgs_train_list = tf.split(imgs_train, num_split)
        labels_train_list = tf.split(labels_train, num_split)

        preds_list = []
        loss_list = []
        # Iterate over all the batch splits
        for i, (imgs,
                labels) in enumerate(zip(imgs_train_list, labels_train_list)):
            tf.summary.image('imgs/train',
                             montage_tf(imgs, 1, 8),
                             max_outputs=1)

            # Create the model; variables are created only by the first split
            # of tower 0, and every other split/tower reuses them.
            reuse = True if (tower > 0 or i > 0) else None
            preds, layers = self.model.net(imgs,
                                           self.data_generator.num_classes,
                                           reuse=reuse)
            preds_list.append(preds)

            # Compute losses
            loss = self.model.loss(scope, preds,
                                   self.data_generator.format_labels(labels),
                                   tower)
            tf.get_variable_scope().reuse_variables()

            # Handle dependencies with update_ops (batch-norm). Note that
            # tf.get_collection returns every UPDATE_OPS entry registered in
            # the graph so far, including those of earlier splits and towers.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            if update_ops:
                updates = tf.group(*update_ops)
                loss = control_flow_ops.with_dependencies([updates], loss)

            # Store the loss on this split in the list
            loss_list.append(loss)

        # Calculate the gradients on all the batch splits; one (grad, var)
        # list per split, all over the same `weights`.
        weights = get_variables_to_train(self.train_scopes)
        grads_list = [opt.compute_gradients(l, weights) for l in loss_list]

        # A dictionary with a list of gradients corresponding to the model variables
        grads_accum = {v: [] for (_, v) in grads_list[0]}

        # Flatten the gradients of each split into one (N, 1) column tensor
        grads_flat = [
            tf.concat([tf.reshape(g, (-1, 1)) for (g, v) in grad], axis=0)
            for grad in grads_list
        ]

        # Compute the mini-batch weights: split 0 is the reference gradient,
        # and each other split i gets <g_val, g_i> / (<g_i, g_i> + mu).
        val_grad = grads_flat[0]
        w = [
            tf.divide(
                tf.reduce_sum(tf.multiply(val_grad, train_grad)),
                tf.reduce_sum(tf.multiply(train_grad, train_grad)) + self.mu)
            for train_grad in grads_flat[1:]
        ]

        # Multiply mini-batch gradients by l1 normalized weights. Split 0's
        # own gradient is not accumulated; it only serves as the reference.
        # NOTE(review): if every w_i is 0, w_l1norm is 0 and this divides by
        # zero.
        w_l1norm = tf.reduce_sum(tf.abs(w))
        for i, grads in enumerate(grads_list[1:]):
            for g, v in grads:
                grads_accum[v].append(tf.multiply(g, w[i] / w_l1norm))
        tf.summary.histogram('w', tf.stack(w))

        # Apply weight-decay: the gradient term weight_decay * v for variables
        # whose op name ends in 'weights', and 0.0 for all others.
        grads_wd = {
            v: self.model.weight_decay *
            v if v.op.name.endswith('weights') else 0.0
            for (_, v) in grads_list[0]
        }

        # Accumulate all the gradients per variable (weighted sum plus weight
        # decay), in the variable order of the first split.
        grads = [(tf.accumulate_n(grads_accum[v]) + grads_wd[v], v)
                 for (_, v) in grads_list[0]]

        self.summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope)
        return tf.reduce_mean(loss_list), grads, layers
コード例 #9
0
ファイル: demo_run_WSCI.py プロジェクト: hutt94/paper_code
def main(_):
    """Train the WSCI network and log test accuracy after every epoch.

    The training configuration is hard-coded into ``FLAGS`` at the top of
    this function, overriding any command-line values for those fields.
    Each epoch runs ``num_batches`` training steps and then evaluates the
    whole test split. A line ``"<epoch> <accuracy>"`` is then appended to
    ``<FLAGS.train_dir>/log``.

    Args:
        _: unused positional argument (presumably argv from the app runner).
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = str(FLAGS.gpu_id)
    FLAGS.num_preprocessing_threads = 10
    FLAGS.max_epoch_num = 100
    FLAGS.train_split_name = 'train'
    FLAGS.test_split_name = 'test'
    FLAGS.model_name = 'inception_v3'

    FLAGS.dataset_dir = 'datasets'
    FLAGS.attr2class_file = os.path.join(FLAGS.dataset_dir, 'attr2class.txt')
    FLAGS.train_dir = 'output'
    FLAGS.checkpoint_path = os.path.join('pretrained_models',
                                         '%s.ckpt' % (FLAGS.model_name))
    log_file_path = os.path.join(FLAGS.train_dir, 'log')

    if not os.path.isdir(FLAGS.train_dir):
        os.makedirs(FLAGS.train_dir)

    FLAGS.checkpoint_exclude_scopes = 'InceptionV3/Logits,InceptionV3/AuxLogits,VAE'
    FLAGS.trainable_scopes = 'VAE'  # if learning all parameters including CNN, set FLAGS.trainable_scopes=None
    FLAGS.checkpoint_exclude_keywords = None
    FLAGS.batch_size = 64

    with tf.Graph().as_default():
        # load dataset
        dataset = dataset_factory.get_dataset(FLAGS.train_split_name,
                                              FLAGS.dataset_dir)
        test_dataset = dataset_factory.get_dataset(FLAGS.test_split_name,
                                                   FLAGS.dataset_dir)
        num_batches = int(
            math.ceil(dataset.num_samples / float(FLAGS.batch_size)))
        num_test_batches = int(
            math.ceil(test_dataset.num_samples / float(FLAGS.batch_size)))
        train_image_size = nets_factory.get_network_fn(
            FLAGS.model_name, dataset.num_classes).default_image_size
        images, labels = load_batch(FLAGS,
                                    dataset,
                                    train_image_size,
                                    train_image_size,
                                    is_training=True)
        test_images, test_labels = load_batch(FLAGS,
                                              test_dataset,
                                              train_image_size,
                                              train_image_size,
                                              is_training=False)

        # load class attributes
        attr2class = np.loadtxt(FLAGS.attr2class_file, np.float32)

        # build networks
        train_batch_loss, train_summary = WSCI_network('train',
                                                       FLAGS.model_name,
                                                       images, labels,
                                                       attr2class, False,
                                                       FLAGS.beta)
        test_correct_arr = WSCI_network('test', FLAGS.model_name, test_images,
                                        test_labels, attr2class, True)

        # optimizer
        global_step = slim.create_global_step()
        config_lr = configure_learning_rate(FLAGS, dataset.num_samples,
                                            global_step)
        optimizer = configure_optimizer(learning_rate=config_lr)
        variables_to_train = get_variables_to_train(FLAGS)
        grads_and_vars = optimizer.compute_gradients(train_batch_loss,
                                                     variables_to_train)
        update_ops = [
            optimizer.apply_gradients(grads_and_vars, global_step=global_step)
        ]
        # NOTE(review): batch-norm moving_mean/moving_variance updates are not
        # run because the line below is commented out. With only the 'VAE'
        # scope trainable this is presumably intentional; confirm before
        # setting trainable_scopes=None.
        # update_ops += tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_op = tf.identity(train_batch_loss)

        # main code
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        with tf.Session(config=config) as sess:
            # initialization
            sess.run(tf.global_variables_initializer())

            with slim.queues.QueueRunners(sess):
                # Presumably restores FLAGS.checkpoint_path minus the
                # excluded scopes/keywords; confirm in get_init_fn.
                init_fn = get_init_fn(FLAGS.checkpoint_path,
                                      FLAGS.checkpoint_exclude_scopes,
                                      FLAGS.checkpoint_exclude_keywords)
                init_fn(sess)

                for iepoch in range(FLAGS.max_epoch_num):
                    # training
                    # NOTE(review): train_summary is evaluated but never
                    # written anywhere (there is no FileWriter here).
                    for ibatch in range(num_batches):
                        print('iepoch %d: train %d/%d' % (iepoch, ibatch,
                                                          num_batches))
                        sess.run([train_op, train_summary])

                    # test
                    correct_num = 0
                    for ibatch in range(num_test_batches):
                        print('iepoch %d: test %d/%d' % (iepoch, ibatch,
                                                         num_test_batches))
                        correct_arr = sess.run(test_correct_arr)
                        correct_num += np.sum(correct_arr)
                    # NOTE(review): assumes every test batch is full; if
                    # load_batch emits a smaller final batch, the denominator
                    # should be test_dataset.num_samples instead.
                    test_acc = float(correct_num) / (num_test_batches *
                                                     FLAGS.batch_size)

                    with open(log_file_path, 'a+') as fid:
                        fid.write('%d %f\n' % (iepoch, test_acc))
コード例 #10
0
    def build_model(self, batch_queue, opt_g, opt_d, opt_c, scope, tower_id):
        """Build one tower of LCI-based self-supervised training.

        One image batch feeds three sub-networks:

        * the LCI encoder/decoder, in tower-specific scopes; its gradients
          are computed with ``opt_g``;
        * a patch discriminator, in a tower-specific scope; its gradients are
          computed with ``opt_d``;
        * a 6-way transformation classifier over ``self.train_scopes``; its
          gradients are computed with ``opt_c``.

        The classifier input is, in order: untransformed images (label 0),
        LCI images (label 1), rotated images (labels 2-4) and warped images
        (label 5). This assumes ``all_rot`` yields three rotated copies,
        i.e. 3 * batch_size images; TODO confirm.

        The encoder/decoder loss is ``self.model.inpainter_loss`` minus
        ``self.model.loss_lci_adv`` of the classifier predictions.

        Args:
            batch_queue: dataset iterator; ``get_next()`` yields
                ``(images, _)``.
            opt_g, opt_d, opt_c: optimizers used only for
                ``compute_gradients``.
            scope: name scope of this tower, used to collect its summaries.
            tower_id: tower index, appended to the LCI variable-scope names.

        Returns:
            ``(loss_ae_lci, loss_disc_lci, loss_c, grads_g, grads_d, grads_c,
            {})``.
        """
        # Per-tower variable scopes: the LCI encoder, decoder and patch
        # discriminator are not shared across towers.
        lci_enc_scope = 'encoder_ae_{}'.format(tower_id)
        lci_dec_scope = 'decoder_ae_{}'.format(tower_id)
        lci_disc_scope = 'discriminator_{}'.format(tower_id)

        # Load batch of images
        imgs_train, _ = batch_queue.get_next()
        imgs_train.set_shape([
            self.model.batch_size,
        ] + self.model.im_shape)

        # Create warped images. For each image, n_warp_points random source
        # control points are chosen; each destination point is the source
        # shifted by a uniform offset in [-w_mag, w_mag] per coordinate.
        # NOTE(review): p_x is drawn over im_shape[0] and p_y over
        # im_shape[1]. Presumably sparse_image_warp expects (row, col) order,
        # so only the names are misleading; TODO confirm.
        w_mag = self.model.im_shape[0] / self.warp_factor
        p_x = tf.random_uniform([self.model.batch_size, self.n_warp_points],
                                minval=0,
                                maxval=self.model.im_shape[0])
        p_y = tf.random_uniform([self.model.batch_size, self.n_warp_points],
                                minval=0,
                                maxval=self.model.im_shape[1])
        c_points_src = tf.stack([p_x, p_y], axis=-1)
        c_points_dest = c_points_src + tf.random_uniform(
            c_points_src.get_shape(), -w_mag, w_mag)
        imgs_warp, _ = contrib.image.sparse_image_warp(imgs_train,
                                                       c_points_src,
                                                       c_points_dest)
        tf.summary.image('imgs/img_warp',
                         montage_tf(imgs_warp, 4, 8),
                         max_outputs=1)

        # Perform LCI with the tower-specific encoder/decoder scopes.
        patch_lci, patch_ae, mask_erase, mask_orig, crop_img, imgs_lci, imgs_patchae =\
            self.model.lci(imgs_train, enc_scope=lci_enc_scope, dec_scope=lci_dec_scope)
        tf.summary.image('imgs/real_imgs',
                         montage_tf(imgs_patchae, 4, 8),
                         max_outputs=1)
        tf.summary.image('imgs/fake_imgs',
                         montage_tf(imgs_lci, 4, 8),
                         max_outputs=1)

        # Build untransformed images: the first half of the original batch
        # plus the second half of the images with autoencoded patches.
        imgs_nt_1, _ = tf.split(imgs_train, 2)
        _, imgs_nt_2 = tf.split(imgs_patchae, 2)
        imgs_nt = tf.concat([imgs_nt_1, imgs_nt_2], 0)

        # Perform additional augmentations to make detection of LCI harder
        imgs_lci = random_crop_rot(imgs_lci, self.crop_sz)
        imgs_nt = random_crop_rot(imgs_nt, self.crop_sz)
        imgs_warp = random_crop_rot(imgs_warp, self.crop_sz)

        # Generate the rotated images from the untransformed ones.
        imgs_rot, _ = all_rot(imgs_nt)

        # Patch discriminator for LCI: the LCI patches are scored as fake and
        # the original crops as real. NOTE(review): update_collection
        # presumably controls spectral-norm update registration (only the
        # real pass uses None); confirm in self.model.patch_disc.
        preds_fake = self.model.patch_disc(patch_lci,
                                           update_collection="NO_OPS",
                                           disc_scope=lci_disc_scope)
        preds_real = self.model.patch_disc(crop_img,
                                           update_collection=None,
                                           disc_scope=lci_disc_scope)

        # The transformation classifier; the input order must match `labels`.
        class_in = tf.concat([imgs_nt, imgs_lci, imgs_rot, imgs_warp], 0)
        preds = self.model.net(class_in)

        # Build SSL labels: batch_size copies of each class 0..5, one-hot
        # encoded (0 untransformed, 1 LCI, 2-4 rotations, 5 warp).
        labels = tf.concat([
            tf.zeros((self.model.batch_size, ), dtype=tf.int32),
            tf.ones((self.model.batch_size, ), dtype=tf.int32), 2 * tf.ones(
                (self.model.batch_size, ), dtype=tf.int32), 3 * tf.ones(
                    (self.model.batch_size, ), dtype=tf.int32), 4 * tf.ones(
                        (self.model.batch_size, ), dtype=tf.int32),
            5 * tf.ones((self.model.batch_size, ), dtype=tf.int32)
        ], 0)
        labels = tf.one_hot(labels, 6)

        # Compute losses. The encoder/decoder is also trained adversarially
        # against the classifier by subtracting loss_lci_adv.
        loss_c = self.model.loss_ssl(preds, labels)
        loss_disc_lci = self.model.discriminator_loss(preds_fake, preds_real)
        loss_ae_lci = self.model.inpainter_loss(preds_fake, crop_img,
                                                patch_lci, mask_erase,
                                                patch_ae, mask_orig)
        loss_ae_lci -= self.model.loss_lci_adv(preds, labels)

        # Handle dependencies with update_ops (batch-norm). Only the
        # classifier loss waits on them; the LCI losses do not.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        if update_ops:
            updates = tf.group(*update_ops)
            loss_c = control_flow_ops.with_dependencies([updates], loss_c)

        # Calculate the gradients for the batch of data on this tower.
        grads_d = opt_d.compute_gradients(
            loss_disc_lci, get_variables_to_train(lci_disc_scope))
        grads_g = opt_g.compute_gradients(
            loss_ae_lci,
            get_variables_to_train('{},{}'.format(lci_enc_scope,
                                                  lci_dec_scope)))
        grads_c = opt_c.compute_gradients(
            loss_c, get_variables_to_train(self.train_scopes, print_vars=True))

        # Create some summaries (first conv kernels of the 'features' scope,
        # and the learning-rate decay multiplier).
        if self.weight_summary:
            with tf.variable_scope('features', reuse=True):
                weights_disc_1 = slim.variable('conv_1/weights')
            tf.summary.image('weights/conv_1',
                             weights_montage(weights_disc_1, 6, 16),
                             max_outputs=1)
        tf.summary.scalar('lr_decay_mult', self.lr_decay_mult())
        self.summaries += tf.get_collection(tf.GraphKeys.SUMMARIES, scope)

        return loss_ae_lci, loss_disc_lci, loss_c, grads_g, grads_d, grads_c, {}
コード例 #11
0
    def train_model(self, chpt_path=None):
        """Build the multi-GPU LCI training graph and run the training loop.

        Each GPU gets its own LCI encoder/decoder and patch discriminator.
        Each tower gets separate ``self.optimizer('g')`` /
        ``self.optimizer('d')`` instances that apply that tower's gradients.

        The transformation classifier's gradients are averaged over towers
        and applied once per step through ``opt_c.apply_gradients`` with
        ``decay_var_list=wd_vars``.

        Training resumes from the newest checkpoint in
        ``self.get_save_dir()`` if one exists. Otherwise the init function
        built from ``chpt_path`` runs, if there is one. Progress, summary and
        checkpoint intervals are about 1/2000, 1/200 and 1/40 of
        ``self.num_train_steps``, clamped to at least one step.

        Args:
            chpt_path: optional checkpoint handed to ``self.make_init_fn``.

        Raises:
            AssertionError: if the classifier loss becomes NaN (when
                assertions are enabled).
        """
        if chpt_path:
            print('Restoring from: {}'.format(chpt_path))
        g = tf.Graph()
        with g.as_default():
            with tf.device('/cpu:0'):
                # Init global step
                self.global_step = tf.train.create_global_step()

                # Init data
                batch_queue = self.get_data_queue()

                # Optimizer for the classifier
                opt_c = self.optimizer_class()

                # Calculate the gradients for each model tower.
                train_ops_g = []
                train_ops_d = []
                tower_grads_c = []
                # NOTE: these losses are summed (not averaged) over towers.
                loss_c = 0.
                loss_g = 0.
                loss_d = 0.

                with tf.variable_scope(tf.get_variable_scope()):
                    for i in range(self.num_gpus):
                        with tf.device('/gpu:%d' % i):
                            # LCI parameters are not shared across GPUs
                            opt_g = self.optimizer('g')
                            opt_d = self.optimizer('d')

                            with tf.name_scope('tower_{}'.format(i)) as scope:
                                l_g, l_d, l_c, grad_g, grad_d, grad_c, layers_d = \
                                    self.build_model(batch_queue, opt_g, opt_d, opt_c, scope, i)
                                loss_c += l_c
                                loss_g += l_g
                                loss_d += l_d

                            # Training ops for LCI
                            train_op_g = opt_g.apply_gradients(grad_g)
                            train_op_d = opt_d.apply_gradients(grad_d)
                            train_ops_d.append(train_op_d)
                            train_ops_g.append(train_op_g)

                            # Aggregate gradients for the transformation classifier
                            tower_grads_c.append(grad_c)

                # Average gradients for classifier from all GPUs
                grad_c = average_gradients(tower_grads_c)

                # Make summaries. grad_d, grad_g and layers_d come from the
                # last tower built.
                self.make_summaries(grad_d + grad_g + grad_c, layers_d)

                # Apply the gradients to adjust the shared variables.
                wd_vars = get_variables_to_train(self.train_scopes)
                if self.excl_gamma_wd:
                    wd_vars = [v for v in wd_vars if 'gamma' not in v.op.name]
                if self.excl_beta_wd:
                    wd_vars = [v for v in wd_vars if 'beta' not in v.op.name]
                print('WD variables: {}'.format([v.op.name for v in wd_vars]))
                train_op_c = opt_c.apply_gradients(
                    grad_c,
                    global_step=self.global_step,
                    decay_var_list=wd_vars)

                # Group all updates into a single train op.
                train_op = control_flow_ops.with_dependencies(
                    [train_op_c] + train_ops_d + train_ops_g,
                    loss_d + loss_g + loss_c)

                # Create a saver.
                saver = tf.train.Saver(tf.global_variables())
                init_fn = self.make_init_fn(chpt_path)

                # Build the summary operation from the last tower summaries.
                summary_op = tf.summary.merge(self.summaries)

                # Build an initialization operation to run below.
                init = tf.global_variables_initializer()

                # Start running operations on the Graph.
                sess = tf.Session(config=tf.ConfigProto(
                    allow_soft_placement=True, log_device_placement=False),
                                  graph=g)
                try:
                    sess.run(init)
                    prev_ckpt = get_checkpoint_path(self.get_save_dir())
                    if prev_ckpt:
                        print('Restoring from previous checkpoint: {}'.format(
                            prev_ckpt))
                        saver.restore(sess, prev_ckpt)
                    elif init_fn:
                        init_fn(sess)

                    summary_writer = tf.summary.FileWriter(
                        self.get_save_dir(), sess.graph)
                    try:
                        init_step = sess.run(self.global_step)
                        print('Start training at step: {}'.format(init_step))

                        # Clamp each interval to >= 1 so that short runs
                        # (num_train_steps < divisor) do not hit a modulo
                        # by zero.
                        log_every = max(1, self.num_train_steps // 2000)
                        summary_every = max(1, self.num_train_steps // 200)
                        save_every = max(1, self.num_train_steps // 40)
                        # Each tower's build_model pulls its own batch of
                        # self.model.batch_size images from batch_queue.
                        num_examples_per_step = (self.model.batch_size *
                                                 self.num_gpus)

                        for step in range(init_step, self.num_train_steps):

                            start_time = time.time()
                            _, loss_value = sess.run([train_op, loss_c])
                            duration = time.time() - start_time

                            assert not np.isnan(
                                loss_value), 'Model diverged with loss = NaN'

                            if step % log_every == 0:
                                examples_per_sec = (num_examples_per_step /
                                                    duration)
                                sec_per_batch = duration
                                print(
                                    '{}: step {}/{}, loss = {} ({} examples/sec; {} sec/batch)'
                                    .format(datetime.now(), step,
                                            self.num_train_steps, loss_value,
                                            examples_per_sec, sec_per_batch))
                                sys.stdout.flush()

                            if step % summary_every == 0:
                                print('Writing summaries...')
                                summary_str = sess.run(summary_op)
                                summary_writer.add_summary(summary_str, step)

                            # Save the model checkpoint periodically.
                            if step % save_every == 0 or (
                                    step + 1) == self.num_train_steps:
                                checkpoint_path = os.path.join(
                                    self.get_save_dir(), 'model.ckpt')
                                print('Saving checkpoint to: {}'.format(
                                    checkpoint_path))
                                saver.save(sess,
                                           checkpoint_path,
                                           global_step=step)
                    finally:
                        # Flush pending summary events to disk.
                        summary_writer.close()
                finally:
                    sess.close()
コード例 #12
0
    def build_model(self, batch_queue, opt, scope, tower_id):
        """Build the per-tower graph for transform (and optional skip) prediction.

        Pulls one batch from ``batch_queue`` and applies the pre-processor's
        training augmentation. Each slice along axis 1 is then treated as one
        "transform" view of the batch, and ``self.model`` is trained to
        classify which view every clip came from. When ``self.skip_pred`` is
        set, the model also predicts ``skip_label``, a one-hot over
        ``self.pre_processor.n_speeds`` classes.

        Args:
            batch_queue: Object whose ``get_next()`` yields a 3-tuple
                ``(videos, skip_label, _)``; the third element is unused.
            opt: Optimizer providing ``compute_gradients``.
            scope: Name scope of this tower. It is passed to the model loss
                and used to collect the summaries created under it.
            tower_id: Not used by this method.

        Returns:
            Tuple ``(loss, grads, {})``:
                - ``loss``: the total loss tensor, gated on any ``UPDATE_OPS``.
                - ``grads``: the gradient/variable pairs for the variables of
                  ``self.train_scopes``.
                - ``{}``: an empty dict.
        """
        batch_size = self.batch_size
        clips, skip_label, _ = batch_queue.get_next()
        clips.set_shape((batch_size, ) + self.pre_processor.out_shape)
        print('vids_train : {}'.format(clips.get_shape().as_list()))
        tf.compat.v1.summary.histogram('skip_label', skip_label)

        # Shared augmentation, then split into one tensor per view (axis 1).
        clips = self.pre_processor.augment_train(clips)
        views = tf.unstack(clips, axis=1)
        transform_names = ['orig'] + self.pre_processor.transforms

        # One image montage per view. NOTE(review): 16 is presumably the
        # number of frames per clip -- TODO confirm.
        for view, name in zip(views, transform_names):
            frames = tf.unstack(view, axis=1)
            grid = montage_tf(tf.concat(frames, 0), 16, batch_size)
            tf.compat.v1.summary.image('imgs/{}'.format(name), grid,
                                       max_outputs=1)

        # Stack all views along the batch axis; view k gets class label k.
        num_transform_classes = len(views)
        clips = tf.concat(views, 0)
        label_chunks = [
            k * tf.ones((batch_size, ), dtype=tf.int32)
            for k in range(num_transform_classes)
        ]
        transform_labels = tf.one_hot(tf.concat(label_chunks, 0),
                                      num_transform_classes)

        # The skip head is only scored on the first 2 * batch_size examples
        # (see the slice below), so the one-hot skip labels are tiled twice.
        # NOTE(review): assumes the first two views share the same skip
        # label -- TODO confirm.
        num_skip_classes = self.pre_processor.n_speeds
        skip_label = tf.one_hot(skip_label, num_skip_classes)
        skip_label = tf.concat([skip_label, skip_label], 0)

        num_classes = num_transform_classes
        if self.skip_pred:
            num_classes += num_skip_classes

        # Forward pass. The last axis holds the transform classes first,
        # followed by the skip classes when skip prediction is enabled.
        preds = self.model.model(clips, num_classes)

        def _log_accuracy(tag, predicted, expected):
            # Scalar summary of the accuracy of `predicted` against `expected`.
            tf.compat.v1.summary.scalar(
                tag, slim.metrics.accuracy(predicted, expected))

        transform_preds = preds
        if self.skip_pred:
            transform_preds, skip_preds = tf.split(
                preds, [num_transform_classes, num_skip_classes], -1)
            skip_preds = skip_preds[:2 * batch_size]
            _log_accuracy('accuracy/skip_pred', tf.argmax(skip_preds, 1),
                          tf.argmax(skip_label, 1))

        # Accuracy over all views combined, then per view.
        transform_pred_ids = tf.argmax(transform_preds, 1)
        transform_true_ids = tf.argmax(transform_labels, 1)
        _log_accuracy('accuracy/all_transforms', transform_pred_ids,
                      transform_true_ids)
        num_views = len(transform_names)
        per_view = zip(tf.split(transform_pred_ids, num_views),
                       tf.split(transform_true_ids, num_views),
                       transform_names)
        for pred_ids, true_ids, name in per_view:
            _log_accuracy('accuracy/{}'.format(name), pred_ids, true_ids)

        # Losses: transform classification, plus the skip head if enabled.
        loss_transform = self.model.loss(scope, transform_preds,
                                         transform_labels, summary=False)
        tf.compat.v1.summary.scalar('losses/loss_transform', loss_transform)
        loss = loss_transform
        if self.skip_pred:
            loss_skip = self.model.loss(scope, skip_preds, skip_label,
                                        summary=False)
            tf.compat.v1.summary.scalar('losses/loss_skip', loss_skip)
            loss = loss_transform + loss_skip

        # Gate the loss on the collected UPDATE_OPS (e.g. batch-norm moving
        # averages), so they run whenever the loss is evaluated.
        pending_updates = tf.compat.v1.get_collection(
            tf.compat.v1.GraphKeys.UPDATE_OPS)
        if pending_updates:
            loss = control_flow_ops.with_dependencies(
                [tf.group(*pending_updates)], loss)

        # Gradients for this tower's batch, restricted to the train scopes.
        trainable = get_variables_to_train(self.train_scopes)
        grads = opt.compute_gradients(loss, trainable)

        self.summaries += tf.compat.v1.get_collection(
            tf.compat.v1.GraphKeys.SUMMARIES, scope)
        return loss, grads, {}