Exemplo n.º 1
0
def save_gradient_pic(D, fixed_generated_images, iteration, opts):
    """Visualize the discriminator's input gradients on a fixed image batch.

    Backprops the discriminator's real-vs-fake MSE loss (target 1) into the
    input images and saves the per-pixel gradient magnitudes as an image
    grid under ``opts.sample_dir``.

    D: discriminator network (callable on an image batch).
    fixed_generated_images: batch of generated images held constant across
        training — presumably NCHW float tensors; TODO confirm with caller.
    iteration: training iteration, used only in the output filename.
    opts: options namespace; only ``opts.sample_dir`` is read here.
    """
    pic = utils.to_var(fixed_generated_images)
    # We want d(loss)/d(pic), not parameter gradients, so the input itself
    # must require grad.
    pic.requires_grad_(True)
    loss = mse_loss(D(pic), 1)
    path = os.path.join(opts.sample_dir,
                        'gradients-{:06d}.png'.format(iteration))
    loss.backward()
    gradients = utils.to_data(pic.grad)

    # Replace every channel with the per-pixel gradient L2 norm taken across
    # channels (keepdims=True broadcasts the norm back to all channels, so
    # the saved image is grayscale-like, not "red channel only").
    gradients[:, :] = np.sqrt(np.sum(gradients**2, axis=1, keepdims=True))
    grid = create_image_grid(gradients)
    # NOTE(review): scipy.misc.imsave was deprecated in SciPy 1.0 and removed
    # in 1.2 — this only runs on old SciPy; consider imageio.imwrite.
    scipy.misc.imsave(path, grid)
Exemplo n.º 2
0
    def evaluate_none_task(self, model, teacher_model, valid_loader):
        """Distillation evaluation: mean MSE between the last hidden states
        of the student and teacher models over the validation set.

        Returns a single ``(metric_name, score)`` pair where the score is
        the negated loss, so that higher is better. When no teacher is
        given there is nothing to compare against, so a placeholder
        ``-inf`` metric is returned instead.
        """
        if teacher_model is None:
            return [('no-metric', float("-inf"))]

        loss_sum = 0
        step_count = 0
        sample_count = 0
        elapsed = 0.0

        for _step, batch in enumerate(valid_loader):
            # Move tensor entries onto the GPU; leave everything else as-is.
            moved = {}
            for key, val in batch.items():
                moved[key] = val.cuda() if isinstance(val, torch.Tensor) else val
            batch = moved

            # Only the student forward pass is timed — that is the latency
            # we report per sample below.
            t0 = time.time()
            with torch.no_grad():
                student_outputs = model(batch)
            elapsed += time.time() - t0

            with torch.no_grad():
                teacher_outputs = teacher_model(batch)
                tmp_loss = losses.mse_loss(student_outputs["hidden"][-1],
                                           teacher_outputs["hidden"][-1])

            loss_sum += tmp_loss.mean().item()
            step_count += 1
            sample_count += valid_loader.batch_size
            if (_step + 1) % 100 == 0:
                logger.info("Eval: %d/%d steps finished" %
                            (_step + 1, len(valid_loader.dataset) //
                             valid_loader.batch_size))

        logger.info("Inference time = {:.2f}s, [{:.4f} ms / sample] ".format(
            elapsed, elapsed * 1000 / sample_count))

        eval_loss = loss_sum / step_count
        logger.info("Eval loss: {}".format(eval_loss))

        return [("last_layer_mse", -eval_loss)]
Exemplo n.º 3
0
def training_loop(dataloader_X, dataloader_Y, test_dataloader_X,
                  test_dataloader_Y, opts):
    """Runs the CycleGAN training loop.

        * Saves checkpoint every opts.checkpoint_every iterations
        * Saves generated samples every opts.sample_every iterations

    Fixes over the original:
        * ``iterator.next()`` (Python 2 only) replaced by the builtin
          ``next(iterator)``, which works on Python 3 DataLoader iterators.
        * ``tensor.data[0]`` (fails on 0-dim tensors in PyTorch >= 0.5)
          replaced by ``tensor.item()``.
    """

    # Create generators and discriminators (optionally resuming from disk)
    if opts.load:
        G_XtoY, G_YtoX, D_X, D_Y = load_checkpoint(opts)
    else:
        G_XtoY, G_YtoX, D_X, D_Y = create_model(opts)

    g_params = list(G_XtoY.parameters()) + list(
        G_YtoX.parameters())  # Get generator parameters
    d_params = list(D_X.parameters()) + list(
        D_Y.parameters())  # Get discriminator parameters

    # Create optimizers for the generators and discriminators
    g_optimizer = optim.Adam(g_params, opts.lr, [opts.beta1, opts.beta2])
    d_optimizer = optim.Adam(d_params, opts.lr, [opts.beta1, opts.beta2])

    iter_X = iter(dataloader_X)
    iter_Y = iter(dataloader_Y)

    test_iter_X = iter(test_dataloader_X)
    test_iter_Y = iter(test_dataloader_Y)

    # Get some fixed data from domains X and Y for sampling. These are images that are held
    # constant throughout training, that allow us to inspect the model's performance.
    fixed_X = utils.to_var(next(test_iter_X)[0])
    fixed_Y = utils.to_var(next(test_iter_Y)[0])

    iter_per_epoch = min(len(iter_X), len(iter_Y))

    for iteration in range(1, opts.train_iters + 1):

        # Reset data_iter for each epoch so neither iterator is exhausted
        if iteration % iter_per_epoch == 0:
            iter_X = iter(dataloader_X)
            iter_Y = iter(dataloader_Y)

        images_X, labels_X = next(iter_X)
        images_X, labels_X = utils.to_var(images_X), utils.to_var(
            labels_X).long().squeeze()

        images_Y, labels_Y = next(iter_Y)
        images_Y, labels_Y = utils.to_var(images_Y), utils.to_var(
            labels_Y).long().squeeze()

        # ============================================
        #            TRAIN THE DISCRIMINATORS
        # ============================================

        # Train with real images
        d_optimizer.zero_grad()

        # 1. Compute the discriminator losses on real images
        #    (least-squares GAN loss with target 1 for real images)
        D_X_loss = (D_X(images_X) - 1).pow(2).sum() / len(images_X)
        D_Y_loss = (D_Y(images_Y) - 1).pow(2).sum() / len(images_Y)

        d_real_loss = D_X_loss + D_Y_loss
        d_real_loss.backward()
        d_optimizer.step()

        # Train with fake images
        d_optimizer.zero_grad()

        # 2. Generate fake images that look like domain X based on real images in domain Y
        fake_X = G_YtoX(images_Y)

        # 3. Compute the loss for D_X (target 0 for fakes)
        D_X_loss = mse_loss(D_X(fake_X), 0)

        # 4. Generate fake images that look like domain Y based on real images in domain X
        fake_Y = G_XtoY(images_X)

        # 5. Compute the loss for D_Y
        D_Y_loss = mse_loss(D_Y(fake_Y), 0)

        d_fake_loss = D_X_loss + D_Y_loss
        d_fake_loss.backward()
        d_optimizer.step()

        # =========================================
        #            TRAIN THE GENERATORS
        # =========================================

        # ---- Y--X-->Y cycle ----
        g_optimizer.zero_grad()

        # 1. Generate fake images that look like domain X based on real images in domain Y
        fake_X = G_YtoX(images_Y)

        # 2. Compute the generator loss based on domain X
        #    (generator wants D_X to output 1 on its fakes)
        g_loss = mse_loss(D_X(fake_X), 1)

        if opts.use_cycle_consistency_loss:
            reconstructed_Y = G_XtoY(fake_X)
            # 3. Compute the cycle consistency loss (the reconstruction loss)
            cycle_consistency_loss = mse_loss(images_Y, reconstructed_Y)
            g_loss += cycle_consistency_loss

        g_loss.backward()
        g_optimizer.step()

        # ---- X--Y-->X cycle ----
        g_optimizer.zero_grad()

        # 1. Generate fake images that look like domain Y based on real images in domain X
        fake_Y = G_XtoY(images_X)

        # 2. Compute the generator loss based on domain Y
        g_loss = mse_loss(D_Y(fake_Y), 1)

        if opts.use_cycle_consistency_loss:
            reconstructed_X = G_YtoX(fake_Y)
            # 3. Compute the cycle consistency loss (the reconstruction loss)
            cycle_consistency_loss = mse_loss(images_X, reconstructed_X)
            g_loss += cycle_consistency_loss

        g_loss.backward()
        g_optimizer.step()

        # Print the log info
        if iteration % opts.log_step == 0:
            print(
                'Iteration [{:5d}/{:5d}] | d_real_loss: {:6.4f} | d_Y_loss: {:6.4f} | d_X_loss: {:6.4f} | '
                'd_fake_loss: {:6.4f} | g_loss: {:6.4f}'.format(
                    iteration, opts.train_iters, d_real_loss.item(),
                    D_Y_loss.item(), D_X_loss.item(), d_fake_loss.item(),
                    g_loss.item()))

        # Save the generated samples
        if iteration % opts.sample_every == 0:
            save_samples(iteration, fixed_Y, fixed_X, G_YtoX, G_XtoY, opts)

        # Save the model parameters
        if iteration % opts.checkpoint_every == 0:
            checkpoint(iteration, G_XtoY, G_YtoX, D_X, D_Y, opts)
