Ejemplo n.º 1
0
def _q_learning_config(env_name):
    """Return the Q-learning recipe used for ``env_name``.

    Returns:
        A tuple ``(learner_kwargs, max_steps, solved_threshold)``:

        * ``learner_kwargs``: keyword arguments forwarded to ``NeuralQLearner``
          (in addition to ``summary_writer``).
        * ``max_steps``: per-episode step cap, or ``None`` to run each episode
          until the environment reports ``done``.
        * ``solved_threshold``: mean reward over the last 100 episodes at which
          training stops.

    Raises:
        ValueError: if no recipe exists for ``env_name``.
    """
    cartpole_kwargs = dict(
        init_exp=0.5,
        anneal_steps=10000,  # N steps for annealing exploration
        discount_factor=0.95,
        batch_size=32,
        target_update_rate=0.01)

    if env_name in ('CartPole-v0', 'CartPole-v2', 'CartPole-v3',
                    'CartPole-v4', 'CartPole-v5'):
        return cartpole_kwargs, 200, 195.0

    if env_name == 'CartPole-v1':
        # No explicit cap: the episode ends when the env reports done.
        return cartpole_kwargs, None, 475.0

    if env_name == 'MountainCar-v0':
        mountaincar_kwargs = dict(
            batch_size=64,
            anneal_steps=10000,  # N steps for annealing exploration
            replay_buffer_size=1000000,
            discount_factor=0.95,  # discount future rewards
            target_update_rate=0.01,
            reg_param=0.01)  # regularization constant
        return mountaincar_kwargs, None, -110.0

    if env_name == 'Pendulum-v0':
        pendulum_kwargs = dict(
            init_exp=0.5,
            anneal_steps=10000,  # N steps for annealing exploration
            discount_factor=0.99,
            batch_size=32,
            target_update_rate=0.01)
        # NOTE(review): threshold mirrors CartPole; confirm it is reachable
        # for Pendulum's reward scale.
        return pendulum_kwargs, 200, 195.0

    raise ValueError('No Q-learning settings for environment: %s' % env_name)


def _run_episodes(env, q_learner, featurize, env_name, max_episodes,
                  max_steps, solved_threshold):
    """Train ``q_learner`` online on ``env`` and print per-episode statistics.

    Every observation (initial and next) is passed through ``featurize``
    before the learner sees it. On each step the action picked by
    ``q_learner.eGreedyAction`` is executed, the transition is stored, and
    ``q_learner.updateModel`` is called once.

    Args:
        env: environment exposing ``reset()`` and ``step(action)``.
        q_learner: the ``NeuralQLearner`` being trained.
        featurize: callable mapping a raw observation to the network input.
        env_name: name used in the "solved" message.
        max_episodes: maximum number of episodes to run.
        max_steps: per-episode step cap, or ``None`` for no cap.
        solved_threshold: stop once the mean reward of the last 100 episodes
            is at least this value.
    """
    episode_history = deque(maxlen=100)
    for i_episode in range(max_episodes):
        state = featurize(env.reset())
        total_rewards = 0

        steps = itertools.count() if max_steps is None else range(max_steps)
        for t in steps:
            # env.render()
            action = q_learner.eGreedyAction(state[np.newaxis, :])
            next_state, reward, done, _ = env.step(action)
            next_state = featurize(next_state)

            total_rewards += reward
            q_learner.storeExperience(state, action, reward, next_state, done)
            q_learner.updateModel()
            state = next_state

            if done:
                break

        episode_history.append(total_rewards)
        mean_rewards = np.mean(episode_history)

        print("Episode {}".format(i_episode))
        print("Finished after {} timesteps".format(t + 1))
        print("Reward for this episode: {}".format(total_rewards))
        print("Average reward for last 100 episodes: {:.2f}".format(
            mean_rewards))

        if mean_rewards >= solved_threshold:
            print("Environment {} solved after {} episodes".format(
                env_name, i_episode + 1))
            break


def main():
    """Run random-projection Q-learning experiments on a Gym environment.

    Builds a Q-network (``fc_dir``, projected ``fc``, or ``fourier``-feature
    variant) and creates the optimizer. It optionally restores a checkpoint
    given via ``--load ckpt:misc``. On a fresh run it rescales the projection
    basis matrices so each row of the overall basis matrix is normalized.
    Finally it trains with ``NeuralQLearner`` until the environment's solved
    threshold is reached or the episode budget is exhausted.

    Raises:
        Exception: on an unknown architecture or optimizer.
        ValueError: if ``--env_name`` has no training recipe.
    """
    parser = make_standard_parser('Random Projection RL experiments.',
                                  arch_choices=arch_choices,
                                  skip_train=True,
                                  skip_val=True)

    parser.add_argument('--vsize',
                        type=int,
                        default=100,
                        help='Dimension of intrinsic parameter space.')
    parser.add_argument('--depth',
                        type=int,
                        default=2,
                        help='Number of layers in FNN.')
    parser.add_argument('--width',
                        type=int,
                        default=100,
                        help='Width of layers in FNN.')
    # The default must come from env_choices: argparse does not validate
    # defaults against `choices`.
    parser.add_argument('--env_name',
                        type=str,
                        default=env_choices[0],
                        choices=env_choices,
                        help='Which environment to use (choices: %s).' %
                        (env_choices, ))
    args = parser.parse_args()

    env_name = args.env_name
    # Fail fast, before building anything, if this env has no recipe.
    learner_kwargs, max_steps, solved_threshold = _q_learning_config(env_name)

    proj_type = 'dense'

    if args.arch == 'fourier':
        d_Fourier = 200  # number of random Fourier features

    # Get a TF session registered with Keras and set numpy and TF seeds
    sess = setup_session_and_seeds(args.seed)

    # 0. LOAD ENV
    env = gym.make(env_name)

    state_dim = env.observation_space.shape[0]
    if env_name == 'Pendulum-v0':
        # NOTE(review): Pendulum's action space has no `.n`; a single output
        # is used here - confirm this matches how actions are interpreted.
        num_actions = 1
    else:
        num_actions = env.action_space.n

    # 1. CREATE MODEL
    with WithTimer('Make model'):
        if args.arch == 'fc_dir':
            model = build_model_fc_dir(state_dim,
                                       num_actions,
                                       weight_decay=args.l2,
                                       depth=args.depth,
                                       width=args.width,
                                       shift_in=None)
        elif args.arch == 'fc':
            model = build_model_fc(state_dim,
                                   num_actions,
                                   weight_decay=args.l2,
                                   vsize=args.vsize,
                                   depth=args.depth,
                                   width=args.width,
                                   shift_in=None,
                                   proj_type='dense')
        elif args.arch == 'fourier':
            model = build_model_fourier(d_Fourier,
                                        num_actions,
                                        weight_decay=args.l2,
                                        depth=args.depth,
                                        width=args.width,
                                        shift_in=None)
        else:
            raise Exception('Unknown network architecture: %s' % args.arch)

    print('All model weights:')
    summarize_weights(model.trainable_weights)
    print('Model summary:')
    model.summary()
    model.print_trainable_warnings()

    # 2. COMPUTE GRADS AND CREATE OPTIMIZER
    if args.opt == 'sgd':
        optimizer = tf.train.MomentumOptimizer(args.lr, args.mom)
    elif args.opt == 'rmsprop':
        optimizer = tf.train.RMSPropOptimizer(args.lr, momentum=args.mom)
    elif args.opt == 'adam':
        optimizer = tf.train.AdamOptimizer(args.lr, args.beta1, args.beta2)
    else:
        raise Exception('Unknown optimizer: %s' % args.opt)

    summarize_opt(optimizer)

    # 3. OPTIONALLY SAVE OR LOAD VARIABLES (e.g. model params, model running BN means, optimization momentum, ...) and then finalize initialization
    saver = tf.train.Saver(
        max_to_keep=None) if (args.output or args.load) else None
    if args.load:
        ckptfile, miscfile = args.load.split(':')
        # Restore values directly to graph
        saver.restore(sess, ckptfile)
        # NOTE: unpickling executes code - only load trusted files.
        with gzip.open(miscfile) as ff:
            saved = pickle.load(ff)
            buddy = saved['buddy']
    else:
        buddy = StatsBuddy()
    buddy.tic()  # call if new run OR resumed run

    # Initialize any missed vars (e.g. optimization momentum, ... if not loaded from checkpoint)
    uninitialized_vars = tf_get_uninitialized_variables(sess)
    init_missed_vars = tf.variables_initializer(uninitialized_vars,
                                                'init_missed_vars')

    sess.run(init_missed_vars)

    # Print warnings about any TF vs. Keras shape mismatches
    warn_misaligned_shapes(model)
    # Make sure all variables, which are model variables, have been initialized (e.g. model params and model running BN means)
    tf_assert_all_init(sess)

    # 3.5 Normalize the overall basis matrix across the (multiple) unnormalized basis matrices for each layer
    basis_matrices = []
    normalizers = []

    for layer in model.layers:
        try:
            basis_matrices.extend(layer.offset_creator.basis_matrices)
        except AttributeError:
            continue  # layer is not projected
        try:
            normalizers.extend(layer.offset_creator.basis_matrix_normalizers)
        except AttributeError:
            continue

    # A restored checkpoint already holds normalized values.
    if len(basis_matrices) > 0 and not args.load:

        if proj_type == 'sparse':
            # Norm of overall basis matrix rows (num elements in each sum == total parameters in model)
            bm_row_norms = tf.sqrt(
                tf.add_n([
                    tf.sparse_reduce_sum(tf.square(bm), 1)
                    for bm in basis_matrices
                ]))
            # Assign `normalizer` Variable to these row norms to achieve normalization of the basis matrix
            # in the TF computational graph
            rescale_basis_matrices = [
                tf.assign(var, tf.reshape(bm_row_norms, var.shape))
                for var in normalizers
            ]
            _ = sess.run(rescale_basis_matrices)
        elif proj_type == 'dense':
            # Row norms are taken across ALL layers' matrices (summed squares
            # along axis 1), then each matrix is divided by them in place.
            bm_sums = [
                tf.reduce_sum(tf.square(bm), 1) for bm in basis_matrices
            ]
            divisor = tf.expand_dims(tf.sqrt(tf.add_n(bm_sums)), 1)
            rescale_basis_matrices = [
                tf.assign(var, var / divisor) for var in basis_matrices
            ]
            _ = sess.run(rescale_basis_matrices)
        else:
            raise NotImplementedError('what to do with fastfood?')

    # 3.5 Fourier features of the observations if required
    if args.arch == 'fourier':
        RP = np.random.randn(d_Fourier, state_dim)
        Rb = np.random.uniform(-math.pi, math.pi, d_Fourier)
        # The network (and learner) now consume d_Fourier-dim features.
        state_dim = d_Fourier

        def featurize(observation):
            return np.sin(RP.dot(observation) + Rb)
    else:

        def featurize(observation):
            return observation

    # 4. SETUP TENSORBOARD LOGGING

    writer = None
    if args.output:
        mkdir_p(args.output)
        writer = tf.summary.FileWriter(
            args.output + "/{}-experiment-1".format(env_name), sess.graph)

    # 5. TRAIN
    q_learner = NeuralQLearner(sess,
                               optimizer,
                               model,
                               state_dim,
                               num_actions,
                               summary_writer=writer,
                               **learner_kwargs)

    if args.ipy:
        print('Embed: before train / val loop (Ctrl-D to continue)')
        embed()

    _run_episodes(env,
                  q_learner,
                  featurize,
                  env_name,
                  max_episodes=10000,
                  max_steps=max_steps,
                  solved_threshold=solved_threshold)
Ejemplo n.º 2
0
def main():
    """Train a direct or randomly projected ImageNet model with Horovod.

    Parses the command line, loads the train/val HDF5 files, builds the
    requested architecture and trains it with ``hvd.DistributedOptimizer``.
    Each global minibatch of ``--minibatch`` examples is split evenly across
    the Horovod workers; only full minibatches are used.

    Every epoch first evaluates the validation set (unless ``--skipval``),
    then optionally snapshots (rank 0 only), then makes one pass over the
    training set. A final extra loop pass only reports validation stats.
    """
    parser = make_standard_parser(
        'Distributed Training of Direct or RProj model on Imagenet',
        arch_choices=arch_choices)

    parser.add_argument('--vsize',
                        type=int,
                        default=100,
                        help='Dimension of intrinsic parameter space.')
    parser.add_argument('--minibatch',
                        '--mb',
                        type=int,
                        default=256,
                        help='Size of minibatch.')
    parser.add_argument('--denseproj',
                        action='store_true',
                        help='Use a dense projection.')
    parser.add_argument('--sparseproj',
                        action='store_true',
                        help='Use a sparse projection.')
    parser.add_argument('--fastfoodproj',
                        action='store_true',
                        help='Use a fastfood projection.')

    args = parser.parse_args()

    minibatch_size = args.minibatch
    train_style, val_style = ('',
                              '') if args.nocolor else (colorama.Fore.BLUE,
                                                        colorama.Fore.MAGENTA)

    # Projected architectures need exactly one projection flag; direct ones
    # none. parser.error (unlike assert) is not stripped under `python -O`.
    n_proj_specified = sum(
        [args.denseproj, args.sparseproj, args.fastfoodproj])
    if args.arch in arch_choices_projected:
        if n_proj_specified != 1:
            parser.error(
                'Arch "%s" requires projection. Specify exactly one of {denseproj, sparseproj, fastfoodproj} options.'
                % args.arch)
    elif n_proj_specified != 0:
        parser.error(
            'Arch "%s" does not require projection, so do not specify any of {denseproj, sparseproj, fastfoodproj} options.'
            % args.arch)

    if args.denseproj:
        proj_type = 'dense'
    elif args.sparseproj:
        proj_type = 'sparse'
    else:
        proj_type = 'fastfood'

    # Initialize Horovod
    hvd.init()

    # Each worker processes an equal slice of every global minibatch. Floor
    # division keeps this an int for slicing; if minibatch_size is not a
    # multiple of hvd.size() the remainder examples are not used.
    worker_minibatch_size = minibatch_size // hvd.size()

    # Pin GPU to be used to process local rank (one GPU per process)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    my_rank = hvd.local_rank()
    print('I am worker %d' % my_rank)

    config.gpu_options.visible_device_list = str(hvd.local_rank())
    K.set_session(tf.Session(config=config))

    # 0. LOAD DATA
    train_h5 = h5py.File(args.train_h5, 'r')
    train_x = train_h5['images']
    train_y = train_h5['labels']
    val_h5 = h5py.File(args.val_h5, 'r')
    val_x = val_h5['images']
    val_y = val_h5['labels']

    # load into memory if less than 1 GB (estimate assumes 4 bytes/element)
    if train_x.size * 4 + val_x.size * 4 < 1e9:
        train_x, train_y = np.array(train_x), np.array(train_y)
        val_x, val_y = np.array(val_x), np.array(val_y)

    # 1. CREATE MODEL
    extra_feed_dict = {}

    with WithTimer('Make model'):
        if args.arch == 'alexnet_dir':
            shift_in = np.array([104, 117, 123], dtype='float32')
            model = build_alexnet_direct(weight_decay=args.l2,
                                         shift_in=shift_in)
            randmirrors = True
            randcrops = True
            cropsize = (227, 227)

        elif args.arch == 'squeeze_dir':
            model = build_squeezenet_direct(weight_decay=args.l2,
                                            shift_in=np.array([104, 117, 123]))
            randmirrors = True
            randcrops = True
            cropsize = (224, 224)

        elif args.arch == 'alexnet':
            if proj_type == 'fastfood':
                model = build_alexnet_fastfood(weight_decay=args.l2,
                                               shift_in=np.array(
                                                   [104, 117, 123]),
                                               vsize=args.vsize)
            else:
                raise Exception('not implemented')
            randmirrors = True
            randcrops = True
            cropsize = (227, 227)

        elif args.arch == 'squeeze':
            if proj_type == 'fastfood':
                model = build_squeezenet_fastfood(weight_decay=args.l2,
                                                  shift_in=np.array(
                                                      [104, 117, 123]),
                                                  vsize=args.vsize)
            else:
                raise Exception('not implemented')
            randmirrors = True
            randcrops = True
            cropsize = (224, 224)

        else:
            raise Exception('Unknown network architecture: %s' % args.arch)

    if my_rank == 0:
        print('All model weights:')
        summarize_weights(model.trainable_weights)
        print('Model summary:')
        model.summary()
        model.print_trainable_warnings()

    lr = args.lr

    if args.opt == 'sgd':
        opt = tf.train.MomentumOptimizer(lr, args.mom)
    elif args.opt == 'rmsprop':
        opt = tf.train.RMSPropOptimizer(lr, momentum=args.mom)
    elif args.opt == 'adam':
        opt = tf.train.AdamOptimizer(lr, args.beta1, args.beta2)
    else:
        raise Exception('Unknown optimizer: %s' % args.opt)

    # Add Horovod Distributed Optimizer
    opt = hvd.DistributedOptimizer(opt)
    global_step = tf.contrib.framework.get_or_create_global_step()
    train_step = opt.minimize(model.v.loss, global_step=global_step)

    sess = K.get_session()
    # Start every worker from rank 0's variable values.
    sess.run(hvd.broadcast_global_variables(0))

    # 3. OPTIONALLY SAVE OR LOAD VARIABLES (e.g. model params, model running BN means, optimization momentum, ...) and then finalize initialization
    saver = tf.train.Saver(
        max_to_keep=None) if (args.output or args.load) else None
    if args.load:
        ckptfile, miscfile = args.load.split(':')
        # Restore values directly to graph
        saver.restore(sess, ckptfile)
        # NOTE: unpickling executes code - only load trusted files.
        with gzip.open(miscfile) as ff:
            saved = pickle.load(ff)
            buddy = saved['buddy']
    else:
        buddy = StatsBuddy()
    buddy.tic()  # call if new run OR resumed run

    # 4. SETUP TENSORBOARD LOGGING
    train_histogram_summaries = get_collection_intersection_summary(
        'train_collection', 'orig_histogram')
    train_scalar_summaries = get_collection_intersection_summary(
        'train_collection', 'orig_scalar')
    val_histogram_summaries = get_collection_intersection_summary(
        'val_collection', 'orig_histogram')
    val_scalar_summaries = get_collection_intersection_summary(
        'val_collection', 'orig_scalar')
    param_histogram_summaries = get_collection_intersection_summary(
        'param_collection', 'orig_histogram')

    writer = None
    if args.output:
        mkdir_p(args.output)
        writer = tf.summary.FileWriter(args.output, sess.graph)

    ## 5. TRAIN
    # Number of FULL global minibatches in each set.
    train_iters = train_y.shape[0] // minibatch_size
    val_iters = val_y.shape[0] // minibatch_size
    impreproc = ImagePreproc()

    # How often to log data
    def do_log_params(ep, it, ii):
        return True

    def do_log_val(ep, it, ii):
        return True

    def do_log_train(ep, it, ii):
        # Log on powers of two during the first epoch, then every epoch.
        return (it < train_iters and it & (it - 1) == 0
                or it >= train_iters and it % train_iters == 0)

    if args.ipy:
        print('Embed: before train / val loop (Ctrl-D to continue)')
        embed()

    while buddy.epoch < args.epochs + 1:
        # 0. Log params
        if args.output and do_log_params(
                buddy.epoch, buddy.train_iter,
                0) and param_histogram_summaries is not None:
            params_summary_str, = sess.run([param_histogram_summaries])
            writer.add_summary(params_summary_str, buddy.train_iter)

        # 1. Evaluate val set performance
        if not args.skipval:
            tic2()
            for ii in range(val_iters):
                with WithTimer('(worker %d) val iter %d/%d' %
                               (my_rank, ii, val_iters),
                               quiet=not args.verbose):

                    start_idx = ii * minibatch_size

                    # each worker gets a portion of the minibatch
                    my_start = start_idx + my_rank * worker_minibatch_size
                    my_end = my_start + worker_minibatch_size

                    batch_x = val_x[my_start:my_end]
                    batch_y = val_y[my_start:my_end]

                    if randcrops:
                        batch_x = impreproc.center_crops(batch_x, cropsize)
                    feed_dict = {
                        model.v.input_images: batch_x,
                        model.v.input_labels: batch_y,
                        K.learning_phase(): 0,
                    }
                    feed_dict.update(extra_feed_dict)
                    fetch_dict = model.trackable_dict
                    with WithTimer('(worker %d) sess.run val iter' % my_rank,
                                   quiet=not args.verbose):
                        result_val = sess_run_dict(sess,
                                                   fetch_dict,
                                                   feed_dict=feed_dict)

                    buddy.note_weighted_list(
                        batch_x.shape[0],
                        model.trackable_names,
                        [result_val[k] for k in model.trackable_names],
                        prefix='val_')

            if args.output and do_log_val(buddy.epoch, buddy.train_iter, 0):
                log_scalars(
                    writer,
                    buddy.train_iter, {
                        'mean_%s' % name: value
                        for name, value in buddy.epoch_mean_list_re('^val_')
                    },
                    prefix='buddy')

            # max(..., 1) avoids ZeroDivisionError on a tiny val set.
            print(
                '\ntime: %f. after training for %d epochs:\n%3d (worker %d) val:   %s (%.3gs/i)'
                % (buddy.toc(), buddy.epoch, buddy.train_iter, my_rank,
                   buddy.epoch_mean_pretty_re('^val_', style=val_style),
                   toc2() / max(val_iters, 1)))

        # 2. Possiby Snapshot, possibly quit
        # only worker 0 handles it
        if args.output and args.snapshot_to and args.snapshot_every:
            snap_intermed = args.snapshot_every > 0 and buddy.train_iter % args.snapshot_every == 0
            snap_end = buddy.epoch == args.epochs
            if snap_intermed or snap_end:
                # Snapshot
                if my_rank == 0:
                    save_path = saver.save(
                        sess, '%s/%s_%04d.ckpt' %
                        (args.output, args.snapshot_to, buddy.epoch))
                    print('snappshotted model to %s' % save_path)
                    with gzip.open(
                            '%s/%s_misc_%04d.pkl.gz' %
                        (args.output, args.snapshot_to, buddy.epoch),
                            'w') as ff:
                        saved = {'buddy': buddy}
                        pickle.dump(saved, ff)

        if buddy.epoch == args.epochs:
            if args.ipy:
                print('Embed: at end of training (Ctrl-D to exit)')
                embed()
            break  # Extra pass at end: just report val stats and skip training

        # 3. Train on training set
        tic3()
        for ii in range(train_iters):
            tic2()

            with WithTimer('(worker %d) train iter %d/%d' %
                           (my_rank, ii, train_iters),
                           quiet=not args.verbose):

                if args.shuffletrain:
                    start_idx = np.random.randint(train_x.shape[0] -
                                                  minibatch_size)
                else:
                    start_idx = ii * minibatch_size

                # each worker gets a portion of the minibatch
                my_start = start_idx + my_rank * worker_minibatch_size
                my_end = my_start + worker_minibatch_size

                batch_x = train_x[my_start:my_end]
                batch_y = train_y[my_start:my_end]

                if randcrops:
                    batch_x = impreproc.random_crops(batch_x, cropsize,
                                                     randmirrors)

                feed_dict = {
                    model.v.input_images: batch_x,
                    model.v.input_labels: batch_y,
                    K.learning_phase(): 1,
                }
                feed_dict.update(extra_feed_dict)

                fetch_dict = {'train_step': train_step}
                fetch_dict.update(model.trackable_and_update_dict)

                if args.output and do_log_train(buddy.epoch, buddy.train_iter,
                                                ii):
                    if param_histogram_summaries is not None:
                        fetch_dict.update({
                            'param_histogram_summaries':
                            param_histogram_summaries
                        })
                    if train_histogram_summaries is not None:
                        fetch_dict.update({
                            'train_histogram_summaries':
                            train_histogram_summaries
                        })
                    if train_scalar_summaries is not None:
                        fetch_dict.update(
                            {'train_scalar_summaries': train_scalar_summaries})

                with WithTimer('(worker %d) sess.run train iter' % my_rank,
                               quiet=not args.verbose):
                    result_train = sess_run_dict(sess,
                                                 fetch_dict,
                                                 feed_dict=feed_dict)

                buddy.note_weighted_list(
                    batch_x.shape[0],
                    model.trackable_names,
                    [result_train[k] for k in model.trackable_names],
                    prefix='train_')

                if do_log_train(buddy.epoch, buddy.train_iter, ii):
                    print('%3d (worker %d) train: %s (%.3gs/i)' %
                          (buddy.train_iter, my_rank,
                           buddy.epoch_mean_pretty_re(
                               '^train_', style=train_style), toc2()))
                    if args.output:
                        if param_histogram_summaries is not None:
                            hist_summary_str = result_train[
                                'param_histogram_summaries']
                            writer.add_summary(hist_summary_str,
                                               buddy.train_iter)
                        if train_histogram_summaries is not None:
                            hist_summary_str = result_train[
                                'train_histogram_summaries']
                            writer.add_summary(hist_summary_str,
                                               buddy.train_iter)
                        if train_scalar_summaries is not None:
                            scalar_summary_str = result_train[
                                'train_scalar_summaries']
                            writer.add_summary(scalar_summary_str,
                                               buddy.train_iter)
                        log_scalars(writer,
                                    buddy.train_iter, {
                                        'batch_%s' % name: value
                                        for name, value in buddy.last_list_re(
                                            '^train_')
                                    },
                                    prefix='buddy')

                if ii > 0 and ii % 100 == 0:
                    print(
                        '  %d: Average iteration time over last 100 train iters: %.3gs'
                        % (ii, toc3() / 100))
                    tic3()

                buddy.inc_train_iter()  # after finished training a mini-batch

        buddy.inc_epoch()  # after finished training whole pass through set

        if args.output and do_log_train(buddy.epoch, buddy.train_iter, 0):
            log_scalars(
                writer,
                buddy.train_iter, {
                    'mean_%s' % name: value
                    for name, value in buddy.epoch_mean_list_re('^train_')
                },
                prefix='buddy')

    print('\nFinal')
    print('%02d:%d val:   %s' %
          (buddy.epoch, buddy.train_iter,
           buddy.epoch_mean_pretty_re('^val_', style=val_style)))
    print('%02d:%d train: %s' %
          (buddy.epoch, buddy.train_iter,
           buddy.epoch_mean_pretty_re('^train_', style=train_style)))

    print('\nfinal_stats epochs %g' % buddy.epoch)
    print('final_stats iters %g' % buddy.train_iter)
    print('final_stats time %g' % buddy.toc())
    for name, value in buddy.epoch_mean_list_all():
        print('final_stats %s %g' % (name, value))

    if args.output:
        writer.close()  # Flush and close

    # Release the HDF5 file handles opened in step 0.
    train_h5.close()
    val_h5.close()