Exemplo n.º 4
0
def train_fun(dataset, opts):
    """Build a TF1 graph and run the train/validation loop for ``opts.epochs``.

    Per epoch: runs the training dataset to exhaustion, then the validation
    dataset, writing summaries to separate train/val writers and saving a
    checkpoint every ``opts.save_freq`` epochs. Optionally warm-starts from
    a pretrained checkpoint and/or resumes from ``opts.epoch_continue``.
    Sets ``opts.run = 'test'`` on exit.
    """
    # dataset and iterator
    dataset_train, dataset_val = dataset.get_dataset(opts)
    iterator = tf.data.Iterator.from_structure(dataset_train.output_types,
                                               dataset_train.output_shapes)
    if opts.train_like_test:
        # "Train like test" mode: fetch batches through placeholders so the
        # network's own predictions can be fed back as extra input channels.
        volume, label = iterator.get_next()
        inputs = tf.placeholder(
            tf.float32,
            shape=[None, None, None, None, 1 + opts.temporal * opts.nJoint])
        labels = tf.placeholder(tf.float32,
                                shape=[None, None, None, None, opts.nJoint])
    else:
        inputs, labels = iterator.get_next()

    # network
    outputs, training = get_network(inputs, opts)

    # loss
    loss, mean_update_ops = mse_loss(outputs, labels, opts)

    # summary
    writer_train = tf.summary.FileWriter(
        os.path.join(opts.output_path, opts.name, 'logs', 'train'),
        tf.get_default_graph())
    writer_val = tf.summary.FileWriter(
        os.path.join(opts.output_path, opts.name, 'logs', 'val'))
    summary_op = tf.summary.merge_all()

    # varlist: split network variables into those present in the pretrained
    # checkpoint and genuinely new ones (matched by name without ":0").
    name_list = [
        ns[0] for ns in tf.train.list_variables(
            os.path.join(opts.output_path, opts.use_pretrain,
                         'pretrainmodel.ckpt'))
    ] if opts.use_pretrain != '' else []
    pretrain_list = [
        v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                     scope=opts.network)
        if v.name[:-2] in name_list
    ]
    newtrain_list = [
        v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                     scope=opts.network)
        if v.name[:-2] not in name_list
    ]
    print('pretrain var: %d, newtrain var: %d' %
          (len(pretrain_list), len(newtrain_list)))

    # optimizer
    optimizer = Optimizer(opts, pretrain_list, newtrain_list,
                          dataset_train.length)
    train_op = optimizer.get_train_op(loss)
    my_update_op = tf.group(mean_update_ops)

    # save and load
    saver = tf.train.Saver(var_list=newtrain_list + pretrain_list)
    if opts.use_pretrain != '':
        saver_pretrain = tf.train.Saver(var_list=pretrain_list)

    # main loop
    with tf.Session(config=tf.ConfigProto(log_device_placement=False,
                                          allow_soft_placement=True)) as sess:
        sess.run(tf.global_variables_initializer())
        if opts.use_pretrain != '':
            # NOTE(review): this restore path uses opts.use_pretrain +
            # 'pretrain' while list_variables above used opts.use_pretrain
            # alone — confirm which directory layout is intended.
            saver_pretrain.restore(
                sess,
                os.path.join(opts.output_path, opts.use_pretrain + 'pretrain',
                             'pretrainmodel.ckpt'))
        if opts.epoch_continue > 0:
            saver.restore(
                sess,
                os.path.join(opts.output_path, opts.use_continue,
                             'model%d.ckpt' % opts.epoch_continue))
        print('training loop start')
        start_train = time.time()
        for epoch in range(opts.epoch_continue + 1, opts.epochs + 1):
            print('epoch: %d' % epoch)
            start_ep = time.time()
            # train: iterate the training set until OutOfRangeError
            print('training')
            sess.run(iterator.make_initializer(dataset_train))
            sess.run(tf.local_variables_initializer())
            while True:
                try:
                    if opts.train_like_test:
                        v, l = sess.run([volume, label])
                        # With probability opts.train_like_test, replace the
                        # auxiliary input channels with the model's own
                        # prediction (mimicking test-time conditions).
                        if random.random() < opts.train_like_test:
                            l_p = sess.run(outputs[-1],
                                           feed_dict={
                                               training: False,
                                               inputs: v
                                           })
                            v = np.concatenate([v[:, :, :, :, :1], l_p],
                                               axis=-1)
                        summary_train, _ = sess.run([summary_op, train_op],
                                                    feed_dict={
                                                        training: True,
                                                        inputs: v,
                                                        labels: l
                                                    })
                    else:
                        summary_train, _ = sess.run([summary_op, train_op],
                                                    feed_dict={training: True})
                except tf.errors.OutOfRangeError:
                    # Only the last batch's summary of the epoch is written.
                    writer_train.add_summary(summary_train, epoch)
                    break
            print('step: %d' % optimizer.get_global_step(sess))
            # validation: same loop but runs my_update_op instead of train_op
            print('validation')
            sess.run(iterator.make_initializer(dataset_val))
            sess.run(tf.local_variables_initializer())
            while True:
                try:
                    if opts.train_like_test:
                        v, l = sess.run([volume, label])
                        if random.random() < opts.train_like_test:
                            l_p = sess.run(outputs[-1],
                                           feed_dict={
                                               training: False,
                                               inputs: v
                                           })
                            v = np.concatenate([v[:, :, :, :, :1], l_p],
                                               axis=-1)
                        summary_val, _ = sess.run([summary_op, my_update_op],
                                                  feed_dict={
                                                      training: False,
                                                      inputs: v,
                                                      labels: l
                                                  })
                    else:
                        summary_val, _ = sess.run([summary_op, my_update_op],
                                                  feed_dict={training: False})
                except tf.errors.OutOfRangeError:
                    writer_val.add_summary(summary_val, epoch)
                    break
            # save model
            if epoch % opts.save_freq == 0 or epoch == opts.epochs:
                print('save model')
                saver.save(
                    sess,
                    os.path.join(opts.output_path, opts.name,
                                 'model%d.ckpt' % epoch))
            print("epoch end, elapsed time: %ds, total time: %ds" %
                  (time.time() - start_ep, time.time() - start_train))
        print('training loop end')
        writer_train.close()
        writer_val.close()
    opts.run = 'test'