Ejemplo n.º 3
0
def main():
    parser = make_standard_parser(
        'Train a GAN model on simple square images or Clevr two-object color images',
        arch_choices=arch_choices,
        skip_train=True,
        skip_val=True)
    parser.add_argument('--z_dim',
                        type=int,
                        default=10,
                        help='Dimension of noise vector')
    parser.add_argument('--lr2',
                        type=float,
                        default=None,
                        help='learning rate for generator')
    parser.add_argument('--feature_match',
                        '-fm',
                        action='store_true',
                        help='use feature matching loss for generator.')
    parser.add_argument(
        '--feature_match_loss_weight',
        '-fmalpha',
        type=float,
        default=1.0,
        help='weight on the feature matching loss for generator.')
    parser.add_argument(
        '--pairedz',
        action='store_true',
        help='If True, pair the same z with a training batch each epoch')
    parser.add_argument(
        '--eval-train-every',
        type=int,
        default=0,
        help='evaluate whole training set every N epochs. 0 to disable.')

    args = parser.parse_args()

    args.skipval = True

    minibatch_size = args.minibatch
    train_style, val_style = ('',
                              '') if args.nocolor else (colorama.Fore.BLUE,
                                                        colorama.Fore.MAGENTA)
    evaltrain_style = '' if args.nocolor or args.eval_train_every <= 0 else colorama.Fore.CYAN

    black_divider = True if args.arch.startswith('clevr') else False

    # Get a TF session and set numpy and TF seeds
    sess = setup_session_and_seeds(args.seed, assert_gpu=not args.cpu)

    # 0. LOAD DATA
    if args.arch.startswith('simple'):
        fd = h5py.File('data/rectangle_4_uniform.h5', 'r')
        train_x = np.array(fd['train_imagegray'],
                           dtype=float) / 255.0  # shape (2368, 64, 64, 1)
        val_x = np.array(fd['val_imagegray'],
                         dtype=float) / 255.0  # shape (768, 64, 64, 1)
        train_x = np.concatenate((train_x, val_x),
                                 axis=0)  # shape (3136, 64, 64, 1)

    elif args.arch.startswith('clevr'):
        (train_x, val_x) = load_sort_of_clevr()
        # shape (50000, 64, 64, 3)
        train_x = np.concatenate((train_x, val_x), axis=0)

    else:
        raise Exception('Unknown network architecture: %s' % args.arch)

    print 'Train data loaded: {} images, size {}'.format(
        train_x.shape[0], train_x.shape[1:])
    #print 'Val data loaded: {} images, size {}'.format(val_x.shape[0], val_x.shape[1:])

    #print 'Label dimension: {}'.format(val_y.shape[1:])

    # 1. CREATE MODEL
    assert len(train_x.shape) == 4, "image data must be of 4 dimensions"
    image_h, image_w, image_c = train_x.shape[1], train_x.shape[
        2], train_x.shape[3]

    model = build_model(args, image_h, image_w, image_c)

    print 'All model weights:'
    summarize_weights(model.trainable_weights)
    print 'Model summary:'
    # model.summary()      # TOREPLACE
    print 'Another model summary:'
    model.summarize_named(prefix='  ')
    print_trainable_warnings(model)

    # 2. COMPUTE GRADS AND CREATE OPTIMIZER
    lr_gen = args.lr2 if args.lr2 else args.lr

    if args.opt == 'sgd':
        d_opt = tf.train.MomentumOptimizer(args.lr, args.mom)
        g_opt = tf.train.MomentumOptimizer(lr_gen, args.mom)
    elif args.opt == 'rmsprop':
        d_opt = tf.train.RMSPropOptimizer(args.lr, momentum=args.mom)
        g_opt = tf.train.RMSPropOptimizer(lr_gen, momentum=args.mom)
    elif args.opt == 'adam':
        d_opt = tf.train.AdamOptimizer(args.lr, args.beta1, args.beta2)
        g_opt = tf.train.AdamOptimizer(lr_gen, args.beta1, args.beta2)

    # Optimize w.r.t all trainable params in the model

    all_vars = model.trainable_variables
    d_vars = [var for var in all_vars if 'discriminator' in var.name]
    g_vars = [var for var in all_vars if 'generator' in var.name]

    d_grads_and_vars = d_opt.compute_gradients(
        model.d_loss, d_vars, gate_gradients=tf.train.Optimizer.GATE_GRAPH)
    d_train_step = d_opt.apply_gradients(d_grads_and_vars)
    g_grads_and_vars = g_opt.compute_gradients(
        model.g_loss, g_vars, gate_gradients=tf.train.Optimizer.GATE_GRAPH)
    g_train_step = g_opt.apply_gradients(g_grads_and_vars)

    hist_summaries_traintest(model.d_real_logits, model.d_fake_logits)

    add_grads_and_vars_hist_summaries(d_grads_and_vars)
    add_grads_and_vars_hist_summaries(g_grads_and_vars)
    image_summaries_traintest(model.fake_images)

    # 3. OPTIONALLY SAVE OR LOAD VARIABLES (e.g. model params, model running
    # BN means, optimization momentum, ...) and then finalize initialization
    saver = tf.train.Saver(
        max_to_keep=None) if (args.output or args.load) else None
    if args.load:
        ckptfile, miscfile = args.load.split(':')
        # Restore values directly to graph
        saver.restore(sess, ckptfile)
        with gzip.open(miscfile) as ff:
            saved = pickle.load(ff)
            buddy = saved['buddy']
    else:
        buddy = StatsBuddy(pretty_replaces=[('evaltrain_', ''), (
            'eval', '')]) if args.eval_train_every > 0 else StatsBuddy()

    buddy.tic()  # call if new run OR resumed run

    tf.global_variables_initializer().run()

    # 4. SETUP TENSORBOARD LOGGING

    param_histogram_summaries = get_collection_intersection_summary(
        'param_collection', 'orig_histogram')
    train_histogram_summaries = get_collection_intersection_summary(
        'train_collection', 'orig_histogram')
    train_scalar_summaries = get_collection_intersection_summary(
        'train_collection', 'orig_scalar')
    test_histogram_summaries = get_collection_intersection_summary(
        'test_collection', 'orig_histogram')
    test_scalar_summaries = get_collection_intersection_summary(
        'test_collection', 'orig_scalar')
    train_image_summaries = get_collection_intersection_summary(
        'train_collection', 'orig_image')
    test_image_summaries = get_collection_intersection_summary(
        'test_collection', 'orig_image')

    writer = None
    if args.output:
        mkdir_p(args.output)
        writer = tf.summary.FileWriter(args.output, sess.graph)

    # 5. TRAIN
    train_iters = (train_x.shape[0]) // minibatch_size
    if not args.skipval:
        val_iters = (val_x.shape[0]) // minibatch_size

    if args.ipy:
        print 'Embed: before train / val loop (Ctrl-D to continue)'
        embed()

    # 2. use same noise, eval on 100 samples and save G(z),
    np.random.seed()
    eval_batch_size = 100
    eval_z = np.random.uniform(-1, 1, size=(eval_batch_size, args.z_dim))

    while buddy.epoch < args.epochs + 1:
        # How often to log data
        def do_log_params(ep, it, ii):
            return True

        def do_log_val(ep, it, ii):
            return True

        def do_log_train(ep, it, ii):
            return (it < train_iters and it & it - 1 == 0
                    or it >= train_iters and it % train_iters == 0
                    )  # Log on powers of two then every epoch

        # 0. Log params
        if args.output and do_log_params(
                buddy.epoch, buddy.train_iter,
                0) and param_histogram_summaries is not None:
            params_summary_str, = sess.run([param_histogram_summaries])
            writer.add_summary(params_summary_str, buddy.train_iter)

        # 1. Evaluate generator by showing random generated results
        #    Evaluate descriminator by showing seeing correct rate on generated and real (hold-out) results
        #assert(args.skipval), "only support training now"

        if not args.skipval:
            tic2()
            # use different noise, eval on larger number of samples and get
            # correct rate
            np.random.seed()
            val_z = np.random.uniform(-1, 1, size=(val_x.shape[0], args.z_dim))

            with WithTimer('sess.run val iter', quiet=not args.verbose):
                feed_dict = {
                    model.input_images: val_x,
                    model.input_noise: val_z,
                    learning_phase(): 0
                }

                if 'input_labels' in model.named_keys():
                    feed_dict.update({model.input_labels: val_y})

                val_corr_fake_bn0, val_corr_real_bn0 = sess.run(
                    [model.correct_fake, model.correct_real],
                    feed_dict=feed_dict)

                feed_dict[learning_phase()] = 1
                val_corr_fake_bn1, val_corr_real_bn1 = sess.run(
                    [model.correct_fake, model.correct_real],
                    feed_dict=feed_dict)

            if args.output and do_log_val(buddy.epoch, buddy.train_iter, 0):
                fetch_dict = {}
                if test_image_summaries is not None:
                    fetch_dict.update(
                        {'test_image_summaries': test_image_summaries})
                if test_scalar_summaries is not None:
                    fetch_dict.update(
                        {'test_scalar_summaries': test_scalar_summaries})
                if test_histogram_summaries is not None:
                    fetch_dict.update(
                        {'test_histogram_summaries': test_histogram_summaries})
                if fetch_dict:
                    summary_strs = sess_run_dict(sess,
                                                 fetch_dict,
                                                 feed_dict=feed_dict)

            buddy.note_list([
                'correct_real_bn0', 'correct_fake_bn0', 'correct_real_bn1',
                'correct_fake_bn1'
            ], [
                val_corr_real_bn0, val_corr_fake_bn0, val_corr_real_bn1,
                val_corr_fake_bn1
            ],
                            prefix='val_')

            print(
                '%3d (ep %d) val: %s (%.3gs/ep)' %
                (buddy.train_iter, buddy.epoch,
                 buddy.epoch_mean_pretty_re('^val_', style=val_style), toc2()))

            if args.output and do_log_val(buddy.epoch, buddy.train_iter, 0):
                log_scalars(
                    writer,
                    buddy.train_iter, {
                        'mean_%s' % name: value
                        for name, value in buddy.epoch_mean_list_re('^val_')
                    },
                    prefix='buddy')

                if test_image_summaries is not None:
                    image_summary_str = summary_strs['test_image_summaries']
                    writer.add_summary(image_summary_str, buddy.train_iter)
                if test_scalar_summaries is not None:
                    scalar_summary_str = summary_strs['test_scalar_summaries']
                    writer.add_summary(scalar_summary_str, buddy.train_iter)
                if test_histogram_summaries is not None:
                    hist_summary_str = summary_strs['test_histogram_summaries']
                    writer.add_summary(hist_summary_str, buddy.train_iter)

        # In addition, evalutate 1000 more images
        np.random.seed()
        eval_more = np.random.uniform(-1, 1, size=(1000, args.z_dim))
        feed_dict2 = {
            # (100,-) generated outside of loop to keep the same every round
            model.input_noise:
            eval_z,
            learning_phase():
            0
        }

        eval_samples_bn0 = sess.run(model.fake_images, feed_dict=feed_dict2)

        feed_dict2[learning_phase()] = 1
        eval_samples_bn1 = sess.run(model.fake_images, feed_dict=feed_dict2)

        # feed in 10 times because coordconv cannot handle too big of a batch
        for cc in range(10):
            eval_z2 = eval_more[cc * 100:(cc + 1) * 100, :]
            _eval_more_samples = sess.run(
                model.fake_images,
                feed_dict={
                    model.input_noise: eval_z2,  # (1000,-)
                    learning_phase(): 0
                })
            eval_more_samples = _eval_more_samples if cc == 0 else np.concatenate(
                (eval_more_samples, _eval_more_samples), axis=0)

        if args.output:
            mkdir_p('{}/fake_images'.format(args.output))
            # eval_samples_bn*: e.g. (100, 64, 64, 3)
            save_images(eval_samples_bn0, [10, 10],
                        '{}/fake_images/g_out_bn0_epoch_{}_iter_{}.png'.format(
                            args.output, buddy.epoch, buddy.train_iter),
                        black_divider=black_divider)
            save_images(eval_samples_bn1, [10, 10],
                        '{}/fake_images/g_out_bn1_epoch_{}.png'.format(
                            args.output, buddy.epoch),
                        black_divider=black_divider)
            save_average_image(
                eval_more_samples,
                '{}/fake_images/g_out_averaged_epoch_{}_iter_{}.png'.format(
                    args.output, buddy.epoch, buddy.train_iter))

        # 2. Possiby Snapshot, possibly quit
        if args.output and args.snapshot_to and args.snapshot_every:
            snap_intermed = args.snapshot_every > 0 and buddy.train_iter % args.snapshot_every == 0
            snap_end = buddy.epoch == args.epochs
            if snap_intermed or snap_end:
                # Snapshot
                save_path = saver.save(
                    sess, '%s/%s_%04d.ckpt' %
                    (args.output, args.snapshot_to, buddy.epoch))
                print 'snappshotted model to', save_path
                with gzip.open(
                        '%s/%s_misc_%04d.pkl.gz' %
                    (args.output, args.snapshot_to, buddy.epoch), 'w') as ff:
                    saved = {'buddy': buddy}
                    pickle.dump(saved, ff)
                # snapshot sampled images too
                ff = h5py.File(
                    '%s/sampled_images_%04d.h5' % (args.output, buddy.epoch),
                    'w')
                ff.create_dataset('eval_samples_bn0', data=eval_samples_bn0)
                ff.create_dataset('eval_samples_bn1', data=eval_samples_bn1)
                ff.create_dataset('eval_z', data=eval_z)
                ff.create_dataset('eval_z_more', data=eval_more)
                ff.create_dataset('eval_more_samples', data=eval_more_samples)
                ff.close()

        # 2. Possiby evaluate the training set
        if args.eval_train_every > 0:
            if buddy.epoch % args.eval_train_every == 0:
                tic2()
                for ii in xrange(train_iters):
                    start_idx = ii * minibatch_size
                    if args.pairedz:
                        np.random.seed(args.seed + ii)
                    else:
                        np.random.seed()
                    batch_z = np.random.uniform(-1,
                                                1,
                                                size=(minibatch_size,
                                                      args.z_dim))

                    batch_x = train_x[start_idx:start_idx + minibatch_size]
                    batch_y = train_y[start_idx:start_idx + minibatch_size]

                    feed_dict = {
                        model.input_images: batch_x,
                        # model.input_labels: batch_y,
                        model.input_noise: batch_z,
                        learning_phase(): 0,
                    }

                    if 'input_labels' in model.named_keys():
                        feed_dict.update({model.input_labels: val_y})

                    fetch_dict = model.trackable_dict()
                    result_eval_train = sess_run_dict(sess,
                                                      fetch_dict,
                                                      feed_dict=feed_dict)
                    buddy.note_weighted_list(
                        batch_x.shape[0],
                        model.trackable_names(), [
                            result_eval_train[k]
                            for k in model.trackable_names()
                        ],
                        prefix='evaltrain_bn0_')

                    feed_dict = {
                        model.input_images: batch_x,
                        # model.input_labels: batch_y,
                        model.input_noise: batch_z,
                        learning_phase(): 1,
                    }
                    if 'input_labels' in model.named_keys():
                        feed_dict.update({model.input_labels: val_y})

                    result_eval_train = sess_run_dict(sess,
                                                      fetch_dict,
                                                      feed_dict=feed_dict)
                    buddy.note_weighted_list(
                        batch_x.shape[0],
                        model.trackable_names(), [
                            result_eval_train[k]
                            for k in model.trackable_names()
                        ],
                        prefix='evaltrain_bn1_')

                    if args.output:
                        log_scalars(writer,
                                    buddy.train_iter, {
                                        'batch_%s' % name: value
                                        for name, value in buddy.last_list_re(
                                            '^evaltrain_bn0_')
                                    },
                                    prefix='buddy')
                        log_scalars(writer,
                                    buddy.train_iter, {
                                        'batch_%s' % name: value
                                        for name, value in buddy.last_list_re(
                                            '^evaltrain_bn1_')
                                    },
                                    prefix='buddy')
                if args.output:
                    log_scalars(writer,
                                buddy.epoch, {
                                    'mean_%s' % name: value
                                    for name, value in
                                    buddy.epoch_mean_list_re('^evaltrain_bn0_')
                                },
                                prefix='buddy')
                    log_scalars(writer,
                                buddy.epoch, {
                                    'mean_%s' % name: value
                                    for name, value in
                                    buddy.epoch_mean_list_re('^evaltrain_bn1_')
                                },
                                prefix='buddy')

                print('%3d (ep %d) evaltrain: %s (%.3gs/ep)' %
                      (buddy.train_iter, buddy.epoch,
                       buddy.epoch_mean_pretty_re(
                           '^evaltrain_bn0_', style=evaltrain_style), toc2()))
                print('%3d (ep %d) evaltrain: %s (%.3gs/ep)' %
                      (buddy.train_iter, buddy.epoch,
                       buddy.epoch_mean_pretty_re(
                           '^evaltrain_bn1_', style=evaltrain_style), toc2()))

        if buddy.epoch == args.epochs:
            if args.ipy:
                print 'Embed: at end of training (Ctrl-D to exit)'
                embed()
            break  # Extra pass at end: just report val stats and skip training

        # 3. Train on training set

        if args.shuffletrain:
            train_order = np.random.permutation(train_x.shape[0])
            train_order2 = np.random.permutation(train_x.shape[0])
        tic3()
        for ii in xrange(train_iters):
            tic2()
            start_idx = ii * minibatch_size
            if args.pairedz:
                np.random.seed(args.seed + ii)
            else:
                np.random.seed()

            batch_z = np.random.uniform(-1,
                                        1,
                                        size=(minibatch_size, args.z_dim))

            if args.shuffletrain:
                #batch_x = train_x[train_order[start_idx:start_idx + minibatch_size]]
                batch_x = train_x[sorted(train_order[start_idx:start_idx +
                                                     minibatch_size].tolist())]
                if args.feature_match:
                    assert args.shuffletrain, "feature matching loss requires shuffle train"
                    batch_x2 = train_x[sorted(
                        train_order2[start_idx:start_idx +
                                     minibatch_size].tolist())]
                if 'input_labels' in model.named_keys():
                    batch_y = train_y[sorted(
                        train_order[start_idx:start_idx +
                                    minibatch_size].tolist())]
            else:
                batch_x = train_x[start_idx:start_idx + minibatch_size]
                if 'input_labels' in model.named_keys():
                    batch_y = train_y[start_idx:start_idx + minibatch_size]

            feed_dict = {
                model.input_images: batch_x,
                # model.input_labels: batch_y,
                model.input_noise: batch_z,
                learning_phase(): 1,
            }

            if 'input_labels' in model.named_keys():
                feed_dict.update({model.input_labels: batch_y})
            if 'input_images2' in model.named_keys():
                feed_dict.update({model.input_images2: batch_x2})

            fetch_dict = model.trackable_and_update_dict()

            if args.output and do_log_train(buddy.epoch, buddy.train_iter, ii):
                if train_histogram_summaries is not None:
                    fetch_dict.update({
                        'train_histogram_summaries':
                        train_histogram_summaries
                    })
                if train_scalar_summaries is not None:
                    fetch_dict.update(
                        {'train_scalar_summaries': train_scalar_summaries})
                if train_image_summaries is not None:
                    fetch_dict.update(
                        {'train_image_summaries': train_image_summaries})

            with WithTimer('sess.run train iter', quiet=not args.verbose):
                result_train = sess_run_dict(sess,
                                             fetch_dict,
                                             feed_dict=feed_dict)

                # if result_train['d_loss'] < result_train['g_loss']:
                #    #print 'Only train G'
                #    sess.run(g_train_step, feed_dict=feed_dict)
                # else:
                #    #print 'Train both D and G'
                #    sess.run(d_train_step, feed_dict=feed_dict)
                #    sess.run(g_train_step, feed_dict=feed_dict)
                #    sess.run(g_train_step, feed_dict=feed_dict)
                sess.run(d_train_step, feed_dict=feed_dict)
                sess.run(g_train_step, feed_dict=feed_dict)
                sess.run(g_train_step, feed_dict=feed_dict)

            if do_log_train(buddy.epoch, buddy.train_iter, ii):
                buddy.note_weighted_list(
                    batch_x.shape[0],
                    model.trackable_names(),
                    [result_train[k] for k in model.trackable_names()],
                    prefix='train_')
                print('[%5d] [%2d/%2d] train: %s (%.3gs/i)' %
                      (buddy.train_iter, buddy.epoch, args.epochs,
                       buddy.epoch_mean_pretty_re('^train_',
                                                  style=train_style), toc2()))

            if args.output and do_log_train(buddy.epoch, buddy.train_iter, ii):
                if train_histogram_summaries is not None:
                    hist_summary_str = result_train[
                        'train_histogram_summaries']
                    writer.add_summary(hist_summary_str, buddy.train_iter)
                if train_scalar_summaries is not None:
                    scalar_summary_str = result_train['train_scalar_summaries']
                    writer.add_summary(scalar_summary_str, buddy.train_iter)
                if train_image_summaries is not None:
                    image_summary_str = result_train['train_image_summaries']
                    writer.add_summary(image_summary_str, buddy.train_iter)
                log_scalars(
                    writer,
                    buddy.train_iter, {
                        'batch_%s' % name: value
                        for name, value in buddy.last_list_re('^train_')
                    },
                    prefix='buddy')

            if ii > 0 and ii % 100 == 0:
                print '  %d: Average iteration time over last 100 train iters: %.3gs' % (
                    ii, toc3() / 100)
                tic3()

            buddy.inc_train_iter()  # after finished training a mini-batch

        buddy.inc_epoch()  # after finished training whole pass through set

        if args.output and do_log_train(buddy.epoch, buddy.train_iter, 0):
            log_scalars(
                writer,
                buddy.train_iter, {
                    'mean_%s' % name: value
                    for name, value in buddy.epoch_mean_list_re('^train_')
                },
                prefix='buddy')

    print '\nFinal'
    print '%02d:%d val:   %s' % (buddy.epoch, buddy.train_iter,
                                 buddy.epoch_mean_pretty_re('^val_',
                                                            style=val_style))
    print '%02d:%d train: %s' % (buddy.epoch, buddy.train_iter,
                                 buddy.epoch_mean_pretty_re('^train_',
                                                            style=train_style))

    print '\nfinal_stats epochs %g' % buddy.epoch
    print 'final_stats iters %g' % buddy.train_iter
    print 'final_stats time %g' % buddy.toc()
    for name, value in buddy.epoch_mean_list_all():
        print 'final_stats %s %g' % (name, value)

    if args.output:
        writer.close()  # Flush and close