Exemplo n.º 5
0
def main(result_dir, resume_file, resume_epoch, nepochs, f, input_type,
         model_type, num_iters, admm_filters, admm_strides, admm_kernels, lr,
         val_only, train_size, val_size, dataset, redraw_subset, batch_size,
         repeat, admm_tv_loss, no_vis_output, val_output_every, png_output,
         png_output_dir):
    """Train/evaluate a depth-completion model (TF1 graph mode).

    Builds the selected dataset pipeline, input transform and model, then
    runs ``nepochs`` epochs of training (skipped when ``val_only``) each
    followed by a validation pass, logging errors, saving checkpoints and
    optional pickle/PNG visualization outputs under ``result_dir``.

    NOTE(review): ``dataset``, ``input_type`` and ``model_type`` each lack an
    ``else`` branch — an unrecognized value raises NameError later
    (``make_datasets`` / ``make_inputs`` / ``build_model`` undefined).
    """
    def clip(rgb):
        # Clamp to the displayable [0, 255] range for visualization dumps.
        return np.maximum(np.minimum(rgb, 255), 0)

    if dataset == 'kitti':

        trainfiles = ld.get_train_paths('/dataset/kitti-depth/tfrecords/train')
        num_train_examples = ld.count_records(trainfiles)

        print('Got {} training files with {} records'.format(
            len(trainfiles), num_train_examples))

        valfiles = ld.get_train_paths('/dataset/kitti-depth/tfrecords/val')
        num_val_examples = ld.count_records(valfiles)
        print('Got {} validation files with {} records'.format(
            len(valfiles), num_val_examples))
        make_datasets = lambda mkinpts, bs: ld.make_kitti_datasets(
            mkinpts, trainfiles, valfiles, bs, repeat=repeat)
    elif dataset == 'kitti_test_selection':
        test_root = '/dataset/kitti-depth/depth_selection/test_depth_completion_anonymous'
        num_train_examples = len(
            ld.get_train_paths(test_root + '/velodyne_raw', suffix='png'))
        num_val_examples = num_train_examples
        make_datasets = lambda mkinpts, bs: ld.make_selection_datasets(
            mkinpts, test_root)
    elif dataset == 'kitti_val_selection':
        val_root = '/dataset/kitti-depth/depth_selection/val_selection_cropped'
        num_train_examples = len(
            ld.get_train_paths(val_root + '/velodyne_raw', suffix='png'))
        num_val_examples = num_train_examples
        make_datasets = lambda mkinpts, bs: ld.make_selection_datasets(
            mkinpts, val_root)

    print('Got {} training examples'.format(num_train_examples))
    print('Got {} validation examples'.format(num_val_examples))

    # Negative sizes mean "use the full split".
    if train_size < 0:
        train_size = num_train_examples
    if val_size < 0:
        val_size = num_val_examples

    if input_type == 'raw':

        def make_raw_inputs(urgb, m, g, mraw, raw, s):
            m1 = mraw
            return urgb, m1, m1 * raw, m, g, s

        make_inputs = make_raw_inputs
    elif input_type == 'raw_frac':

        # NOTE(review): this closure captures the fraction parameter ``f`` by
        # reference; ``f`` is later rebound to file handles by ``with open(...)
        # as f`` below. That is harmless only if the dataset pipeline calls
        # this solely at graph-construction time — confirm.
        def make_raw_frac_inputs(urgb, m, g, mraw, raw, s):
            m1, d1 = sparsify(raw, mraw, f)
            return urgb, m1, d1, m, g, s

        make_inputs = make_raw_frac_inputs

    if model_type == 'admm':

        def build_admm(m1, d1, m2, d2, is_training):
            return admm.make_admm_pnp(
                m1,
                d1,
                m2,
                d2,  # Adapt to PnP-Depth
                tv_loss=admm_tv_loss,
                num_iters=num_iters,
                filters=admm_filters,
                strides=admm_strides,
                kernels=admm_kernels)

        build_model = build_admm
    elif model_type == 'cnn_deep':
        build_model = lambda m1, d1, m2, d2, is_training: build_net18(
            m1, d1, m2, d2, is_training)
    elif model_type == 'sparse_cnn':
        build_model = lambda m1, d1, m2, d2, is_training: make_sparse_cnn(
            m1, d1, m2, d2)

    # Error loggers: metric names with (width, precision) print formats.
    train_log = os.path.join(result_dir, 'train_log.txt')
    train_errors = ErrorLogger([
        'rmse',
        'grmse',
        'mae',
        'gmae',
        'mre',
        'del_1',
        'del_2',
        'del_3',
    ], [(8, 5), (8, 5), (8, 5), (8, 5), (8, 5), (5, 2), (5, 2), (5, 2)],
                               train_log)
    val_log = os.path.join(result_dir, 'val_log.txt')
    val_errors = ErrorLogger([
        'rmse',
        'grmse',
        'mae',
        'gmae',
        'mre',
        'del_1',
        'del_2',
        'del_3',
    ], [(8, 5), (8, 5), (8, 5), (8, 5), (8, 5), (5, 2), (5, 2), (5, 2)],
                             val_log)

    config = tf.ConfigProto()
    config.gpu_options.per_process_gpu_memory_fraction = 1.0
    config.gpu_options.allow_growth = True
    with tf.Graph().as_default(), tf.Session(config=config) as sess:

        # One reinitializable iterator shared between the train and val sets.
        train_dataset, val_dataset, take_pl = make_datasets(
            make_inputs, batch_size)
        print(train_dataset.output_shapes)
        iterator = tf.data.Iterator.from_structure(train_dataset.output_types,
                                                   train_dataset.output_shapes)
        rgb_t, m1_t, d1_t, ground_mask, ground, s_t = iterator.get_next()

        train_data_init_op = iterator.make_initializer(train_dataset)
        val_data_init_op = iterator.make_initializer(val_dataset)

        is_training = tf.placeholder(tf.bool, name='is_training')
        output, loss, monitor, summary, model_train_op = build_model(
            m1_t, d1_t, ground_mask, ground, is_training)

        # Error tensors; the g-prefixed variants restrict the mask to pixels
        # that also had sparse input (m1_t).
        mse_t = losses.mse_loss(output, ground, ground_mask)
        mae_t = losses.mae_loss(output, ground, ground_mask)
        mre_t = losses.mre_loss(output, ground, ground_mask)
        rmse_t = losses.rmse_loss(output, ground, ground_mask)
        gmae_t = losses.mae_loss(output, ground, m1_t * ground_mask)
        grmse_t = losses.rmse_loss(output, ground, m1_t * ground_mask)
        del_1_t, del_2_t, del_3_t = losses.deltas(output, ground, ground_mask,
                                                  1.01)

        errors_t = {
            'rmse': rmse_t,
            'mae': mae_t,
            'mre': mre_t,
            'del_1': del_1_t,
            'del_2': del_2_t,
            'del_3': del_3_t,
            'grmse': grmse_t,
            'gmae': gmae_t
        }

        optimizer = tf.train.AdamOptimizer(learning_rate=lr)

        # Models may supply their own train op; otherwise minimize loss with
        # batch-norm style update ops as dependencies.
        if model_train_op is not None:
            train_op = model_train_op
        else:
            extra_train_op = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(extra_train_op):
                train_op = optimizer.minimize(loss)

        saver = tf.train.Saver(max_to_keep=nepochs + 1)
        sess.run(tf.global_variables_initializer())
        if resume_file:
            print('Restoring from {}'.format(resume_file))
            saver.restore(sess, resume_file)

        best_rmse = float('inf')
        best_epoch = -1

        train_take = ld.make_take(num_train_examples, train_size)
        val_take = ld.make_take(num_val_examples, val_size)

        num_epochs = nepochs
        if val_only:
            num_epochs = 1

        for i in range(resume_epoch, num_epochs):
            if not val_only:
                num_batches = train_size // batch_size
                batchnum = 1

                if redraw_subset:
                    print('Redrawing Subset')
                    train_take = ld.make_take(num_train_examples, train_size)
                train_errors.clear()
                sess.run(train_data_init_op, feed_dict={take_pl: train_take})
                while True:
                    try:
                        start = time.time()
                        (err, pred, mg, g, rgb, m1, d1, m, s,
                         _) = sess.run([
                             errors_t, output, ground_mask, ground, rgb_t,
                             m1_t, d1_t, monitor, summary, train_op
                         ],
                                       feed_dict={is_training: True})
                        print('{}s to run'.format(time.time() - start))
                        train_errors.update(err)
                        print('{} in input, {} in ground truth'.format(
                            np.mean(np.sum(m1 > 0, axis=(1, 2, 3))),
                            np.mean(np.sum(mg > 0, axis=(1, 2, 3)))))
                        print('Epoch {}, Batch {}/{} {}'.format(
                            i, batchnum, num_batches,
                            train_errors.update_log_string(err)))
                        for key, value in m.items():
                            print('{}: {}'.format(key, value))
                        if batchnum % 500 == 0:
                            filename = 'train_output{}.pickle'.format(batchnum)
                            # NOTE(review): ``as f`` rebinds the parameter
                            # ``f`` (sparsify fraction) to a file handle.
                            with open(os.path.join(result_dir, filename),
                                      'wb') as f:
                                pickle.dump(
                                    {
                                        'rgb': clip(rgb[0, :, :, :]),
                                        'd1': m1[0, :, :, :] * d1[0, :, :, :],
                                        'm0': s['m'][0] if 'm' in s else None,
                                        'ground': g[0, :, :, :],
                                        'pred': pred[0, :, :, :],
                                        'summary': s
                                    }, f)
                        batchnum += 1

                    except tf.errors.OutOfRangeError:
                        break
                train_errors.log()
                with open(os.path.join(result_dir, 'summary.pickle'),
                          'wb') as f:
                    pickle.dump(s, f)
                print('Done epoch {}, RMSE = {}'.format(
                    i, train_errors.get('rmse')))
                save_path = saver.save(
                    sess, os.path.join(result_dir,
                                       '{:02}-model.ckpt'.format(i)))
                print('Model saved in {}'.format(save_path))

            num_batches = val_size
            batchnum = 1

            val_errors.clear()
            sess.run(val_data_init_op, feed_dict={take_pl: val_take})
            best_batch = float('inf')
            worst_batch = 0
            rmses = {}
            # NOTE(review): this rebinds the EPOCH loop variable ``i`` to a
            # per-batch counter — ``best_epoch = i`` below therefore records
            # a batch count, not the epoch, and the next epoch iteration is
            # unaffected only because ``range`` controls the loop. Likely a
            # bug; a separate counter name should be used.
            i = 0

            while True:
                try:
                    start = time.time()
                    (err, pred, g, rgb, m1, d1, m, s,
                     seqid) = sess.run([
                         errors_t, output, ground, rgb_t, m1_t, d1_t, monitor,
                         summary, s_t
                     ],
                                       feed_dict={is_training: False})
                    print('{}s to run'.format(time.time() - start))
                    rmses[i] = err['rmse']
                    i = i + 1
                    val_errors.update(err)
                    print('{}/{} {}'.format(batchnum, num_batches,
                                            val_errors.update_log_string(err)))
                    for key, value in m.items():
                        print('{}: {}'.format(key, value))
                    if png_output:
                        # KITTI depth PNG convention: depth * 256 as 16-bit.
                        ID = os.path.basename(seqid[0].decode())
                        filename = os.path.join(png_output_dir, ID)
                        out = np.round(np.squeeze(pred[0, :, :, 0]) * 256.0)
                        out = out.astype(np.int32)
                        Image.fromarray(out).save(filename, bits=16)

                    if not no_vis_output:
                        vis_log = {
                            'rgb': rgb[0, :, :, :],
                            'd1': m1[0, :, :, :] * d1[0, :, :, :],
                            'ground': g[0, :, :, :],
                            'pred': pred[0, :, :, :]
                        }
                        if 'm' in s:
                            vis_log['m0'] = s['m'][0]
                        if err['rmse'] < best_batch:
                            best_batch = err['rmse']
                            filename = os.path.join(result_dir,
                                                    'val_best.pickle')
                            with open(filename, 'wb') as f:
                                pickle.dump(vis_log, f)
                        if err['rmse'] > worst_batch:
                            worst_batch = err['rmse']
                            filename = os.path.join(result_dir,
                                                    'val_worst.pickle')
                            with open(filename, 'wb') as f:
                                pickle.dump(vis_log, f)
                        if batchnum % val_output_every == 0:
                            filename = os.path.join(
                                result_dir,
                                'val_output-{:04}.pickle'.format(batchnum))
                            with open(filename, 'wb') as f:
                                pickle.dump(vis_log, f)
                    batchnum += 1
                except tf.errors.OutOfRangeError:
                    break
            val_errors.log()
            if val_errors.get('rmse') < best_rmse and not val_only:
                best_epoch = i
                best_rmse = val_errors.get('rmse')
                save_path = saver.save(
                    sess, os.path.join(result_dir, 'best-model.ckpt'))
                print('Best model saved in {}'.format(save_path))
            with open(os.path.join(result_dir, 'errors.pickle'), 'wb') as f:
                pickle.dump(rmses, f)
            print('Validation RMSE: {}'.format(val_errors.get('rmse')))