# Example no. 4
# 0
def main():
    parser = make_standard_parser(
        'Coordconv',
        arch_choices=arch_choices,
        skip_train=True,
        skip_val=True)
    # re-add train and val h5s as optional
    parser.add_argument('--data_h5', type=str,
                        default='./data/rectangle_4_uniform.h5',
                        help='data file in hdf5.')
    parser.add_argument('--x_dim', type=int, default=64,
                        help='x dimension of the output image')
    parser.add_argument('--y_dim', type=int, default=64,
                        help='y dimension of the output image')
    parser.add_argument('--lrpolicy', type=str, default='constant',
                        choices=lr_policy_choices, help='LR policy.')
    parser.add_argument('--lrstepratio', type=float,
                        default=.1, help='LR policy step ratio.')
    parser.add_argument('--lrmaxsteps', type=int, default=5,
                        help='LR policy step ratio.')
    parser.add_argument('--lrstepevery', type=int, default=50,
                        help='LR policy step ratio.')
    parser.add_argument('--filter_size', '-fs', type=int, default=3,
                        help='filter size in deconv network')
    parser.add_argument('--channel_mul', '-mul', type=int, default=2,
        help='Deconv model channel multiplier to make bigger models')
    parser.add_argument('--use_mse_loss', '-mse', action='store_true',
                        help='use mse loss instead of cross entropy')
    parser.add_argument('--use_sigm_loss', '-sig', action='store_true',
                        help='use sigmoid loss instead of cross entropy')
    parser.add_argument('--interm_loss', '-interm', default=None,
        choices=(None, 'softmax', 'mse'),
        help='add intermediate loss to end-to-end painter model')
    parser.add_argument('--no_softmax', '-nosfmx', action='store_true',
                        help='Remove softmax sharpening layer in model')

    args = parser.parse_args()

    if args.lrpolicy == 'step':
        lr_policy = LRPolicyStep(args)
    elif args.lrpolicy == 'valstep':
        lr_policy = LRPolicyValStep(args)
    else:
        lr_policy = LRPolicyConstant(args)

    minibatch_size = args.minibatch
    train_style, val_style = (
        '', '') if args.nocolor else (
        colorama.Fore.BLUE, colorama.Fore.MAGENTA)

    sess = setup_session_and_seeds(args.seed, assert_gpu=not args.cpu)

    # 0. Load data or generate data on the fly
    print 'Loading data: {}'.format(args.data_h5)

    if args.arch in ['deconv_classification',
                     'coordconv_classification',
                     'upsample_conv_coords',
                     'upsample_coordconv_coords']:

        # option a: generate data on the fly
        #data = list(itertools.product(range(args.x_dim),range(args.y_dim)))
        # random.shuffle(data)

        #train_test_split = .8
        #val_reps = int(args.x_dim * args.x_dim * train_test_split) // minibatch_size
        #val_size = val_reps * minibatch_size
        #train_end = args.x_dim * args.x_dim - val_size
        #train_x, val_x = np.array(data[:train_end]).astype('int'), np.array(data[train_end:]).astype('int')
        #train_y, val_y = None, None
        #DATA_GEN_ON_THE_FLY = True

        # option b: load the data
        fd = h5py.File(args.data_h5, 'r')

        train_x = np.array(fd['train_locations'], dtype=int)  # shape (2368, 2)
        train_y = np.array(fd['train_onehots'], dtype=float)  # shape (2368, 64, 64, 1)
        val_x = np.array(fd['val_locations'], dtype=float)  # shape (768, 2)
        val_y = np.array(fd['val_onehots'], dtype=float)  # shape (768, 64, 64, 1)
        DATA_GEN_ON_THE_FLY = False

        # number of image channels
        image_c = train_y.shape[-1] if train_y is not None and len(train_y.shape) == 4 else 1

    elif args.arch == 'conv_onehot_image':
        fd = h5py.File(args.data_h5, 'r')
        train_x = np.array(
            fd['train_onehots'],
            dtype=int)  # shape (2368, 64, 64, 1)
        train_y = np.array(fd['train_imagegray'],
                           dtype=float) / 255.0  # shape (2368, 64, 64, 1)
        val_x = np.array(
            fd['val_onehots'],
            dtype=float)  # shape (768, 64, 64, 1)
        val_y = np.array(fd['val_imagegray'], dtype=float) / \
            255.0  # shape (768, 64, 64, 1)

        image_c = train_y.shape[-1]

    elif args.arch == 'deconv_rendering':
        fd = h5py.File(args.data_h5, 'r')
        train_x = np.array(fd['train_locations'], dtype=int)  # shape (2368, 2)
        train_y = np.array(fd['train_imagegray'],
                           dtype=float) / 255.0  # shape (2368, 64, 64, 1)
        val_x = np.array(fd['val_locations'], dtype=float)  # shape (768, 2)
        val_y = np.array(fd['val_imagegray'], dtype=float) / \
            255.0  # shape (768, 64, 64, 1)

        image_c = train_y.shape[-1]

    elif args.arch == 'conv_regressor' or args.arch == 'coordconv_regressor':
        fd = h5py.File(args.data_h5, 'r')
        train_y = np.array(
            fd['train_normalized_locations'],
            dtype=float)  # shape (2368, 2)
        # /255.0 # shape (2368, 64, 64, 1)
        train_x = np.array(fd['train_onehots'], dtype=float)
        val_y = np.array(
            fd['val_normalized_locations'],
            dtype=float)  # shape (768, 2)
        val_x = np.array(
            fd['val_onehots'],
            dtype=float)  # shape (768, 64, 64, 1)

        image_c = train_x.shape[-1]

    elif args.arch == 'coordconv_rendering' or args.arch == 'deconv_bottleneck':
        fd = h5py.File(args.data_h5, 'r')
        train_x = np.array(fd['train_locations'], dtype=int)  # shape (2368, 2)
        train_y = np.array(fd['train_imagegray'],
                           dtype=float) / 255.0  # shape (2368, 64, 64, 1)
        val_x = np.array(fd['val_locations'], dtype=float)  # shape (768, 2)
        val_y = np.array(fd['val_imagegray'], dtype=float) / 255.0  # shape (768, 64, 64, 1)

        # add one-hot anyways to track accuracy etc. even if not used in loss
        train_onehot = np.array(
            fd['train_onehots'],
            dtype=int)  # shape (2368, 64, 64, 1)
        val_onehot = np.array(
            fd['val_onehots'],
            dtype=int)  # shape (768, 64, 64, 1)

        image_c = train_y.shape[-1]

    train_size = train_x.shape[0]
    val_size = val_x.shape[0]

    # 1. CREATE MODEL
    input_coords = tf.placeholder(
        shape=(None,2),
        dtype='float32',
        name='input_coords')  # cast later in model into float
    input_onehot = tf.placeholder(
        shape=(None, args.x_dim, args.y_dim, 1),
        dtype='float32',
        name='input_onehot')
    input_images = tf.placeholder(
        shape=(None, args.x_dim, args.y_dim, image_c),
        dtype='float32',
        name='input_images')

    if args.arch == 'deconv_classification':
        model = DeconvPainter(l2=args.l2, x_dim=args.x_dim, y_dim=args.y_dim,
                              fs=args.filter_size, mul=args.channel_mul,
                              onthefly=DATA_GEN_ON_THE_FLY,
                              use_mse_loss=args.use_mse_loss,
                              use_sigm_loss=args.use_sigm_loss)

        model.a('input_coords', input_coords)

        if not DATA_GEN_ON_THE_FLY:
            model.a('input_onehot', input_onehot)

        model([input_coords]) if DATA_GEN_ON_THE_FLY else model([input_coords, input_onehot])

    if args.arch == 'conv_regressor':
        regress_type = 'conv_uniform' if 'uniform' in args.data_h5 else 'conv_quarant'
        model = ConvRegressor(l2=args.l2, mul=args.channel_mul,
                              _type=regress_type)
        model.a('input_coords', input_coords)
        model.a('input_onehot', input_onehot)
        # call model on inputs
        model([input_onehot, input_coords])

    if args.arch == 'coordconv_regressor':
        model = ConvRegressor(l2=args.l2, mul=args.channel_mul,
                              _type='coordconv')
        model.a('input_coords', input_coords)
        model.a('input_onehot', input_onehot)
        # call model on inputs
        model([input_onehot, input_coords])

    if args.arch == 'conv_onehot_image':
        model = ConvImagePainter(l2=args.l2, fs=args.filter_size, mul=args.channel_mul,
            use_mse_loss=args.use_mse_loss, use_sigm_loss=args.use_sigm_loss,
            version='working')
            # version='simple') # version='simple' to hack a 9x9 all-ones filter solution
        model.a('input_onehot', input_onehot)
        model.a('input_images', input_images)
        # call model on inputs
        model([input_onehot, input_images])

    if args.arch == 'deconv_rendering':
        model = DeconvPainter(l2=args.l2, x_dim=args.x_dim, y_dim=args.y_dim,
                              fs=args.filter_size, mul=args.channel_mul,
                              onthefly=False,
                              use_mse_loss=args.use_mse_loss,
                              use_sigm_loss=args.use_sigm_loss)
        model.a('input_coords', input_coords)
        model.a('input_images', input_images)
        # call model on inputs
        model([input_coords, input_images])

    elif args.arch == 'coordconv_classification':
        model = CoordConvPainter(
            l2=args.l2,
            x_dim=args.x_dim,
            y_dim=args.y_dim,
            include_r=False,
            mul=args.channel_mul,
            use_mse_loss=args.use_mse_loss,
            use_sigm_loss=args.use_sigm_loss)

        model.a('input_coords', input_coords)
        model.a('input_onehot', input_onehot)

        model([input_coords, input_onehot])
        #raise Exception('Not implemented yet')

    elif args.arch == 'coordconv_rendering':
        model = CoordConvImagePainter(
            l2=args.l2,
            x_dim=args.x_dim,
            y_dim=args.y_dim,
            include_r=False,
            mul=args.channel_mul,
            fs=args.filter_size,
            use_mse_loss=args.use_mse_loss,
            use_sigm_loss=args.use_sigm_loss,
            interm_loss=args.interm_loss,
            no_softmax=args.no_softmax,
            version='working')
        # version='simple') # version='simple' to hack a 9x9 all-ones filter solution
        model.a('input_coords', input_coords)
        model.a('input_onehot', input_onehot)
        model.a('input_images', input_images)

        # always input three things to calculate relevant metrics
        model([input_coords, input_onehot, input_images])
    elif args.arch == 'deconv_bottleneck':
        model = DeconvBottleneckPainter(
            l2=args.l2,
            x_dim=args.x_dim,
            y_dim=args.y_dim,
            mul=args.channel_mul,
            fs=args.filter_size,
            use_mse_loss=args.use_mse_loss,
            use_sigm_loss=args.use_sigm_loss,
            interm_loss=args.interm_loss,
            no_softmax=args.no_softmax,
            version='working')  # version='simple' to hack a 9x9 all-ones filter solution
        model.a('input_coords', input_coords)
        model.a('input_onehot', input_onehot)
        model.a('input_images', input_images)

        # always input three things to calculate relevant metrics
        model([input_coords, input_onehot, input_images])

    elif args.arch == 'upsample_conv_coords' or args.arch == 'upsample_coordconv_coords':
        _coordconv = True if args.arch == 'upsample_coordconv_coords' else False
        model = UpsampleConvPainter(
            l2=args.l2,
            x_dim=args.x_dim,
            y_dim=args.y_dim,
            mul=args.channel_mul,
            fs=args.filter_size,
            use_mse_loss=args.use_mse_loss,
            use_sigm_loss=args.use_sigm_loss,
            coordconv=_coordconv)
        model.a('input_coords', input_coords)
        model.a('input_onehot', input_onehot)
        model([input_coords, input_onehot])

    print 'All model weights:'
    summarize_weights(model.trainable_weights)
    #print 'Model summary:'
    print 'Another model summary:'
    model.summarize_named(prefix='  ')
    print_trainable_warnings(model)

    # 2. COMPUTE GRADS AND CREATE OPTIMIZER
    # a placeholder for dynamic learning rate
    input_lr = tf.placeholder(tf.float32, shape=[])
    if args.opt == 'sgd':
        opt = tf.train.MomentumOptimizer(input_lr, args.mom)
    elif args.opt == 'rmsprop':
        opt = tf.train.RMSPropOptimizer(input_lr, momentum=args.mom)
    elif args.opt == 'adam':
        opt = tf.train.AdamOptimizer(input_lr, args.beta1, args.beta2)

    grads_and_vars = opt.compute_gradients(
        model.loss,
        model.trainable_weights,
        gate_gradients=tf.train.Optimizer.GATE_GRAPH)
    train_step = opt.apply_gradients(grads_and_vars)
    # added to train_ and param_ collections
    add_grads_and_vars_hist_summaries(grads_and_vars)

    summarize_opt(opt)
    print 'LR Policy:', lr_policy

    # add_grad_summaries(grads_and_vars)
    if not args.arch.endswith('regressor'):
        image_summaries_traintest(model.logits)

    if 'input_onehot' in model.named_keys():
        image_summaries_traintest(model.input_onehot)
    if 'input_images' in model.named_keys():
        image_summaries_traintest(model.input_images)
    if 'prob' in model.named_keys():
        image_summaries_traintest(model.prob)
    if 'center_prob' in model.named_keys():
        image_summaries_traintest(model.center_prob)
    if 'center_logits' in model.named_keys():
        image_summaries_traintest(model.center_logits)
    if 'pixelwise_prob' in model.named_keys():
        image_summaries_traintest(model.pixelwise_prob)
    if 'center_logits' in model.named_keys():
        image_summaries_traintest(model.center_logits)
    if 'sharpened_logits' in model.named_keys():
        image_summaries_traintest(model.sharpened_logits)

    # 3. OPTIONALLY SAVE OR LOAD VARIABLES (e.g. model params, model running
    # BN means, optimization momentum, ...) and then finalize initialization
    saver = tf.train.Saver(
        max_to_keep=None) if (
        args.output or args.load) else None
    if args.load:
        ckptfile, miscfile = args.load.split(':')
        # Restore values directly to graph
        saver.restore(sess, ckptfile)
        with gzip.open(miscfile) as ff:
            saved = pickle.load(ff)
            buddy = saved['buddy']
    else:
        buddy = StatsBuddy()

    buddy.tic()    # call if new run OR resumed run

    # Check if special layers are initialized right
    #last_layer_w = [var for var in tf.global_variables() if 'painting_layer/kernel:0' in var.name][0]
    #last_layer_b = [var for var in tf.global_variables() if 'painting_layer/bias:0' in var.name][0]

    # Initialize any missed vars (e.g. optimization momentum, ... if not
    # loaded from checkpoint)
    uninitialized_vars = tf_get_uninitialized_variables(sess)
    init_missed_vars = tf.variables_initializer(
        uninitialized_vars, 'init_missed_vars')
    sess.run(init_missed_vars)
    # Print warnings about any TF vs. Keras shape mismatches
    # warn_misaligned_shapes(model)
    # Make sure all variables, which are model variables, have been
    # initialized (e.g. model params and model running BN means)
    tf_assert_all_init(sess)
    # tf.global_variables_initializer().run()

    # 4. SETUP TENSORBOARD LOGGING with tf.summary.merge

    train_histogram_summaries = get_collection_intersection_summary(
        'train_collection', 'orig_histogram')
    train_scalar_summaries = get_collection_intersection_summary(
        'train_collection', 'orig_scalar')
    test_histogram_summaries = get_collection_intersection_summary(
        'test_collection', 'orig_histogram')
    test_scalar_summaries = get_collection_intersection_summary(
        'test_collection', 'orig_scalar')
    param_histogram_summaries = get_collection_intersection_summary(
        'param_collection', 'orig_histogram')
    train_image_summaries = get_collection_intersection_summary(
        'train_collection', 'orig_image')
    test_image_summaries = get_collection_intersection_summary(
        'test_collection', 'orig_image')

    writer = None
    if args.output:
        mkdir_p(args.output)
        writer = tf.summary.FileWriter(args.output, sess.graph)

    # 5. TRAIN

    train_iters = (train_size) // minibatch_size + \
        int(train_size % minibatch_size > 0)
    if not args.skipval:
        val_iters = (val_size) // minibatch_size + \
            int(val_size % minibatch_size > 0)

    if args.ipy:
        print 'Embed: before train / val loop (Ctrl-D to continue)'
        embed()

    while buddy.epoch < args.epochs + 1:
        # How often to log data
        def do_log_params(ep, it, ii): return True
        def do_log_val(ep, it, ii): return True

        def do_log_train(
            ep,
            it,
            ii): return (
            it < train_iters and it & it -
            1 == 0 or it >= train_iters and it %
            train_iters == 0)  # Log on powers of two then every epoch

        # 0. Log params
        if args.output and do_log_params(
                buddy.epoch,
                buddy.train_iter,
                0) and param_histogram_summaries is not None:
            params_summary_str, = sess.run([param_histogram_summaries])
            writer.add_summary(params_summary_str, buddy.train_iter)

        # 1. Forward test on validation set
        if not args.skipval:
            feed_dict = {learning_phase(): 0}
            if 'input_coords' in model.named_keys():
                val_coords = val_y if args.arch.endswith(
                    'regressor') else val_x
                feed_dict.update({model.input_coords: val_coords})

            if 'input_onehot' in model.named_keys():
                # if 'val_onehot' not in locals():
                if not args.arch == 'coordconv_rendering' and not args.arch == 'deconv_bottleneck':
                    if args.arch == 'conv_onehot_image' or args.arch.endswith('regressor'):
                        val_onehot = val_x
                    else:
                        val_onehot = val_y
                feed_dict.update({
                    model.input_onehot: val_onehot,
                })
            if 'input_images' in model.named_keys():
                feed_dict.update({
                    model.input_images: val_images,
                })

            fetch_dict = model.trackable_dict()

            if args.output and do_log_val(buddy.epoch, buddy.train_iter, 0):
                if test_image_summaries is not None:
                    fetch_dict.update(
                        {'test_image_summaries': test_image_summaries})
                if test_scalar_summaries is not None:
                    fetch_dict.update(
                        {'test_scalar_summaries': test_scalar_summaries})
                if test_histogram_summaries is not None:
                    fetch_dict.update(
                        {'test_histogram_summaries': test_histogram_summaries})

            with WithTimer('sess.run val iter', quiet=not args.verbose):
                result_val = sess_run_dict(
                    sess, fetch_dict, feed_dict=feed_dict)

            buddy.note_list(
                model.trackable_names(), [
                    result_val[k] for k in model.trackable_names()], prefix='val_')
            print (
                '[%5d] [%2d/%2d] val: %s (%.3gs/i)' %
                (buddy.train_iter,
                 buddy.epoch,
                 args.epochs,
                 buddy.epoch_mean_pretty_re(
                     '^val_',
                     style=val_style),
                    toc2()))

            if args.output and do_log_val(buddy.epoch, buddy.train_iter, 0):
                log_scalars(
                    writer, buddy.train_iter, {
                        'mean_%s' %
                        name: value for name, value in buddy.epoch_mean_list_re('^val_')}, prefix='val')
                if test_image_summaries is not None:
                    image_summary_str = result_val['test_image_summaries']
                    writer.add_summary(image_summary_str, buddy.train_iter)
                if test_scalar_summaries is not None:
                    scalar_summary_str = result_val['test_scalar_summaries']
                    writer.add_summary(scalar_summary_str, buddy.train_iter)
                if test_histogram_summaries is not None:
                    hist_summary_str = result_val['test_histogram_summaries']
                    writer.add_summary(hist_summary_str, buddy.train_iter)

        # 2. Possiby Snapshot, possibly quit
        if args.output and args.snapshot_to and args.snapshot_every:
            snap_intermed = args.snapshot_every > 0 and buddy.train_iter % args.snapshot_every == 0
            #snap_end = buddy.epoch == args.epochs
            snap_end = lr_policy.train_done(buddy)
            if snap_intermed or snap_end:
                # Snapshot network and buddy
                save_path = saver.save(
                    sess, '%s/%s_%04d.ckpt' %
                    (args.output, args.snapshot_to, buddy.epoch))
                print 'snappshotted model to', save_path
                with gzip.open('%s/%s_misc_%04d.pkl.gz' % (args.output, args.snapshot_to, buddy.epoch), 'w') as ff:
                    saved = {'buddy': buddy}
                    pickle.dump(saved, ff)
                # Snapshot evaluation data and metrics
                _, _ = evaluate_net(
                    args, buddy, model, train_size, train_x, train_y, val_x, val_y, fd, sess)

        lr = lr_policy.get_lr(buddy)

        if buddy.epoch == args.epochs:
            if args.ipy:
                print 'Embed: at end of training (Ctrl-D to exit)'
                embed()
            break   # Extra pass at end: just report val stats and skip training

        print '********* at epoch %d, LR is %g' % (buddy.epoch, lr)

        # 3. Train on training set
        if args.shuffletrain:
            train_order = np.random.permutation(train_size)
        tic3()
        for ii in xrange(train_iters):
            tic2()
            start_idx = ii * minibatch_size
            end_idx = min(start_idx + minibatch_size, train_size)

            if args.shuffletrain:  # default true
                batch_x = train_x[sorted(
                    train_order[start_idx:end_idx].tolist())]
                if train_y is not None:
                    batch_y = train_y[sorted(
                        train_order[start_idx:end_idx].tolist())]
                # if 'train_onehot' in locals():
                if args.arch == 'coordconv_rendering' or args.arch == 'deconv_bottleneck':
                    batch_onehot = train_onehot[sorted(
                        train_order[start_idx:end_idx].tolist())]
            else:
                batch_x = train_x[start_idx:end_idx]
                if train_y is not None:
                    batch_y = train_y[start_idx:end_idx]
                # if 'train_onehot' in locals():
                if args.arch == 'coordconv_rendering' or args.arch == 'deconv_bottleneck':
                    batch_onehot = train_onehot[start_idx:end_idx]

            feed_dict = {learning_phase(): 1, input_lr: lr}
            if 'input_coords' in model.named_keys():
                batch_coords = batch_y if args.arch.endswith(
                    'regressor') else batch_x
                feed_dict.update({model.input_coords: batch_coords})
            if 'input_onehot' in model.named_keys():
                # if 'batch_onehot' not in locals():
                # if not (args.arch == 'coordconv_rendering' and
                # args.add_interm_loss):
                if not args.arch == 'coordconv_rendering' and not args.arch == 'deconv_bottleneck':
                    if args.arch == 'conv_onehot_image' or args.arch.endswith(
                            'regressor'):
                        batch_onehot = batch_x
                    else:
                        batch_onehot = batch_y
                feed_dict.update({
                    model.input_onehot: batch_onehot,
                })
            if 'input_images' in model.named_keys():
                feed_dict.update({
                    model.input_images: batch_images,
                })

            fetch_dict = model.trackable_and_update_dict()

            fetch_dict.update({'train_step': train_step})

            if args.output and do_log_train(buddy.epoch, buddy.train_iter, ii):
                if train_histogram_summaries is not None:
                    fetch_dict.update(
                        {'train_histogram_summaries': train_histogram_summaries})
                if train_scalar_summaries is not None:
                    fetch_dict.update(
                        {'train_scalar_summaries': train_scalar_summaries})
                if train_image_summaries is not None:
                    fetch_dict.update(
                        {'train_image_summaries': train_image_summaries})

            with WithTimer('sess.run train iter', quiet=not args.verbose):
                result_train = sess_run_dict(
                    sess, fetch_dict, feed_dict=feed_dict)

            buddy.note_weighted_list(
                batch_x.shape[0], model.trackable_names(), [
                    result_train[k] for k in model.trackable_names()], prefix='train_')

            if do_log_train(buddy.epoch, buddy.train_iter, ii):
                print (
                    '[%5d] [%2d/%2d] train: %s (%.3gs/i)' %
                    (buddy.train_iter,
                     buddy.epoch,
                     args.epochs,
                     buddy.epoch_mean_pretty_re(
                         '^train_',
                         style=train_style),
                        toc2()))

            if args.output and do_log_train(buddy.epoch, buddy.train_iter, ii):
                if train_histogram_summaries is not None:
                    hist_summary_str = result_train['train_histogram_summaries']
                    writer.add_summary(hist_summary_str, buddy.train_iter)
                if train_scalar_summaries is not None:
                    scalar_summary_str = result_train['train_scalar_summaries']
                    writer.add_summary(scalar_summary_str, buddy.train_iter)
                if train_image_summaries is not None:
                    image_summary_str = result_train['train_image_summaries']
                    writer.add_summary(image_summary_str, buddy.train_iter)
                log_scalars(
                    writer, buddy.train_iter, {
                        'batch_%s' %
                        name: value for name, value in buddy.last_list_re('^train_')}, prefix='train')

            if ii > 0 and ii % 100 == 0:
                print '  %d: Average iteration time over last 100 train iters: %.3gs' % (
                    ii, toc3() / 100)
                tic3()

            buddy.inc_train_iter()   # after finished training a mini-batch

        buddy.inc_epoch()   # after finished training whole pass through set

        if args.output and do_log_train(buddy.epoch, buddy.train_iter, 0):
            log_scalars(
                writer, buddy.train_iter, {
                    'mean_%s' %
                    name: value for name, value in buddy.epoch_mean_list_re('^train_')}, prefix='train')

    print '\nFinal'
    print '%02d:%d val:   %s' % (buddy.epoch,
                                 buddy.train_iter,
                                 buddy.epoch_mean_pretty_re(
                                     '^val_',
                                     style=val_style))
    print '%02d:%d train: %s' % (buddy.epoch,
                                 buddy.train_iter,
                                 buddy.epoch_mean_pretty_re(
                                     '^train_',
                                     style=train_style))

    print '\nEnd of training. Saving evaluation results on whole train and val set.'

    final_tr_metrics, final_va_metrics = evaluate_net(
        args, buddy, model, train_size, train_x, train_y, val_x, val_y, fd, sess)

    print '\nFinal evaluation on whole train and val'
    for name, value in final_tr_metrics.iteritems():
        print 'final_stats_eval train_%s %g' % (name, value)
    for name, value in final_va_metrics.iteritems():
        print 'final_stats_eval val_%s %g' % (name, value)

    print '\nfinal_stats epochs %g' % buddy.epoch
    print 'final_stats iters %g' % buddy.train_iter
    print 'final_stats time %g' % buddy.toc()
    for name, value in buddy.epoch_mean_list_all():
        print 'final_stats %s %g' % (name, value)

    if args.output:
        writer.close()   # Flush and close