def test_sintel(restore_model_dir,
                model_name,
                start_step,
                end_step,
                checkpoint_interval,
                dataset_config={},
                is_scale=True,
                is_write_summmary=True,
                num_parallel_calls=4,
                network_mode='v1'):
    """Evaluate a range of saved checkpoints on the Sintel optical-flow set.

    For every checkpoint step in ``[start_step, end_step]`` (stride
    ``checkpoint_interval``) the model is restored and run once over the
    whole evaluation set.  Per checkpoint it reports EPE/MSE/ABS/Fl errors,
    both as a plain mean over samples and weighted by the number of valid
    mask pixels, plus EPE broken down by occlusion status (matched /
    unmatched) and by ground-truth flow magnitude band (0-10, 10-40, 40+).

    Args:
        restore_model_dir: root directory containing ``<model_name>/model-<step>``
            checkpoints.
        model_name: sub-directory / summary tag for the model.
        start_step, end_step, checkpoint_interval: checkpoint step range
            (inclusive) and stride.
        dataset_config: dict with ``data_list_file``, ``img_dir`` and
            ``dataset_type`` keys for ``SintelDataset``.
        is_scale: forwarded to ``pyramid_processing``.
        is_write_summmary: if True, per-checkpoint means are also written
            as TF summaries.  NOTE(review): parameter name is misspelled
            (``summmary``) but kept for caller compatibility.
        num_parallel_calls: dataset pipeline parallelism.
        network_mode: forwarded to ``pyramid_processing``.

    NOTE(review): ``dataset_config={}`` is a mutable default argument; it is
    only read here, so it is harmless, but ``None`` + fallback would be the
    conventional form.
    """
    dataset_name = 'Sintel'
    dataset = SintelDataset(data_list_file=dataset_config['data_list_file'],
                            img_dir=dataset_config['img_dir'],
                            dataset_type=dataset_config['dataset_type'])
    iterator = dataset.create_evaluation_iterator(
        dataset.data_list, num_parallel_calls=num_parallel_calls)
    batch_img0, batch_img1, batch_img2, batch_flow, batch_mask, batch_occlusion = iterator.get_next(
    )
    # Split the validity mask into occluded / non-occluded pixels.
    # Assumes batch_occlusion is 1 at occluded pixels — TODO confirm
    # against SintelDataset.
    batch_mask_occ = tf.multiply(batch_occlusion, batch_mask)
    batch_mask_noc = tf.multiply(1 - batch_occlusion, batch_mask)
    # Masks for the three ground-truth flow-magnitude bands.
    batch_flow_norm = tf.norm(batch_flow, axis=-1, keepdims=True)
    batch_mask_s0_10 = tf.cast(
        tf.logical_and(batch_flow_norm >= 0., batch_flow_norm < 10.),
        tf.float32)
    batch_mask_s0_10 = tf.multiply(batch_mask_s0_10, batch_mask)
    batch_mask_s10_40 = tf.cast(
        tf.logical_and(batch_flow_norm >= 10., batch_flow_norm < 40.),
        tf.float32)
    batch_mask_s10_40 = tf.multiply(batch_mask_s10_40, batch_mask)
    batch_mask_s40_plus = tf.cast(batch_flow_norm >= 40., tf.float32)
    batch_mask_s40_plus = tf.multiply(batch_mask_s40_plus, batch_mask)

    # Pixel counts per mask, used later for pixel-weighted means.
    num_mask = tf.reduce_sum(batch_mask)
    num_mask_occ = tf.reduce_sum(batch_mask_occ)
    num_mask_noc = tf.reduce_sum(batch_mask_noc)
    num_mask_s0_10 = tf.reduce_sum(batch_mask_s0_10)
    num_mask_s10_40 = tf.reduce_sum(batch_mask_s10_40)
    num_mask_s40_plus = tf.reduce_sum(batch_mask_s40_plus)

    data_num = dataset.data_num

    flow_estimated, _ = pyramid_processing(batch_img0,
                                           batch_img1,
                                           batch_img2,
                                           train=False,
                                           trainable=False,
                                           is_scale=is_scale,
                                           network_mode=network_mode)

    restore_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    diff = flow_estimated['full_res'] - batch_flow
    EPE_loss, _ = epe_loss(diff, batch_mask)
    MSE_loss, _ = mse_loss(diff, batch_mask)
    ABS_loss, _ = abs_loss(diff, batch_mask)
    Fl_all = compute_Fl(batch_flow, flow_estimated['full_res'], batch_mask)
    # Each banded EPE is forced to 0 when its mask is empty, to avoid
    # NaN from dividing by a zero pixel count inside epe_loss.
    EPE_loss_matched, _ = epe_loss(diff, batch_mask_noc)
    EPE_loss_matched = tf.where(
        tf.reduce_sum(batch_mask_noc) < 1., 0., EPE_loss_matched)
    EPE_loss_unmatched, _ = epe_loss(diff, batch_mask_occ)
    EPE_loss_unmatched = tf.where(
        tf.reduce_sum(batch_mask_occ) < 1., 0., EPE_loss_unmatched)
    EPE_loss_s0_10, _ = epe_loss(diff, batch_mask_s0_10)
    EPE_loss_s0_10 = tf.where(
        tf.reduce_sum(batch_mask_s0_10) < 1., 0., EPE_loss_s0_10)
    EPE_loss_s10_40, _ = epe_loss(diff, batch_mask_s10_40)
    EPE_loss_s10_40 = tf.where(
        tf.reduce_sum(batch_mask_s10_40) < 1., 0., EPE_loss_s10_40)
    EPE_loss_s40_plus, _ = epe_loss(diff, batch_mask_s40_plus)
    # NOTE(review): `1` here vs `1.` in the branches above — works via
    # implicit conversion but is inconsistent.
    EPE_loss_s40_plus = tf.where(
        tf.reduce_sum(batch_mask_s40_plus) < 1, 0., EPE_loss_s40_plus)

    summary_writer = tf.summary.FileWriter(
        logdir='/'.join([dataset_name, 'summary', 'test', model_name]))
    saver = tf.train.Saver(var_list=restore_vars)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())
    sess.run(iterator.initializer)
    steps = np.arange(start_step,
                      end_step + 1,
                      checkpoint_interval,
                      dtype='int32')
    for step in steps:
        # Restore this checkpoint and evaluate every sample once.
        saver.restore(sess,
                      '%s/%s/model-%d' % (restore_model_dir, model_name, step))
        EPE = np.zeros([data_num])
        MSE = np.zeros([data_num])
        ABS = np.zeros([data_num])
        Fl = np.zeros([data_num])
        EPE_matched = np.zeros([data_num])
        EPE_unmatched = np.zeros([data_num])
        EPE_s0_10 = np.zeros([data_num])
        EPE_s10_40 = np.zeros([data_num])
        EPE_s40_plus = np.zeros([data_num])

        np_num_mask = np.zeros([data_num])
        np_num_mask_occ = np.zeros([data_num])
        np_num_mask_noc = np.zeros([data_num])
        np_num_mask_s0_10 = np.zeros([data_num])
        np_num_mask_s10_40 = np.zeros([data_num])
        np_num_mask_s40_plus = np.zeros([data_num])
        start_time = time.time()
        for i in range(data_num):
            EPE[i], MSE[i], ABS[i], Fl[i], EPE_matched[i], EPE_unmatched[i], EPE_s0_10[i], EPE_s10_40[i], EPE_s40_plus[i], \
                np_num_mask[i], np_num_mask_occ[i], np_num_mask_noc[i], np_num_mask_s0_10[i], np_num_mask_s10_40[i], np_num_mask_s40_plus[i] = sess.run(
                [EPE_loss, MSE_loss, ABS_loss, Fl_all, EPE_loss_matched, EPE_loss_unmatched, EPE_loss_s0_10, EPE_loss_s10_40, EPE_loss_s40_plus,
                 num_mask, num_mask_occ, num_mask_noc, num_mask_s0_10, num_mask_s10_40, num_mask_s40_plus])

        mean_time = (time.time() - start_time) / data_num
        # Unweighted means (each sample counts equally).
        mean_EPE = np.mean(EPE)
        mean_MSE = np.mean(MSE)
        mean_ABS = np.mean(ABS)
        mean_Fl = np.mean(Fl)
        mean_EPE_matched = np.mean(EPE_matched)
        mean_EPE_unmatched = np.mean(EPE_unmatched)
        mean_EPE_s0_10 = np.mean(EPE_s0_10)
        mean_EPE_s10_40 = np.mean(EPE_s10_40)
        mean_EPE_s40_plus = np.mean(EPE_s40_plus)

        # Pixel-weighted means (samples with more valid pixels count more).
        weighted_mean_EPE = np.sum(np.multiply(
            EPE, np_num_mask)) / np.sum(np_num_mask)
        weighted_mean_MSE = np.sum(np.multiply(
            MSE, np_num_mask)) / np.sum(np_num_mask)
        weighted_mean_ABS = np.sum(np.multiply(
            ABS, np_num_mask)) / np.sum(np_num_mask)
        weighted_mean_Fl = np.sum(np.multiply(
            Fl, np_num_mask)) / np.sum(np_num_mask)
        weighted_mean_EPE_matched = np.sum(
            np.multiply(EPE_matched,
                        np_num_mask_noc)) / np.sum(np_num_mask_noc)
        weighted_mean_EPE_unmatched = np.sum(
            np.multiply(EPE_unmatched,
                        np_num_mask_occ)) / np.sum(np_num_mask_occ)
        weighted_mean_EPE_s0_10 = np.sum(
            np.multiply(EPE_s0_10,
                        np_num_mask_s0_10)) / np.sum(np_num_mask_s0_10)
        weighted_mean_EPE_s10_40 = np.sum(
            np.multiply(EPE_s10_40,
                        np_num_mask_s10_40)) / np.sum(np_num_mask_s10_40)
        weighted_mean_EPE_s40_plus = np.sum(
            np.multiply(EPE_s40_plus,
                        np_num_mask_s40_plus)) / np.sum(np_num_mask_s40_plus)

        print(
            'step %d: EPE: %.6f, mse: %.6f, abs: %.6f, Fl: %.6f, EPE_matched: %.6f, EPE_unmatched: %.6f, EPE_s0_10: %.6f, EPE_s10_40: %.6f, EPE_s40_plus: %.6f, \n \
               weighted_EPE: %.6f, weighted_MSE: %.6f, weighted_ABS: %.6f, weighted_Fl: %.6f, weighted_EPE_matched: %.6f, weighted_EPE_unmatched: %.6f, \n \
               weighted_EPE_s0_10: %.6f, weighted_EPE_s10_40: %.6f, weighted_EPE_s40_plus: %.6f, time_cost: %.6f'
            %
            (step, mean_EPE, mean_MSE, mean_ABS, mean_Fl, mean_EPE_matched,
             mean_EPE_unmatched, mean_EPE_s0_10, mean_EPE_s10_40,
             mean_EPE_s40_plus, weighted_mean_EPE, weighted_mean_MSE,
             weighted_mean_ABS, weighted_mean_Fl, weighted_mean_EPE_matched,
             weighted_mean_EPE_unmatched, weighted_mean_EPE_s0_10,
             weighted_mean_EPE_s10_40, weighted_mean_EPE_s40_plus, mean_time))

        #print('step %d: EPE: %.6f, mse: %.6f, abs: %.6f, Fl: %.6f, EPE_matched: %.6f, EPE_unmatched: %.6f, EPE_s0_10: %.6f, EPE_s10_40: %.6f, EPE_s40_plus: %.6f, time_cost: %.6f' %
        #(step, mean_EPE, mean_MSE, mean_ABS, mean_Fl, mean_EPE_matched, mean_EPE_unmatched, mean_EPE_s0_10, mean_EPE_s10_40, mean_EPE_s40_plus, mean_time))

        if is_write_summmary:
            # Write the unweighted means as scalar summaries keyed by step.
            summary = tf.Summary()
            summary.value.add(tag='EPE', simple_value=mean_EPE)
            summary.value.add(tag='mse', simple_value=mean_MSE)
            summary.value.add(tag='abs', simple_value=mean_ABS)
            summary.value.add(tag='Fl', simple_value=mean_Fl)
            summary.value.add(tag='EPE_matched', simple_value=mean_EPE_matched)
            summary.value.add(tag='EPE_unmatched',
                              simple_value=mean_EPE_unmatched)
            summary.value.add(tag='EPE_s0_10', simple_value=mean_EPE_s0_10)
            summary.value.add(tag='EPE_s10_40', simple_value=mean_EPE_s10_40)
            summary.value.add(tag='EPE_s40_plus',
                              simple_value=mean_EPE_s40_plus)
            summary.value.add(tag='time_cost', simple_value=mean_time)
            summary_writer.add_summary(summary, global_step=step)
Exemplo n.º 7
0
    def evaluate_text_classify(self, model, valid_loader):
        """Evaluate a text-classification model on a validation loader.

        Runs one pass over ``valid_loader``, computing the evaluation loss
        (MSE for single-logit regression-style outputs, cross-entropy for
        2-D logits) and every metric named in ``self.metrics``.

        Args:
            model: callable evaluated as ``model(batch)``; must return a
                dict containing ``"logits"``.
            valid_loader: iterable of batch dicts; each batch must contain
                ``"label_ids"``.  Tensor values are moved to GPU.

        Returns:
            List of ``(metric_name, value)`` tuples, one per requested
            metric (``classification_report`` only logs, it appends
            nothing).

        Raises:
            RuntimeError: if the logits have an unsupported rank.
            NotImplementedError: if an unknown metric name is requested.
        """
        total_loss = 0
        total_steps = 0
        total_samples = 0
        hit_num = 0
        total_num = 0

        logits_list = list()
        y_trues = list()

        total_spent_time = 0.0
        for _step, batch in enumerate(valid_loader):
            batch = {
                key: val.cuda() if isinstance(val, torch.Tensor) else val
                for key, val in batch.items()
            }

            # Time only the forward pass, not data movement or metrics.
            infer_start_time = time.time()
            with torch.no_grad():
                student_outputs = model(batch)
            infer_end_time = time.time()
            total_spent_time += infer_end_time - infer_start_time

            assert "logits" in student_outputs and "label_ids" in batch
            logits, label_ids = student_outputs["logits"], batch["label_ids"]

            y_trues.extend(label_ids.tolist())
            logits_list.extend(logits.tolist())
            hit_num += torch.sum(
                torch.argmax(logits, dim=-1) == label_ids).item()
            total_num += label_ids.shape[0]

            # Single-logit outputs are treated as regression (MSE);
            # 2-D logits as classification (cross-entropy).
            if len(logits.shape) == 1 or logits.shape[-1] == 1:
                tmp_loss = losses.mse_loss(logits, label_ids)
            elif len(logits.shape) == 2:
                tmp_loss = losses.cross_entropy(logits, label_ids)
            else:
                raise RuntimeError

            total_loss += tmp_loss.mean().item()
            total_steps += 1
            total_samples += valid_loader.batch_size
            if (_step + 1) % 100 == 0:
                logger.info("Eval: %d/%d steps finished" %
                            (_step + 1, len(valid_loader.dataset) //
                             valid_loader.batch_size))

        logger.info("Inference time = {:.2f}s, [{:.4f} ms / sample] ".format(
            total_spent_time, total_spent_time * 1000 / total_samples))

        eval_loss = total_loss / total_steps
        logger.info("Eval loss: {}".format(eval_loss))

        logits_list = np.array(logits_list)
        eval_outputs = list()
        for metric in self.metrics:
            if metric.endswith("accuracy"):
                acc = hit_num / total_num
                logger.info("Accuracy: {}".format(acc))
                eval_outputs.append(("accuracy", acc))
            elif metric == "f1":
                f1 = f1_score(y_trues, np.argmax(logits_list, axis=-1))
                logger.info("F1: {}".format(f1))
                eval_outputs.append(("f1", f1))
            elif metric == "macro-f1":
                f1 = f1_score(y_trues,
                              np.argmax(logits_list, axis=-1),
                              average="macro")
                logger.info("Macro F1: {}".format(f1))
                eval_outputs.append(("macro-f1", f1))
            elif metric == "micro-f1":
                f1 = f1_score(y_trues,
                              np.argmax(logits_list, axis=-1),
                              average="micro")
                logger.info("Micro F1: {}".format(f1))
                eval_outputs.append(("micro-f1", f1))
            elif metric == "auc":
                auc = roc_auc_score(y_trues, np.argmax(logits_list, axis=-1))
                logger.info("AUC: {}".format(auc))
                eval_outputs.append(("auc", auc))
            elif metric == "matthews_corrcoef":
                mcc = matthews_corrcoef(y_trues, np.argmax(logits_list,
                                                           axis=-1))
                logger.info("Matthews Corrcoef: {}".format(mcc))
                eval_outputs.append(("matthews_corrcoef", mcc))
            elif metric == "pearson_and_spearman":
                preds = logits_list[:, 0]
                pearson_corr = pearsonr(preds, y_trues)[0]
                spearman_corr = spearmanr(preds, y_trues)[0]
                logger.info("Peasrson: {}".format(pearson_corr))
                logger.info("Spearmanr: {}".format(spearman_corr))
                corr = (pearson_corr + spearman_corr) / 2.0
                logger.info("Peasrson_and_spearmanr: {}".format(corr))
                eval_outputs.append(("pearson_and_spearman", corr))
            elif metric == "classification_report":
                logger.info("\n{}".format(
                    classification_report(y_trues,
                                          np.argmax(logits_list, axis=-1),
                                          digits=4)))
            # BUGFIX: was `elif "last_layer_mse" in self.metrics:` which
            # tested membership in the whole metrics list instead of
            # comparing the current metric — any unknown metric was silently
            # swallowed by this branch whenever "last_layer_mse" was also
            # requested, and NotImplementedError became unreachable.
            elif metric == "last_layer_mse":
                logger.info("Last layer MSE: {}".format(eval_loss))
                eval_outputs.append(("last_layer_mse", -eval_loss))
            else:
                raise NotImplementedError("Metric %s not implemented" % metric)
        return eval_outputs