Ejemplo n.º 5
0
def main():
    """Train a Region Proposal Network on cropped Field-of-MNIST images.

    Parses command-line options (the shared options from
    ``make_standard_parser`` plus the RPN / LR-policy options added here),
    loads ``data/field_of_mnist_cropped_64x64_5objs.h5``, builds a
    ``RegionProposalSampler`` (with CoordConv for ``coord_rpn_sampler``) and
    then, once per epoch:

    * runs a validation pass,
    * optionally shows proposals interactively (``--showbox``),
    * optionally saves proposal figures,
    * optionally snapshots,
    * runs a training pass.

    Images are fed one at a time (``minibatch_size = 1``). TensorBoard
    summaries, figures and snapshots are written only when ``--output`` is
    given.

    Raises:
        ValueError: if ``--arch`` or ``--opt`` is not a supported choice.
    """
    lr_policy_choices = ('constant', 'step', 'valstep')

    parser = make_standard_parser('Region Proposal Net',
                                  arch_choices=arch_choices,
                                  skip_train=True,
                                  skip_val=True)
    parser.add_argument(
        '--num',
        '-N',
        type=int,
        default=2,
        help='Load the Field-of-MNIST dataset with NUM digits per image.')
    parser.add_argument('--lrpolicy',
                        type=str,
                        default='constant',
                        choices=lr_policy_choices,
                        help='LR policy.')
    parser.add_argument('--lrstepratio',
                        type=float,
                        default=.1,
                        help='LR policy step ratio.')
    parser.add_argument('--lrmaxsteps',
                        type=int,
                        default=5,
                        help='LR policy step ratio.')
    parser.add_argument('--lrstepevery',
                        type=int,
                        default=50,
                        help='LR policy step ratio.')
    parser.add_argument('--clip',
                        action='store_true',
                        help='clip predicted and ground truth boxes.')
    parser.add_argument('--same',
                        action='store_true',
                        help='Use `same` filter instead of `valid` in conv.')
    parser.add_argument('--showbox',
                        action='store_true',
                        help='show moved box during training.')

    args = parser.parse_args()

    if args.lrpolicy == 'step':
        lr_policy = LRPolicyStep(args)
    elif args.lrpolicy == 'valstep':
        lr_policy = LRPolicyValStep(args)
    else:
        lr_policy = LRPolicyConstant(args)

    # One image per step: `input_gtbox` below has no batch axis, so the
    # ground-truth boxes of exactly one image are fed per sess.run.
    minibatch_size = 1
    train_style, val_style = (('', '') if args.nocolor else
                              (colorama.Fore.BLUE, colorama.Fore.MAGENTA))

    sess = setup_session_and_seeds(args.seed, assert_gpu=not args.cpu)

    # 0. Load data. The shape comments record the expected file contents.
    # Parts of the ground-truth boxes may lie outside the canvas. The file
    # also stores 'train_class' / 'valid_class' labels, which this RPN does
    # not use.
    with h5py.File('data/field_of_mnist_cropped_64x64_5objs.h5', 'r') as ff:
        train_ims = np.array(ff['train_ims'])  # (9000, 64, 64, 1)
        train_pos = np.array(ff['train_pos'])  # (9000, 5, 4)
        valid_ims = np.array(ff['valid_ims'])  # (1000, 64, 64, 1)
        valid_pos = np.array(ff['valid_pos'])  # (1000, 5, 4)

    im_h, im_w, im_c = train_ims.shape[1:4]
    train_size = train_ims.shape[0]
    val_size = valid_ims.shape[0]

    print('Data loaded:\n\timage shape: {}x{}x{}'.format(im_h, im_w, im_c))
    print('\ttrain size: {}\n\ttest size: {}'.format(train_size, val_size))
    print('\tnumber of objects per image: {}'.format(train_pos.shape[1]))

    ####################
    # RPN parameters
    ####################
    rpn_params = RPNParams(anchors=np.array([(15, 15), (20, 20), (25, 25),
                                             (15, 20), (20, 25), (20, 15),
                                             (25, 20), (15, 25), (25, 15)]),
                           rpn_hidden_dim=32,
                           zero_box_conv=False,
                           weight_init_std=0.01,
                           anchor_scale=1.0)

    bsamp_params = BoxSamplerParams(hi_thresh=0.5,
                                    lo_thresh=0.1,
                                    sample_size=12)

    nms_params = NMSParams(
        nms_thresh=0.8,
        max_proposals=10,
    )

    # 1. CREATE MODEL

    input_images = tf.placeholder(shape=(None, im_h, im_w, im_c),
                                  dtype='float32',
                                  name='input_images')
    input_gtbox = tf.placeholder(shape=(train_pos.shape[1], 4),
                                 dtype='float32',
                                 name='input_gtbox')

    # The two supported architectures differ only in the CoordConv flag.
    if args.arch not in ('rpn_sampler', 'coord_rpn_sampler'):
        raise ValueError('Architecture {} unknown'.format(args.arch))
    model = RegionProposalSampler(rpn_params,
                                  bsamp_params,
                                  nms_params,
                                  l2=args.l2,
                                  im_h=im_h,
                                  im_w=im_w,
                                  coordconv=(args.arch == 'coord_rpn_sampler'),
                                  clip=args.clip,
                                  filtersame=args.same)

    # Anchor grid size depends on the conv padding mode: 16x16 with `same`,
    # 13x13 with `valid` (presumably the RPN output map size -- confirm).
    if args.same:
        feat_hw = (16, 16)
        anchors = make_anchors_mnist_same(
            feat_hw, minibatch_size,
            rpn_params.anchors)  # (batch, 16, 16, 4k)
    else:
        feat_hw = (13, 13)
        anchors = make_anchors_mnist(
            feat_hw, minibatch_size,
            rpn_params.anchors)  # (batch, 13, 13, 4k)
    input_anchors = tf.placeholder(shape=feat_hw +
                                   (4 * rpn_params.num_anchors,),
                                   dtype='float32',
                                   name='input_anchors')
    anchors = anchors[0]  # drop the batch axis to match `input_anchors`

    model.a('input_images', input_images)
    model.a('input_anchors', input_anchors)
    model.a('input_gtbox', input_gtbox)

    model([input_images, input_anchors, input_gtbox])

    print('All model weights:')
    summarize_weights(model.trainable_weights)
    print('Another model summary:')
    model.summarize_named(prefix='  ')
    print_trainable_warnings(model)

    # 2. COMPUTE GRADS AND CREATE OPTIMIZER
    input_lr = tf.placeholder(
        tf.float32, shape=[])  # fed with lr_policy's value on every train step
    if args.opt == 'sgd':
        opt = tf.train.MomentumOptimizer(input_lr, args.mom)
    elif args.opt == 'rmsprop':
        opt = tf.train.RMSPropOptimizer(input_lr, momentum=args.mom)
    elif args.opt == 'adam':
        opt = tf.train.AdamOptimizer(input_lr, args.beta1, args.beta2)
    else:
        # Fail fast instead of hitting a NameError on `opt` further down.
        raise ValueError('Optimizer {} unknown'.format(args.opt))

    grads_and_vars = opt.compute_gradients(
        model.loss,
        model.trainable_weights,
        gate_gradients=tf.train.Optimizer.GATE_GRAPH)
    train_step = opt.apply_gradients(grads_and_vars)
    add_grads_and_vars_hist_summaries(
        grads_and_vars)  # added to train_ and param_ collections

    summarize_opt(opt)
    print('LR Policy:', lr_policy)

    # 3. OPTIONALLY SAVE OR LOAD VARIABLES (e.g. model params, model running
    # BN means, optimization momentum, ...) and then finalize initialization
    saver = tf.train.Saver(
        max_to_keep=None) if (args.output or args.load) else None
    if args.load:
        ckptfile, miscfile = args.load.split(':')
        # Restore values directly to graph
        saver.restore(sess, ckptfile)
        # NOTE(security): pickle.load can execute arbitrary code; only load
        # misc files from trusted sources (e.g. snapshots written below).
        with gzip.open(miscfile) as ff:
            saved = pickle.load(ff)
            buddy = saved['buddy']
    else:
        buddy = StatsBuddy()

    buddy.tic()  # call if new run OR resumed run

    # Initialize any missed vars (e.g. optimization momentum, ... if not
    # loaded from checkpoint)
    uninitialized_vars = tf_get_uninitialized_variables(sess)
    init_missed_vars = tf.variables_initializer(uninitialized_vars,
                                                'init_missed_vars')
    sess.run(init_missed_vars)
    tf_assert_all_init(sess)

    # 4. SETUP TENSORBOARD LOGGING with tf.summary.merge

    train_histogram_summaries = get_collection_intersection_summary(
        'train_collection', 'orig_histogram')
    train_scalar_summaries = get_collection_intersection_summary(
        'train_collection', 'orig_scalar')
    test_histogram_summaries = get_collection_intersection_summary(
        'test_collection', 'orig_histogram')
    test_scalar_summaries = get_collection_intersection_summary(
        'test_collection', 'orig_scalar')
    param_histogram_summaries = get_collection_intersection_summary(
        'param_collection', 'orig_histogram')
    train_image_summaries = get_collection_intersection_summary(
        'train_collection', 'orig_image')
    test_image_summaries = get_collection_intersection_summary(
        'test_collection', 'orig_image')

    writer = None
    if args.output:
        mkdir_p(args.output)
        writer = tf.summary.FileWriter(args.output, sess.graph)

    # 5. TRAIN

    train_iters = train_size // minibatch_size
    if not args.skipval:
        val_iters = val_size // minibatch_size

    if args.output:
        # Random validation images whose proposals are saved as figures.
        show_indices = np.random.permutation(val_size)[:9]
        mkdir_p('{}/figures'.format(args.output))

    if args.ipy:
        print('Embed: before train / val loop (Ctrl-D to continue)')
        embed()

    # How often to log data: params and val on every call; train at
    # iteration 0 and powers of two during the first epoch, then whenever
    # the iteration count is a multiple of an epoch.
    def do_log_params(ep, it, ii):
        return True

    def do_log_val(ep, it, ii):
        return True

    def do_log_train(ep, it, ii):
        return (it < train_iters and (it & (it - 1)) == 0
                or it >= train_iters and it % train_iters == 0)

    while buddy.epoch < args.epochs + 1:
        # 0. Log params
        if args.output and do_log_params(
                buddy.epoch, buddy.train_iter,
                0) and param_histogram_summaries is not None:
            params_summary_str, = sess.run([param_histogram_summaries])
            writer.add_summary(params_summary_str, buddy.train_iter)

        # 1. Forward test on validation set
        if not args.skipval:
            for ii in range(val_iters):
                tic2()
                start_idx = ii * minibatch_size
                end_idx = min(start_idx + minibatch_size, val_size)
                if end_idx <= start_idx:
                    continue

                feed_dict = {
                    model.input_images: valid_ims[start_idx:end_idx],
                    model.input_anchors: anchors,
                    model.input_gtbox: valid_pos[start_idx:end_idx][0],
                    learning_phase(): 0
                }

                fetch_dict = model.trackable_dict()

                if args.output and do_log_val(buddy.epoch, buddy.train_iter,
                                              0):
                    if test_image_summaries is not None:
                        fetch_dict.update(
                            {'test_image_summaries': test_image_summaries})
                    if test_scalar_summaries is not None:
                        fetch_dict.update(
                            {'test_scalar_summaries': test_scalar_summaries})
                    if test_histogram_summaries is not None:
                        fetch_dict.update({
                            'test_histogram_summaries':
                            test_histogram_summaries
                        })

                with WithTimer('sess.run val iter', quiet=not args.verbose):
                    result_val = sess_run_dict(sess,
                                               fetch_dict,
                                               feed_dict=feed_dict)

                buddy.note_weighted_list(
                    minibatch_size,
                    model.trackable_names(),
                    [result_val[k] for k in model.trackable_names()],
                    prefix='val_')

            # Done all val set
            print('[%5d] [%2d/%2d] val: %s (%.3gs/i)' %
                  (buddy.train_iter, buddy.epoch, args.epochs,
                   buddy.epoch_mean_pretty_re('^val_',
                                              style=val_style), toc2()))

            if args.output and do_log_val(buddy.epoch, buddy.train_iter, 0):
                log_scalars(
                    writer,
                    buddy.train_iter, {
                        'mean_%s' % name: value
                        for name, value in buddy.epoch_mean_list_re('^val_')
                    },
                    prefix='val')
                # Summaries come from the last validation iteration.
                if test_image_summaries is not None:
                    image_summary_str = result_val['test_image_summaries']
                    writer.add_summary(image_summary_str, buddy.train_iter)
                if test_scalar_summaries is not None:
                    scalar_summary_str = result_val['test_scalar_summaries']
                    writer.add_summary(scalar_summary_str, buddy.train_iter)
                if test_histogram_summaries is not None:
                    hist_summary_str = result_val['test_histogram_summaries']
                    writer.add_summary(hist_summary_str, buddy.train_iter)

        # Interactive preview of proposals on a few fixed validation images.
        # It uses its own index list and figure, so it does not clobber
        # `show_indices` or the figures saved under --output below.
        if args.showbox:
            showbox_indices = [55, 555, 678]
            figure('showbox').clf()
            for col, show_idx in enumerate(showbox_indices, 1):
                [pos_box, pos_score, neg_box, neg_score] = sess.run(
                    [
                        model.pos_box, model.pos_score, model.neg_box,
                        model.neg_score
                    ],
                    feed_dict={
                        model.input_images: valid_ims[show_idx:show_idx + 1],
                        model.input_anchors: anchors,
                        model.input_gtbox: valid_pos[show_idx],
                        learning_phase(): 0
                    })
                subplot(1, 3, col)
                plot_pos_boxes(valid_ims[show_idx],
                               valid_pos[show_idx],
                               pos_box,
                               pos_score,
                               showlabel=False)
            show()

        if args.output:
            switch_backend('Agg')
            plot_fetch_dict = {
                'pos_box': model.pos_box,
                'pos_score': model.pos_score,
                'neg_box': model.neg_box,
                'neg_score': model.neg_score,
                'nms_boxes': model.nms_boxes,
                'nms_scores': model.nms_scores,
            }

            # Figure 1: sampled positive/negative train boxes.
            # Figure 2: NMS proposals.
            # Clear both explicitly instead of relying on switch_backend()
            # to close them. Newer matplotlib skips that when the backend is
            # unchanged, which would stack every epoch onto the same axes.
            fig1 = figure(1)
            fig1.clf()
            fig2 = figure(2)
            fig2.clf()
            for cc, show_idx in enumerate(show_indices, 1):
                feed_dict = {
                    model.input_images: valid_ims[show_idx:show_idx + 1],
                    model.input_anchors: anchors,
                    model.input_gtbox: valid_pos[show_idx],
                    learning_phase(): 0
                }
                result_plots = sess_run_dict(sess,
                                             plot_fetch_dict,
                                             feed_dict=feed_dict)
                figure(fig1.number)
                subplot(3, 3, cc)
                plot_boxes_pos_neg(valid_ims[show_idx], valid_pos[show_idx],
                                   result_plots['pos_box'],
                                   result_plots['neg_box'])
                figure(fig2.number)
                subplot(3, 3, cc)
                plot_pos_boxes_thickness(valid_ims[show_idx],
                                         valid_pos[show_idx],
                                         result_plots['nms_boxes'],
                                         result_plots['nms_scores'])

            fig1.set_size_inches(10, 10)
            fig1.savefig('{}/figures/pos_neg_train_box_epoch_{}.png'.format(
                args.output, buddy.epoch),
                         dpi=100)
            fig2.set_size_inches(10, 10)
            fig2.savefig('{}/figures/nms_test_box_epoch_{}.png'.format(
                args.output, buddy.epoch),
                         dpi=100)

        # 2. Possibly snapshot, possibly quit
        if args.output and args.snapshot_to and args.snapshot_every:
            snap_intermed = args.snapshot_every > 0 and buddy.train_iter % args.snapshot_every == 0
            snap_end = lr_policy.train_done(buddy)
            if snap_intermed or snap_end:
                # Snapshot network and buddy
                save_path = saver.save(
                    sess, '%s/%s_%04d.ckpt' %
                    (args.output, args.snapshot_to, buddy.epoch))
                print('snappshotted model to', save_path)
                with gzip.open(
                        '%s/%s_misc_%04d.pkl.gz' %
                    (args.output, args.snapshot_to, buddy.epoch), 'w') as ff:
                    saved = {'buddy': buddy}
                    pickle.dump(saved, ff)

        lr = lr_policy.get_lr(buddy)

        if buddy.epoch == args.epochs:
            if args.ipy:
                print('Embed: at end of training (Ctrl-D to exit)')
                embed()
            break  # Extra pass at end: just report val stats and skip training

        print('********* at epoch %d, LR is %g' % (buddy.epoch, lr))

        # 3. Train on training set
        if args.shuffletrain:
            train_order = np.random.permutation(train_size)
        tic3()
        for ii in range(train_iters):
            tic2()
            start_idx = ii * minibatch_size
            end_idx = min(start_idx + minibatch_size, train_size)

            if end_idx <= start_idx:
                continue

            if args.shuffletrain:  # default true
                batch_idx = sorted(train_order[start_idx:end_idx].tolist())
                batch_ims = train_ims[batch_idx]
                batch_pos = train_pos[batch_idx]
            else:
                batch_ims = train_ims[start_idx:end_idx]
                batch_pos = train_pos[start_idx:end_idx]

            feed_dict = {
                model.input_images: batch_ims,
                model.input_anchors: anchors,
                model.input_gtbox: batch_pos[0],
                learning_phase(): 1,
                input_lr: lr
            }

            fetch_dict = model.trackable_and_update_dict()

            fetch_dict.update({'train_step': train_step})

            if args.output and do_log_train(buddy.epoch, buddy.train_iter, ii):
                if train_histogram_summaries is not None:
                    fetch_dict.update({
                        'train_histogram_summaries':
                        train_histogram_summaries
                    })
                if train_scalar_summaries is not None:
                    fetch_dict.update(
                        {'train_scalar_summaries': train_scalar_summaries})
                if train_image_summaries is not None:
                    fetch_dict.update(
                        {'train_image_summaries': train_image_summaries})

            with WithTimer('sess.run train iter', quiet=not args.verbose):
                result_train = sess_run_dict(sess,
                                             fetch_dict,
                                             feed_dict=feed_dict)

            buddy.note_weighted_list(
                minibatch_size,
                model.trackable_names(),
                [result_train[k] for k in model.trackable_names()],
                prefix='train_')

            if do_log_train(buddy.epoch, buddy.train_iter, ii):
                print('[%5d] [%2d/%2d] train: %s (%.3gs/i)' %
                      (buddy.train_iter, buddy.epoch, args.epochs,
                       buddy.epoch_mean_pretty_re(
                           '^train_', style=train_style), toc2()))

            if args.output and do_log_train(buddy.epoch, buddy.train_iter, ii):
                if train_histogram_summaries is not None:
                    hist_summary_str = result_train[
                        'train_histogram_summaries']
                    writer.add_summary(hist_summary_str, buddy.train_iter)
                if train_scalar_summaries is not None:
                    scalar_summary_str = result_train['train_scalar_summaries']
                    writer.add_summary(scalar_summary_str, buddy.train_iter)
                if train_image_summaries is not None:
                    image_summary_str = result_train['train_image_summaries']
                    writer.add_summary(image_summary_str, buddy.train_iter)
                log_scalars(
                    writer,
                    buddy.train_iter, {
                        'batch_%s' % name: value
                        for name, value in buddy.last_list_re('^train_')
                    },
                    prefix='train')

            if ii > 0 and ii % 100 == 0:
                print(
                    '  %d: Average iteration time over last 100 train iters: %.3gs'
                    % (ii, toc3() / 100))
                tic3()

            buddy.inc_train_iter()  # after finished training a mini-batch

        buddy.inc_epoch()  # after finished training whole pass through set

        if args.output and do_log_train(buddy.epoch, buddy.train_iter, 0):
            log_scalars(
                writer,
                buddy.train_iter, {
                    'mean_%s' % name: value
                    for name, value in buddy.epoch_mean_list_re('^train_')
                },
                prefix='train')

    print('\nFinal')
    print('%02d:%d val:   %s' %
          (buddy.epoch, buddy.train_iter,
           buddy.epoch_mean_pretty_re('^val_', style=val_style)))
    print('%02d:%d train: %s' %
          (buddy.epoch, buddy.train_iter,
           buddy.epoch_mean_pretty_re('^train_', style=train_style)))

    print(
        '\nEnd of training. Saving evaluation results on whole train and val set.'
    )

    if args.output:
        writer.close()  # Flush and close