Exemplo n.º 8
0
def training_loop(train_dataloader, opts):
    """Runs the LSGAN training loop.
        * Saves checkpoints every opts.checkpoint_every iterations
        * Saves generated samples every opts.sample_every iterations

    Args:
        train_dataloader: yields (real_images, labels) batches.
        opts: options namespace; reads lr, beta1, beta2, noise_size,
            num_epochs, gen_loss_threshold, log_step, sample_every,
            checkpoint_every.
    """

    # Create generators and discriminators
    G, D = create_model(opts)

    # Create optimizers for the generators and discriminators
    g_optimizer = optim.Adam(G.parameters(), opts.lr, [opts.beta1, opts.beta2])
    d_optimizer = optim.Adam(D.parameters(), opts.lr, [opts.beta1, opts.beta2])

    # Generate fixed noise for sampling from the generator
    fixed_noise = sample_noise(
        opts.noise_size)  # batch_size x noise_size x 1 x 1

    iteration = 1

    total_train_iters = opts.num_epochs * len(train_dataloader)

    for epoch in range(opts.num_epochs):

        for batch in train_dataloader:

            real_images, labels = batch
            real_images, labels = utils.to_var(real_images), utils.to_var(
                labels).long().squeeze()

            ################################################
            ###         TRAIN THE DISCRIMINATOR         ####
            ################################################

            d_optimizer.zero_grad()

            # 1. Compute the discriminator loss on real images
            #    (least-squares GAN: push D(real) toward 1)
            D_real_loss = mse_loss(D(real_images), 1) / 2

            # 2. Sample noise
            noise = sample_noise(opts.noise_size)

            # 3. Generate fake images from the noise
            fake_images = G(noise)

            # 4. Compute the discriminator loss on the fake images
            #    (push D(fake) toward 0)
            D_fake_loss = mse_loss(D(fake_images), 0) / 2

            # 5. Compute the total discriminator loss
            D_total_loss = D_real_loss + D_fake_loss

            D_total_loss.backward()
            d_optimizer.step()

            ###########################################
            ###          TRAIN THE GENERATOR        ###
            ###########################################
            first_iter = True
            G_steps = 0

            # Keep stepping the generator until its loss drops below the
            # configured threshold (at least one step per iteration;
            # `first_iter` short-circuits before G_loss is defined).
            while first_iter or G_loss.item() > opts.gen_loss_threshold:
                g_optimizer.zero_grad()

                # 1. Sample noise
                noise = sample_noise(opts.noise_size)

                # 2. Generate fake images from the noise
                fake_images = G(noise)

                # 3. Compute the generator loss (push D(fake) toward 1)
                G_loss = mse_loss(D(fake_images), 1)

                G_loss.backward()
                g_optimizer.step()
                first_iter = False
                G_steps += 1

            # Print the log info
            if iteration % opts.log_step == 0:
                # BUGFIX: `.data[0]` is pre-0.4 PyTorch; indexing a 0-dim
                # loss tensor raises on modern PyTorch. Use `.item()`.
                print(
                    'Iteration [{:4d}/{:4d}] | D_real_loss: {:6.4f} | D_fake_loss: {:6.4f} | G_loss: {:6.4f} | G_steps: {:4d}'
                    .format(iteration, total_train_iters, D_real_loss.item(),
                            D_fake_loss.item(), G_loss.item(), G_steps))

            # Save the generated samples
            if iteration % opts.sample_every == 0:
                save_samples(G, fixed_noise, iteration, opts)

            # Save the model parameters
            if iteration % opts.checkpoint_every == 0:
                checkpoint(iteration, G, D, opts)

            iteration += 1
Exemplo n.º 9
0
# Optimizer for the generator (the encoder gets its own loss below;
# its optimizer is presumably created elsewhere — verify against the
# surrounding script).
generator_optimizer = tf.train.AdamOptimizer(learning_rate=args.lr)

print('Encoder')
encoder.summary()
print('Generator')
generator.summary()

#
# Define losses
#

# Regularization terms on the latent codes: for the real input (z),
# the reconstruction (zr) and the prior samples (zpp).  The `_ng`
# variants presumably come from a stop-gradient branch so the encoder's
# adversarial term does not backprop into the generator — TODO confirm
# in the model-construction code.
l_reg_z = losses.reg_loss(z_mean, z_log_var)
l_reg_zr_ng = losses.reg_loss(zr_mean_ng, zr_log_var_ng)
l_reg_zpp_ng = losses.reg_loss(zpp_mean_ng, zpp_log_var_ng)

# Pixel-space reconstruction losses: xr is the direct reconstruction,
# xr_latent the reconstruction used for the generator's objective.
l_ae = losses.mse_loss(encoder_input, xr, args.original_shape)
l_ae2 = losses.mse_loss(encoder_input, xr_latent, args.original_shape)

# Encoder adversarial objective: regularize real codes while pushing the
# (stop-gradient) fake-code regularizers above the margin args.m
# (hinge terms), IntroVAE-style; args.alpha and args.beta weight the
# adversarial and reconstruction parts.
encoder_l_adv = l_reg_z + args.alpha * K.maximum(
    0., args.m - l_reg_zr_ng) + args.alpha * K.maximum(0.,
                                                       args.m - l_reg_zpp_ng)