Ejemplo n.º 6
0
# Put the lab root (the parent of this script's directory) on sys.path so
# the project-local `general` package imported just below can be found.
lab_root = os.path.join(os.path.abspath(os.path.dirname(__file__)), os.pardir)
sys.path.insert(1, lab_root)

from general.util import mkdir_p

# Module-level generation settings.
train_size = 50000
test_size = 200
img_size = 64
size = 10

# Output directory `sort_of_clevr/`, located next to this script; mkdir_p
# is called on it at import time.
dirs = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'sort_of_clevr')
mkdir_p(dirs)

# Object colours as RGB tuples. Currently disabled alternatives:
# (0, 255, 0), (255, 156, 0), (128, 128, 128), (255, 255, 0).
colors = [
    (0, 0, 255),  # blue
    (255, 0, 0),  # red
]


def center_generate(objects):
    while True:
        pas = True
Ejemplo n.º 7
0
def main():
    parser = make_standard_parser('Random Projection Experiments.',
                                  arch_choices=arch_choices)

    parser.add_argument('--vsize',
                        type=int,
                        default=100,
                        help='Dimension of intrinsic parmaeter space.')
    parser.add_argument('--d_rate',
                        '--dr',
                        type=float,
                        default=0.0,
                        help='Dropout rate.')
    parser.add_argument('--depth',
                        type=int,
                        default=2,
                        help='Number of layers in FNN.')
    parser.add_argument('--width',
                        type=int,
                        default=100,
                        help='Width of  layers in FNN.')
    parser.add_argument('--minibatch',
                        '--mb',
                        type=int,
                        default=128,
                        help='Size of minibatch.')
    parser.add_argument('--lr_ratio',
                        '--lrr',
                        type=float,
                        default=.1,
                        help='Ratio to decay LR by every LR_EPSTEP epochs.')
    parser.add_argument(
        '--lr_epochs',
        '--lrep',
        type=float,
        default=0,
        help='Decay LR every LR_EPSTEP epochs. 0 to turn off decay.')
    parser.add_argument('--lr_steps',
                        '--lrst',
                        type=float,
                        default=3,
                        help='Max LR steps.')

    parser.add_argument('--c1',
                        type=int,
                        default=6,
                        help='Channels in first conv layer, for LeNet.')
    parser.add_argument('--c2',
                        type=int,
                        default=16,
                        help='Channels in second conv layer, for LeNet.')
    parser.add_argument('--d1',
                        type=int,
                        default=120,
                        help='Channels in first dense layer, for LeNet.')
    parser.add_argument('--d2',
                        type=int,
                        default=84,
                        help='Channels in second dense layer, for LeNet.')

    parser.add_argument('--denseproj',
                        action='store_true',
                        help='Use a dense projection.')
    parser.add_argument('--sparseproj',
                        action='store_true',
                        help='Use a sparse projection.')
    parser.add_argument('--fastfoodproj',
                        action='store_true',
                        help='Use a fastfood projection.')

    parser.add_argument('--partial_data',
                        '--pd',
                        type=float,
                        default=1.0,
                        help='Percentage of dataset.')

    parser.add_argument(
        '--skiptfevents',
        action='store_true',
        help='Skip writing tf events files even if output is used.')

    args = parser.parse_args()

    n_proj_specified = sum(
        [args.denseproj, args.sparseproj, args.fastfoodproj])
    if args.arch in arch_choices_projected:
        assert n_proj_specified == 1, 'Arch "%s" requires projection. Specify exactly one of {denseproj, sparseproj, fastfoodproj} options.' % args.arch
    else:
        assert n_proj_specified == 0, 'Arch "%s" does not require projection, so do not specify any of {denseproj, sparseproj, fastfoodproj} options.' % args.arch

    if args.denseproj:
        proj_type = 'dense'
    elif args.sparseproj:
        proj_type = 'sparse'
    else:
        proj_type = 'fastfood'

    train_style, val_style = ('',
                              '') if args.nocolor else (colorama.Fore.BLUE,
                                                        colorama.Fore.MAGENTA)

    # Get a TF session registered with Keras and set numpy and TF seeds
    sess = setup_session_and_seeds(args.seed)

    # 0. LOAD DATA
    train_h5 = h5py.File(args.train_h5, 'r')
    train_x = train_h5['images']
    train_y = train_h5['labels']
    val_h5 = h5py.File(args.val_h5, 'r')
    val_x = val_h5['images']
    val_y = val_h5['labels']

    # loadpath = "./dataset/ag_news.p"
    # x = pickle.load(open(loadpath, "rb"))
    # train_x, val_x, test_x = x[0], x[1], x[2]
    # train_y, val_y, test_y = x[3], x[4], x[5]
    # wordtoix, ixtoword = x[6], x[7]

    # train_x = prepare_data_for_cnn(train_x, 100, 5)
    # val_x = prepare_data_for_cnn(val_x, 100, 5)

    # #weightInit = tf.random_uniform_initializer(-0.001, 0.001)
    # #W = tf.get_variable('W', [13010, 300], initializer=weightInit)

    # W = np.random.rand(13010, 300)

    # #pdb.set_trace()

    # train_x = [np.take(W, i, axis=0) for i in train_x]
    # train_x = np.array(train_x, dtype='float32')

    # val_x = [np.take(W, i, axis=0) for i in val_x]
    # val_x = np.array(val_x, dtype='float32')

    #pdb.set_trace()

    train_x = np.array(train_x, dtype='float32')
    val_x = np.array(val_x, dtype='float32')

    if args.partial_data < 1.0:
        n_train_ = int(train_y.size * args.partial_data)
        n_test_ = int(val_y.size * args.partial_data)
        train_x = train_x[:n_train_]
        train_y = train_y[:n_train_]
        val_x = val_x[:n_test_]
        val_y = val_y[:n_test_]

    # load into memory if less than 1 GB
    if train_x.size * 4 + val_x.size * 4 < 1e9:
        train_x, train_y = np.array(train_x), np.array(train_y)
        val_x, val_y = np.array(val_x), np.array(val_y)

    # 1. CREATE MODEL
    randmirrors = False
    randcrops = False
    cropsize = None

    with WithTimer('Make model'):
        if args.arch == 'mnistfc_dir':
            model = build_model_mnist_fc_dir(weight_decay=args.l2,
                                             depth=args.depth,
                                             width=args.width)
        elif args.arch == 'mnistfc':
            if proj_type == 'fastfood':
                model = build_model_mnist_fc_fastfood(weight_decay=args.l2,
                                                      vsize=args.vsize,
                                                      depth=args.depth,
                                                      width=args.width)
            else:
                model = build_model_mnist_fc(weight_decay=args.l2,
                                             vsize=args.vsize,
                                             depth=args.depth,
                                             width=args.width,
                                             proj_type=proj_type)
        elif args.arch == 'mnistconv':
            model = build_cnn_model_mnist(weight_decay=args.l2,
                                          vsize=args.vsize)
        elif args.arch == 'mnistconv_dir':
            model = build_cnn_model_direct_mnist(weight_decay=args.l2)
        elif args.arch == 'cifarfc_dir':
            model = build_model_cifar_fc_dir(weight_decay=args.l2,
                                             depth=args.depth,
                                             width=args.width)
        elif args.arch == 'cifarfc':
            if proj_type == 'fastfood':
                model = build_model_cifar_fc_fastfood(weight_decay=args.l2,
                                                      vsize=args.vsize,
                                                      depth=args.depth,
                                                      width=args.width)
            else:
                model = build_model_cifar_fc(weight_decay=args.l2,
                                             vsize=args.vsize,
                                             depth=args.depth,
                                             width=args.width,
                                             proj_type=proj_type)
        elif args.arch == 'mnistlenet_dir':
            model = build_LeNet_direct_mnist(weight_decay=args.l2,
                                             c1=args.c1,
                                             c2=args.c2,
                                             d1=args.d1,
                                             d2=args.d2)

        elif args.arch == 'mnistMLPlenet_dir':
            model = build_MLPLeNet_direct_mnist(weight_decay=args.l2)
        elif args.arch == 'mnistMLPlenet':
            if proj_type == 'fastfood':
                model = build_model_mnist_MLPLeNet_fastfood(
                    weight_decay=args.l2, vsize=args.vsize)

        elif args.arch == 'mnistUntiedlenet_dir':
            model = build_UntiedLeNet_direct_mnist(weight_decay=args.l2)
        elif args.arch == 'mnistUntiedlenet':
            if proj_type == 'fastfood':
                model = build_model_mnist_UntiedLeNet_fastfood(
                    weight_decay=args.l2, vsize=args.vsize)

        elif args.arch == 'cifarMLPlenet_dir':
            model = build_MLPLeNet_direct_cifar(weight_decay=args.l2)
        elif args.arch == 'cifarMLPlenet':
            if proj_type == 'fastfood':
                model = build_model_cifar_MLPLeNet_fastfood(
                    weight_decay=args.l2, vsize=args.vsize)

        elif args.arch == 'cifarUntiedlenet_dir':
            model = build_UntiedLeNet_direct_cifar(weight_decay=args.l2)
        elif args.arch == 'cifarUntiedlenet':
            if proj_type == 'fastfood':
                model = build_model_cifar_UntiedLeNet_fastfood(
                    weight_decay=args.l2, vsize=args.vsize)

        elif args.arch == 'mnistlenet':
            if proj_type == 'fastfood':
                model = build_model_mnist_LeNet_fastfood(weight_decay=args.l2,
                                                         vsize=args.vsize)
            else:
                model = build_LeNet_mnist(weight_decay=args.l2,
                                          vsize=args.vsize,
                                          proj_type=proj_type)

        elif args.arch == 'cifarlenet_dir':
            model = build_LeNet_direct_cifar(weight_decay=args.l2,
                                             d_rate=args.d_rate,
                                             c1=args.c1,
                                             c2=args.c2,
                                             d1=args.d1,
                                             d2=args.d2)
        elif args.arch == 'cifarlenet':
            if proj_type == 'fastfood':
                model = build_model_cifar_LeNet_fastfood(weight_decay=args.l2,
                                                         vsize=args.vsize,
                                                         d_rate=args.d_rate,
                                                         c1=args.c1,
                                                         c2=args.c2,
                                                         d1=args.d1,
                                                         d2=args.d2)
            else:
                model = build_LeNet_cifar(weight_decay=args.l2,
                                          vsize=args.vsize,
                                          proj_type=proj_type,
                                          d_rate=args.d_rate)
        elif args.arch == 'cifarDenseNet_dir':
            model = build_DenseNet_direct_cifar(weight_decay=args.l2,
                                                depth=25,
                                                nb_dense_block=1,
                                                growth_rate=12)
        elif args.arch == 'cifarDenseNet':
            if proj_type == 'fastfood':
                model = build_DenseNet_cifar_fastfood(weight_decay=args.l2,
                                                      vsize=args.vsize,
                                                      depth=25,
                                                      nb_dense_block=1,
                                                      growth_rate=12)

        elif args.arch == 'alexnet_dir':
            model = build_alexnet_direct(weight_decay=args.l2,
                                         shift_in=np.array([104, 117, 123]))
            args.shuffletrain = False
            randmirrors = True
            randcrops = True
            cropsize = (227, 227)

        elif args.arch == 'squeeze_dir':
            model = build_squeezenet_direct(weight_decay=args.l2,
                                            shift_in=np.array([104, 117, 123]))
            args.shuffletrain = False
            randmirrors = True
            randcrops = True
            cropsize = (224, 224)

        elif args.arch == 'alexnet':
            if proj_type == 'fastfood':
                model = build_alexnet_fastfood(weight_decay=args.l2,
                                               shift_in=np.array(
                                                   [104, 117, 123]),
                                               vsize=args.vsize)
            else:
                raise Exception('not implemented')
            args.shuffletrain = False
            randmirrors = True
            randcrops = True
            cropsize = (227, 227)
        else:
            raise Exception('Unknown network architecture: %s' % args.arch)

    print 'All model weights:'
    total_params = summarize_weights(model.trainable_weights)
    print 'Model summary:'
    model.summary()

    model.print_trainable_warnings()

    input_lr = tf.placeholder(tf.float32, shape=[])
    lr_stepper = LRStepper(args.lr, args.lr_ratio, args.lr_epochs,
                           args.lr_steps)

    # 2. COMPUTE GRADS AND CREATE OPTIMIZER
    if args.opt == 'sgd':
        opt = tf.train.MomentumOptimizer(input_lr, args.mom)
    elif args.opt == 'rmsprop':
        opt = tf.train.RMSPropOptimizer(input_lr, momentum=args.mom)
    elif args.opt == 'adam':
        opt = tf.train.AdamOptimizer(input_lr, args.beta1, args.beta2)

    # Optimize w.r.t all trainable params in the model
    grads_and_vars = opt.compute_gradients(
        model.v.loss,
        model.trainable_weights,
        gate_gradients=tf.train.Optimizer.GATE_GRAPH)
    train_step = opt.apply_gradients(grads_and_vars)
    add_grad_summaries(grads_and_vars)
    summarize_opt(opt)

    # 3. OPTIONALLY SAVE OR LOAD VARIABLES (e.g. model params, model running BN means, optimization momentum, ...) and then finalize initialization
    saver = tf.train.Saver(
        max_to_keep=None) if (args.output or args.load) else None
    if args.load:
        ckptfile, miscfile = args.load.split(':')
        # Restore values directly to graph
        saver.restore(sess, ckptfile)
        with gzip.open(miscfile) as ff:
            saved = pickle.load(ff)
            buddy = saved['buddy']
    else:
        buddy = StatsBuddy()
    buddy.tic()  # call if new run OR resumed run

    # Initialize any missed vars (e.g. optimization momentum, ... if not loaded from checkpoint)
    uninitialized_vars = tf_get_uninitialized_variables(sess)
    init_missed_vars = tf.variables_initializer(uninitialized_vars,
                                                'init_missed_vars')

    sess.run(init_missed_vars)

    # Print warnings about any TF vs. Keras shape mismatches
    warn_misaligned_shapes(model)
    # Make sure all variables, which are model variables, have been initialized (e.g. model params and model running BN means)
    tf_assert_all_init(sess)

    # 3.5 Normalize the overall basis matrix across the (multiple) unnormalized basis matrices for each layer
    basis_matrices = []
    normalizers = []

    for layer in model.layers:
        try:
            basis_matrices.extend(layer.offset_creator.basis_matrices)
        except AttributeError:
            continue
        try:
            normalizers.extend(layer.offset_creator.basis_matrix_normalizers)
        except AttributeError:
            continue

    if len(basis_matrices) > 0 and not args.load:

        if proj_type == 'sparse':

            # Norm of overall basis matrix rows (num elements in each sum == total parameters in model)
            bm_row_norms = tf.sqrt(
                tf.add_n([
                    tf.sparse_reduce_sum(tf.square(bm), 1)
                    for bm in basis_matrices
                ]))
            # Assign `normalizer` Variable to these row norms to achieve normalization of the basis matrix
            # in the TF computational graph
            rescale_basis_matrices = [
                tf.assign(var, tf.reshape(bm_row_norms, var.shape))
                for var in normalizers
            ]
            _ = sess.run(rescale_basis_matrices)
        elif proj_type == 'dense':
            bm_sums = [
                tf.reduce_sum(tf.square(bm), 1) for bm in basis_matrices
            ]
            divisor = tf.expand_dims(tf.sqrt(tf.add_n(bm_sums)), 1)
            rescale_basis_matrices = [
                tf.assign(var, var / divisor) for var in basis_matrices
            ]
            _ = sess.run(rescale_basis_matrices)
        else:
            print '\nhere\n'
            embed()

            assert False, 'what to do with fastfood?'

    # 4. SETUP TENSORBOARD LOGGING
    train_histogram_summaries = get_collection_intersection_summary(
        'train_collection', 'orig_histogram')
    train_scalar_summaries = get_collection_intersection_summary(
        'train_collection', 'orig_scalar')
    val_histogram_summaries = get_collection_intersection_summary(
        'val_collection', 'orig_histogram')
    val_scalar_summaries = get_collection_intersection_summary(
        'val_collection', 'orig_scalar')
    param_histogram_summaries = get_collection_intersection_summary(
        'param_collection', 'orig_histogram')

    writer = None
    if args.output:
        mkdir_p(args.output)
        if not args.skiptfevents:
            writer = tf.summary.FileWriter(args.output, sess.graph)

    # 5. TRAIN
    train_iters = (train_y.shape[0] - 1) / args.minibatch + 1
    val_iters = (val_y.shape[0] - 1) / args.minibatch + 1
    impreproc = ImagePreproc()

    if args.ipy:
        print 'Embed: before train / val loop (Ctrl-D to continue)'
        embed()

    fastest_avg_iter_time = 1e9

    while buddy.epoch < args.epochs + 1:
        # How often to log data
        do_log_params = lambda ep, it, ii: False
        do_log_val = lambda ep, it, ii: True
        do_log_train = lambda ep, it, ii: (
            it < train_iters and it & it - 1 == 0 or it >= train_iters and it %
            train_iters == 0)  # Log on powers of two then every epoch

        # 0. Log params
        if args.output and do_log_params(
                buddy.epoch, buddy.train_iter, 0
        ) and param_histogram_summaries is not None and not args.skiptfevents:
            params_summary_str, = sess.run([param_histogram_summaries])
            writer.add_summary(params_summary_str, buddy.train_iter)

        # 1. Evaluate val set performance
        if not args.skipval:
            tic2()
            for ii in xrange(val_iters):
                start_idx = ii * args.minibatch
                batch_x = val_x[start_idx:start_idx + args.minibatch]
                batch_y = val_y[start_idx:start_idx + args.minibatch]
                if randcrops:
                    batch_x = impreproc.center_crops(batch_x, cropsize)
                feed_dict = {
                    model.v.input_images: batch_x,
                    model.v.input_labels: batch_y,
                    K.learning_phase(): 0,
                }
                fetch_dict = model.trackable_dict
                with WithTimer('sess.run val iter', quiet=not args.verbose):
                    result_val = sess_run_dict(sess,
                                               fetch_dict,
                                               feed_dict=feed_dict)

                buddy.note_weighted_list(
                    batch_x.shape[0],
                    model.trackable_names,
                    [result_val[k] for k in model.trackable_names],
                    prefix='val_')

            if args.output and not args.skiptfevents and do_log_val(
                    buddy.epoch, buddy.train_iter, 0):
                log_scalars(
                    writer,
                    buddy.train_iter, {
                        'mean_%s' % name: value
                        for name, value in buddy.epoch_mean_list_re('^val_')
                    },
                    prefix='buddy')

            print(
                '\ntime: %f. after training for %d epochs:\n%3d val:   %s (%.3gs/i)'
                % (buddy.toc(), buddy.epoch, buddy.train_iter,
                   buddy.epoch_mean_pretty_re(
                       '^val_', style=val_style), toc2() / val_iters))

        # 2. Possiby Snapshot, possibly quit
        if args.output and args.snapshot_to and args.snapshot_every:
            snap_intermed = args.snapshot_every > 0 and buddy.train_iter % args.snapshot_every == 0
            snap_end = buddy.epoch == args.epochs
            if snap_intermed or snap_end:
                # Snapshot
                save_path = saver.save(
                    sess, '%s/%s_%04d.ckpt' %
                    (args.output, args.snapshot_to, buddy.epoch))
                print 'snappshotted model to', save_path
                with gzip.open(
                        '%s/%s_misc_%04d.pkl.gz' %
                    (args.output, args.snapshot_to, buddy.epoch), 'w') as ff:
                    saved = {'buddy': buddy}
                    pickle.dump(saved, ff)

        if buddy.epoch == args.epochs:
            if args.ipy:
                print 'Embed: at end of training (Ctrl-D to exit)'
                embed()
            break  # Extra pass at end: just report val stats and skip training

        # 3. Train on training set
        #train_order = range(train_x.shape[0])
        if args.shuffletrain:
            train_order = np.random.permutation(train_x.shape[0])

        tic3()
        for ii in xrange(train_iters):
            tic2()
            start_idx = ii * args.minibatch

            if args.shuffletrain:
                batch_x = train_x[train_order[start_idx:start_idx +
                                              args.minibatch]]
                batch_y = train_y[train_order[start_idx:start_idx +
                                              args.minibatch]]
            else:
                batch_x = train_x[start_idx:start_idx + args.minibatch]
                batch_y = train_y[start_idx:start_idx + args.minibatch]
            if randcrops:
                batch_x = impreproc.random_crops(batch_x, cropsize,
                                                 randmirrors)
            feed_dict = {
                model.v.input_images: batch_x,
                model.v.input_labels: batch_y,
                input_lr: lr_stepper.lr(buddy),
                K.learning_phase(): 1,
            }

            fetch_dict = {'train_step': train_step}
            fetch_dict.update(model.trackable_and_update_dict)

            if args.output and not args.skiptfevents and do_log_train(
                    buddy.epoch, buddy.train_iter, ii):
                if param_histogram_summaries is not None:
                    fetch_dict.update({
                        'param_histogram_summaries':
                        param_histogram_summaries
                    })
                if train_histogram_summaries is not None:
                    fetch_dict.update({
                        'train_histogram_summaries':
                        train_histogram_summaries
                    })
                if train_scalar_summaries is not None:
                    fetch_dict.update(
                        {'train_scalar_summaries': train_scalar_summaries})

            with WithTimer('sess.run train iter', quiet=not args.verbose):
                result_train = sess_run_dict(sess,
                                             fetch_dict,
                                             feed_dict=feed_dict)

            buddy.note_weighted_list(
                batch_x.shape[0],
                model.trackable_names,
                [result_train[k] for k in model.trackable_names],
                prefix='train_')

            if do_log_train(buddy.epoch, buddy.train_iter, ii):
                print('%3d train: %s (%.3gs/i)' %
                      (buddy.train_iter,
                       buddy.epoch_mean_pretty_re('^train_',
                                                  style=train_style), toc2()))
                if args.output and not args.skiptfevents:
                    if param_histogram_summaries is not None:
                        hist_summary_str = result_train[
                            'param_histogram_summaries']
                        writer.add_summary(hist_summary_str, buddy.train_iter)
                    if train_histogram_summaries is not None:
                        hist_summary_str = result_train[
                            'train_histogram_summaries']
                        writer.add_summary(hist_summary_str, buddy.train_iter)
                    if train_scalar_summaries is not None:
                        scalar_summary_str = result_train[
                            'train_scalar_summaries']
                        writer.add_summary(scalar_summary_str,
                                           buddy.train_iter)
                    log_scalars(
                        writer,
                        buddy.train_iter, {
                            'batch_%s' % name: value
                            for name, value in buddy.last_list_re('^train_')
                        },
                        prefix='buddy')

            if ii > 0 and ii % 100 == 0:
                avg_iter_time = toc3() / 100
                tic3()
                fastest_avg_iter_time = min(fastest_avg_iter_time,
                                            avg_iter_time)
                print '  %d: Average iteration time over last 100 train iters: %.3gs' % (
                    ii, avg_iter_time)

            buddy.inc_train_iter()  # after finished training a mini-batch

        buddy.inc_epoch()  # after finished training whole pass through set

        if args.output and not args.skiptfevents and do_log_train(
                buddy.epoch, buddy.train_iter, 0):
            log_scalars(
                writer,
                buddy.train_iter, {
                    'mean_%s' % name: value
                    for name, value in buddy.epoch_mean_list_re('^train_')
                },
                prefix='buddy')

    print '\nFinal'
    print '%02d:%d val:   %s' % (buddy.epoch, buddy.train_iter,
                                 buddy.epoch_mean_pretty_re('^val_',
                                                            style=val_style))
    print '%02d:%d train: %s' % (buddy.epoch, buddy.train_iter,
                                 buddy.epoch_mean_pretty_re('^train_',
                                                            style=train_style))

    print '\nfinal_stats epochs %g' % buddy.epoch
    print 'final_stats iters %g' % buddy.train_iter
    print 'final_stats time %g' % buddy.toc()
    print 'final_stats total_params %g' % total_params
    print 'final_stats fastest_avg_iter_time %g' % fastest_avg_iter_time
    for name, value in buddy.epoch_mean_list_all():
        print 'final_stats %s %g' % (name, value)

    if args.output and not args.skiptfevents:
        writer.close()  # Flush and close
Ejemplo n.º 8
0
def main():
    """Run the Low Rank Basis toy experiment: build, optionally resume, train and report.

    Parses the standard experiment flags (from ``make_standard_parser``) plus the
    options added below. It builds the toy model with ``build_toy``; with ``--denseproj``
    the model is built with ``proj=True`` and intrinsic dimension ``--vsize``. It then
    alternates a validation pass and a training pass until ``args.epochs`` epochs are done.
    Each epoch is a single val step and a single train step (``train_iters = val_iters = 1``).
    The feed dicts carry only the Keras learning-phase flag and, when training, the
    learning rate; no data batches are fed.

    Side effects: may restore state from ``--load`` (``<ckpt path>:<misc .pkl.gz path>``).
    It may write TensorBoard event files and snapshots under ``--output``, and it prints
    progress plus ``final_stats`` lines to stdout.

    Raises:
        ValueError: if ``args.opt`` is not one of 'sgd', 'rmsprop' or 'adam'.
        NotImplementedError: if the model exposes basis matrices but ``proj_type`` has no
            normalization routine.
    """
    parser = make_standard_parser('Low Rank Basis experiments.', skip_train=True, skip_val=True, arch_choices=['one'])

    parser.add_argument('--DD', type=int, default=1000, help='Dimension of full parameter space.')
    parser.add_argument('--vsize', type=int, default=100, help='Dimension of intrinsic parameter space.')
    parser.add_argument('--lr_ratio', '--lrr', type=float, default=.5, help='Ratio to decay LR by every LR_EPSTEP epochs.')
    parser.add_argument('--lr_epochs', '--lrep', type=float, default=0, help='Decay LR every LR_EPSTEP epochs. 0 to turn off decay.')
    parser.add_argument('--lr_steps', '--lrst', type=float, default=3, help='Max LR steps.')

    parser.add_argument('--denseproj', action='store_true', help='Use a dense projection.')

    parser.add_argument('--skiptfevents', action='store_true', help='Skip writing tf events files even if output is used.')

    args = parser.parse_args()

    # This script only supports a dense projection or no projection at all.
    proj_type = 'dense' if args.denseproj else None

    train_style, val_style = ('', '') if args.nocolor else (colorama.Fore.BLUE, colorama.Fore.MAGENTA)

    # Get a TF session registered with Keras and set numpy and TF seeds
    sess = setup_session_and_seeds(args.seed)

    # 1. CREATE MODEL
    with WithTimer('Make model'):
        if args.denseproj:
            model = build_toy(weight_decay=args.l2, DD=args.DD, groups=10, vsize=args.vsize, proj=True)
        else:
            model = build_toy(weight_decay=args.l2, DD=args.DD, proj=False)

    print('All model weights:')
    total_params = summarize_weights(model.trainable_weights)
    print('Model summary:')
    model.summary()
    model.print_trainable_warnings()

    # The learning rate is fed on every train step so LRStepper can decay it over time.
    input_lr = tf.placeholder(tf.float32, shape=[])
    lr_stepper = LRStepper(args.lr, args.lr_ratio, args.lr_epochs, args.lr_steps)

    # 2. COMPUTE GRADS AND CREATE OPTIMIZER
    if args.opt == 'sgd':
        opt = tf.train.MomentumOptimizer(input_lr, args.mom)
    elif args.opt == 'rmsprop':
        opt = tf.train.RMSPropOptimizer(input_lr, momentum=args.mom)
    elif args.opt == 'adam':
        opt = tf.train.AdamOptimizer(input_lr, args.beta1, args.beta2)
    else:
        # Fail fast with a clear message instead of a NameError on `opt` below.
        raise ValueError('Unknown optimizer: %s' % args.opt)

    # Optimize w.r.t all trainable params in the model
    grads_and_vars = opt.compute_gradients(model.v.loss, model.trainable_weights, gate_gradients=tf.train.Optimizer.GATE_GRAPH)
    train_step = opt.apply_gradients(grads_and_vars)
    add_grad_summaries(grads_and_vars)
    summarize_opt(opt)

    # 3. OPTIONALLY SAVE OR LOAD VARIABLES (e.g. model params, model running BN means, optimization momentum, ...) and then finalize initialization
    saver = tf.train.Saver(max_to_keep=None) if (args.output or args.load) else None
    if args.load:
        # --load has the form '<checkpoint path>:<misc pickle path>'.
        ckptfile, miscfile = args.load.split(':')
        # Restore values directly to graph
        saver.restore(sess, ckptfile)
        # NOTE(security): pickle.load can execute arbitrary code; only load trusted files.
        with gzip.open(miscfile) as ff:
            saved = pickle.load(ff)
            buddy = saved['buddy']
    else:
        buddy = StatsBuddy()
    buddy.tic()    # call if new run OR resumed run

    # Initialize any missed vars (e.g. optimization momentum, ... if not loaded from checkpoint)
    uninitialized_vars = tf_get_uninitialized_variables(sess)
    init_missed_vars = tf.variables_initializer(uninitialized_vars, 'init_missed_vars')
    sess.run(init_missed_vars)

    # Print warnings about any TF vs. Keras shape mismatches
    warn_misaligned_shapes(model)
    # Make sure all variables, which are model variables, have been initialized (e.g. model params and model running BN means)
    tf_assert_all_init(sess)

    # 3.5 Normalize the overall basis matrix across the (multiple) unnormalized basis matrices for each layer
    basis_matrices = []
    normalizers = []
    for layer in model.layers:
        # Layers without an offset_creator (or without basis matrices) are not projected.
        try:
            basis_matrices.extend(layer.offset_creator.basis_matrices)
        except AttributeError:
            continue
        try:
            normalizers.extend(layer.offset_creator.basis_matrix_normalizers)
        except AttributeError:
            pass

    # Skipped when resuming; presumably the restored checkpoint already holds the
    # normalized values (TODO confirm the Saver covers the basis/normalizer variables).
    if basis_matrices and not args.load:
        if proj_type == 'sparse':
            # Not reachable with this script's flags (proj_type is 'dense' or None);
            # retained for parity with the sparse-projection code path.
            # Norm of overall basis matrix rows (num elements in each sum == total parameters in model)
            bm_row_norms = tf.sqrt(tf.add_n([tf.sparse_reduce_sum(tf.square(bm), 1) for bm in basis_matrices]))
            # Assign `normalizer` Variable to these row norms to achieve normalization of the basis matrix
            # in the TF computational graph
            rescale_basis_matrices = [tf.assign(var, tf.reshape(bm_row_norms, var.shape)) for var in normalizers]
            sess.run(rescale_basis_matrices)
        elif proj_type == 'dense':
            # Sum squared entries along axis 1 of each matrix and add the sums across
            # all matrices. Dividing by the square root gives every row of the
            # column-wise concatenated basis unit L2 norm.
            bm_sums = [tf.reduce_sum(tf.square(bm), 1) for bm in basis_matrices]
            divisor = tf.expand_dims(tf.sqrt(tf.add_n(bm_sums)), 1)
            rescale_basis_matrices = [tf.assign(var, var / divisor) for var in basis_matrices]
            sess.run(rescale_basis_matrices)
        else:
            # No normalization routine exists for other projection types (e.g. fastfood).
            # Raise rather than `assert`, which is stripped under `python -O`.
            raise NotImplementedError('Basis matrix normalization is not implemented for proj_type=%r' % (proj_type,))

    # 4. SETUP TENSORBOARD LOGGING
    # The val_* summaries are built but not fetched below; the calls are kept as-is
    # in case building them has graph side effects.
    train_histogram_summaries = get_collection_intersection_summary('train_collection', 'orig_histogram')
    train_scalar_summaries    = get_collection_intersection_summary('train_collection', 'orig_scalar')
    val_histogram_summaries   = get_collection_intersection_summary('val_collection', 'orig_histogram')
    val_scalar_summaries      = get_collection_intersection_summary('val_collection', 'orig_scalar')
    param_histogram_summaries = get_collection_intersection_summary('param_collection', 'orig_histogram')

    writer = None
    if args.output:
        mkdir_p(args.output)
        if not args.skiptfevents:
            writer = tf.summary.FileWriter(args.output, sess.graph)

    # 5. TRAIN
    # The toy model consumes no data batches, so every epoch is one val step and one train step.
    train_iters = 1
    val_iters = 1

    if args.ipy:
        print('Embed: before train / val loop (Ctrl-D to continue)')
        embed()

    # With train_iters == 1 the 100-iteration timing block below never fires, so this
    # stays at its sentinel value and is reported as such in final_stats.
    fastest_avg_iter_time = 1e9

    # Logging schedule. It is loop-invariant, so it is defined once rather than per epoch.
    def do_log_params(ep, it, ii):
        return False

    def do_log_val(ep, it, ii):
        return True

    def do_log_train(ep, it, ii):
        # Log on powers of two then every epoch
        return (it < train_iters and (it & (it - 1)) == 0) or (it >= train_iters and it % train_iters == 0)

    while buddy.epoch < args.epochs + 1:
        # 0. Log params
        if (args.output and do_log_params(buddy.epoch, buddy.train_iter, 0)
                and param_histogram_summaries is not None and not args.skiptfevents):
            params_summary_str, = sess.run([param_histogram_summaries])
            writer.add_summary(params_summary_str, buddy.train_iter)

        # 1. Evaluate val set performance
        if not args.skipval:
            tic2()
            for ii in range(val_iters):
                with WithTimer('val iter %d/%d' % (ii, val_iters), quiet=not args.verbose):
                    feed_dict = {
                        K.learning_phase(): 0,
                    }
                    fetch_dict = model.trackable_dict
                    with WithTimer('sess.run val iter', quiet=not args.verbose):
                        result_val = sess_run_dict(sess, fetch_dict, feed_dict=feed_dict)

                    buddy.note_weighted_list(1, model.trackable_names, [result_val[k] for k in model.trackable_names], prefix='val_')

            if args.output and not args.skiptfevents and do_log_val(buddy.epoch, buddy.train_iter, 0):
                log_scalars(writer, buddy.train_iter,
                            {'mean_%s' % name: value for name, value in buddy.epoch_mean_list_re('^val_')},
                            prefix='buddy')

            print('\ntime: %f. after training for %d epochs:\n%3d val:   %s (%.3gs/i)'
                  % (buddy.toc(), buddy.epoch, buddy.train_iter, buddy.epoch_mean_pretty_re('^val_', style=val_style), toc2() / val_iters))

        # 2. Possibly snapshot, possibly quit
        if args.output and args.snapshot_to and args.snapshot_every:
            snap_intermed = args.snapshot_every > 0 and buddy.train_iter % args.snapshot_every == 0
            snap_end = buddy.epoch == args.epochs
            if snap_intermed or snap_end:
                # Snapshot (saver is non-None here because args.output is set)
                save_path = saver.save(sess, '%s/%s_%04d.ckpt' % (args.output, args.snapshot_to, buddy.epoch))
                print('snappshotted model to %s' % save_path)
                with gzip.open('%s/%s_misc_%04d.pkl.gz' % (args.output, args.snapshot_to, buddy.epoch), 'w') as ff:
                    saved = {'buddy': buddy}
                    pickle.dump(saved, ff)

        if buddy.epoch == args.epochs:
            if args.ipy:
                print('Embed: at end of training (Ctrl-D to exit)')
                embed()
            break   # Extra pass at end: just report val stats and skip training

        # 3. Train on training set
        tic3()
        for ii in range(train_iters):
            with WithTimer('train iter %d/%d' % (ii, train_iters), quiet=not args.verbose):
                tic2()
                feed_dict = {
                    input_lr: lr_stepper.lr(buddy),
                    K.learning_phase(): 1,
                }

                fetch_dict = {'train_step': train_step}
                fetch_dict.update(model.trackable_and_update_dict)

                if args.output and not args.skiptfevents and do_log_train(buddy.epoch, buddy.train_iter, ii):
                    if param_histogram_summaries is not None:
                        fetch_dict['param_histogram_summaries'] = param_histogram_summaries
                    if train_histogram_summaries is not None:
                        fetch_dict['train_histogram_summaries'] = train_histogram_summaries
                    if train_scalar_summaries is not None:
                        fetch_dict['train_scalar_summaries'] = train_scalar_summaries

                with WithTimer('sess.run train iter', quiet=not args.verbose):
                    result_train = sess_run_dict(sess, fetch_dict, feed_dict=feed_dict)

                buddy.note_weighted_list(1, model.trackable_names, [result_train[k] for k in model.trackable_names], prefix='train_')

                if do_log_train(buddy.epoch, buddy.train_iter, ii):
                    print('%3d train: %s (%.3gs/i)' % (buddy.train_iter, buddy.epoch_mean_pretty_re('^train_', style=train_style), toc2()))
                    if args.output and not args.skiptfevents:
                        if param_histogram_summaries is not None:
                            hist_summary_str = result_train['param_histogram_summaries']
                            writer.add_summary(hist_summary_str, buddy.train_iter)
                        if train_histogram_summaries is not None:
                            hist_summary_str = result_train['train_histogram_summaries']
                            writer.add_summary(hist_summary_str, buddy.train_iter)
                        if train_scalar_summaries is not None:
                            scalar_summary_str = result_train['train_scalar_summaries']
                            writer.add_summary(scalar_summary_str, buddy.train_iter)
                        log_scalars(writer, buddy.train_iter,
                                    {'batch_%s' % name: value for name, value in buddy.last_list_re('^train_')},
                                    prefix='buddy')
                        log_scalars(writer, buddy.train_iter, {'batch_lr': lr_stepper.lr(buddy)}, prefix='buddy')

                if ii > 0 and ii % 100 == 0:
                    avg_iter_time = toc3() / 100
                    tic3()
                    fastest_avg_iter_time = min(fastest_avg_iter_time, avg_iter_time)
                    print('  %d: Average iteration time over last 100 train iters: %.3gs' % (ii, avg_iter_time))

                buddy.inc_train_iter()   # after finished training a mini-batch

        buddy.inc_epoch()   # after finished training whole pass through set

        if args.output and not args.skiptfevents and do_log_train(buddy.epoch, buddy.train_iter, 0):
            log_scalars(writer, buddy.train_iter,
                        {'mean_%s' % name: value for name, value in buddy.epoch_mean_list_re('^train_')},
                        prefix='buddy')

    print('\nFinal')
    print('%02d:%d val:   %s' % (buddy.epoch, buddy.train_iter, buddy.epoch_mean_pretty_re('^val_', style=val_style)))
    print('%02d:%d train: %s' % (buddy.epoch, buddy.train_iter, buddy.epoch_mean_pretty_re('^train_', style=train_style)))

    print('\nfinal_stats epochs %g' % buddy.epoch)
    print('final_stats iters %g' % buddy.train_iter)
    print('final_stats time %g' % buddy.toc())
    print('final_stats total_params %g' % total_params)
    print('final_stats fastest_avg_iter_time %g' % fastest_avg_iter_time)
    for name, value in buddy.epoch_mean_list_all():
        print('final_stats %s %g' % (name, value))

    if args.output and not args.skiptfevents:
        writer.close()   # Flush and close