encoder_loss = encoder_l_adv + args.beta * l_ae

# Generator side uses the non-stop-gradient code regularizers directly.
l_reg_zr = losses.reg_loss(zr_mean, zr_log_var)
l_reg_zpp = losses.reg_loss(zpp_mean, zpp_log_var)

generator_l_adv = args.alpha * l_reg_zr + args.alpha * l_reg_zpp
generator_loss = generator_l_adv + args.beta * l_ae2

#
# Define training step operations
def train():
    """Build the CNN-with-branches graph and train it from queued records.

    Reads training batches via ``rec.data_inputs``, optimizes
    ``losses.total_loss`` with ``ut.train``, periodically logs the
    per-epoch average loss/MSE, writes summaries, checkpoints the model,
    and finally dumps the loss history to a .mat file.
    """
    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False)

        # Low/mid/high-resolution "today" inputs and the "tomorrow" target;
        # the remaining tuple members are unused here.
        ltoday, mtoday, htoday, tomorrow, _, _, _, _, _ = rec.data_inputs(
            FLAGS.train_input_path, FLAGS.train_batch_size, conf.shape_dict,
            30, False, False)
        predictions, _, _, _ = cnn_branches.cnn_with_branch(
            ltoday, mtoday, htoday, conf.HEIGHT * conf.HIGH_WIDTH,
            FLAGS.train_batch_size)
        reality = tf.reshape(tomorrow, predictions.get_shape())
        mse = losses.mse_loss(predictions, reality)
        loss = losses.total_loss(predictions, reality, losses.main_loss)
        train_step = ut.train(loss, global_step,
                              conf.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN)
        saver = tf.train.Saver(tf.global_variables())
        summary_op = tf.summary.merge_all()

        init = tf.global_variables_initializer()
        coord = tf.train.Coordinator()
        sess = tf.Session()
        #tf_debug.add_debug_tensor_watch(sess,'l_conv1')
        #sess = tf_debug.LocalCLIDebugWrapperSession(sess,)
        sess.run(init)

        tf.train.start_queue_runners(sess=sess, coord=coord)

        summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)

        loss_list = []
        mse_list = []
        total_loss_list = []

        # NOTE(review): `xrange` is Python 2 only.
        for step in xrange(FLAGS.epoch *
                           conf.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN + 1):
            start_time = time.time()
            # `mse_loss` here shadows nothing in this scope but reuses the
            # name of the loss helper — it is the scalar MSE for this step.
            _, loss_val, mse_loss = sess.run([train_step, loss, mse])
            duration = time.time() - start_time

            assert not np.isnan(loss_val), 'Model diverged with loss = NaN'
            loss_list.append(loss_val)
            mse_list.append(mse_loss)

            # Once per epoch: log running averages and write summaries.
            if step % conf.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN == 0:
                num_examples_per_step = FLAGS.train_batch_size
                examples_per_sec = 0  #num_examples_per_step / duration
                sec_per_batch = float(duration)
                average_loss_value = np.mean(loss_list)
                average_mse_value = np.mean(mse_list)
                total_loss_list.append(average_loss_value)
                # NOTE(review): only loss_list is cleared here; mse_list
                # keeps growing, so average_mse_value is a running average
                # over ALL steps while average_loss_value is per-epoch —
                # likely a bug; confirm intent before changing.
                loss_list.clear()
                format_str = (
                    '%s: epoch %d, loss = %.4f , mse = %.4f (%.1f examples/sec; %.3f '
                    'sec/batch)')
                print(
                    format_str %
                    (datetime.now(), step /
                     conf.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN, average_loss_value,
                     average_mse_value, examples_per_sec, sec_per_batch))
                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, step)
            # Checkpoint roughly every 10 epochs (the +1 makes the stride
            # slightly off an exact epoch multiple — presumably intentional
            # to also hit step 0; verify).
            if step % (conf.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN * 10 + 1) == 0:
                checkpoint_path = os.path.join(FLAGS.checkpoint_dir,
                                               'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)

        matlab.save_matrix(FLAGS.train_dir + 'cnn_branch_loss.mat',
                           total_loss_list, 'cnn_branch_loss')