コード例 #1
0
    def federated_step(self, fitter):
        """
        Run a federated step.

        Local rank 0 pulls models; every other rank starts with
        ``pull_success = send_fake = False``. In multi-GPU mode rank 0's
        ``pull_success`` and ``self.train_end`` are broadcast with
        ``comm.bcast`` so all ranks take the same branch. If the pull
        succeeded, global variables are broadcast from rank 0, every rank calls
        ``self.model_manager.train()``, and rank 0 pushes the models when that
        call returns a truthy value. Otherwise, if rank 0 was asked to send a
        fake update, it calls ``self.push_models(push_fake=True)``.

        Args:
            fitter: object exposing ``train_ctx.my_rank`` (this process's
                local rank) and ``multi_gpu`` (whether to sync across ranks).
        """
        client_local_rank = fitter.train_ctx.my_rank
        if client_local_rank == 0:
            # Per the original note, the pulled data are all lists.
            pull_success, send_fake = self.pull_models()
        else:
            pull_success, send_fake = False, False

        if fitter.multi_gpu:
            # Share rank 0's decision so every rank enters the same collectives.
            # send_fake is not broadcast: it is only acted on by rank 0 below.
            pull_success = comm.bcast(pull_success, root=0)
            self.train_end = comm.bcast(self.train_end, root=0)

        if pull_success:  # pull models from multiple servers
            if fitter.multi_gpu:
                # NOTE(review): with horovod.tensorflow in graph mode this only
                # builds the broadcast op without running it; confirm ``hvd`` is
                # a flavour (e.g. horovod.keras) that executes the broadcast.
                hvd.broadcast_global_variables(root_rank=0)
            if self.model_manager.train():  # do local fitting
                if client_local_rank == 0:
                    self.push_models()  # push models
        elif send_fake:
            if client_local_rank == 0:
                self.push_models(push_fake=True)  # push a fake update
コード例 #2
0
    def test_horovod_broadcast_graph_mode(self):
        """Test that tries to broadcast tensorflow global variables
        in graph execution mode. This call should not raise any exception.

        The test is skipped when eager execution is active.
        """
        if hvd.util._executing_eagerly():
            # The skip reason states the mode this test requires.
            self.skipTest("Only in graph execution mode")

        hvd.broadcast_global_variables(root_rank=0)
コード例 #3
0
    def test_horovod_broadcast_eager_mode_error(self):
        """Broadcasting TensorFlow global variables under eager execution
        must raise RuntimeError.

        Skipped unless eager execution is active.
        """
        running_eagerly = hvd.util._executing_eagerly()
        if not running_eagerly:
            self.skipTest("Only in eager execution mode")

        self.assertRaises(RuntimeError, hvd.broadcast_global_variables,
                          root_rank=0)
コード例 #4
0
def train():
    """Build a one-weight linear model and run a single Horovod training step.

    The model computes ``inputs @ weight`` with ``inputs`` fixed to zeros and
    fits it to a target of ones under mean-squared error, using Adam wrapped in
    ``hvd.DistributedOptimizer``. Inside a session it initializes all global
    variables, broadcasts them from rank 0 and runs one optimizer step. Any
    exception is printed with its traceback and passed to a
    ``tf.train.Coordinator`` via ``request_stop``; it is not re-raised.

    NOTE(review): ``config`` is not defined in this function; presumably it is
    a module-level ``tf.ConfigProto`` -- confirm.
    """
    coord = tf.train.Coordinator()
    # Created in the graph but never passed to the optimizer below.
    global_step = tf.Variable(0, name='global_step', trainable=False)

    with tf.variable_scope('model'):
        learning_rate = tf.convert_to_tensor(0.02)

        weight = tf.Variable([[0]], trainable=True, dtype=tf.float32)

        inputs = tf.convert_to_tensor(np.zeros((1, 1), dtype=np.float32))
        targets = tf.convert_to_tensor(np.ones((1, 1), dtype=np.float32))
        prediction = tf.matmul(inputs, weight)
        loss = tf.losses.mean_squared_error(targets, prediction)

        dist_optimizer = hvd.DistributedOptimizer(
            tf.train.AdamOptimizer(learning_rate))
        train_op = dist_optimizer.minimize(loss, var_list=weight)

    # Train!
    with tf.Session(config=config) as sess:
        try:
            sess.run(tf.global_variables_initializer())
            # Start every rank from rank 0's initial values.
            bcast_op = hvd.broadcast_global_variables(0)
            bcast_op.run()
            sess.run([train_op])
        except Exception as e:
            print('Exiting due to exception: %s' % e)
            traceback.print_exc()
            coord.request_stop(e)
コード例 #5
0
    def broadcast_global_variables(cls, *args):
        """Get a TensorFlow operation to broadcast all the global variables.

        Delegates to ``mgw.broadcast_global_variables(*args)`` and returns its
        result.

        Raises:
            NameError: if the name ``mgw`` is not bound (the module was not
                imported). The lookup failure is chained as ``__cause__``.
        """
        # Guard only the lookup of ``mgw`` so that a NameError raised inside
        # the delegate is not misreported as a missing import.
        try:
            backend = mgw
        except NameError as err:
            raise NameError('module <mgw> not imported') from err
        return backend.broadcast_global_variables(*args)
コード例 #6
0
ファイル: hooks.py プロジェクト: mnslarcher/automl
    def begin(self):
        """Prepare the broadcast op and the pretrained-weight saver.

        The Horovod broadcast op is (re)built on ``self.device`` whenever it is
        missing or belongs to a graph other than the current default graph.

        If ``self._model_dir`` already has a latest checkpoint not ending in
        ``'model.ckpt-0'``, that checkpoint wins and nothing else is done.
        Otherwise, only on rank 0 and only if ``self._pretrained_model_path``
        is non-empty, every global variable whose op name is stored in that
        checkpoint is appended to ``self._variables_to_restore``. Names
        starting with any prefix in ``self._exclusions`` are skipped, and
        ``'global_step'`` is always added to the exclusions. Finally a
        ``tf.train.Saver`` over those variables is stored in ``self._saver``.
        """
        if not self.bcast_op or self.bcast_op.graph != tf.get_default_graph():
            with tf.device(self.device):
                self.bcast_op = hvd.broadcast_global_variables(self.root_rank)

        if self._model_dir is not None:
            checkpoint_path = checkpoint_management.latest_checkpoint(
                self._model_dir)
            # An existing non-initial checkpoint takes precedence over the
            # pretrained weights.
            if checkpoint_path is not None and not checkpoint_path.endswith(
                    'model.ckpt-0'):
                hvd_info_rank0(
                    '>>>>> model_dir {} has checkpoint {}, not using pretrained_model_path <<<<<'
                    .format(self._model_dir, checkpoint_path))
                return

        if not self._pretrained_model_path or not is_rank0():
            return

        reader = pywrap_tensorflow.NewCheckpointReader(
            self._pretrained_model_path)
        # A set gives O(1) membership tests instead of scanning a sorted list.
        checkpoint_var_names = set(reader.get_variable_to_shape_map())

        self._exclusions.add('global_step')
        # str.startswith accepts a tuple and matches any of the prefixes.
        excluded_prefixes = tuple(self._exclusions)

        for var in tf.global_variables():
            name = var.op.name
            if name in checkpoint_var_names and not name.startswith(
                    excluded_prefixes):
                self._variables_to_restore.append(var)

        self._saver = tf.train.Saver(var_list=self._variables_to_restore)
コード例 #7
0
    def __init__(self, device=''):
        """Initialize Horovod and pre-build the broadcast and join ops.

        Args:
            device: device string handed to ``tf.device`` for both ops.
        """
        hvd.init()
        with tf.device(device):
            bcast_op = hvd.broadcast_global_variables(0)
            exit_op = hvd.join()
        self._bcast_op = bcast_op
        self._exit_op = exit_op

        # Nothing has been broadcast yet.
        self._broadcast_done = False
        super(HorovodSyncHook, self).__init__()
コード例 #8
0
ファイル: __init__.py プロジェクト: kioco/horovod
def broadcast_global_variables(root_rank):
    """Broadcast every global variable from ``root_rank`` to all other processes.

    The Horovod broadcast op is built first and then run in the Keras backend
    session.

    Arguments:
        root_rank: Rank of the process whose global-variable values are sent
                   to all other processes.

    Returns:
        The value returned by running the op in ``K.get_session()``.
    """
    broadcast_op = hvd.broadcast_global_variables(root_rank)
    keras_session = K.get_session()
    return keras_session.run(broadcast_op)
コード例 #9
0
def broadcast_global_variables(root_rank):
    """Send all global variables from one process to every other process.

    Arguments:
        root_rank: Rank of the process that owns the values to broadcast.

    Returns:
        Result of executing the broadcast op in the Keras backend session.
    """
    op = hvd.broadcast_global_variables(root_rank)
    result = K.get_session().run(op)
    return result
コード例 #10
0
    def __init__(self,
                 model_name,
                 tokenize=None,
                 pbmodel_dir=None,
                 use_hvd=False):
        """Set up tokenizer, vocabularies, TF config/graph/session, model and saver.

        Args:
            model_name: forwarded as ``MMCH_Model(model_name=...)`` and stored
                on ``self.model_name``.
            tokenize: callable turning a string into a list of tokens. When
                None, a ``jieba.Tokenizer`` is used on the text after every
                whitespace run has been replaced by ','.
            pbmodel_dir: when given, the model is loaded with
                ``MMCH_Model.from_pbmodel`` (prediction only) instead of being
                built, initialized and (optionally) broadcast.
            use_hvd: request Horovod distributed training; it only takes effect
                when the module-level ``HVD_ENABLE`` flag is truthy.
        """
        # This object owns the session, graph, config and saver.
        self.model_name = model_name
        if tokenize is None:
            self.jieba = jieba.Tokenizer()
            # self.jieba.load_userdict(f'{curr_dir}/../data/segword.dct')
            self.tokenize = lambda t: self.jieba.lcut(re.sub(r'\s+', ',', t))
        else:
            self.tokenize = tokenize
        # Tokenize and re-join with single spaces.
        self.cut = lambda t: ' '.join(self.tokenize(t))
        # Token -> id vocabularies, loaded from curr_dir/../data.
        self.token2id_dct = {
            # 'word2id': utils.Any2Id.from_file(f'{curr_dir}/../data/mmch_word2id.dct', use_line_no=True),  # in-house data
            # 'word2id': utils.Any2Id.from_file(f'{curr_dir}/../data/mmch_char2id.dct', use_line_no=True),  # in-house data
            'word2id':
            utils.Any2Id.from_file(f'{curr_dir}/../data/DB_mmch_word2id.dct',
                                   use_line_no=True),  # Douban multi-turn dialogue corpus
            'char2id':
            utils.Any2Id.from_file(f'{curr_dir}/../data/DB_mmch_char2id.dct',
                                   use_line_no=True),  # Douban multi-turn dialogue corpus
        }
        self.config = tf.ConfigProto(
            allow_soft_placement=True,
            gpu_options=tf.GPUOptions(allow_growth=True),
        )
        # Horovod is used only if requested AND the HVD_ENABLE flag is truthy.
        self.use_hvd = use_hvd if HVD_ENABLE else False
        if self.use_hvd:
            hvd.init()
            self.hvd_rank = hvd.rank()
            self.hvd_size = hvd.size()
            # Make only the GPU matching this process's local rank visible.
            self.config.gpu_options.visible_device_list = str(hvd.local_rank())
        self.graph = tf.Graph()
        self.sess = tf.Session(config=self.config, graph=self.graph)

        if pbmodel_dir is not None:  # prediction only
            self.model = MMCH_Model.from_pbmodel(pbmodel_dir, self.sess)
        else:
            with self.graph.as_default():
                self.model = MMCH_Model(model_name=self.model_name,
                                        run_model=self)
                if self.use_hvd:
                    # NOTE(review): mutates the optimizer's private ``_lr``; confirm the
                    # optimizer reads it lazily (i.e. after this point).
                    self.model.optimizer._lr = self.model.optimizer._lr * self.hvd_size  # distributed training => larger global batch, so scale up the learning rate
                    self.model.hvd_optimizer = hvd.DistributedOptimizer(
                        self.model.optimizer)
                    # Replace the model's train op with the Horovod-wrapped one.
                    self.model.train_op = self.model.hvd_optimizer.minimize(
                        self.model.loss, global_step=self.model.global_step)
                self.sess.run(tf.global_variables_initializer())
                if self.use_hvd:
                    # Start every worker from rank 0's initial weights.
                    self.sess.run(hvd.broadcast_global_variables(0))

        with self.graph.as_default():
            self.saver = tf.train.Saver(
                max_to_keep=100)  # must be created inside the graph context
コード例 #11
0
ファイル: train.py プロジェクト: rickerliang/inc_2018
    def train(self):
        """Run the full training job inside a fresh TensorFlow session.

        All global variables are initialized and the model loads its
        pretrained weights. When ``self.args.use_horovod`` is set, the
        variables are then broadcast from rank 0. Finally the session is
        handed to ``self.training_process``. With ``self.args.tfdbg`` the
        session is wrapped in the tfdbg command-line debugger first.
        """
        with tf.Session(config=self.session_config) as sess:
            print("session run...")
            run_sess = (tf_debug.LocalCLIDebugWrapperSession(sess)
                        if self.args.tfdbg else sess)

            run_sess.run(tf.global_variables_initializer())
            self.model.load_pretrained_weight(run_sess)
            if self.args.use_horovod:
                run_sess.run(hvd.broadcast_global_variables(0))

            self.training_process(run_sess)

            print("train complete")
コード例 #12
0
ファイル: callbacks.py プロジェクト: eshnil2000/horovod
    def on_batch_end(self, batch, logs=None):
        """Broadcast model state from ``self.root_rank`` once, after the first batch.

        Under eager execution, when the model exposes ``variables``, the model
        variables and then the optimizer variables are broadcast directly.
        Otherwise a global-variable broadcast op is built and run in the
        backend session. Once ``self.broadcast_done`` is set, calls do nothing.
        """
        if not self.broadcast_done:
            with tf.device(self.device):
                eager_model = (hvd._executing_eagerly()
                               and hasattr(self.model, 'variables'))
                if eager_model:
                    # TensorFlow 2.0 or TensorFlow eager
                    hvd.broadcast_variables(self.model.variables,
                                            root_rank=self.root_rank)
                    hvd.broadcast_variables(self.model.optimizer.variables(),
                                            root_rank=self.root_rank)
                else:
                    broadcast_op = hvd.broadcast_global_variables(
                        self.root_rank)
                    session = self.backend.get_session()
                    session.run(broadcast_op)

            self.broadcast_done = True
コード例 #13
0
    def set_model(self, model):
        """Build the graph, session, Horovod optimizer and train op for ``model``.

        A new graph is created with a session whose only visible GPU is the
        one matching this process's Horovod local rank. ``model(config=self)``
        is instantiated under variable scope ``"model"`` with a uniform Xavier
        initializer. If ``self.optimizer`` is None, one is chosen from
        ``self.opt_method`` ("Adagrad"/"adagrad", "Adadelta"/"adadelta",
        "Adam"/"adam", anything else -> gradient descent) with learning rate
        ``self.alpha * hvd.size()``. The optimizer is wrapped in
        ``hvd.DistributedOptimizer`` to build ``self.train_op``. Rank 0 also
        gets ``self.saver``. All global variables are then initialized and
        broadcast from rank 0.
        """
        self.model = model
        self.graph = tf.Graph()
        with self.graph.as_default():
            # Horovod: one GPU per process, chosen by local rank.
            config1 = tf.ConfigProto(log_device_placement=False)
            config1.gpu_options.allow_growth = True
            config1.gpu_options.visible_device_list = str(hvd.local_rank())
            self.sess = tf.Session(config=config1)
            with self.sess.as_default():
                initializer = tf.contrib.layers.xavier_initializer(
                    uniform=True)
                with tf.variable_scope("model",
                                       reuse=None,
                                       initializer=initializer):
                    self.trainModel = self.model(config=self)
                    # Horovod: scale the learning rate by the worker count and
                    # wrap the optimizer with hvd.DistributedOptimizer.
                    if self.optimizer is None:
                        scaled_lr = self.alpha * hvd.size()
                        if self.opt_method in ("Adagrad", "adagrad"):
                            self.optimizer = tf.train.AdagradOptimizer(
                                learning_rate=scaled_lr,
                                initial_accumulator_value=1e-20)
                        elif self.opt_method in ("Adadelta", "adadelta"):
                            self.optimizer = tf.train.AdadeltaOptimizer(
                                scaled_lr)
                        elif self.opt_method in ("Adam", "adam"):
                            self.optimizer = tf.train.AdamOptimizer(scaled_lr)
                        else:
                            self.optimizer = tf.train.GradientDescentOptimizer(
                                scaled_lr)
                    self.dist_optimizer = hvd.DistributedOptimizer(
                        self.optimizer)
                    self.train_op = self.dist_optimizer.minimize(
                        self.trainModel.loss)
                if hvd.rank() == 0:
                    # Only rank 0 gets a saver.
                    self.saver = tf.train.Saver()
                self.sess.run(tf.global_variables_initializer())

                # Horovod: start every rank from rank 0's initial weights.
                self.sess.run(hvd.broadcast_global_variables(0))
コード例 #14
0
 def _start_train(self, hvd, sess):
     """Initialize variables, optionally resume from a checkpoint, sync, then train.

     All global and local variables are initialized. If
     ``self.config.ckpt_dir`` yields any ``.index`` files, the checkpoint
     prefix derived from the last one is restored. The variables are then
     broadcast from rank 0 and ``self._train_step`` is called.

     Args:
         hvd: Horovod module (or compatible object) providing
             ``broadcast_global_variables``.
         sess: TensorFlow session used for initialization, restore and
             training.
     """
     graph = tf.get_default_graph()
     saver = tf.train.Saver(max_to_keep=5000)
     with graph.as_default() as graph:
         global_init_fn = tf.global_variables_initializer()
         local_init_fn = tf.local_variables_initializer()
         init_fn = tf.group(global_init_fn, local_init_fn)
         index_suffix = ".index"
         # Strip only a trailing ".index" to get the checkpoint prefix;
         # split(".index")[0] would also cut at an earlier ".index" in the path.
         all_ckpt_list = [
             path[:-len(index_suffix)] if path.endswith(index_suffix) else path
             for path in list_getter(self.config.ckpt_dir, 'index')
         ]
         sess.run(init_fn)
         # NOTE(review): "latest" below assumes list_getter returns checkpoints
         # ordered oldest -> newest; confirm.
         if all_ckpt_list:  # assumed the current model is intended to continue training if latest checkpoint exists
             print('Training will be continued from the last checkpoint...')
             saver.restore(sess, all_ckpt_list[-1])
             print('The last checkpoint is loaded!')
         else:
             print('Training will be started from scratch...')
         # Make all ranks start from rank 0's (possibly restored) weights.
         sess.run(hvd.broadcast_global_variables(0))
         self._train_step(graph, sess, saver)
コード例 #15
0
    def finalize(self, load_path, adam_epsilon):
        """Build the train op, initialize or load weights, and freeze the graph.

        Args:
            load_path: checkpoint handed to ``self.load``; falsy to skip loading.
            adam_epsilon: ``epsilon`` for the Adam optimizer.

        With Horovod enabled, the optimizer is wrapped in
        ``hvd.DistributedOptimizer``, only rank 0 loads ``load_path``, and the
        weights are then broadcast from rank 0. With Horovod disabled, the
        process loads ``load_path`` itself without consulting Horovod. The
        default graph is finalized, so no ops can be added afterwards.
        """
        opt = tf.train.AdamOptimizer(self.LR, epsilon=adam_epsilon)
        if not self.disable_hvd:
            opt = hvd.DistributedOptimizer(opt)
        self.train_op = opt.minimize(self.loss)
        self.step = self.act_model.step
        self.step_fake_action = self.act_model.step_fake_action
        self.value = self.act_model.value
        self.initial_state = self.act_model.initial_state
        self.sess.run(tf.global_variables_initializer())
        # Without Horovod there is no meaningful rank (and hvd.rank() presumably
        # requires hvd.init()), so load directly. With Horovod only rank 0 loads;
        # the broadcast below distributes the weights to the other ranks.
        if load_path and (self.disable_hvd or hvd.rank() == 0):
            self.load(load_path)
        if not self.disable_hvd:
            self.sess.run(hvd.broadcast_global_variables(0))
        tf.get_default_graph().finalize()

        # Maps each fetched tensor/op to the name it is reported under.
        self.loss_requested_dict = {self.pg_loss: 'policy_loss',
                                    self.vf_loss: 'value_loss',
                                    self.l2_loss: 'l2_loss',
                                    self.entropy: 'policy_entropy',
                                    self.approxkl: 'approxkl',
                                    self.clipfrac: 'clipfrac',
                                    self.train_op: ''}
        self.init_requested_loss()
コード例 #16
0
def main(args):
    """Train a facenet classifier (softmax + center loss) with Horovod.

    Creates timestamped log and model directories, writes the arguments and
    revision info, loads (and optionally filters and splits) the dataset, and
    builds a queue-based input pipeline and the training graph. It then trains
    for ``args.max_nrof_epochs`` epochs with periodic validation, optional LFW
    evaluation, checkpointing, and per-epoch statistics written to
    ``stat.h5`` in the log directory.

    Horovod: each process sees only the GPU of its local rank. After variable
    initialization and the optional restore of ``args.pretrained_model``, all
    global variables are broadcast from rank 0 so every worker starts
    identically.

    NOTE(review): every rank runs this function and writes to the same
    timestamped directories, logs and checkpoints; confirm that is intended.

    Returns:
        The model directory path.
    """
    network = importlib.import_module(args.model_def)
    image_size = (args.image_size, args.image_size)

    subdir = datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S')
    log_dir = os.path.join(os.path.expanduser(args.logs_base_dir), subdir)
    if not os.path.isdir(
            log_dir):  # Create the log directory if it doesn't exist
        try:
            os.makedirs(log_dir)
        except OSError as exc:
            # Presumably tolerates another rank creating it first.
            print(exc)
    model_dir = os.path.join(os.path.expanduser(args.models_base_dir), subdir)
    if not os.path.isdir(
            model_dir):  # Create the model directory if it doesn't exist
        try:
            os.makedirs(model_dir)
        except OSError as exc:
            print(exc)

    stat_file_name = os.path.join(log_dir, 'stat.h5')

    # Write arguments to a text file
    facenet.write_arguments_to_file(args, os.path.join(log_dir,
                                                       'arguments.txt'))

    # Store some git revision info in a text file in the log directory
    src_path, _ = os.path.split(os.path.realpath(__file__))
    facenet.store_revision_info(src_path, log_dir, ' '.join(sys.argv))

    np.random.seed(seed=args.seed)
    random.seed(args.seed)
    dataset = facenet.get_dataset(args.data_dir)
    if args.filter_filename:
        dataset = filter_dataset(dataset,
                                 os.path.expanduser(args.filter_filename),
                                 args.filter_percentile,
                                 args.filter_min_nrof_images_per_class)

    if args.validation_set_split_ratio > 0.0:
        train_set, val_set = facenet.split_dataset(
            dataset, args.validation_set_split_ratio,
            args.min_nrof_val_images_per_class, 'SPLIT_IMAGES')
    else:
        train_set, val_set = dataset, []

    nrof_classes = len(train_set)

    print('Model directory: %s' % model_dir)
    print('Log directory: %s' % log_dir)
    pretrained_model = None
    if args.pretrained_model:
        pretrained_model = os.path.expanduser(args.pretrained_model)
        print('Pre-trained model: %s' % pretrained_model)

    if args.lfw_dir:
        print('LFW directory: %s' % args.lfw_dir)
        # Read the file containing the pairs used for testing
        pairs = lfw.read_pairs(os.path.expanduser(args.lfw_pairs))
        # Get the paths for the corresponding images
        lfw_paths, actual_issame = lfw.get_paths(
            os.path.expanduser(args.lfw_dir), pairs)

    with tf.Graph().as_default():
        tf.set_random_seed(args.seed)
        global_step = tf.Variable(0, trainable=False)

        # Get a list of image paths and their labels
        image_list, label_list = facenet.get_image_paths_and_labels(train_set)
        assert len(image_list) > 0, 'The training set should not be empty'

        val_image_list, val_label_list = facenet.get_image_paths_and_labels(
            val_set)

        # Create a queue that produces indices into the image_list and label_list
        labels = ops.convert_to_tensor(label_list, dtype=tf.int32)
        range_size = array_ops.shape(labels)[0]
        index_queue = tf.train.range_input_producer(range_size,
                                                    num_epochs=None,
                                                    shuffle=True,
                                                    seed=None,
                                                    capacity=32)

        index_dequeue_op = index_queue.dequeue_many(
            args.batch_size * args.epoch_size, 'index_dequeue')

        learning_rate_placeholder = tf.placeholder(tf.float32,
                                                   name='learning_rate')
        batch_size_placeholder = tf.placeholder(tf.int32, name='batch_size')
        phase_train_placeholder = tf.placeholder(tf.bool, name='phase_train')
        image_paths_placeholder = tf.placeholder(tf.string,
                                                 shape=(None, 1),
                                                 name='image_paths')
        labels_placeholder = tf.placeholder(tf.int32,
                                            shape=(None, 1),
                                            name='labels')
        control_placeholder = tf.placeholder(tf.int32,
                                             shape=(None, 1),
                                             name='control')

        nrof_preprocess_threads = 4
        input_queue = data_flow_ops.FIFOQueue(
            capacity=2000000,
            dtypes=[tf.string, tf.int32, tf.int32],
            shapes=[(1, ), (1, ), (1, )],
            shared_name=None,
            name=None)
        enqueue_op = input_queue.enqueue_many(
            [image_paths_placeholder, labels_placeholder, control_placeholder],
            name='enqueue_op')
        image_batch, label_batch = facenet.create_input_pipeline(
            input_queue, image_size, nrof_preprocess_threads,
            batch_size_placeholder)

        image_batch = tf.identity(image_batch, 'image_batch')
        image_batch = tf.identity(image_batch, 'input')
        label_batch = tf.identity(label_batch, 'label_batch')

        print('Number of classes in training set: %d' % nrof_classes)
        print('Number of examples in training set: %d' % len(image_list))

        print('Number of classes in validation set: %d' % len(val_set))
        print('Number of examples in validation set: %d' % len(val_image_list))

        print('Building training graph')

        # Build the inference graph
        prelogits, _ = network.inference(
            image_batch,
            args.keep_probability,
            phase_train=phase_train_placeholder,
            bottleneck_layer_size=args.embedding_size,
            weight_decay=args.weight_decay)
        logits = slim.fully_connected(
            prelogits,
            len(train_set),
            activation_fn=None,
            weights_initializer=slim.initializers.xavier_initializer(),
            weights_regularizer=slim.l2_regularizer(args.weight_decay),
            scope='Logits',
            reuse=False)

        embeddings = tf.nn.l2_normalize(prelogits, 1, 1e-10, name='embeddings')

        # Norm for the prelogits
        eps = 1e-4
        prelogits_norm = tf.reduce_mean(
            tf.norm(tf.abs(prelogits) + eps, ord=args.prelogits_norm_p,
                    axis=1))
        tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
                             prelogits_norm * args.prelogits_norm_loss_factor)

        # Add center loss
        prelogits_center_loss, _ = facenet.center_loss(prelogits, label_batch,
                                                       args.center_loss_alfa,
                                                       nrof_classes)
        tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
                             prelogits_center_loss * args.center_loss_factor)

        learning_rate = tf.train.exponential_decay(
            learning_rate_placeholder,
            global_step,
            args.learning_rate_decay_epochs * args.epoch_size,
            args.learning_rate_decay_factor,
            staircase=True)
        tf.summary.scalar('learning_rate', learning_rate)

        # Calculate the average cross entropy loss across the batch
        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=label_batch,
            logits=logits,
            name='cross_entropy_per_example')
        cross_entropy_mean = tf.reduce_mean(cross_entropy,
                                            name='cross_entropy')
        tf.add_to_collection('losses', cross_entropy_mean)

        correct_prediction = tf.cast(
            tf.equal(tf.argmax(logits, 1), tf.cast(label_batch, tf.int64)),
            tf.float32)
        accuracy = tf.reduce_mean(correct_prediction)

        # Calculate the total losses
        regularization_losses = tf.get_collection(
            tf.GraphKeys.REGULARIZATION_LOSSES)
        total_loss = tf.add_n([cross_entropy_mean] + regularization_losses,
                              name='total_loss')

        # Build a Graph that trains the model with one batch of examples and updates the model parameters
        train_op = facenet.train(total_loss, global_step, args.optimizer,
                                 learning_rate, args.moving_average_decay,
                                 tf.global_variables(), args.log_histograms)

        # Create a saver
        saver = tf.train.Saver(tf.trainable_variables(), max_to_keep=3)

        # Build the summary operation based on the TF collection of Summaries.
        summary_op = tf.summary.merge_all()

        # Start running operations on the Graph.
        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=args.gpu_memory_fraction)
        config = tf.ConfigProto(gpu_options=gpu_options,
                                log_device_placement=False)
        # Horovod: make only the GPU of this process's local rank visible.
        config.gpu_options.visible_device_list = str(hvd.local_rank())
        sess = tf.Session(config=config)
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        # Horovod: broadcast initial variable states from rank 0 to all other processes.
        # This is necessary to ensure consistent initialization of all workers when
        # training is started with random weights or restored from a checkpoint.
        bcast = hvd.broadcast_global_variables(0)

        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
        coord = tf.train.Coordinator()
        tf.train.start_queue_runners(coord=coord, sess=sess)

        with sess.as_default():

            if pretrained_model:
                print('Restoring pretrained model: %s' % pretrained_model)
                saver.restore(sess, pretrained_model)

            # Horovod: building the broadcast op alone does nothing; run it,
            # after any restore, so all ranks start from rank 0's weights.
            sess.run(bcast)

            # Training and validation loop
            print('Running training')
            nrof_steps = args.max_nrof_epochs * args.epoch_size
            nrof_val_samples = int(
                math.ceil(args.max_nrof_epochs / args.validate_every_n_epochs)
            )  # Validate every validate_every_n_epochs as well as in the last epoch
            stat = {
                'loss':
                np.zeros((nrof_steps, ), np.float32),
                'center_loss':
                np.zeros((nrof_steps, ), np.float32),
                'reg_loss':
                np.zeros((nrof_steps, ), np.float32),
                'xent_loss':
                np.zeros((nrof_steps, ), np.float32),
                'prelogits_norm':
                np.zeros((nrof_steps, ), np.float32),
                'accuracy':
                np.zeros((nrof_steps, ), np.float32),
                'val_loss':
                np.zeros((nrof_val_samples, ), np.float32),
                'val_xent_loss':
                np.zeros((nrof_val_samples, ), np.float32),
                'val_accuracy':
                np.zeros((nrof_val_samples, ), np.float32),
                'lfw_accuracy':
                np.zeros((args.max_nrof_epochs, ), np.float32),
                'lfw_valrate':
                np.zeros((args.max_nrof_epochs, ), np.float32),
                'learning_rate':
                np.zeros((args.max_nrof_epochs, ), np.float32),
                'time_train':
                np.zeros((args.max_nrof_epochs, ), np.float32),
                'time_validate':
                np.zeros((args.max_nrof_epochs, ), np.float32),
                'time_evaluate':
                np.zeros((args.max_nrof_epochs, ), np.float32),
                'prelogits_hist':
                np.zeros((args.max_nrof_epochs, 1000), np.float32),
            }
            for epoch in range(1, args.max_nrof_epochs + 1):
                step = sess.run(global_step, feed_dict=None)
                # Train for one epoch
                t = time.time()
                cont = train(
                    args, sess, epoch, image_list, label_list,
                    index_dequeue_op, enqueue_op, image_paths_placeholder,
                    labels_placeholder, learning_rate_placeholder,
                    phase_train_placeholder, batch_size_placeholder,
                    control_placeholder, global_step, total_loss, train_op,
                    summary_op, summary_writer, regularization_losses,
                    args.learning_rate_schedule_file, stat, cross_entropy_mean,
                    accuracy, learning_rate, prelogits, prelogits_center_loss,
                    args.random_rotate, args.random_crop, args.random_flip,
                    prelogits_norm, args.prelogits_hist_max,
                    args.use_fixed_image_standardization)
                stat['time_train'][epoch - 1] = time.time() - t

                if not cont:
                    break

                t = time.time()
                if len(val_image_list) > 0 and (
                    (epoch - 1) % args.validate_every_n_epochs
                        == args.validate_every_n_epochs - 1
                        or epoch == args.max_nrof_epochs):
                    validate(args, sess, epoch, val_image_list, val_label_list,
                             enqueue_op, image_paths_placeholder,
                             labels_placeholder, control_placeholder,
                             phase_train_placeholder, batch_size_placeholder,
                             stat, total_loss, regularization_losses,
                             cross_entropy_mean, accuracy,
                             args.validate_every_n_epochs,
                             args.use_fixed_image_standardization)
                stat['time_validate'][epoch - 1] = time.time() - t

                # Save variables and the metagraph if it doesn't exist already
                save_variables_and_metagraph(sess, saver, summary_writer,
                                             model_dir, subdir, epoch)

                # Evaluate on LFW
                t = time.time()
                if args.lfw_dir:
                    evaluate(sess, enqueue_op, image_paths_placeholder,
                             labels_placeholder, phase_train_placeholder,
                             batch_size_placeholder, control_placeholder,
                             embeddings, label_batch, lfw_paths, actual_issame,
                             args.lfw_batch_size, args.lfw_nrof_folds, log_dir,
                             step, summary_writer, stat, epoch,
                             args.lfw_distance_metric, args.lfw_subtract_mean,
                             args.lfw_use_flipped_images,
                             args.use_fixed_image_standardization)
                stat['time_evaluate'][epoch - 1] = time.time() - t

                print('Saving statistics')
                with h5py.File(stat_file_name, 'w') as f:
                    for key, value in stat.items():
                        f.create_dataset(key, data=value)

    return model_dir
コード例 #17
0
ファイル: model.py プロジェクト: hologerry/pix2pix-flow
def abstract_model_xy(sess, hps, feeds, train_iterators, test_iterators, data_inits, lr, f_loss):
    """Build train / test / save / restore / init helpers for a two-domain (A, B) model.

    Parameters
    ----------
    sess : tf.Session-like
        Session captured by every closure attached to the returned namespace.
    hps : hyper-parameter object
        Fields read here: ``joint_train``, ``gradient_checkpointing``,
        ``direct_iterator``, ``restore_path_A``, ``restore_path_B``; it is also
        forwarded to ``optim.Optimizer.adamax``.
    feeds : dict
        Feedable tensors keyed ``'x_A'``, ``'y_A'``, ``'x_B'``, ``'y_B'``.
    train_iterators, test_iterators : dict
        Keyed ``'A'`` and ``'B'``; each value is called with no arguments and
        unpacked into an ``(x, y)`` pair.
    data_inits : dict
        ``data_inits[domain]['x' | 'y']`` values are fed during the initial
        (``init=True``) pass when no checkpoint is restored.
    lr : tensor
        Learning-rate tensor; used as a feed key for the ``_lr`` argument of
        the training closures (presumably a placeholder -- confirm at caller).
    f_loss : callable
        ``f_loss(iterators, is_training, reuse=..., init=...)`` builds the loss
        graph. With ``is_training=True`` it must return a 6-tuple when
        ``hps.joint_train`` is set, otherwise a 4-tuple (see unpacking below).

    Returns
    -------
    m : class
        A bare class used as a namespace. Attributes: sess, feeds, lr,
        get_train_data, train_A/train_B/train, polyak_swap_A/B,
        optimizer_A/B, test_A/test_B (and get_test_data when not using
        direct iterators), save_A/B, save_ema_A/B, restore_A/B.
    """

    # == Create class with static fields and methods
    # (the class object itself is used as a mutable namespace; it is never instantiated)
    class m(object):
        pass
    m.sess = sess
    m.feeds = feeds
    m.lr = lr

    # === Loss and optimizer
    # With joint training f_loss additionally returns the flattened eps
    # tensors for each domain; they are not used further in this function.
    if hps.joint_train:
        (loss_train_A, stats_train_A, eps_flatten_A, loss_train_B, stats_train_B, eps_flatten_B) \
            = f_loss(train_iterators, is_training=True)
    else:
        (loss_train_A, stats_train_A, loss_train_B, stats_train_B) \
            = f_loss(train_iterators, is_training=True)

    all_params = tf.trainable_variables()

    # Get train data op
    # Pulls one (x, y) batch from each domain's training iterator.
    def get_train_data():
        x_A, y_A = train_iterators['A']()
        x_B, y_B = train_iterators['B']()
        return x_A, y_A, x_B, y_B
    m.get_train_data = get_train_data

    # A
    with tf.variable_scope('optim_A'):
        # Domain membership is a plain substring test on the variable name,
        # not a scope-prefix test.
        params_A = [param for param in all_params if 'A/' in param.name]
        if hps.gradient_checkpointing == 1:
            # Local import so the dependency is only needed when checkpointing is on.
            from memory_saving_gradients import gradients
            gs_A = gradients(loss_train_A, params_A)
        else:
            gs_A = tf.gradients(loss_train_A, params_A)
        m.optimizer_A = optim.Optimizer()
        train_op_A, polyak_swap_op_A, ema_A = m.optimizer_A.adamax(
            params_A, gs_A, alpha=lr, hps=hps)
        if hps.direct_iterator:
            # Inputs come straight from the iterator ops, so only lr is fed.
            m.train_A = lambda _lr: sess.run(
                [train_op_A, stats_train_A], {lr: _lr})[1]
        else:
            # Inputs for both domains are fed explicitly; returns stats_train_A.
            def _train_A(_lr, _x_A, _y_A, _x_B, _y_B):
                return sess.run([train_op_A, stats_train_A], {feeds['x_A']: _x_A,
                                                              feeds['y_A']: _y_A,
                                                              feeds['x_B']: _x_B,
                                                              feeds['y_B']: _y_B,
                                                              lr: _lr})[1]
            m.train_A = _train_A
        m.polyak_swap_A = lambda: sess.run(polyak_swap_op_A)
    # B
    with tf.variable_scope('optim_B'):
        # Same substring-based partitioning as for domain A.
        params_B = [param for param in all_params if 'B/' in param.name]
        if hps.gradient_checkpointing == 1:
            from memory_saving_gradients import gradients
            gs_B = gradients(loss_train_B, params_B)
        else:
            gs_B = tf.gradients(loss_train_B, params_B)
        m.optimizer_B = optim.Optimizer()
        train_op_B, polyak_swap_op_B, ema_B = m.optimizer_B.adamax(
            params_B, gs_B, alpha=lr, hps=hps)
        if hps.direct_iterator:
            m.train_B = lambda _lr: sess.run(
                [train_op_B, stats_train_B], {lr: _lr})[1]
        else:
            def _train_B(_lr, _x_A, _y_A, _x_B, _y_B):
                return sess.run([train_op_B, stats_train_B], {feeds['x_A']: _x_A,
                                                              feeds['y_A']: _y_A,
                                                              feeds['x_B']: _x_B,
                                                              feeds['y_B']: _y_B,
                                                              lr: _lr})[1]
            m.train_B = _train_B
        m.polyak_swap_B = lambda: sess.run(polyak_swap_op_B)

    # Runs both domains' train ops in one session call and returns
    # [stats_train_A, stats_train_B].
    # NOTE(review): unlike train_A/train_B this always requires explicit
    # feeds, even when hps.direct_iterator is set.
    def _train(_lr, _x_A, _y_A, _x_B, _y_B):
        return sess.run([train_op_A, train_op_B, stats_train_A, stats_train_B],
                        {feeds['x_A']: _x_A, feeds['y_A']: _y_A,
                         feeds['x_B']: _x_B, feeds['y_B']: _y_B,
                         lr: _lr})[-2:]
    m.train = _train

    # === Testing
    # Rebuilds the loss graph in eval mode, reusing the training variables.
    # NOTE(review): this always unpacks 4 values, whereas the training call
    # unpacks 6 when hps.joint_train is set -- confirm f_loss returns 4 here
    # in joint mode.
    loss_test_A, stats_test_A, loss_test_B, stats_test_B = f_loss(
        test_iterators, False, reuse=True)
    if hps.direct_iterator:
        m.test_A = lambda: sess.run(stats_test_A)
        m.test_B = lambda: sess.run(stats_test_B)
    else:
        # Get test data op
        def get_test_data():
            x_A, y_A = test_iterators['A']()
            x_B, y_B = test_iterators['B']()
            return x_A, y_A, x_B, y_B
        m.get_test_data = get_test_data

        def _test_A(_x_A, _y_A, _x_B, _y_B):
            return sess.run(stats_test_A, {feeds['x_A']: _x_A,
                                           feeds['y_A']: _y_A,
                                           feeds['x_B']: _x_B,
                                           feeds['y_B']: _y_B})

        def _test_B(_x_A, _y_A, _x_B, _y_B):
            return sess.run(stats_test_B, {feeds['x_A']: _x_A,
                                           feeds['y_A']: _y_A,
                                           feeds['x_B']: _x_B,
                                           feeds['y_B']: _y_B})
        m.test_A = _test_A
        m.test_B = _test_B

    # === Saving and restoring
    # NOTE(review): tf.train.Saver() is built without a var_list, which per the
    # TF1 API means all saveable variables in the graph (not only the A/B
    # params); the variable_scope here does not restrict it. The EMA savers are
    # built from ema_X.variables_to_restore().
    with tf.variable_scope('saver_A'):
        saver_A = tf.train.Saver()
        saver_ema_A = tf.train.Saver(ema_A.variables_to_restore())
        m.save_ema_A = lambda path_A: saver_ema_A.save(
            sess, path_A, write_meta_graph=False)
        m.save_A = lambda path_A: saver_A.save(
            sess, path_A, write_meta_graph=False)
        m.restore_A = lambda path_A: saver_A.restore(sess, path_A)

    with tf.variable_scope('saver_B'):
        saver_B = tf.train.Saver()
        saver_ema_B = tf.train.Saver(ema_B.variables_to_restore())
        m.save_ema_B = lambda path_B: saver_ema_B.save(
            sess, path_B, write_meta_graph=False)
        m.save_B = lambda path_B: saver_B.save(
            sess, path_B, write_meta_graph=False)
        m.restore_B = lambda path_B: saver_B.restore(sess, path_B)
        # Debug trace left in the original code.
        print("After saver")

    # === Initialize the parameters
    # NOTE(review): when exactly one restore path is given, nothing below
    # initializes the other domain's variables (unless the shared saver's
    # checkpoint covers them) -- confirm this is intended.
    if hps.restore_path_A != '':
        m.restore_A(hps.restore_path_A)
    if hps.restore_path_B != '':
        m.restore_B(hps.restore_path_B)
    if hps.restore_path_A == '' and hps.restore_path_B == '':
        # Build an init-mode copy of the graph (init=True for get_variable_ddi
        # and actnorm), presumably data-dependent initialization.
        with Z.arg_scope([Z.get_variable_ddi, Z.actnorm], init=True):
            results_init = f_loss(None, False, reuse=True, init=True)

        # Global (not just trainable) variables, so variables created under the
        # 'optim_A' / 'optim_B' scopes also match the 'A/' / 'B/' substrings.
        all_params = tf.global_variables()
        params_A = [param for param in all_params if 'A/' in param.name]
        params_B = [param for param in all_params if 'B/' in param.name]
        sess.run(tf.variables_initializer(params_A))
        sess.run(tf.variables_initializer(params_B))
        feeds_dict = {feeds['x_A']: data_inits['A']['x'],
                      feeds['y_A']: data_inits['A']['y'],
                      feeds['x_B']: data_inits['B']['x'],
                      feeds['y_B']: data_inits['B']['y']}
        sess.run(results_init, feeds_dict)
    # Make every Horovod rank start from rank 0's variable values.
    sess.run(hvd.broadcast_global_variables(0))

    return m
コード例 #18
0
        "segment_ids": tf.FixedLenFeature([128], tf.int64),
        "label_ids": tf.FixedLenFeature([], tf.int64),
    }

    params = Bunch({})
    params.epoch = epoch
    params.batch_size = 32
    jd_test = "/data/xuht/jd_comment/train.tfrecords"
    print(params["batch_size"], "===batch size===")
    input_fn = tf_data_utils.train_input_fn(jd_test,
                                            tf_data_utils._decode_record,
                                            name_to_features, params)

    sess = tf.Session(config=sess_config)

    init_op = tf.group(tf.local_variables_initializer())
    sess.run(init_op)

    sess.run(hvd.broadcast_global_variables(0))

    i = 0
    cnt = 0
    while True:
        try:
            features = sess.run(input_fn)
            i += 1
            cnt += 1
        except tf.errors.OutOfRangeError:
            print("End of dataset")
            break
    print(i)
def main(argv=None):
    '''Train a convolutional VAE on MNIST with Horovod, feeding a tf.data pipeline.

    The CLI description is the module docstring (``main.__doc__`` is
    overwritten with it below). Every rank trains; rank 0 additionally
    prints the model summary, runs the timing callbacks, evaluates on the
    MNIST test set and saves 'vae_scatter.ps' and 'vae_digit.ps' to the
    current directory.

    :param argv: Optional list of extra command-line arguments, appended to
        ``sys.argv`` before ``parser_`` runs (presumably ``parser_`` reads
        ``sys.argv``).
    '''
    main.__doc__ = __doc__
    # list.extend() mutates in place and returns None, so it must not be used
    # as an expression value; extend first, then rely on sys.argv.
    if argv is not None:
        sys.argv.extend(argv)
    desc = main.__doc__  # .format(os.path.basename(__file__))
    # CLI parser
    args = parser_(desc)

    # Several ranks may share one GPU: consecutive local ranks are grouped
    # nranks_per_gpu at a time onto the same device index.
    nranks_per_gpu = args.nranks_per_gpu
    local_rank = hvd.local_rank()
    gpu_local_rank = local_rank // nranks_per_gpu
    print('local_rank, GPU_LOCAL_RANK: {}, {}'.format(
        local_rank, gpu_local_rank))

    # Pin the GPU used by this process (possibly shared with other ranks).
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.visible_device_list = str(gpu_local_rank)
    K.set_session(tf.Session(config=config))

    # input image dimensions
    img_rows, img_cols, img_chns = 28, 28, 1
    # number of convolutional filters to use
    filters = 64
    # convolution kernel size
    num_conv = 3

    hvdsize = hvd.size()

    batch_size = 128  # 100
    if K.image_data_format() == 'channels_first':
        original_img_size = (img_chns, img_rows, img_cols)
    else:
        original_img_size = (img_rows, img_cols, img_chns)
    latent_dim = 2
    intermediate_dim = 128
    epsilon_std = 1.0
    epochs = args.epochs  # 5

    # train the VAE on MNIST digits
    (x_train, _), (x_test, y_test) = mnist.load_data()

    x_train = x_train.astype('float32') / 255.
    x_train = x_train.reshape((x_train.shape[0],) + original_img_size)
    x_test = x_test.astype('float32') / 255.
    x_test = x_test.reshape((x_test.shape[0],) + original_img_size)

    if hvd.rank() == 0:
        print('x_train.shape:', x_train.shape)

    train_samples = x_train.shape[0]
    speedupopt = args.speedup
    if speedupopt == SpeedupOpts.imgspersec:
        # Each rank steps through the whole training set per epoch.
        steps_per_epoch = train_samples // batch_size
    else:
        # One epoch is split across ranks: ceil(train_samples / global batch).
        # Integer ceiling division; round(x + 0.5) rounds half to even and
        # added a spurious step when x was an exact odd integer.
        global_batch = batch_size * hvdsize
        steps_per_epoch = -(-train_samples // global_batch)

    # Create the dataset and its associated one-shot iterator.
    buffer_size = 10000
    dataset = Dataset.from_tensor_slices(x_train)
    dataset = dataset.repeat()
    dataset = dataset.shuffle(buffer_size)
    dataset = dataset.batch(batch_size)
    iterator = dataset.make_one_shot_iterator()
    x_train_batch = iterator.get_next()

    # ldict holds all layers. They are instantiated once and shared among the
    # vae, encoder and generator models.
    ldict = make_shared_layers_dict(
        img_chns, img_rows, img_cols, batch_size, filters,
        num_conv, intermediate_dim, latent_dim, epsilon_std)

    # The model input is bound directly to the dataset tensor.
    x = Input(tensor=x_train_batch)
    vae = make_vae(ldict, x)
    # :  :type vae: Model

    lr = 0.001  # * hvdsize
    opt = tf.train.RMSPropOptimizer(lr)
    # Add Horovod Distributed Optimizer (averages gradients across ranks).
    opt = hvd.DistributedOptimizer(opt)
    opt = TFOptimizer(opt)

    # loss=None: the VAE loss is presumably attached inside make_vae.
    vae.compile(optimizer=opt, loss=None)
    if hvd.rank() == 0:
        vae.summary()

    callbacks = []
    if hvd.rank() == 0:
        callbacks += [BatchTiming(), SamplesPerSec(batch_size * hvdsize)]

    # Start every rank from rank 0's initial weights.
    sess = K.get_session()
    sess.run(hvd.broadcast_global_variables(0))

    # Fit the model using data from the TF data tensors.
    vae.fit(steps_per_epoch=steps_per_epoch, epochs=epochs,
            callbacks=callbacks)

    if hvd.rank() == 0:
        # Rebuild the VAE on a regular Input (sharing the trained layers) so it
        # can be evaluated on in-memory test data.
        x = Input(shape=original_img_size)
        vae_val = make_vae(ldict, x)
        vae_val.compile(optimizer=opt, loss=None)
        loss = vae_val.evaluate(x=x_test, y=None, batch_size=batch_size)
        print('\n\nVAE VALIDATION LOSS: {}'.format(loss))

        x = Input(shape=original_img_size)
        z_mean, _ = get_encoded(ldict, x)
        encoder = Model(x, z_mean)
        # :  :type encoder: Model

        decoder_input = Input(shape=(latent_dim,))
        x_decoded_mean_squash = get_decoded(ldict, decoder_input)
        generator = Model(decoder_input, x_decoded_mean_squash)
        # :  :type generator: Model

        # display a 2D plot of the digit classes in the latent space
        x_test_encoded = encoder.predict(x_test, batch_size=batch_size)
        plt.figure(figsize=(6, 6))
        plt.scatter(x_test_encoded[:, 0], x_test_encoded[:, 1], c=y_test)
        plt.colorbar()
        plt.savefig('vae_scatter.ps')
        plt.close()

        # display a 2D manifold of the digits
        n = 15  # figure with 15x15 digits
        digit_size = 28
        figure = np.zeros((digit_size * n, digit_size * n))
        # Linearly spaced coordinates on the unit square are mapped through
        # the inverse CDF (ppf) of the Gaussian to get latent values z, since
        # the prior of the latent space is Gaussian.
        grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
        grid_y = norm.ppf(np.linspace(0.05, 0.95, n))

        for i, yi in enumerate(grid_x):
            for j, xi in enumerate(grid_y):
                z_sample = np.array([[xi, yi]])
                # Replicate to a full batch; only the first decoding is used.
                z_sample = np.tile(z_sample, batch_size).reshape(batch_size, 2)
                x_decoded = generator.predict(z_sample, batch_size=batch_size)
                digit = x_decoded[0].reshape(digit_size, digit_size)
                figure[i * digit_size: (i + 1) * digit_size,
                       j * digit_size: (j + 1) * digit_size] = digit

        plt.figure(figsize=(10, 10))
        plt.imshow(figure, cmap='Greys_r')
        plt.savefig('vae_digit.ps')
        plt.close()

    K.clear_session()
コード例 #20
0
ファイル: __init__.py プロジェクト: fightseed/horovod-1
 def broadcast_global_variables(backend, root_rank):
     """Build Horovod's global-variable broadcast op for ``root_rank`` and evaluate it.

     The op is handed to ``_eval`` together with ``backend``; whatever
     ``_eval`` returns is passed back to the caller unchanged.
     """
     broadcast_op = hvd.broadcast_global_variables(root_rank)
     return _eval(backend, broadcast_op)
def main(argv=None):
    '''Train a convolutional VAE on MNIST with Horovod, fitting in-memory arrays.

    The CLI description is the module docstring (``main.__doc__`` is
    overwritten with it below). Every rank trains; rank 0 additionally
    prints the model summary, runs the timing callbacks, evaluates on the
    MNIST test set and saves 'vae_scatter.ps' and 'vae_digit.ps' to the
    current directory.

    :param argv: Optional list of extra command-line arguments, appended to
        ``sys.argv`` before ``parser_`` runs (presumably ``parser_`` reads
        ``sys.argv``).
    '''
    main.__doc__ = __doc__
    # list.extend() mutates in place and returns None, so it must not be used
    # as an expression value; extend first, then rely on sys.argv.
    if argv is not None:
        sys.argv.extend(argv)
    desc = main.__doc__  # .format(os.path.basename(__file__))
    # CLI parser
    args = parser_(desc)

    # Several ranks may share one GPU: consecutive local ranks are grouped
    # nranks_per_gpu at a time onto the same device index.
    nranks_per_gpu = args.nranks_per_gpu
    local_rank = hvd.local_rank()
    gpu_local_rank = local_rank // nranks_per_gpu
    print('local_rank, GPU_LOCAL_RANK: {}, {}'.format(
        local_rank, gpu_local_rank))

    # Pin the GPU used by this process (possibly shared with other ranks).
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.visible_device_list = str(gpu_local_rank)
    K.set_session(tf.Session(config=config))

    # input image dimensions
    img_rows, img_cols, img_chns = 28, 28, 1
    # number of convolutional filters to use
    filters = 64
    # convolution kernel size
    num_conv = 3

    hvdsize = hvd.size()

    batch_size = 128  # 100
    if K.image_data_format() == 'channels_first':
        original_img_size = (img_chns, img_rows, img_cols)
    else:
        original_img_size = (img_rows, img_cols, img_chns)
    latent_dim = 2
    intermediate_dim = 128
    epsilon_std = 1.0
    epochs = args.epochs  # 5

    # train the VAE on MNIST digits
    (x_train, _), (x_test, y_test) = mnist.load_data()

    # Every rank trains on the full training set. A per-rank shard
    # (x_train[rank * len // size:(rank + 1) * len // size]) is possible, but a
    # tf queue or dataset preserves uniform random sampling better.

    x_train = x_train.astype('float32') / 255.
    x_train = x_train.reshape((x_train.shape[0],) + original_img_size)
    x_test = x_test.astype('float32') / 255.
    x_test = x_test.reshape((x_test.shape[0],) + original_img_size)

    if hvd.rank() == 0:
        print('x_train.shape:', x_train.shape)

    vae, encoder, generator = make_vae_and_codec(
        original_img_size, img_chns, img_rows, img_cols, batch_size,
        filters, num_conv, intermediate_dim, latent_dim, epsilon_std)
    # :  :type vae: Model

    lr = 0.001  # * hvdsize
    opt = tf.train.RMSPropOptimizer(lr)
    # Add Horovod Distributed Optimizer (averages gradients across ranks).
    opt = hvd.DistributedOptimizer(opt)
    opt = TFOptimizer(opt)

    # loss=None: the VAE loss is presumably attached inside make_vae_and_codec.
    vae.compile(optimizer=opt, loss=None)
    if hvd.rank() == 0:
        vae.summary()

    callbacks = []
    if hvd.rank() == 0:
        callbacks += [BatchTiming(), SamplesPerSec(batch_size * hvdsize)]

    # Start every rank from rank 0's initial weights.
    sess = K.get_session()
    sess.run(hvd.broadcast_global_variables(0))

    vae.fit(x_train,
            shuffle=True,
            epochs=epochs,
            batch_size=batch_size,
            validation_data=(x_test, None),
            callbacks=callbacks)

    if hvd.rank() == 0:
        loss = vae.evaluate(x=x_test, y=None, batch_size=batch_size)
        print('\n\nVAE VALIDATION LOSS: {}'.format(loss))

        # display a 2D plot of the digit classes in the latent space
        x_test_encoded = encoder.predict(x_test, batch_size=batch_size)
        plt.figure(figsize=(6, 6))
        plt.scatter(x_test_encoded[:, 0], x_test_encoded[:, 1], c=y_test)
        plt.colorbar()
        plt.savefig('vae_scatter.ps')
        plt.close()

        # display a 2D manifold of the digits
        n = 15  # figure with 15x15 digits
        digit_size = 28
        figure = np.zeros((digit_size * n, digit_size * n))
        # Linearly spaced coordinates on the unit square are mapped through
        # the inverse CDF (ppf) of the Gaussian to get latent values z, since
        # the prior of the latent space is Gaussian.
        grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
        grid_y = norm.ppf(np.linspace(0.05, 0.95, n))

        for i, yi in enumerate(grid_x):
            for j, xi in enumerate(grid_y):
                z_sample = np.array([[xi, yi]])
                # Replicate to a full batch; only the first decoding is used.
                z_sample = np.tile(z_sample, batch_size).reshape(batch_size, 2)
                x_decoded = generator.predict(z_sample, batch_size=batch_size)
                digit = x_decoded[0].reshape(digit_size, digit_size)
                figure[i * digit_size: (i + 1) * digit_size,
                       j * digit_size: (j + 1) * digit_size] = digit

        plt.figure(figsize=(10, 10))
        plt.imshow(figure, cmap='Greys_r')
        plt.savefig('vae_digit.ps')
        plt.close()

    K.clear_session()
コード例 #22
0
def tf_train_flow(
        train_once_fn,
        model_dir=None,
        log_dir=None,
        max_models_keep=1,
        save_interval_seconds=600,
        save_interval_steps=1000,
        num_epochs=None,
        num_steps=None,
        save_model=True,
        save_interval_epochs=None,
        freeze_graph=False,
        num_steps_per_epoch=0,
        restore_from_latest=True,
        metric_eval_fn=None,
        valid_interval_epochs=0,
        inference_fn=None,
        inference_interval_epochs=0,
        init_fn=None,
        restore_fn=None,
        restore_include=None,
        restore_exclude=None,
        save_all_scope=False,  #TODO save load from restore scope only but svae all
        variables_to_restore=None,
        variables_to_save=None,  #by default will be the same as variables_to_restore
        output_collection_names=None,
        output_node_names=None,
        learning_rate=None,  #not use yet, just use in train_once
        learning_rate_patience=None,
        learning_rate_decay_factor=None,
        write_during_train=True,
        model=None,
        sess=None):
    """
  similary flow as tf_flow, but add model try reload and save
  """
    use_horovod = 'OMPI_COMM_WORLD_RANK' in os.environ

    model_dir_ = model_dir
    if use_horovod and hvd.rank() != 0:
        model_dir = None

    if sess is None:
        #TODO melt.get_session is global session but may cause non close at last
        sess = melt.get_session()

    if FLAGS.use_tpu:
        sess.run(tpu.initialize_system())
    #logging.info('tf_train_flow start')
    #logging.info('max_models_keep:', max_models_keep)
    #logging.info('save_interval_seconds:', save_interval_seconds)

    if model_dir:
        if model:
            checkpoint = tf.train.Checkpoint(model=model)
            ckpt_dir = model_dir + '/ckpt'
            checkpoint_prefix = os.path.join(ckpt_dir, 'ckpt')

        #this is usefull for you use another model with another scope, and just load and restore/save initalize your scope vars!
        #this is not for finetune but mainly for like using another model as in predict like this introducing graph other model scope and ignore here

        # var_list = None if not restore_scope else tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=restore_scope)
        # #logging.info('-------------var_list', var_list)

        # if not variables_to_restore:
        #   variables_to_restore = var_list

        if not variables_to_restore:
            variables_to_restore = slim.get_variables_to_restore(
                include=restore_include, exclude=restore_exclude)

        if not variables_to_save:
            variables_to_save = variables_to_restore
        if save_all_scope:
            variables_to_save = None

        #if variables_to_restore is None:
        logging.info('variables_to_restore from %s' % model_dir)
        #load all var in checkpoint try to save all var(might more then original checkpoint) if not specifiy variables_to_save
        varnames_in_checkpoint = melt.get_checkpoint_varnames(model_dir)
        #logging.info('varnames_in_checkpoint: {}'.format(varnames_in_checkpoint))

        # TODO has someproblem say  tf.Variable 'r_net/text_encoder/cudnn_rnn/cu_dnngru/recurrent_kernel/adam_v:0' even though in checkpoint I have renated it as ignore/rnet
        variables_to_restore_from_model = slim.get_variables_to_restore(
            include=varnames_in_checkpoint)
        #logging.info('variables_to_restore_from_model: {}'.format(variables_to_restore_from_model))
        if not variables_to_restore:
            variables_to_restore = variables_to_restore_from_model
        else:
            variables_to_restore = [
                v for v in variables_to_restore
                if v in variables_to_restore_from_model
            ]
        if restore_exclude:
            for excl in restore_exclude:
                variables_to_restore = [
                    v for v in variables_to_restore if not excl in v.name
                ]
        #--tf 1.6 adadelta will have same vars...
        variables_to_restore = list(set(variables_to_restore))
        #logging.info('variables_to_restore', variables_to_restore[:100])
        logging.info('variables_to_restore', [
            x for x in variables_to_restore if not 'OptimizeLoss' in x.name
        ][:100])

    ##finally remove global_step since melt.apps.train will handle it!
    global_step = tf.train.get_or_create_global_step()

    #variables_to_restore = [v for v in variables_to_restore if not tf.GraphKeys.GLOBAL_STEP in v.name]
    #variables_to_restore = [v for v in variables_to_restore if not 'learning_rate' in v.name]

    # TODO fixme if step, step2.. and in checkpoint step then here will be step2...
    #print('------------', [v for v in variables_to_restore if 'step' in v.name])
    loader = tf.train.Saver(var_list=variables_to_restore)

    logging.info('max models to keep {}, keep every {} hours'.format(
        max_models_keep, save_interval_seconds / 3600.0))
    saver = tf.train.Saver(
        max_to_keep=max_models_keep,
        keep_checkpoint_every_n_hours=save_interval_seconds / 3600.0,
        var_list=variables_to_save)
    epoch_saver = tf.train.Saver(var_list=variables_to_save, max_to_keep=1000)
    best_epoch_saver = tf.train.Saver(var_list=variables_to_save)
    #logging.info('variables_to_save:{}'.format(variables_to_save))

    # # #TODO for safe restore all init will be ok ?
    # # if variables_to_restore is None:
    init_op = tf.group(
        tf.global_variables_initializer(
        ),  #variables_initializer(global_variables())
        tf.local_variables_initializer()
    )  #variables_initializer(local_variables())
    # # else:
    # #   init_op = tf.group(tf.variables_initializer(variables_to_restore),
    # #                      tf.local_variables_initializer())

    ##--mostly this will be fine except for using assistant predictor, initialize again! will make assistant predictor wrong
    ##so assume to all run init op! if using assistant predictor, make sure it use another session

    # https://stackoverflow.com/questions/35164529/in-tensorflow-is-there-any-way-to-just-initialize-uninitialised-variables
    # def guarantee_initialized_variables(session, list_of_variables = None):
    #   if list_of_variables is None:
    #       list_of_variables = tf.global_variables()
    #   uninitialized_variables = list(tf.get_variable(name) for name in
    #                                  session.run(tf.report_uninitialized_variables(list_of_variables)))
    #   return unintialized_variables

    # unintialized_variables = guarantee_initialized_variables(sess)
    # init_op = tf.group(tf.initialize_variables(uninitialized_vars), tf.local_variables_initializer())

    timer = gezi.Timer('sess run init_op in melt.tf_train_flow')
    #model.save('./weights')

    # notice
    sess.run(init_op)

    timer.print_elapsed()

    #melt.init_uninitialized_variables(sess)

    #pre_step means the step last saved, train without pretrained,then -1
    pre_step = -1
    fixed_pre_step = -1  #fixed pre step is for epoch num to be correct if you change batch size
    #print(model_dir)
    pre_epoch = None
    if model_dir:
        model_path = _get_model_path(model_dir, save_model)
        # if not model_path:
        #   model_path = _get_model_path(os.path.join(model_dir, 'epoch'))
        #print(model_path)
        model_dir = gezi.get_dir(
            model_dir)  #incase you pass ./model/model-ckpt1000 -> ./model

        if model_path is not None:
            if not restore_from_latest:
                logging.info('using recent but not latest model')
                model_path = melt.recent_checkpoint(model_dir)
            model_name = os.path.basename(model_path)
            timer = gezi.Timer(
                'Loading and training from existing model [%s]' % model_path)
            if restore_fn is not None:
                restore_fn(sess)
            loader.restore(sess, model_path)
            ## not supported
            #model.save()
            #model.save_weights('./weights')
            timer.print()
            #pre_step = melt.get_model_step(model_path) - 1 if FLAGS.global_step is None else FLAGS.global_step -1
            # TODO check ..
            pre_step = sess.run(tf.train.get_global_step()) - 1
            pre_epoch = melt.get_model_epoch(
                model_path
            ) if FLAGS.global_epoch is None else FLAGS.global_epoch
            fixed_pre_step = pre_step
            # if pre_epoch is not None:
            #   #like using batch size 32, then reload train using batch size 64
            #   if abs(pre_step / num_steps_per_epoch - pre_epoch) > 0.1:
            #     fixed_pre_step = int(pre_epoch * num_steps_per_epoch)
            #     logging.info('Warning, epoch is diff with pre_step / num_steps_per_epoch:{}, pre_epoch:{},maybe you change batch size and we will adjust to set pre_step as {}'\
            #       .format(pre_step / num_steps_per_epoch, pre_epoch, fixed_pre_step))
        else:
            latest_checkpoint = None
            if not use_horovod:  #now will hang
                try:
                    latest_checkpoint = tf.train.latest_checkpoint(ckpt_dir)
                    if latest_checkpoint:
                        logging.info(
                            'Try start from eager trained mode, latest checkpoint:',
                            latest_checkpoint)
                        checkpoint.restore(latest_checkpoint).run_restore_ops(
                            session=sess)

                        pre_epoch = int(latest_checkpoint.split('-')[-1])
                        #pre_step = pre_epoch * num_steps_per_epoch - 1
                        # TODO check
                        pre_step = sess.run(tf.train.get_global_step()) - 1
                        fixed_pre_step = pre_step
                        logging.info('Start step is:', pre_step)
                except Exception:
                    logging.info(
                        'Something wrong with restore from eager trained model'
                    )
                if latest_checkpoint is None:
                    logging.info('Train all start step 0')
                    #https://stackoverflow.com/questions/40220201/tensorflow-tf-initialize-all-variables-vs-tf-initialize-local-variables
                    #tf.initialize_all_variables() is a shortcut to tf.initialize_variables(tf.all_variables()),
                    #tf.initialize_local_variables() is a shortcut to tf.initialize_variables(tf.local_variables()),
                    #which initializes variables in GraphKeys.VARIABLES and GraphKeys.LOCAL_VARIABLE collections, respectively.
                    #init_op = tf.group(tf.global_variables_initializer(),
                    #                   tf.local_variables_initializer())
                    #[var for var in tf.all_variables() if var.op.name.startswith(restore_scope)] will be the same as tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=restore_scope)

                    #sess.run(init_op)

                    #like use image model, build image graph, reload first train, and then will go to same checkpoint all varaible just restore will ok
                    #for finetune from loading other model init
                    if init_fn is not None:
                        init_fn(sess)

    if gezi.env_has('METRIC'):
        l = metric_eval_fn(model_path)
        print(list(zip(l[1], l[0])))
        exit(0)

    #sess.run(tf.assign(global_step, tf.constant(global_step_val, dtype=tf.int64)))
    try:
        learning_rate = tf.get_collection('learning_rate')[-1]
        learning_rate_weight = tf.get_collection('learning_rate_weight')[-1]
        sess.run(tf.assign(learning_rate,
                           learning_rate * learning_rate_weight))
    except Exception:
        # if not using weight_decay but using optimizer decay then will go here as learning rate is a tensor can not assign
        pass

    try:
        logging.info('Actual start global step:',
                     sess.run(global_step), 'learning rate:',
                     sess.run(learning_rate), 'learning_rate_weight:',
                     sess.run(learning_rate_weight))
    except Exception:
        pass

    if model_dir_:
        #if save_interval_epochs and num_steps_per_epoch and num_steps >= 0:
        epoch_dir = os.path.join(model_dir_, 'epoch')
        gezi.try_mkdir(epoch_dir)
        checkpoint_path = os.path.join(model_dir_, 'model.ckpt')

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    if use_horovod:
        bcast = hvd.broadcast_global_variables(0)
        sess.run(bcast)

    #tf.train.write_graph(sess.graph_def, model_dir, 'train.pbtxt')
    only_one_step = False
    try:
        if use_horovod:
            ## TODO FIXME why bcast here not work ? simple test work see tests/bcast.py
            #comm.bcast(pre_step, root=0)
            temp = np.array([pre_step, fixed_pre_step])
            comm.Bcast(temp, root=0)
            pre_step = temp[0]
            fixed_pre_step = temp[1]

        step = start = pre_step + 1
        fixed_step = fixed_pre_step + 1

        #first = True

        #hack just for save one model after load
        if num_steps < 0 or (num_steps and num_steps < step):
            logging.info('just load and resave then exit')
            model_path_ = _get_checkpoint_path(checkpoint_path,
                                               step,
                                               num_steps_per_epoch,
                                               epoch=pre_epoch)
            saver.save(sess, model_path_, global_step=step + 1)
            if freeze_graph:
                melt.freeze_graph(sess, model_path_, step + 1,
                                  output_collection_names, output_node_names)
            sess.close()
            exit(0)

        if num_epochs < 0:
            only_one_step = True
            logging.info('just run one step')

        if FLAGS.work_mode != 'train':
            assert not os.path.isdir(FLAGS.model_dir), FLAGS.model_dir
            if 'valid' in FLAGS.work_mode:
                vals, names = metric_eval_fn(FLAGS.model_dir)
                logging.info(list(zip(names, vals)))
            if 'test' in FLAGS.work_mode:
                inference_fn(FLAGS.model_dir)
            exit(0)

        #early_stop = True #TODO allow config
        num_bad_epochs = 0
        pre_epoch_eval_loss = 1e20
        best_epoch_eval_loss = 1e20
        num_allowed_bad_epochs = 4  #allow 5 non decrease eval loss epochs  before stop
        epoch_saved_step = 0
        while not coord.should_stop():
            model_step_path = None
            if model_dir_:
                model_path_ = os.path.join(
                    epoch_dir, 'model.ckpt-%.2f' %
                    (fixed_step / float(num_steps_per_epoch)))
                model_step_path_ = model_path_ + '-' + str(step)
                if (write_during_train and metric_eval_fn is not None
                        and valid_interval_epochs and fixed_step %
                        int(num_steps_per_epoch * valid_interval_epochs) == 0):
                    model_step_path = model_step_path_
                else:
                    model_step_path = None

            if step == 0:
                model_step_path = None

            #print('--------------------step', step)
            stop = train_once_fn(
                sess,
                step,
                is_start=(step == start),
                fixed_step=fixed_step,
                num_epochs=num_epochs,
                model_path=model_step_path,
                use_horovod=use_horovod,
                ## TODO FIXME this line will cause   tensorflow.python.framework.errors_impl.NotFoundError: Resource localhost/save_counter/N10tensorflow3VarE does not exist.
            )

            #first = False

            if only_one_step:
                stop = True

            step += 1
            fixed_step += 1

            if save_model and step and model_dir:
                #step 0 is also saved! actually train one step and save
                if step % save_interval_steps == 0:
                    timer = gezi.Timer(
                        'save model step %d to %s' % (step, checkpoint_path),
                        False)
                    model_path_ = _get_checkpoint_path(checkpoint_path,
                                                       fixed_step,
                                                       num_steps_per_epoch)
                    saver.save(sess, model_path_, global_step=step)
                    if freeze_graph:
                        melt.freeze_graph(sess, model_path_, step,
                                          output_collection_names,
                                          output_node_names)
                    #if log_dir != model_dir:
                    #  assert log_dir
                    #  command = 'rsync -l -r -t %s/* %s' % (log_dir, model_dir)
                    #  print(command, file=sys.stderr)
                    #  os.system(command)
                    timer.print_elapsed()

                if save_interval_steps and num_steps_per_epoch and fixed_step % int(
                        num_steps_per_epoch * save_interval_epochs) == 0:
                    # TODO only epoch in name not sep ?
                    epoch_saved_step = step
                    model_path_ = os.path.join(
                        epoch_dir, 'model.ckpt-%.2f' %
                        (fixed_step / float(num_steps_per_epoch)))
                    model_step_path = model_path_ + '-' + str(step)
                    epoch_saver.save(sess, model_path_, global_step=step)
                    #epoch_saver.save(sess, model_path_)

                    ## TODO FIXME do not support tf.keras save currently with horovod
                    # if model:
                    #   #model.save_weights(epoch_dir + '/ckpt-%.2f' % (fixed_step / float(num_steps_per_epoch)))
                    #   # TODO FIXME if restart will save from 1... again..
                    #   checkpoint.save(checkpoint_prefix, session=sess)
                    #   #print(sess.run(checkpoint.save_counter))

                    if freeze_graph:
                        melt.freeze_graph(sess, model_path_, step,
                                          output_collection_names,
                                          output_node_names)

                if write_during_train:
                    if inference_fn is not None and inference_interval_epochs and fixed_step % int(
                            num_steps_per_epoch *
                            inference_interval_epochs) == 0:
                        model_step_path = model_path_ + '-' + str(step)
                        try:
                            #print('--------------inference fn')
                            inference_fn(model_path=model_step_path)
                        except Exception:
                            logging.info(traceback.format_exc())

                    # if metric_eval_fn is not None and valid_interval_epochs and fixed_step % int(num_steps_per_epoch * valid_interval_epochs) == 0:
                    #   model_step_path = model_path_ + '-' + str(step)
                    #   try:
                    #     metric_eval_fn(model_path=model_step_path)
                    #   except Exception:
                    #     logging.info(traceback.format_exc())

            if stop is True:
                print('Early stop running %d stpes' % (step), file=sys.stderr)
                raise tf.errors.OutOfRangeError(
                    None, None, 'Early stop running %d stpes' % (step))
            if num_steps and (step + 1) == start + num_steps:
                raise tf.errors.OutOfRangeError(None, None,
                                                'Reached max num steps')
            #max_num_epochs = 1000
            max_num_epochs = num_epochs
            #if max_num_epochs and num_steps_per_epoch and fixed_step // num_steps_per_epoch >= max_num_epochs:
            if max_num_epochs and num_steps_per_epoch and fixed_step / num_steps_per_epoch > max_num_epochs:
                raise tf.errors.OutOfRangeError(
                    None, None,
                    'Reached max num epochs of %d' % max_num_epochs)
    #except tf.errors.OutOfRangeError, e:
    except tf.errors.OutOfRangeError:
        # if run 2 epoch and we have just epoch saved, do not need to save only 1 step more model
        if (step - epoch_saved_step > 1) and not (
                step == start
        ) and save_model and step % save_interval_steps != 0 and model_dir:
            model_path_ = _get_checkpoint_path(checkpoint_path, step,
                                               num_steps_per_epoch)
            saver.save(sess, model_path_, global_step=step)
            if freeze_graph:
                melt.freeze_graph(sess, model_path_, step,
                                  output_collection_names, output_node_names)
            if log_dir != model_dir:
                assert log_dir
                command = 'rsync -l -r -t %s/* %s' % (log_dir, model_dir)
                print(command, file=sys.stderr)
                os.system(command)
        if only_one_step:
            logging.info('Done one step')
            exit(0)

        # if (step - epoch_saved_step > 1) and metric_eval_fn is not None:
        #   metric_eval_fn(model_path=model_step_path)

        if (num_epochs and fixed_step / num_steps_per_epoch >= num_epochs) or (
                num_steps and step == start + num_steps):
            logging.info('Done training for %.3f epochs, %d steps.' %
                         (fixed_step / num_steps_per_epoch, step))
            #FIXME becase coord.join seems not work,  RuntimeError: Coordinator stopped with threads still running: Thread-9
            exit(0)
        else:
            logging.info('Should not stop, but stopped at epoch: %.3f' %
                         (fixed_step / num_steps_per_epoch))
            logging.info(traceback.format_exc())
            #raise e
    finally:
        coord.request_stop()

    coord.join(threads, stop_grace_period_secs=5)
    #FIMXE due to use melt.get_session(global not handle del well)
    #Done training for 3090020 steps.
    #Exception TypeError: "'NoneType' object is not callable" in <bound method Session.__del__ of <tensorflow.python.client.session.Session object at 0x7f6cf33cd450>> ignored
    if FLAGS.use_tpu:
        sess.run(tpu.shutdown_system())
    sess.close()
コード例 #23
0
def main():
    """Build the energy-based model graph, then train and/or test it.

    All configuration is read from the module-level ``FLAGS``. Steps:

    1. On Horovod rank 0, create ``<FLAGS.logdir>/<FLAGS.exp>`` and a
       TensorBoard logger (other ranks use ``logger = None``).
    2. Build the dataset, input placeholders and network for
       ``FLAGS.dataset``.
    3. For each of ``FLAGS.num_gpus`` towers: compute the energy of the real
       data, run ``FLAGS.num_steps`` Langevin steps starting from the noise
       input to obtain negative samples, and (when ``FLAGS.train``) build the
       training loss and its gradients.
    4. Average the tower gradients into a single ``train_op``.
    5. Initialize variables, restore a checkpoint on rank 0 when resuming or
       when not training, broadcast rank 0's variables to every worker, then
       call ``train`` (when ``FLAGS.train``) followed by ``test``.

    Raises:
        ValueError: if ``FLAGS.dataset`` is unsupported; if ``FLAGS.train``
            is set and ``FLAGS.objective`` is unsupported; or if
            ``FLAGS.proj_norm`` is non-zero and ``FLAGS.proj_norm_type`` is
            neither ``'l2'`` nor ``'li'``.
    """
    print("Local rank: ", hvd.local_rank(), hvd.size())

    logdir = osp.join(FLAGS.logdir, FLAGS.exp)
    if hvd.rank() == 0:
        if not osp.exists(logdir):
            os.makedirs(logdir)
        logger = TensorBoardOutputFormat(logdir)
    else:
        # Only rank 0 writes logs.
        logger = None

    LABEL = None
    print("Loading data...")
    if FLAGS.dataset == 'cifar10':
        dataset = Cifar10(augment=FLAGS.augment, rescale=FLAGS.rescale)
        test_dataset = Cifar10(train=False, rescale=FLAGS.rescale)
        channel_num = 3

        X_NOISE = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32)

        if FLAGS.large_model:
            model = ResNet32Large(num_channels=channel_num,
                                  num_filters=128,
                                  train=True)
        elif FLAGS.larger_model:
            model = ResNet32Larger(num_channels=channel_num, num_filters=128)
        elif FLAGS.wider_model:
            model = ResNet32Wider(num_channels=channel_num, num_filters=192)
        else:
            model = ResNet32(num_channels=channel_num, num_filters=128)

    elif FLAGS.dataset == 'imagenet':
        dataset = Imagenet(train=True)
        test_dataset = Imagenet(train=False)
        channel_num = 3
        X_NOISE = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 1000), dtype=tf.float32)

        model = ResNet32Wider(num_channels=channel_num, num_filters=256)

    elif FLAGS.dataset == 'imagenetfull':
        # No in-memory dataset object for this case: batches come from the
        # TFImagenetLoader constructed below.
        channel_num = 3
        X_NOISE = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 1000), dtype=tf.float32)

        model = ResNet128(num_channels=channel_num, num_filters=64)

    elif FLAGS.dataset == 'mnist':
        dataset = Mnist(rescale=FLAGS.rescale)
        test_dataset = dataset
        channel_num = 1
        X_NOISE = tf.placeholder(shape=(None, 28, 28), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 28, 28), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32)

        model = MnistNet(num_channels=channel_num,
                         num_filters=FLAGS.num_filters)

    elif FLAGS.dataset == 'dsprites':
        dataset = DSprites(cond_shape=FLAGS.cond_shape,
                           cond_size=FLAGS.cond_size,
                           cond_pos=FLAGS.cond_pos,
                           cond_rot=FLAGS.cond_rot)
        test_dataset = dataset
        channel_num = 1

        X_NOISE = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)

        # The label width depends on which factor the model is conditioned
        # on; the first flag that is set wins, in this order.
        if FLAGS.dpos_only:
            label_dim = 2
        elif FLAGS.dsize_only:
            label_dim = 1
        elif FLAGS.drot_only:
            label_dim = 2
        elif FLAGS.cond_size:
            label_dim = 1
        elif FLAGS.cond_shape:
            label_dim = 3
        elif FLAGS.cond_pos:
            label_dim = 2
        elif FLAGS.cond_rot:
            label_dim = 2
        else:
            label_dim = 3
        LABEL = tf.placeholder(shape=(None, label_dim), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, label_dim), dtype=tf.float32)

        model = DspritesNet(num_channels=channel_num,
                            num_filters=FLAGS.num_filters,
                            cond_size=FLAGS.cond_size,
                            cond_shape=FLAGS.cond_shape,
                            cond_pos=FLAGS.cond_pos,
                            cond_rot=FLAGS.cond_rot)

    else:
        # Previously this fell through to an opaque NameError further down.
        raise ValueError("Unsupported dataset: {}".format(FLAGS.dataset))

    print("Done loading...")

    if FLAGS.dataset == "imagenetfull":
        # In the case of full imagenet, use custom_tensorflow dataloader
        data_loader = TFImagenetLoader('train',
                                       FLAGS.batch_size,
                                       hvd.rank(),
                                       hvd.size(),
                                       rescale=FLAGS.rescale)
    else:
        data_loader = DataLoader(dataset,
                                 batch_size=FLAGS.batch_size,
                                 num_workers=FLAGS.data_workers,
                                 drop_last=True,
                                 shuffle=True)

    weights = [model.construct_weights('context_0')]

    Y = tf.placeholder(shape=(None), dtype=tf.int32)

    # Split every input batch evenly across the GPU towers.
    X_SPLIT = tf.split(X, FLAGS.num_gpus)
    X_NOISE_SPLIT = tf.split(X_NOISE, FLAGS.num_gpus)
    LABEL_SPLIT = tf.split(LABEL, FLAGS.num_gpus)
    LABEL_POS_SPLIT = tf.split(LABEL_POS, FLAGS.num_gpus)
    tower_grads = []

    optimizer = AdamOptimizer(FLAGS.lr, beta1=0.0, beta2=0.999)
    # Horovod: wrap the optimizer so gradients are averaged across workers.
    optimizer = hvd.DistributedOptimizer(optimizer)

    for j in range(FLAGS.num_gpus):

        if FLAGS.model_cclass:
            # Score every image of this tower under each of the 10 one-hot
            # labels, then sample one label per image with the Gumbel-max
            # trick on the negated energies. The sampled labels replace
            # LABEL_SPLIT[j] for the rest of this tower.
            # NOTE(review): 10 classes and 32x32x3 inputs are hard-coded.
            ind_batch_size = FLAGS.batch_size // FLAGS.num_gpus
            # One row per (image, class) pair of THIS tower. It used to be
            # sized with the full FLAGS.batch_size, which does not match
            # x_split when num_gpus > 1.
            one_hot_grid = np.reshape(
                np.tile(np.eye(10), (ind_batch_size, 1, 1)),
                (ind_batch_size * 10, 10))
            label_tensor = tf.Variable(tf.convert_to_tensor(one_hot_grid,
                                                            dtype=tf.float32),
                                       trainable=False,
                                       dtype=tf.float32)
            x_split = tf.tile(
                tf.reshape(X_SPLIT[j], (ind_batch_size, 1, 32, 32, 3)),
                (1, 10, 1, 1, 1))
            x_split = tf.reshape(x_split, (ind_batch_size * 10, 32, 32, 3))
            energy_pos = model.forward(x_split,
                                       weights[0],
                                       label=label_tensor,
                                       stop_at_grad=False)

            energy_pos_full = tf.reshape(energy_pos, (ind_batch_size, 10))
            energy_partition_est = tf.reduce_logsumexp(energy_pos_full,
                                                       axis=1,
                                                       keepdims=True)
            uniform = tf.random_uniform(tf.shape(energy_pos_full))
            # -log(-log(U)) is Gumbel noise. The partition term is constant
            # per row, so it does not affect the argmax.
            label_tensor = tf.argmax(-energy_pos_full -
                                     tf.log(-tf.log(uniform)) -
                                     energy_partition_est,
                                     axis=1)
            label = tf.one_hot(label_tensor, 10, dtype=tf.float32)
            # NOTE(review): debug print that fires on every evaluation of
            # `label`; consider removing.
            label = tf.Print(label, [label_tensor, energy_pos_full])
            LABEL_SPLIT[j] = label
            energy_pos = tf.concat(energy_pos, axis=0)
        else:
            energy_pos = [
                model.forward(X_SPLIT[j],
                              weights[0],
                              label=LABEL_POS_SPLIT[j],
                              stop_at_grad=False)
            ]
            energy_pos = tf.concat(energy_pos, axis=0)

        print("Building graph...")
        x_mod = X_NOISE_SPLIT[j]

        x_grads = []
        energy_negs = []

        # Energy of the starting (noise) samples, before any Langevin step.
        energy_negs.append(
            model.forward(tf.stop_gradient(x_mod),
                          weights[0],
                          label=LABEL_SPLIT[j],
                          stop_at_grad=False,
                          reuse=True))
        eps_begin = tf.zeros(1)

        steps = tf.constant(0)

        def keep_going(i, x):
            """Loop condition: run FLAGS.num_steps Langevin steps."""
            return tf.less(i, FLAGS.num_steps)

        def langevin_step(counter, x_mod):
            """Perturb x_mod with Gaussian noise, then step down the energy.

            Returns the incremented counter and the updated samples, clipped
            to [0, FLAGS.rescale].
            """
            x_mod = x_mod + tf.random_normal(
                tf.shape(x_mod),
                mean=0.0,
                stddev=0.005 * FLAGS.rescale * FLAGS.noise_scale)

            energy_noise = tf.concat([
                model.forward(x_mod,
                              weights[0],
                              label=LABEL_SPLIT[j],
                              reuse=True,
                              stop_at_grad=False,
                              stop_batch=True)
            ],
                                     axis=0)

            x_grad = tf.gradients(FLAGS.temperature * energy_noise,
                                  [x_mod])[0]

            lr = FLAGS.step_lr

            if FLAGS.proj_norm != 0.0:
                if FLAGS.proj_norm_type == 'l2':
                    x_grad = tf.clip_by_norm(x_grad, FLAGS.proj_norm)
                elif FLAGS.proj_norm_type == 'li':
                    x_grad = tf.clip_by_value(x_grad, -FLAGS.proj_norm,
                                              FLAGS.proj_norm)
                else:
                    # An explicit raise, unlike `assert False`, still fires
                    # under `python -O`.
                    raise ValueError(
                        "Other types of projection are not supported!!!")

            if FLAGS.hmc:
                # Step size should be tuned to get around 65% acceptance
                def energy(x):
                    return FLAGS.temperature * \
                        model.forward(x, weights[0], label=LABEL_SPLIT[j], reuse=True)

                x_last = hmc(x_mod, 15., 10, energy)
            else:
                x_last = x_mod - (lr) * x_grad

            x_mod = tf.clip_by_value(x_last, 0, FLAGS.rescale)

            return counter + 1, x_mod

        steps, x_mod = tf.while_loop(keep_going, langevin_step, (steps, x_mod))

        energy_eval = model.forward(x_mod,
                                    weights[0],
                                    label=LABEL_SPLIT[j],
                                    stop_at_grad=False,
                                    reuse=True)
        x_grad = tf.gradients(FLAGS.temperature * energy_eval, [x_mod])[0]
        x_grads.append(x_grad)

        # Energy of the final negative samples (no gradient into x_mod).
        energy_negs.append(
            model.forward(tf.stop_gradient(x_mod),
                          weights[0],
                          label=LABEL_SPLIT[j],
                          stop_at_grad=False,
                          reuse=True))

        test_x_mod = x_mod

        temp = FLAGS.temperature

        energy_neg = energy_negs[-1]
        x_off = tf.reduce_mean(
            tf.abs(x_mod[:tf.shape(X_SPLIT[j])[0]] - X_SPLIT[j]))

        # Use this tower's labels: x_mod holds only this tower's slice of the
        # batch, and it was sampled under LABEL_SPLIT[j]. The full-batch
        # LABEL does not match x_mod's batch size when num_gpus > 1.
        loss_energy = model.forward(x_mod,
                                    weights[0],
                                    reuse=True,
                                    label=LABEL_SPLIT[j],
                                    stop_grad=True)

        print("Finished processing loop construction ...")

        # Rebuilt on every tower iteration, so after the loop it holds the
        # last tower's tensors.
        target_vars = {}

        if FLAGS.cclass or FLAGS.model_cclass:
            # Entropy of the empirical label distribution of tower 0.
            label_sum = tf.reduce_sum(LABEL_SPLIT[0], axis=0)
            label_prob = label_sum / tf.reduce_sum(label_sum)
            label_ent = -tf.reduce_sum(
                label_prob * tf.math.log(label_prob + 1e-7))
        else:
            label_ent = tf.zeros(1)

        target_vars['label_ent'] = label_ent

        if FLAGS.train:

            if FLAGS.objective == 'logsumexp':
                # Shift by the minimum energy so the largest weight is
                # exp(0) = 1 (for temp > 0); the weights carry no gradient.
                energy_neg_reduced = (energy_neg - tf.reduce_min(energy_neg))
                coeff = tf.stop_gradient(tf.exp(-temp * energy_neg_reduced))
                norm_constant = tf.stop_gradient(tf.reduce_sum(coeff)) + 1e-4
                pos_loss = tf.reduce_mean(temp * energy_pos)
                neg_loss = coeff * (-1 * temp * energy_neg) / norm_constant
                loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
            elif FLAGS.objective == 'cd':
                pos_loss = tf.reduce_mean(temp * energy_pos)
                neg_loss = -tf.reduce_mean(temp * energy_neg)
                loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
            elif FLAGS.objective == 'softplus':
                loss_ml = FLAGS.ml_coeff * \
                    tf.nn.softplus(temp * (energy_pos - energy_neg))
            else:
                # Previously an unknown objective surfaced as a NameError on
                # `loss_ml`.
                raise ValueError("Unsupported objective: {}".format(
                    FLAGS.objective))

            loss_total = tf.reduce_mean(loss_ml)

            if not FLAGS.zero_kl:
                loss_total = loss_total + tf.reduce_mean(loss_energy)

            # L2 penalty on the energy magnitudes of positives and negatives.
            loss_total = loss_total + \
                FLAGS.l2_coeff * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square((energy_neg))))

            print("Started gradient computation...")
            gvs = optimizer.compute_gradients(loss_total)
            # Drop variables that receive no gradient from this loss.
            gvs = [(g, v) for (g, v) in gvs if g is not None]

            print("Applying gradients...")

            tower_grads.append(gvs)

            print("Finished applying gradients.")

            target_vars['loss_ml'] = loss_ml
            target_vars['total_loss'] = loss_total
            target_vars['loss_energy'] = loss_energy
            target_vars['weights'] = weights
            target_vars['gvs'] = gvs

        target_vars['X'] = X
        target_vars['Y'] = Y
        target_vars['LABEL'] = LABEL
        target_vars['LABEL_POS'] = LABEL_POS
        target_vars['X_NOISE'] = X_NOISE
        target_vars['energy_pos'] = energy_pos
        target_vars['energy_start'] = energy_negs[0]

        if len(x_grads) >= 1:
            target_vars['x_grad'] = x_grads[-1]
            target_vars['x_grad_first'] = x_grads[0]
        else:
            target_vars['x_grad'] = tf.zeros(1)
            target_vars['x_grad_first'] = tf.zeros(1)

        target_vars['x_mod'] = x_mod
        target_vars['x_off'] = x_off
        target_vars['temp'] = temp
        target_vars['energy_neg'] = energy_neg
        target_vars['test_x_mod'] = test_x_mod
        target_vars['eps_begin'] = eps_begin

    if FLAGS.train:
        grads = average_gradients(tower_grads)
        train_op = optimizer.apply_gradients(grads)
        target_vars['train_op'] = train_op

    config = tf.ConfigProto()

    if hvd.size() > 1:
        # Horovod: pin this process to the GPU matching its local rank.
        config.gpu_options.visible_device_list = str(hvd.local_rank())

    sess = tf.Session(config=config)

    saver = tf.train.Saver(max_to_keep=30, keep_checkpoint_every_n_hours=6)

    total_parameters = 0
    for variable in tf.trainable_variables():
        # shape is an array of tf.Dimension
        variable_parameters = 1
        for dim in variable.get_shape():
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    print("Model has a total of {} parameters".format(total_parameters))

    print("Initializing variables...")
    sess.run(tf.global_variables_initializer())

    # NOTE(review): only rank 0 updates resume_itr; the other ranks pass 0
    # to train().
    resume_itr = 0

    if (FLAGS.resume_iter != -1 or not FLAGS.train) and hvd.rank() == 0:
        model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
        resume_itr = FLAGS.resume_iter
        # saver.restore(sess, model_file)
        optimistic_restore(sess, model_file)

    # Horovod: broadcast rank 0's (possibly restored) variables so that all
    # workers start from identical weights.
    print("Start broadcast")
    sess.run(hvd.broadcast_global_variables(0))
    print("End broadcast")

    if FLAGS.train:
        print("Training phase")
        train(target_vars, saver, sess, logger, data_loader, resume_itr,
              logdir)
    print("Testing phase")
    test(target_vars, saver, sess, logger, data_loader)
コード例 #24
0
    # Compute the cosine similarity between minibatch examples and all embeddings.
    # L2 norm of each row of `embeddings` (reduction over axis 1), kept as a
    # column so the division below broadcasts row-wise.
    # NOTE(review): `keep_dims` is the deprecated TF1 spelling of `keepdims`.
    norm = tf.sqrt(tf.reduce_sum(tf.square(embeddings), 1, keep_dims=True))
    normalized_embeddings = embeddings / norm
    # Rows of the normalized table selected by the ids in `valid_dataset`
    # (presumably the validation word ids; confirm against their definition).
    valid_embeddings = tf.nn.embedding_lookup(
        normalized_embeddings, valid_dataset)
    # Both operands are unit-normalized, so each entry of this product is a
    # cosine similarity: row = selected example, column = table row.
    similarity = tf.matmul(
        valid_embeddings, normalized_embeddings, transpose_b=True)

    # Add variable initializer.
    init = tf.global_variables_initializer()

    # Horovod: broadcast initial variable states from rank 0 to all other processes.
    # This is necessary to ensure consistent initialization of all workers when
    # training is started with random weights or restored from a checkpoint.
    # This only builds the op; it must be run after `init` (see the session
    # block below).
    bcast = hvd.broadcast_global_variables(0)

# Step 5: Begin training.

# Horovod: adjust number of steps based on number of GPUs.
# Each worker runs roughly 1/hvd.size() of a 100000-step budget (plus one).
num_steps = 100000 // hvd.size() + 1

# Horovod: pin GPU to be used to process local rank (one GPU per process)
config = tf.ConfigProto()
# Allocate GPU memory as needed instead of reserving it all up front.
config.gpu_options.allow_growth = True
config.gpu_options.visible_device_list = str(hvd.local_rank())

with tf.Session(graph=graph, config=config) as session:
    # We must initialize all variables before we use them.
    init.run()
    # Run the broadcast after initialization so every rank ends up with
    # rank 0's initial values.
    # NOTE(review): the training loop that consumes `num_steps` is not part
    # of this snippet.
    bcast.run()
コード例 #25
0
def _metrics_summary(sess, metrics_ops_dict, metrics_placeholders,
                     metrics_summary_op):
    """Read the accumulated metric values and return them as a summary.

    Runs every op in ``metrics_ops_dict``, feeds the resulting values into
    ``metrics_placeholders`` (paired positionally with
    ``metrics_ops_dict.values()``) and returns the evaluated
    ``metrics_summary_op``.
    """
    metrics_values = sess.run(list(metrics_ops_dict.values()))
    metrics_feed_dict = {
        placeholder: value
        for placeholder, value in zip(metrics_placeholders, metrics_values)
    }
    metrics_summary, = sess.run(
        [metrics_summary_op],
        feed_dict=metrics_feed_dict,
    )
    return metrics_summary


def start_training(config):
    """Build the network described by ``config`` and run the training loop.

    Values read from ``config``: ``IS_DISTRIBUTION``, ``NETWORK_CLASS``,
    ``NETWORK``, ``DATASET``, ``IS_DEBUG``, ``IS_PRETRAIN`` (plus
    ``PRETRAIN_VARS``/``PRETRAIN_DIR``/``PRETRAIN_FILE``), ``BATCH_SIZE``,
    ``MAX_EPOCHS`` or ``MAX_STEPS``, ``SUMMARISE_STEPS``, ``SAVE_STEPS`` and
    ``TEST_STEPS``.

    When ``config.IS_DISTRIBUTION`` is set, Horovod and mpi4py are
    initialized and the optimizer is wrapped in ``hvd.DistributedOptimizer``.
    Global variables are broadcast from rank 0. Each worker trains on its own
    slice of a shuffle order that rank 0 broadcasts at each epoch boundary.

    Only rank 0 does the following:
      * writes TensorBoard summaries;
      * restores pretrained weights and the latest checkpoint;
      * saves checkpoints (via ``_save_checkpoint``);
      * writes the step-0 minimal graph to ``environment.CHECKPOINTS_DIR``.

    Every rank runs validation at step 0 and every ``TEST_STEPS`` steps.

    Args:
        config: Experiment configuration supporting both attribute and key
            access (presumably an EasyDict — confirm against the caller).
    """
    if config.IS_DISTRIBUTION:
        import horovod.tensorflow as hvd
        # initialize Horovod.
        hvd.init()
        num_worker = hvd.size()
        rank = hvd.rank()
        # verify that MPI multi-threading is supported.
        assert hvd.mpi_threads_supported()
        # make sure MPI is not re-initialized by mpi4py (Horovod did it above).
        import mpi4py.rc
        mpi4py.rc.initialize = False
        from mpi4py import MPI
        comm = MPI.COMM_WORLD
        # check that Horovod and MPI agree on world size and rank.
        assert num_worker == comm.Get_size()
        assert rank == comm.Get_rank()
    else:
        num_worker = 1
        rank = 0

    ModelClass = config.NETWORK_CLASS
    # Network constructor kwargs are the NETWORK entries with lower-cased keys.
    network_kwargs = {key.lower(): val for key, val in config.NETWORK.items()}
    if "TRAIN_VALIDATION_SAVING_SIZE" in config.DATASET.keys():
        use_train_validation_saving = config.DATASET.TRAIN_VALIDATION_SAVING_SIZE > 0
    else:
        use_train_validation_saving = False

    train_dataset = setup_dataset(config, "train", rank)
    print("train dataset num:", train_dataset.num_per_epoch)

    if use_train_validation_saving:
        # Best accuracy seen so far on the train_validation_saving split; a
        # checkpoint is only saved when this improves.
        top_train_validation_saving_set_accuracy = 0
        train_validation_saving_dataset = setup_dataset(
            config, "train_validation_saving", rank)
        print("train_validation_saving dataset num:",
              train_validation_saving_dataset.num_per_epoch)

    validation_dataset = setup_dataset(config, "validation", rank)
    print("validation dataset num:", validation_dataset.num_per_epoch)

    graph = tf.Graph()
    with graph.as_default():
        if ModelClass.__module__.startswith("lmnet.networks.object_detection"):
            model = ModelClass(
                classes=train_dataset.classes,
                num_max_boxes=train_dataset.num_max_boxes,
                is_debug=config.IS_DEBUG,
                **network_kwargs,
            )
        elif ModelClass.__module__.startswith("lmnet.networks.segmentation"):
            model = ModelClass(
                classes=train_dataset.classes,
                label_colors=train_dataset.label_colors,
                is_debug=config.IS_DEBUG,
                **network_kwargs,
            )
        else:
            model = ModelClass(
                classes=train_dataset.classes,
                is_debug=config.IS_DEBUG,
                **network_kwargs,
            )

        global_step = tf.Variable(0, name="global_step", trainable=False)
        is_training_placeholder = tf.placeholder(
            tf.bool, name="is_training_placeholder")

        images_placeholder, labels_placeholder = model.placeholderes()

        output = model.inference(images_placeholder, is_training_placeholder)
        if ModelClass.__module__.startswith("lmnet.networks.object_detection"):
            loss = model.loss(output, labels_placeholder,
                              is_training_placeholder)
        else:
            loss = model.loss(output, labels_placeholder)
        opt = model.optimizer(global_step)
        if config.IS_DISTRIBUTION:
            # add Horovod Distributed Optimizer
            opt = hvd.DistributedOptimizer(opt)
        train_op = model.train(loss, opt, global_step)
        metrics_ops_dict, metrics_update_op = model.metrics(
            output, labels_placeholder)
        # TODO(wakisaka): Deal with many networks.
        model.summary(output, labels_placeholder)

        summary_op = tf.summary.merge_all()

        metrics_summary_op, metrics_placeholders = executor.prepare_metrics(
            metrics_ops_dict)

        init_op = tf.global_variables_initializer()
        # Local variables hold the streaming-metric accumulators; re-running
        # this initializer resets them.
        reset_metrics_op = tf.local_variables_initializer()
        if config.IS_DISTRIBUTION:
            # add Horovod broadcasting variables from rank 0 to all
            bcast_global_variables_op = hvd.broadcast_global_variables(0)

        if use_train_validation_saving:
            saver = tf.train.Saver(max_to_keep=1)
        else:
            saver = tf.train.Saver(max_to_keep=None)

        if config.IS_PRETRAIN:
            all_vars = tf.global_variables()
            pretrain_var_list = [
                var for var in all_vars
                if var.name.startswith(tuple(config.PRETRAIN_VARS))
            ]
            print("pretrain_vars", [var.name for var in pretrain_var_list])
            pretrain_saver = tf.train.Saver(pretrain_var_list,
                                            name="pretrain_saver")

    if config.IS_DISTRIBUTION:
        # For distributed training: one visible GPU per local rank.
        session_config = tf.ConfigProto(gpu_options=tf.GPUOptions(
            allow_growth=True, visible_device_list=str(hvd.local_rank())))
    else:
        session_config = tf.ConfigProto()
    # TODO(wakisaka): XLA JIT
    # session_config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1

    sess = tf.Session(graph=graph, config=session_config)
    sess.run([init_op, reset_metrics_op])

    if rank == 0:
        train_writer = tf.summary.FileWriter(
            environment.TENSORBOARD_DIR + "/train", sess.graph)
        if use_train_validation_saving:
            train_val_saving_writer = tf.summary.FileWriter(
                environment.TENSORBOARD_DIR + "/train_validation_saving")
        val_writer = tf.summary.FileWriter(environment.TENSORBOARD_DIR +
                                           "/validation")

        if config.IS_PRETRAIN:
            print("------- Load pretrain data ----------")
            pretrain_saver.restore(
                sess, os.path.join(config.PRETRAIN_DIR, config.PRETRAIN_FILE))
            sess.run(tf.assign(global_step, 0))

        last_step = 0

        # for recovery
        ckpt = tf.train.get_checkpoint_state(environment.CHECKPOINTS_DIR)
        if ckpt and ckpt.model_checkpoint_path:
            print("--------- Restore last checkpoint -------------")
            saver.restore(sess, ckpt.model_checkpoint_path)
            last_step = sess.run(global_step)
            # TODO(wakisaka): tensorflow v1.3 remain previous event log in tensorboard.
            # https://github.com/tensorflow/tensorflow/blob/r1.3/tensorflow/python/training/supervisor.py#L1072
            # Mark the restart in every event log this run writes to.
            train_writer.add_session_log(SessionLog(status=SessionLog.START),
                                         global_step=last_step + 1)
            if use_train_validation_saving:
                train_val_saving_writer.add_session_log(
                    SessionLog(status=SessionLog.START),
                    global_step=last_step + 1)
            val_writer.add_session_log(SessionLog(status=SessionLog.START),
                                       global_step=last_step + 1)
            print("recovered. last step", last_step)

    if config.IS_DISTRIBUTION:
        # broadcast variables (including any restored state) from rank 0
        sess.run(bcast_global_variables_op)
        # calculate step per epoch for each node
        train_num_per_epoch = train_dataset.num_per_epoch
        num_per_nodes = (train_num_per_epoch + num_worker - 1) // num_worker
        # At least one step per epoch: a shard smaller than one batch would
        # otherwise give 0 and make `step % step_per_epoch` divide by zero.
        step_per_epoch = max(1, num_per_nodes // config.BATCH_SIZE)
        begin_index = (train_num_per_epoch * rank) // num_worker
        end_index = begin_index + num_per_nodes

    # Read after the broadcast, so every rank starts from the same step.
    last_step = sess.run(global_step)

    # Calculate max steps. The priority of config.MAX_EPOCHS is higher than config.MAX_STEPS.
    if "MAX_EPOCHS" in config:
        max_steps = int(train_dataset.num_per_epoch / config.BATCH_SIZE *
                        config.MAX_EPOCHS)
    else:
        max_steps = config.MAX_STEPS
    print("max_steps: {}".format(max_steps))

    for step in range(last_step, max_steps):
        print("step", step)

        if config.IS_DISTRIBUTION:
            # Scatter a fresh shuffle at every epoch boundary, and also on the
            # first iteration. Without that, a run resumed mid-epoch would feed
            # from a dataset that update_dataset() was never called on.
            # last_step is identical on every rank, so all ranks enter this
            # branch together and the collective comm.bcast cannot deadlock.
            if step % step_per_epoch == 0 or step == last_step:
                indices = train_dataset.get_shuffle_index() if rank == 0 else None
                # broadcast shuffled indices
                indices = comm.bcast(indices, 0)
                feed_indices = indices[begin_index:end_index]
                # update each dataset by split indices
                train_dataset.update_dataset(feed_indices)

        images, labels = train_dataset.feed()

        feed_dict = {
            is_training_placeholder: True,
            images_placeholder: images,
            labels_placeholder: labels,
        }

        # True at step 0 and whenever (step + 1) is a multiple of SUMMARISE_STEPS.
        if step * ((step + 1) % config.SUMMARISE_STEPS) == 0 and rank == 0:
            sess.run(reset_metrics_op)
            _, summary, _ = sess.run(
                [train_op, summary_op, metrics_update_op],
                feed_dict=feed_dict,
            )
            train_writer.add_summary(summary, step + 1)
            train_writer.add_summary(
                _metrics_summary(sess, metrics_ops_dict, metrics_placeholders,
                                 metrics_summary_op),
                step + 1)
        else:
            sess.run([train_op], feed_dict=feed_dict)

        to_be_saved = step == 0 or (
            step + 1) == max_steps or (step + 1) % config.SAVE_STEPS == 0

        if to_be_saved and rank == 0:
            if use_train_validation_saving:

                sess.run(reset_metrics_op)
                train_validation_saving_step_size = int(
                    math.ceil(train_validation_saving_dataset.num_per_epoch /
                              config.BATCH_SIZE))
                print("train_validation_saving_step_size",
                      train_validation_saving_step_size)

                for train_validation_saving_step in range(
                        train_validation_saving_step_size):
                    print("train_validation_saving_step",
                          train_validation_saving_step)

                    images, labels = train_validation_saving_dataset.feed()
                    feed_dict = {
                        is_training_placeholder: False,
                        images_placeholder: images,
                        labels_placeholder: labels,
                    }

                    if train_validation_saving_step % config.SUMMARISE_STEPS == 0:
                        summary, _ = sess.run([summary_op, metrics_update_op],
                                              feed_dict=feed_dict)
                        train_val_saving_writer.add_summary(summary, step + 1)
                    else:
                        sess.run([metrics_update_op], feed_dict=feed_dict)

                train_val_saving_writer.add_summary(
                    _metrics_summary(sess, metrics_ops_dict,
                                     metrics_placeholders, metrics_summary_op),
                    step + 1)

                current_train_validation_saving_set_accuracy = sess.run(
                    metrics_ops_dict["accuracy"])

                # Only keep a checkpoint when held-out accuracy improves.
                if current_train_validation_saving_set_accuracy > top_train_validation_saving_set_accuracy:
                    top_train_validation_saving_set_accuracy = current_train_validation_saving_set_accuracy
                    print("New top train_validation_saving accuracy is: ",
                          top_train_validation_saving_set_accuracy)

                    _save_checkpoint(saver, sess, global_step, step)

            else:
                _save_checkpoint(saver, sess, global_step, step)

            if step == 0:
                # check create pb on only first step.
                minimal_graph = tf.graph_util.convert_variables_to_constants(
                    sess,
                    sess.graph.as_graph_def(add_shapes=True),
                    ["output"],
                )
                pb_name = "minimal_graph_with_shape_{}.pb".format(step + 1)
                pbtxt_name = "minimal_graph_with_shape_{}.pbtxt".format(step +
                                                                        1)
                tf.train.write_graph(minimal_graph,
                                     environment.CHECKPOINTS_DIR,
                                     pb_name,
                                     as_text=False)
                tf.train.write_graph(minimal_graph,
                                     environment.CHECKPOINTS_DIR,
                                     pbtxt_name,
                                     as_text=True)

        if step == 0 or (step + 1) % config.TEST_STEPS == 0:
            # init metrics values
            sess.run(reset_metrics_op)
            test_step_size = int(
                math.ceil(validation_dataset.num_per_epoch /
                          config.BATCH_SIZE))
            print("test_step_size", test_step_size)

            for test_step in range(test_step_size):
                print("test_step", test_step)

                images, labels = validation_dataset.feed()
                feed_dict = {
                    is_training_placeholder: False,
                    images_placeholder: images,
                    labels_placeholder: labels,
                }

                if test_step % config.SUMMARISE_STEPS == 0:
                    summary, _ = sess.run([summary_op, metrics_update_op],
                                          feed_dict=feed_dict)
                    if rank == 0:
                        val_writer.add_summary(summary, step + 1)
                else:
                    sess.run([metrics_update_op], feed_dict=feed_dict)

            # Every rank evaluates; only rank 0 owns the summary writer.
            metrics_summary = _metrics_summary(sess, metrics_ops_dict,
                                               metrics_placeholders,
                                               metrics_summary_op)
            if rank == 0:
                val_writer.add_summary(metrics_summary, step + 1)

    # training loop end.
    print("reach max step")
コード例 #26
0
def main(argv=None):
    """Train a Keras CIFAR-10 classifier with Horovod data parallelism.

    Steps:
      * Parse options with ``parser_``; the CLI description is this module's
        docstring.
      * Pin each process to GPU ``local_rank // nranks_per_gpu``.
      * Train with RMSProp (learning rate ``0.0001 * hvd.size()``) wrapped in
        ``hvd.DistributedOptimizer``, with or without real-time augmentation.
      * On rank 0, report the training time and the final test loss/accuracy.

    Args:
        argv: Optional extra command-line arguments appended to ``sys.argv``
            before parsing (presumably ``parser_`` reads ``sys.argv`` —
            confirm).
    """
    main.__doc__ = __doc__
    # list.extend() returns None, so extend sys.argv in place, then bind argv.
    if argv is not None:
        sys.argv.extend(argv)
    argv = sys.argv
    desc = main.__doc__
    # CLI parser
    args = parser_(desc)

    # Initialize Horovod.
    hvd.init()

    logdevp = args.logdevp  # For debugging
    # Device-placement logging and soft placement are switched on together.
    log_device_placement = allow_soft_placement = bool(_DEVPROF or logdevp)

    # Consecutive groups of nranks_per_gpu local ranks share one GPU.
    nranks_per_gpu = args.nranks_per_gpu
    local_rank = hvd.local_rank()
    gpu_local_rank = local_rank // nranks_per_gpu
    print('local_rank, GPU_LOCAL_RANK: {}, {}'.format(
        local_rank, gpu_local_rank))

    config = tf.ConfigProto(log_device_placement=log_device_placement,
                            allow_soft_placement=allow_soft_placement)
    config.gpu_options.allow_growth = True
    config.gpu_options.visible_device_list = str(gpu_local_rank)
    KB.set_session(tf.Session(config=config))

    hvdsize = hvd.size()

    # Optional checkpoint handed to make_model; None when not supplied.
    checkpt = getattr(args, 'checkpt', None)

    batch_size = args.batch_size
    num_classes = 10
    epochs = args.epochs
    data_augmentation = args.aug

    datadir = getattr(args, 'datadir', None)

    # The data, shuffled and split between train and test sets:
    (x_train, y_train), (x_test, y_test) = cifar10_load_data(datadir) \
        if datadir is not None else cifar10.load_data()
    train_samples = x_train.shape[0]
    test_samples = x_test.shape[0]
    # Each rank runs 1/hvdsize of an epoch's batches.
    steps_per_epoch = train_samples // batch_size // hvdsize
    test_batches = test_samples // batch_size
    print(train_samples, 'train samples')
    print(test_samples, 'test samples')

    # Scale pixels to [0, 1] (assumes 8-bit 0..255 input).
    x_train = x_train.astype('float32')
    x_test = x_test.astype('float32')
    x_train /= 255
    x_test /= 255

    # Convert class vectors to binary class matrices.
    y_train = to_categorical(y_train, num_classes).squeeze()
    y_test = to_categorical(y_test, num_classes).squeeze()

    callbacks = []
    if hvd.rank() == 0:
        callbacks += [BatchTiming(), SamplesPerSec(batch_size * hvdsize)]

    print(x_train.shape, 'train shape')
    model = make_model(x_train.shape, num_classes, checkpt)

    # Learning rate scales linearly with the number of workers.
    lr = 0.0001 * hvdsize
    opt = tf.train.RMSPropOptimizer(lr)
    # Add Horovod Distributed Optimizer.
    opt = hvd.DistributedOptimizer(opt)
    opt = TFOptimizer(opt)  # Required for tf.train based optimizers

    # Let's train the model using RMSprop
    model.compile(loss='categorical_crossentropy',
                  optimizer=opt,
                  metrics=['accuracy'])
    if hvd.rank() == 0:
        model.summary()

    # Broadcast rank 0's variables to all ranks. The session is fetched only
    # after the optimizer and model have been built.
    KB.get_session().run(hvd.broadcast_global_variables(0))
    if not data_augmentation:
        print('Not using data augmentation.')

        train_gen = ImageDataGenerator()
        test_gen = ImageDataGenerator()
        # Train the model. The training will randomly sample 1 / N batches of
        # training data and 3 / N batches of validation data on every worker,
        # where N is the number of workers. Over-sampling of validation data
        # helps to increase probability that every validation example will be
        # evaluated.
        start_time = time.time()
        model.fit_generator(
            train_gen.flow(x_train, y_train, batch_size=batch_size),
            steps_per_epoch=steps_per_epoch,
            callbacks=callbacks,
            epochs=epochs,
            verbose=hvd.rank() == 0,
            validation_data=test_gen.flow(x_test, y_test,
                                          batch_size=batch_size),
            validation_steps=3 * test_batches // hvdsize)

    else:
        print('Using real-time data augmentation.')
        # This will do preprocessing and realtime data augmentation:
        datagen = ImageDataGenerator(
            featurewise_center=False,  # set input mean to 0 over the dataset
            samplewise_center=False,  # set each sample mean to 0
            # divide inputs by std of the dataset
            featurewise_std_normalization=False,
            samplewise_std_normalization=False,  # divide each input by its std
            zca_whitening=False,  # apply ZCA whitening
            # randomly rotate images in the range (degrees, 0 to 180)
            rotation_range=0,
            # randomly shift images horizontally (fraction of total width)
            width_shift_range=0.1,
            # randomly shift images vertically (fraction of total height)
            height_shift_range=0.1,
            horizontal_flip=True,  # randomly flip images
            vertical_flip=False)  # randomly flip images

        # Compute quantities required for feature-wise normalization
        # (std, mean, and principal components if ZCA whitening is applied).
        datagen.fit(x_train)

        # Timing starts after datagen.fit so only training is measured.
        start_time = time.time()
        # Fit the model on the batches generated by datagen.flow().
        model.fit_generator(
            datagen.flow(x_train, y_train, batch_size=batch_size),
            steps_per_epoch=steps_per_epoch,
            epochs=epochs,
            validation_data=(x_test, y_test),
            verbose=hvd.rank() == 0,
            callbacks=callbacks)

    if hvd.rank() == 0:
        elapsed_time = time.time() - start_time
        print('[{}] finished in {} s'
              .format('TRAINING', round(elapsed_time, 3)))

        metrics = model.evaluate(x=x_test, y=y_test, batch_size=batch_size)
        print('\nCIFAR VALIDATION LOSS, ACC: {}, {}'.format(*metrics))

    KB.clear_session()
コード例 #27
0
ファイル: callbacks.py プロジェクト: kioco/horovod
 def on_train_begin(self, logs=None):
     """Synchronize all global variables with ``self.root_rank``.

     The Horovod broadcast op is built and then run in the Keras backend
     session, both inside ``tf.device(self.device)``. ``logs`` is unused.
     """
     with tf.device(self.device):
         broadcast = hvd.broadcast_global_variables(self.root_rank)
         session = K.get_session()
         session.run(broadcast)
コード例 #28
0
def main(argv=None):
    """Train Keras ResNet50 from TFRecord input tensors with Horovod.

    Each process is pinned to the GPU matching its Horovod local rank.
    Training batches come from
    ``RecordInputImagenetPreprocessor.device_minibatches`` as 224x224 images
    with labels one-hot encoded over 1000 classes. The model is compiled with
    ``hvd_keras.DistributedOptimizer(KO.Adam())``, and variables are broadcast
    from rank 0 before ``fit``.

    After training, rank 0 builds a second ResNet50 on the ``val=True``
    tensors, copies the trained weights into it, and prints the evaluation
    loss/accuracy.

    Args:
        argv: Optional extra command-line arguments appended to ``sys.argv``
            before parsing (presumably ``_parser`` reads ``sys.argv`` —
            confirm).
    """
    # Initialize Horovod.
    hvd.init()

    # Pin GPU to be used to process local rank (one GPU per process)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.visible_device_list = str(hvd.local_rank())
    KB.set_session(tf.Session(config=config))

    ngpus = hvd.size()

    main.__doc__ = __doc__
    # list.extend() returns None, so extend sys.argv in place, then bind argv.
    if argv is not None:
        sys.argv.extend(argv)
    argv = sys.argv
    desc = main.__doc__
    # CLI parser
    args = _parser(desc)

    num_devices_tfrecord = 1
    height, width = 224, 224  # Image dimensions. Gets resized if not match.
    distort_color = args.distort_color
    data_dir = args.datadir
    batch_size = args.batch_size  # per process (not multiplied by ngpus)
    epochs = args.epochs
    imgs_per_epoch = args.imgs_per_epoch

    # Fit the model using data from the TFRecord data tensors.
    device_minibatches = RecordInputImagenetPreprocessor.device_minibatches
    images_tfrecord, labels_tfrecord, nrecords = device_minibatches(
        num_devices_tfrecord, data_dir, batch_size,
        height, width, distort_color, val=False)
    images_tfrecord = images_tfrecord[0]
    labels_tfrecord = labels_tfrecord[0]

    # categorical_crossentropy needs one-hot targets.
    nclasses = 1000
    labels_tfrecord = tf.one_hot(labels_tfrecord, nclasses)

    nimgs_to_use = imgs_per_epoch if imgs_per_epoch > 0 else nrecords
    steps_per_epoch = nimgs_to_use // batch_size // hvd.size()

    images = Input(tensor=images_tfrecord)
    model = ResNet50(input_tensor=images, weights=None)
    if hvd.rank() == 0:
        model.summary()

        print('Num images: {}'.format(nrecords))

        if nimgs_to_use < nrecords:
            print('Using {} images per epoch'.format(nimgs_to_use))

    opt = KO.Adam()
    opt = hvd_keras.DistributedOptimizer(opt)

    model.compile(loss='categorical_crossentropy',
                  optimizer=opt,
                  target_tensors=[labels_tfrecord])

    # Broadcast variables from rank 0 to all other processes.
    KB.get_session().run(hvd.broadcast_global_variables(0))

    callbacks = []
    if hvd.rank() == 0:
        callbacks += [BatchTiming(),
                      SamplesPerSec(ngpus * batch_size)]

    start_time = time.time()
    model.fit(
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        callbacks=callbacks,
        verbose=1)
    elapsed_time = time.time() - start_time

    if hvd.rank() == 0:
        print('[{}] finished in {} s'
              .format('TRAINING', round(elapsed_time, 3)))

        images_tfrecord_val, labels_tfrecord_val, nrecords_val = \
            device_minibatches(num_devices_tfrecord, data_dir, batch_size,
                               height, width, distort_color, val=True)
        images_tfrecord_val = images_tfrecord_val[0]
        labels_tfrecord_val = labels_tfrecord_val[0]
        labels_tfrecord_val = tf.one_hot(labels_tfrecord_val, nclasses)

        steps_per_epoch_val = nrecords_val // batch_size

        # A built Keras model is bound to the input tensor it was created
        # on. Swapping an entry of model.layers does not reroute its graph.
        # So build a second ResNet50 on the validation tensors and copy the
        # trained weights into it; evaluation then reads validation images
        # that match the validation labels.
        images_val = Input(tensor=images_tfrecord_val)
        model_val = ResNet50(input_tensor=images_val, weights=None)
        model_val.set_weights(model.get_weights())
        model_val.compile(
            loss='categorical_crossentropy',
            optimizer=opt,
            metrics=['accuracy'],
            target_tensors=[labels_tfrecord_val])
        loss = model_val.evaluate(x=None, y=None, steps=steps_per_epoch_val)

        print('\nNum images evaluated, steps: {}, {}'.
              format(nrecords_val, steps_per_epoch_val))
        print('\nTest loss, acc: {}'.format(loss))

    KB.clear_session()  # do this for Horovod
コード例 #29
0
def _downsample_batch(next_elem, downsampling_fact, downsampling_mode, batch,
                      image_height, image_width, num_channels, data_format):
    """Reduce an input batch by ``downsampling_fact`` along height and width.

    Args:
        next_elem: ``(data, label, weight, filename)`` tensors produced by the
            feedable iterator in ``main``; labels and weights are ``(N, H, W)``.
        downsampling_fact: integer reduction factor (> 1).
        downsampling_mode: ``"scale"`` (average-pool data and weights, keep
            the label of one randomly chosen pixel per patch),
            ``"center-crop"`` (crop the central ``1/downsampling_fact``
            window) or ``"random-crop"`` (crop a random window of the target
            size).
        batch: static batch size.
        image_height, image_width: target (downsampled) spatial size.
        num_channels: number of data channels.
        data_format: ``"channels_first"`` or ``"channels_last"`` layout of the
            data tensor.

    Returns:
        A new ``(data, label, weight, filename)`` tuple at the target size.

    Raises:
        ValueError: if ``downsampling_mode`` is not supported.
    """
    data, label, weight, filename = next_elem

    if downsampling_mode == "scale":
        num_sub = downsampling_fact * downsampling_fact
        # one-hot mask selecting a single random pixel inside each patch
        rand_select = tf.cast(tf.one_hot(tf.random_uniform(
            (batch, image_height, image_width),
            minval=0,
            maxval=num_sub,
            dtype=tf.int32),
                                         depth=num_sub,
                                         axis=-1),
                              dtype=tf.int32)
        patch = [1, downsampling_fact, downsampling_fact, 1]
        data = tf.layers.average_pooling2d(data, downsampling_fact,
                                           downsampling_fact, 'valid',
                                           data_format)
        # masking the patch values and taking the max keeps the selected
        # pixel's label (assumes non-negative labels)
        label = tf.reduce_max(tf.multiply(
            tf.image.extract_image_patches(tf.expand_dims(label, axis=-1),
                                           patch, patch, [1, 1, 1, 1],
                                           'VALID'), rand_select),
                              axis=-1)
        weight = tf.squeeze(tf.layers.average_pooling2d(
            tf.expand_dims(weight, axis=-1), downsampling_fact,
            downsampling_fact, 'valid', "channels_last"),
                            axis=-1)
        return data, label, weight, filename

    if downsampling_mode == "center-crop":
        #some parameters: normalized box of side 1/fact centered in the image
        length = 1. / float(downsampling_fact)
        offset = (1. - length) / 2.
        boxes = [[offset, offset, offset + length, offset + length]] * batch
        box_ind = list(range(0, batch))
        crop_size = [image_height, image_width]

        #crop_and_resize expects NHWC: be careful with data order
        if data_format == "channels_first":
            data = tf.transpose(data, perm=[0, 2, 3, 1])

        #crop
        data = tf.image.crop_and_resize(data,
                                        boxes,
                                        box_ind,
                                        crop_size,
                                        method='bilinear',
                                        extrapolation_value=0,
                                        name="data_cropping")
        label = ensure_type(
            tf.squeeze(tf.image.crop_and_resize(tf.expand_dims(label, axis=-1),
                                                boxes,
                                                box_ind,
                                                crop_size,
                                                method='nearest',
                                                extrapolation_value=0,
                                                name="label_cropping"),
                       axis=-1), tf.int32)
        weight = tf.squeeze(tf.image.crop_and_resize(
            tf.expand_dims(weight, axis=-1),
            boxes,
            box_ind,
            crop_size,
            method='bilinear',
            extrapolation_value=0,
            name="weight_cropping"),
                            axis=-1)

        #restore the original data order
        if data_format == "channels_first":
            data = tf.transpose(data, perm=[0, 3, 1, 2])
        return data, label, weight, filename

    if downsampling_mode == "random-crop":
        #some parameters
        crop_size = [batch, image_height, image_width, num_channels + 2]

        #concatenate input, label and weight so they share one random crop
        crop_input = tf.concat([
            data if data_format == "channels_last" else tf.transpose(
                data, perm=[0, 2, 3, 1]),
            ensure_type(tf.expand_dims(label, axis=-1), tf.float32),
            tf.expand_dims(weight, axis=-1)
        ],
                               axis=-1)
        crop_output = tf.image.random_crop(crop_input, crop_size)

        #split apart to restore the iterator output
        crop_image = crop_output[:, :, :, :num_channels]
        crop_label = ensure_type(crop_output[:, :, :, num_channels], tf.int32)
        crop_weight = crop_output[:, :, :, num_channels + 1]
        if data_format != "channels_last":
            crop_image = tf.transpose(crop_image, perm=[0, 3, 1, 2])
        return crop_image, crop_label, crop_weight, filename

    raise ValueError(
        "Error, downsampling mode {} not supported. Supported are [center-crop, random-crop, scale]"
        .format(downsampling_mode))


def _save_validation_sample(image_dir, epoch, eval_steps, comm_rank,
                            predictions, labels, filenames):
    """Dump the first sample of a validation batch into ``image_dir``.

    Writes prediction, label and combined PNGs through ``imsave`` when it is
    available (``have_imsave``); otherwise stores prediction, label and the
    source filename in a single ``.npz`` archive.
    """
    suffix = (str(epoch) + '_estep' + str(eval_steps) + '_rank' +
              str(comm_rank))
    pred = np.argmax(predictions[0, ...], axis=2)
    label = labels[0, ...]
    if have_imsave:
        # class ids are multiplied by 100 so they are visible as gray levels
        imsave(image_dir + '/test_pred_epoch' + suffix + '.png', pred * 100)
        imsave(image_dir + '/test_label_epoch' + suffix + '.png', label * 100)
        imsave(image_dir + '/test_combined_epoch' + suffix + '.png',
               colormap[label, pred])
    else:
        np.savez(image_dir + '/test_epoch' + suffix + '.npz',
                 prediction=pred * 100,
                 label=label * 100,
                 filename=filenames[0])


def main(input_path_train, input_path_validation, downsampling_fact,
         downsampling_mode, channels, data_format, label_id, blocks, weights,
         image_dir, checkpoint_dir, trn_sz, val_sz, loss_type, fs_type,
         optimizer, batch, batchnorm, num_epochs, dtype, chkpt, filter_sz,
         growth, disable_checkpoints, disable_imsave, tracing, trace_dir,
         output_sampling, scale_factor):
    """Train and validate the Tiramisu segmentation network (TF1 graph mode).

    Builds the input pipeline (HDF5 readers, sharded datasets, a feedable
    iterator and optional downsampling), the model, the loss, the optimizer
    and a streaming mean-IoU metric. It then trains inside a
    ``tf.train.MonitoredTrainingSession``. After every full epoch the whole
    validation set is evaluated and loss / IoU are reported (averaged across
    ranks when Horovod is enabled). Rank 0 additionally restores and saves
    checkpoints and dumps validation images.

    Depends on module-level names defined elsewhere in this file, including
    ``horovod`` (switch for distributed mode), ``hvd``, ``nvtx``,
    ``image_height_orig`` / ``image_width_orig`` (full-resolution size),
    ``load_data``, ``h5_input_reader``, ``create_dataset``,
    ``create_tiramisu``, ``ensure_type``, ``focal_loss``, ``graph_flops``,
    ``get_number_of_trainable_parameters``, ``get_optimizer`` /
    ``get_larc_optimizer`` and ``load_model``.

    Args:
        input_path_train, input_path_validation: data locations passed to
            ``load_data`` and ``h5_input_reader``.
        downsampling_fact: integer factor dividing image height and width;
            1 disables downsampling.
        downsampling_mode: ``"scale"``, ``"center-crop"`` or
            ``"random-crop"``.
        channels: sequence of input channels to read; its length is the
            model's input channel count.
        data_format: ``"channels_first"`` or ``"channels_last"``.
        label_id: forwarded to the HDF5 readers.
        blocks: layers per dense block (``nb_layers_per_block``).
        weights: loss weights forwarded to the readers and the model.
        image_dir: output directory for validation dumps (created by rank 0).
        checkpoint_dir: directory for checkpoint saving / restoring.
        trn_sz, val_sz: forwarded to ``load_data`` (presumably sample-count
            limits -- TODO confirm).
        loss_type: ``"weighted"`` (weighted sparse softmax cross entropy) or
            ``"focal"``.
        fs_type: ``"local"`` shards the data over node-local ranks only; any
            other value shards over all ranks.
        optimizer: dict of solver parameters; an ``opt_type`` starting with
            ``"LARC"`` selects the LARC optimizer.
        batch: per-rank batch size.
        batchnorm, growth, filter_sz: model hyper-parameters forwarded to
            ``create_tiramisu``.
        num_epochs: number of training epochs (sets the global-step limit).
        dtype: compute dtype; ``tf.float32`` is reported as FP32, anything
            else as FP16.
        chkpt: unused by this function.
        disable_checkpoints: skip checkpoint restore and saving.
        disable_imsave: skip validation image dumps.
        tracing: steps to trace with ``tracehook.TraceHook``, or ``None``.
        trace_dir: output directory for traces.
        output_sampling: ``sample_target`` of the training reader.
        scale_factor: multiplier applied to the weighted loss.

    Raises:
        ValueError: for an unsupported ``downsampling_mode`` or ``loss_type``,
            or when a rank gets fewer samples than one batch.
    """

    #init horovod
    nvtx.RangePush("init horovod", 1)
    comm_rank = 0
    comm_local_rank = 0
    comm_size = 1
    comm_local_size = 1
    if horovod:
        hvd.init()
        comm_rank = hvd.rank()
        comm_local_rank = hvd.local_rank()
        comm_size = hvd.size()
        #not all horovod versions have that implemented
        try:
            comm_local_size = hvd.local_size()
        except AttributeError:
            comm_local_size = 1
        if comm_rank == 0:
            print("Using distributed computation with Horovod: {} total ranks".
                  format(comm_size))
    nvtx.RangePop()  # init horovod

    #downsampling? recompute image dims
    image_height = image_height_orig // downsampling_fact
    image_width = image_width_orig // downsampling_fact

    #parameters
    per_rank_output = False
    loss_print_interval = 10

    #session config: pin each process to its node-local GPU
    sess_config = tf.ConfigProto(
        inter_op_parallelism_threads=6,  #1
        intra_op_parallelism_threads=1,  #6
        log_device_placement=False,
        allow_soft_placement=True)
    sess_config.gpu_options.visible_device_list = str(comm_local_rank)
    sess_config.gpu_options.force_gpu_compatible = True

    #get data
    training_graph = tf.Graph()
    if comm_rank == 0:
        print("Loading data...")
    trn_data = load_data(input_path_train, True, trn_sz, horovod)
    val_data = load_data(input_path_validation, False, val_sz, horovod)
    if comm_rank == 0:
        print("Shape of trn_data is {}".format(trn_data.shape[0]))
        print("done.")

    #print some stats
    if comm_rank == 0:
        print("Num workers: {}".format(comm_size))
        print("Local batch size: {}".format(batch))
        if dtype == tf.float32:
            print("Precision: {}".format("FP32"))
        else:
            print("Precision: {}".format("FP16"))
        print("Batch normalization: {}".format(batchnorm))
        print("Blocks: {}".format(blocks))
        print("Growth rate: {}".format(growth))
        print("Filter size: {}".format(filter_sz))
        print("Channels: {}".format(channels))
        print("Loss type: {}".format(loss_type))
        print("Loss weights: {}".format(weights))
        print("Loss scale factor: {}".format(scale_factor))
        print("Output sampling target: {}".format(output_sampling))
        #print optimizer parameters
        for k, v in optimizer.items():
            print("Solver Parameters: {k}: {v}".format(k=k, v=v))
        print("Num training samples: {}".format(trn_data.shape[0]))
        print("Num validation samples: {}".format(val_data.shape[0]))
        print("Disable checkpoints: {}".format(disable_checkpoints))
        print("Disable image save: {}".format(disable_imsave))
        print("Downsampling factor: {}".format(downsampling_fact))
        print("Downsampling mode: {}".format(downsampling_mode))

    #compute epochs and stuff: samples are sharded over the ranks that share
    #the file system
    if fs_type == "local":
        num_samples = trn_data.shape[0] // comm_local_size
    else:
        num_samples = trn_data.shape[0] // comm_size
    num_steps_per_epoch = num_samples // batch
    if num_steps_per_epoch < 1:
        # the epoch bookkeeping below divides by num_steps_per_epoch
        raise ValueError(
            "Error, {} samples per rank is less than the local batch size {}."
            .format(num_samples, batch))
    num_steps = num_epochs * num_steps_per_epoch
    if per_rank_output:
        print("Rank {} does {} steps per epoch".format(comm_rank,
                                                       num_steps_per_epoch))

    with training_graph.as_default():
        nvtx.RangePush("TF Init", 3)
        #create readers
        trn_reader = h5_input_reader(input_path_train,
                                     channels,
                                     weights,
                                     dtype,
                                     normalization_file="stats.h5",
                                     update_on_read=False,
                                     data_format=data_format,
                                     label_id=label_id,
                                     sample_target=output_sampling)
        val_reader = h5_input_reader(input_path_validation,
                                     channels,
                                     weights,
                                     dtype,
                                     normalization_file="stats.h5",
                                     update_on_read=False,
                                     data_format=data_format,
                                     label_id=label_id)
        #create datasets
        if fs_type == "local":
            trn_dataset = create_dataset(trn_reader,
                                         trn_data,
                                         batch,
                                         num_epochs,
                                         comm_local_size,
                                         comm_local_rank,
                                         dtype,
                                         shuffle=True)
            val_dataset = create_dataset(val_reader,
                                         val_data,
                                         batch,
                                         1,
                                         comm_local_size,
                                         comm_local_rank,
                                         dtype,
                                         shuffle=False)
        else:
            trn_dataset = create_dataset(trn_reader,
                                         trn_data,
                                         batch,
                                         num_epochs,
                                         comm_size,
                                         comm_rank,
                                         dtype,
                                         shuffle=True)
            val_dataset = create_dataset(val_reader,
                                         val_data,
                                         batch,
                                         1,
                                         comm_size,
                                         comm_rank,
                                         dtype,
                                         shuffle=False)

        #create iterators: a feedable iterator lets the same graph read from
        #the training or the validation dataset depending on the fed handle
        handle = tf.placeholder(tf.string,
                                shape=[],
                                name="iterator-placeholder")
        iterator = tf.data.Iterator.from_string_handle(
            handle, (dtype, tf.int32, dtype, tf.string),
            ((batch, len(channels), image_height_orig,
              image_width_orig) if data_format == "channels_first" else
             (batch, image_height_orig, image_width_orig, len(channels)),
             (batch, image_height_orig, image_width_orig),
             (batch, image_height_orig, image_width_orig), (batch)))
        next_elem = iterator.get_next()

        #if downsampling, do some preprocessing
        if downsampling_fact != 1:
            next_elem = _downsample_batch(next_elem, downsampling_fact,
                                          downsampling_mode, batch,
                                          image_height, image_width,
                                          len(channels), data_format)

        #create init handles; the init ops are run with the matching handle
        #fed so they (re)initialize the corresponding dataset iterator
        #trn
        trn_iterator = trn_dataset.make_initializable_iterator()
        trn_handle_string = trn_iterator.string_handle()
        trn_init_op = iterator.make_initializer(trn_dataset)
        #val
        val_iterator = val_dataset.make_initializable_iterator()
        val_handle_string = val_iterator.string_handle()
        val_init_op = iterator.make_initializer(val_dataset)

        #compute the input filter number based on number of channels used
        num_channels = len(channels)
        nb_filter = 64

        #set up model
        logit, prediction = create_tiramisu(3,
                                            next_elem[0],
                                            image_height,
                                            image_width,
                                            num_channels,
                                            loss_weights=weights,
                                            nb_layers_per_block=blocks,
                                            p=0.2,
                                            wd=1e-4,
                                            dtype=dtype,
                                            batchnorm=batchnorm,
                                            growth_rate=growth,
                                            nb_filter=nb_filter,
                                            filter_sz=filter_sz,
                                            median_filter=False,
                                            data_format=data_format)

        #set up loss
        loss = None
        if loss_type == "weighted":
            #cast weights to FP32
            w_cast = ensure_type(next_elem[2], tf.float32)
            loss = tf.losses.sparse_softmax_cross_entropy(
                labels=next_elem[1],
                logits=logit,
                weights=w_cast,
                reduction=tf.losses.Reduction.SUM_BY_NONZERO_WEIGHTS)
            if scale_factor != 1.0:
                loss *= scale_factor
        elif loss_type == "focal":
            labels_one_hot = tf.contrib.layers.one_hot_encoding(
                next_elem[1], 3)
            labels_one_hot = ensure_type(labels_one_hot, dtype)
            loss = focal_loss(onehot_labels=labels_one_hot,
                              logits=logit,
                              alpha=1.,
                              gamma=2.)
        else:
            raise ValueError("Error, loss type {} not supported."
                             .format(loss_type))

        #determine flops (per rank, scaled to the whole job)
        flops = graph_flops.graph_flops(
            format="NHWC" if data_format == "channels_last" else "NCHW",
            batch=batch,
            sess_config=sess_config)
        flops *= comm_size
        if comm_rank == 0:
            print('training flops: {:.3f} TF/step'.format(flops * 1e-12))

        #number of trainable parameters
        if comm_rank == 0:
            num_params = get_number_of_trainable_parameters()
            print('number of trainable parameters: {} ({} MB)'.format(
                num_params,
                num_params * (4 if dtype == tf.float32 else 2) * (2**-20)))

        if horovod:
            loss_avg = hvd.allreduce(ensure_type(loss, tf.float32))
        else:
            loss_avg = tf.identity(loss)

        #set up global step - keep on CPU
        with tf.device('/device:CPU:0'):
            global_step = tf.train.get_or_create_global_step()

        #set up optimizer
        if optimizer['opt_type'].startswith("LARC"):
            if comm_rank == 0:
                print("Enabling LARC")
            train_op, lr = get_larc_optimizer(optimizer, loss, global_step,
                                              num_steps_per_epoch, horovod)
        else:
            train_op, lr = get_optimizer(optimizer, loss, global_step,
                                         num_steps_per_epoch, horovod)

        #set up streaming metrics
        iou_op, iou_update_op = tf.metrics.mean_iou(labels=next_elem[1],
                                                    predictions=tf.argmax(
                                                        prediction, axis=3),
                                                    num_classes=3,
                                                    weights=None,
                                                    metrics_collections=None,
                                                    updates_collections=None,
                                                    name="iou_score")
        #resets the metric's local accumulators between evaluation passes
        iou_reset_op = tf.variables_initializer([
            i for i in tf.local_variables() if i.name.startswith('iou_score/')
        ])

        if horovod:
            iou_avg = hvd.allreduce(iou_op)
        else:
            iou_avg = tf.identity(iou_op)

        #hooks
        #these hooks are essential. regularize the step hook by adding one additional step at the end
        hooks = [tf.train.StopAtStepHook(last_step=num_steps + 1)]
        #bcast init for bcasting the model after start
        if horovod:
            init_bcast = hvd.broadcast_global_variables(0)

        #initializers:
        init_op = tf.global_variables_initializer()
        init_local_op = tf.local_variables_initializer()

        #checkpointing (rank 0 only)
        if comm_rank == 0:
            checkpoint_save_freq = 5 * num_steps_per_epoch
            checkpoint_saver = tf.train.Saver(max_to_keep=1000)
            if (not disable_checkpoints):
                hooks.append(
                    tf.train.CheckpointSaverHook(
                        checkpoint_dir=checkpoint_dir,
                        save_steps=checkpoint_save_freq,
                        saver=checkpoint_saver))
            #create image dir if not exists
            if not os.path.isdir(image_dir):
                os.makedirs(image_dir)

        if tracing is not None:
            import tracehook
            tracing_hook = tracehook.TraceHook(steps_to_trace=tracing,
                                               cache_traces=True,
                                               trace_dir=trace_dir)
            hooks.append(tracing_hook)

        # instead of averaging losses over an entire epoch, use a moving
        #  window average
        recent_losses = []
        loss_window_size = 10

        #start session
        with tf.train.MonitoredTrainingSession(config=sess_config,
                                               hooks=hooks) as sess:
            #initialize
            sess.run([init_op, init_local_op])
            #restore from checkpoint:
            if comm_rank == 0 and not disable_checkpoints:
                load_model(sess, checkpoint_saver, checkpoint_dir)
            #broadcast loaded model variables
            if horovod:
                sess.run(init_bcast)
            #create iterator handles
            trn_handle, val_handle = sess.run(
                [trn_handle_string, val_handle_string])
            #init iterators
            sess.run(trn_init_op, feed_dict={handle: trn_handle})
            sess.run(val_init_op, feed_dict={handle: val_handle})

            nvtx.RangePop()  # TF Init

            # figure out what step we're on (it won't be 0 if we are
            #  restoring from a checkpoint) so we can count from there
            train_steps = sess.run([global_step])[0]

            #do the training; resume the epoch counter from the restored step
            epoch = int(train_steps) // num_steps_per_epoch + 1
            step = 1

            t_sustained_start = time.time()

            nvtx.RangePush("Training Loop", 4)
            nvtx.RangePush("Epoch", epoch)
            start_time = time.time()
            while not sess.should_stop():

                #training loop
                try:
                    nvtx.RangePush("Step", step)
                    #construct feed dict
                    t_inst_start = time.time()
                    _, tmp_loss = sess.run(
                        [train_op, (loss if per_rank_output else loss_avg)],
                        feed_dict={handle: trn_handle})
                    t_inst_end = time.time()
                    train_steps += 1
                    train_steps_in_epoch = train_steps % num_steps_per_epoch
                    recent_losses = [tmp_loss
                                     ] + recent_losses[0:loss_window_size - 1]
                    train_loss = sum(recent_losses) / len(recent_losses)
                    nvtx.RangePop()  # Step
                    step += 1

                    #print step report
                    if (train_steps % loss_print_interval) == 0:
                        if per_rank_output:
                            print(
                                "REPORT: rank {}, training loss for step {} (of {}) is {}, time {:.3f}"
                                .format(comm_rank, train_steps, num_steps,
                                        train_loss,
                                        time.time() - start_time))
                        else:
                            if comm_rank == 0:
                                print(
                                    "REPORT: training loss for step {} (of {}) is {}, time {:.3f}, r_inst {:.3f}"
                                    .format(
                                        train_steps, num_steps, train_loss,
                                        time.time() - start_time, 1e-12 *
                                        flops / (t_inst_end - t_inst_start)))

                    #do the validation phase at the end of every epoch
                    if train_steps_in_epoch == 0:
                        end_time = time.time()
                        #print epoch report
                        if per_rank_output:
                            print(
                                "COMPLETED: rank {}, training loss for epoch {} (of {}) is {}, time {:.3f}, r_sust {:.3f}"
                                .format(
                                    comm_rank, epoch, num_epochs, train_loss,
                                    time.time() - start_time,
                                    1e-12 * flops * num_steps_per_epoch /
                                    (end_time - t_sustained_start)))
                        else:
                            if comm_rank == 0:
                                print(
                                    "COMPLETED: training loss for epoch {} (of {}) is {}, time {:.3f}, r_sust {:.3f}"
                                    .format(
                                        epoch, num_epochs, train_loss,
                                        time.time() - start_time,
                                        1e-12 * flops * num_steps_per_epoch /
                                        (end_time - t_sustained_start)))

                        #evaluation loop: drain the validation iterator
                        eval_loss = 0.
                        eval_steps = 0
                        nvtx.RangePush("Eval Loop", 7)
                        while True:
                            try:
                                #construct feed dict
                                _, tmp_loss, val_model_predictions, val_model_labels, val_model_filenames = sess.run(
                                    [
                                        iou_update_op,
                                        (loss
                                         if per_rank_output else loss_avg),
                                        prediction, next_elem[1], next_elem[3]
                                    ],
                                    feed_dict={handle: val_handle})

                                #print some images
                                if comm_rank == 0 and not disable_imsave:
                                    _save_validation_sample(
                                        image_dir, epoch, eval_steps,
                                        comm_rank, val_model_predictions,
                                        val_model_labels, val_model_filenames)

                                eval_loss += tmp_loss
                                eval_steps += 1
                            except tf.errors.OutOfRangeError:
                                #validation set exhausted: report and reset
                                eval_steps = np.max([eval_steps, 1])
                                eval_loss /= eval_steps
                                if per_rank_output:
                                    print(
                                        "COMPLETED: rank {}, evaluation loss for epoch {} (of {}) is {}"
                                        .format(comm_rank, epoch, num_epochs,
                                                eval_loss))
                                else:
                                    if comm_rank == 0:
                                        print(
                                            "COMPLETED: evaluation loss for epoch {} (of {}) is {}"
                                            .format(epoch, num_epochs,
                                                    eval_loss))
                                if per_rank_output:
                                    iou_score = sess.run(iou_op)
                                    print(
                                        "COMPLETED: rank {}, evaluation IoU for epoch {} (of {}) is {}"
                                        .format(comm_rank, epoch, num_epochs,
                                                iou_score))
                                else:
                                    iou_score = sess.run(iou_avg)
                                    if comm_rank == 0:
                                        print(
                                            "COMPLETED: evaluation IoU for epoch {} (of {}) is {}"
                                            .format(epoch, num_epochs,
                                                    iou_score))
                                sess.run(iou_reset_op)
                                sess.run(val_init_op,
                                         feed_dict={handle: val_handle})
                                break
                        nvtx.RangePop()  # Eval Loop

                        #reset counters
                        epoch += 1
                        step = 0
                        t_sustained_start = time.time()

                        nvtx.RangePop()  # Epoch
                        nvtx.RangePush("Epoch", epoch)

                except tf.errors.OutOfRangeError:
                    break

            nvtx.RangePop()  # Epoch
            nvtx.RangePop()  # Training Loop

        # write any cached traces to disk
        if tracing is not None:
            tracing_hook.write_traces()
コード例 #30
0
def train_main(dataset,
               model_name='117M',
               seed=None,
               batch_size=2,
               sample_length=1023,
               sample_num=1,
               sample_every=4500,
               run_name='run1',
               restore_from='latest',
               save_every=2000,
               combine=50000):
    """Fine-tune a GPT-2 style model on ``dataset`` with Horovod.

    Builds a next-token loss and a sampling op, wraps Adam in
    ``hvd.DistributedOptimizer``, restores weights from a checkpoint and
    broadcasts them from rank 0, then trains in an endless loop. Rank 0
    checkpoints every ``save_every`` steps, writes samples every
    ``sample_every`` steps and prints the loss every step. The loop only ends
    on ``KeyboardInterrupt``, after which rank 0 saves a final checkpoint.

    Args:
        dataset: dataset spec forwarded to ``load_dataset``.
        model_name: directory under ``chatbot_model/trained_models`` holding
            ``hparams.json`` and the pretrained checkpoint.
        seed: seed for NumPy and the TF graph-level seed.
        batch_size: sequences per training step on each worker.
        sample_length: tokens generated per sample; ``None`` selects
            ``hparams.n_ctx // 2``.
        sample_num: number of samples produced per sampling round.
        sample_every: sampling period in steps (rank 0 only).
        run_name: sub-directory used under ``CHECKPOINT_DIR`` and
            ``SAMPLE_DIR``.
        restore_from: ``'latest'`` (this run's newest checkpoint, falling
            back to the pretrained model), ``'fresh'`` (pretrained model) or
            a checkpoint directory.
        save_every: checkpoint period in steps (rank 0 only).
        combine: forwarded to ``load_dataset``.

    Raises:
        ValueError: if ``sample_length`` exceeds ``hparams.n_ctx``.
    """

    enc = encoder.get_encoder(model_name)
    hparams = model.default_hparams()
    # Override the default hyper-parameters with the model's own hparams.json.
    with open(
            os.path.join('chatbot_model', 'trained_models', model_name,
                         'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if sample_length is None:
        sample_length = hparams.n_ctx // 2
    elif sample_length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    # TF config

    config = tf.ConfigProto()
    # One GPU per Horovod process, selected by the process's local rank.
    config.gpu_options.visible_device_list = str(hvd.local_rank())
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:
        context = tf.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)
        output = model.model(hparams=hparams, X=context)
        # Next-token objective: logits at position t are scored against the
        # token at position t + 1.
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        tf_sample = sample.sample_sequence(hparams=hparams,
                                           length=sample_length,
                                           context=context,
                                           batch_size=batch_size,
                                           temperature=0.8,
                                           top_k=40)

        # Only variables whose name contains 'model' are trained and saved.
        train_vars = [v for v in tf.trainable_variables() if 'model' in v.name]

        opt = tf.train.AdamOptimizer()
        # Horovod wrapper: per the Horovod docs, gradients are averaged across
        # workers inside minimize()/compute_gradients().
        opt = hvd.DistributedOptimizer(opt)
        train_op = opt.minimize(loss, var_list=train_vars)

        # Horovod: broadcast initial variable states from rank 0 to all other processes.
        # This is necessary to ensure consistent initialization of all workers when
        # training is started with random weights or restored from a checkpoint.
        bcast = hvd.broadcast_global_variables(0)

        saver = tf.train.Saver(var_list=train_vars,
                               max_to_keep=5,
                               keep_checkpoint_every_n_hours=2)

        sess.run(tf.global_variables_initializer())

        # Pick the checkpoint to restore. NOTE(review): if no checkpoint is
        # found, ckpt is None and saver.restore below fails.
        if restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(
                    os.path.join('chatbot_model', 'trained_models',
                                 model_name))
        elif restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(
                os.path.join('chatbot_model', 'trained_models', model_name))
        else:
            ckpt = tf.train.latest_checkpoint(restore_from)
        print(str(hvd.local_rank()), 'Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        # Overwrite every worker's variables with rank 0's values.
        bcast.run()

        print(str(hvd.local_rank()), 'Loading dataset...')
        chunks = load_dataset(enc, dataset, combine)
        data_sampler = Sampler(chunks)
        print(str(hvd.local_rank()), 'dataset has', data_sampler.total_size,
              'tokens')
        print(str(hvd.local_rank()), 'Training...')

        counter = 1
        if os.path.exists(os.path.join(CHECKPOINT_DIR, run_name, 'counter')):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(os.path.join(CHECKPOINT_DIR, run_name, 'counter'),
                      'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            # Write a checkpoint tagged with the current step and record the
            # step in the 'counter' file so a resumed run continues from it.
            maketree(os.path.join(CHECKPOINT_DIR, run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, run_name,
                             'model-{}').format(counter))
            saver.save(sess,
                       os.path.join(CHECKPOINT_DIR, run_name, 'model'),
                       global_step=counter)
            with open(os.path.join(CHECKPOINT_DIR, run_name, 'counter'),
                      'w') as fp:
                fp.write(str(counter) + '\n')

        def generate_samples():
            # Generate sample_num samples from one sampled context (the same
            # context is repeated across the batch) and write them to
            # SAMPLE_DIR/run_name/samples-<counter>.
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: batch_size * [context_tokens]})
                for i in range(min(sample_num - index, batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
            # NOTE: prints only the last sample; raises NameError when
            # sample_num <= 0 because `text` is then never bound.
            print(text)
            maketree(os.path.join(SAMPLE_DIR, run_name))
            with open(
                    os.path.join(SAMPLE_DIR, run_name,
                                 'samples-{}').format(counter), 'w') as fp:
                fp.write('\n'.join(all_text))

        # Exponentially-decayed (0.99) running sum of losses and of weights;
        # their ratio is the smoothed loss that gets printed.
        avg_loss = (0.0, 0.0)
        start_time = time.time()

        try:
            while True:

                # Training sequences are a hard-coded 1024 tokens long.
                # NOTE(review): confirm this matches hparams.n_ctx.
                batch = [data_sampler.sample(1024) for _ in range(batch_size)]

                _, lv = sess.run((train_op, loss), feed_dict={context: batch})

                avg_loss = (avg_loss[0] * 0.99 + lv, avg_loss[1] * 0.99 + 1.0)

                # Only rank 0 saves, samples and logs.
                if hvd.rank() == 0:
                    if counter % save_every == 0:
                        save()
                    if counter % sample_every == 0:
                        generate_samples()

                    print(
                        '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                        .format(counter=counter,
                                time=time.time() - start_time,
                                loss=lv,
                                avg=avg_loss[0] / avg_loss[1]))

                counter += 1

        except KeyboardInterrupt:
            print('interrupted')
            if hvd.rank() == 0:
                save()
コード例 #31
0
if args.eager:
    tf.enable_eager_execution(config)

# Set up standard model.
model = getattr(applications, args.model)(weights=None)

opt = tf.train.GradientDescentOptimizer(0.01)

# Horovod: (optional) compression algorithm.
compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none

# Horovod: wrap optimizer with DistributedOptimizer.
opt = hvd.DistributedOptimizer(opt, compression=compression)

# Graph-mode initialisation / broadcast ops. hvd.broadcast_global_variables()
# raises RuntimeError under eager execution, so building it unconditionally
# crashed every --eager run. In eager mode both names are bound to None; eager
# runs have to broadcast the model variables explicitly instead (for example
# with hvd.broadcast_variables).
if args.eager:
    init = None
    bcast_op = None
else:
    init = tf.global_variables_initializer()
    bcast_op = hvd.broadcast_global_variables(0)

# Synthetic input: random images and random integer class labels.
data = tf.random_uniform([args.batch_size, 224, 224, 3])
target = tf.random_uniform([args.batch_size, 1], minval=0, maxval=999, dtype=tf.int64)


def loss_function():
    """Return the synthetic-benchmark loss for one forward pass of ``model``.

    Runs the module-level ``model`` on the random ``data`` batch in training
    mode and scores the output against the random ``target`` labels.

    NOTE(review): the model output is passed as *logits* to a softmax
    cross-entropy. Keras application models usually end in a softmax, so
    these may already be probabilities. Confirm this if the loss value
    itself (not just throughput) matters.
    """
    return tf.losses.sparse_softmax_cross_entropy(
        labels=target,
        logits=model(data, training=True),
    )


def log(s, nl=True):
    """Print ``s`` on Horovod rank 0 only; every other rank stays silent.

    Args:
        s: object to print.
        nl: terminate the output with a newline when true; otherwise print
            without any line terminator.
    """
    is_root = hvd.rank() == 0
    if is_root:
        terminator = '\n' if nl else ''
        print(s, end=terminator)
コード例 #32
0
ファイル: callbacks.py プロジェクト: tobyyouup/horovod
 def on_train_begin(self, logs=None):
     """Broadcast all global variables from ``self.root_rank`` at train start.

     The broadcast op is built under ``tf.device(self.device)`` and executed
     in the current Keras session, so every worker picks up the root rank's
     variable values before the first batch.
     """
     placement = tf.device(self.device)
     with placement:
         K.get_session().run(hvd.broadcast_global_variables(self.root_rank))
コード例 #33
0
ファイル: eval_wsdm_esim_bert.py プロジェクト: P79N6A/BERT
def main(_):
    """Evaluate a pretrained ESIM/BERT pair classifier on ``FLAGS.dev_file``.

    Builds the evaluation graph for the classifier chosen by
    ``FLAGS.model_type``, restores ``FLAGS.init_checkpoint``, broadcasts all
    global variables from Horovod rank 0, then runs the dev set to exhaustion
    and prints mean batch accuracy, overall accuracy and macro-F1. The patched
    model config is also written to ``FLAGS.model_output + "/config.json"``.

    Args:
        _: unused positional argument (presumably argv from ``tf.app.run``).

    Raises:
        ValueError: if ``FLAGS.model_type`` is not one of ``"original"``,
            ``"attn"``, ``"orignal_nonlinear"`` or ``"esim_bert"``.
    """

    hvd.init()

    sess_config = tf.ConfigProto()
    # One GPU per Horovod process, selected by the process's local rank.
    sess_config.gpu_options.visible_device_list = str(hvd.local_rank())

    graph = tf.Graph()
    from sklearn.metrics import accuracy_score, f1_score
    with graph.as_default():
        import json

        # Use context managers so the handles are closed deterministically.
        with open(FLAGS.config_file, "r") as config_fp:
            config = json.load(config_fp)

        init_checkpoint = FLAGS.init_checkpoint
        print("===init checkoutpoint==={}".format(init_checkpoint))

        config = Bunch(config)
        config.use_one_hot_embeddings = True
        config.scope = "esim/bert"
        config.dropout_prob = 0.1
        config.label_type = "single_label"
        config.lstm_dim = 128
        config.num_heads = 4

        # The label map is still read (a missing or malformed file fails
        # early) but is not used: no per-class weighting is applied.
        with open(FLAGS.label_id) as label_fp:
            json.load(label_fp)
        label_tensor = None

        with open(FLAGS.model_output + "/config.json", "w") as config_out:
            json.dump(config, config_out)

        sess = tf.Session(config=sess_config)

        # Per-worker share of the training set; only used to derive the
        # optimizer schedule that the model builder expects.
        train_size = int(FLAGS.train_size / hvd.size())

        num_train_steps = int(train_size / FLAGS.batch_size * FLAGS.epoch)
        num_warmup_steps = int(num_train_steps * 0.01)

        print(num_train_steps, num_warmup_steps, "=============")

        opt_config = Bunch({
            "init_lr": (5e-5 / hvd.size()),
            "num_train_steps": num_train_steps,
            "num_warmup_steps": num_warmup_steps,
            "train_op": "adam"
        })

        model_io_config = Bunch({"fix_lm": True})

        model_io_fn = model_io.ModelIO(model_io_config)

        num_choice = FLAGS.num_classes
        max_seq_length = FLAGS.max_length

        if FLAGS.model_type == "original":
            model_function = bert_order_classifier.classifier_model_fn_builder
        elif FLAGS.model_type == "attn":
            model_function = bert_order_classifier.classifier_attn_model_fn_builder
        elif FLAGS.model_type == "orignal_nonlinear":
            model_function = bert_order_classifier.classifier_model_fn_builder_v1
        elif FLAGS.model_type == "esim_bert":
            model_function = esim_bert.classifier_attn_model_fn_builder
        else:
            # Fail fast with a clear message instead of an UnboundLocalError
            # on model_function a few lines below.
            raise ValueError("Unknown model_type: {!r}".format(
                FLAGS.model_type))

        model_eval_fn = model_function(config,
                                       num_choice,
                                       init_checkpoint,
                                       model_reuse=None,
                                       load_pretrained=True,
                                       model_io_fn=model_io_fn,
                                       model_io_config=model_io_config,
                                       opt_config=opt_config,
                                       input_name=["a", "b"],
                                       label_tensor=label_tensor,
                                       not_storage_params=["adam", "adam_1"],
                                       exclude_scope_dict={"task": "esim"})

        def metric_fn(features, logits, loss):
            """Build per-batch accuracy, predictions and labels from logits."""
            print(logits.get_shape(), "===logits shape===")
            pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32)
            prob = tf.nn.softmax(logits)
            correct = tf.equal(
                tf.cast(pred_label, tf.int32),
                tf.cast(features["label_ids"], tf.int32))
            accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
            return {
                "accuracy": accuracy,
                "loss": loss,
                "pred_label": pred_label,
                "label_ids": features["label_ids"]
            }

        name_to_features = {
            "input_ids_a": tf.FixedLenFeature([max_seq_length], tf.int64),
            "input_mask_a": tf.FixedLenFeature([max_seq_length], tf.int64),
            "segment_ids_a": tf.FixedLenFeature([max_seq_length], tf.int64),
            "input_ids_b": tf.FixedLenFeature([max_seq_length], tf.int64),
            "input_mask_b": tf.FixedLenFeature([max_seq_length], tf.int64),
            "segment_ids_b": tf.FixedLenFeature([max_seq_length], tf.int64),
            "label_ids": tf.FixedLenFeature([], tf.int64),
        }

        def _decode_record(record, name_to_features):
            """Decodes a record to a TensorFlow example.
            """
            example = tf.parse_single_example(record, name_to_features)

            # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
            # So cast all int64 to int32.
            for name in list(example.keys()):
                t = example[name]
                if t.dtype == tf.int64:
                    t = tf.to_int32(t)
                example[name] = t
            return example

        params = Bunch({})
        params.epoch = FLAGS.epoch
        params.batch_size = FLAGS.batch_size

        eval_features = tf_data_utils.eval_input_fn(FLAGS.dev_file,
                                                    _decode_record,
                                                    name_to_features, params)

        [_, eval_loss, eval_per_example_loss,
         eval_logits] = model_eval_fn(eval_features, [],
                                      tf.estimator.ModeKeys.EVAL)
        result = metric_fn(eval_features, eval_logits, eval_loss)

        model_io_fn.set_saver()

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        sess.run(init_op)

        model_io_fn.load_model(sess, init_checkpoint)
        print(" ==succeeded in loading model== ")

        # Make every worker use rank 0's restored weights.
        sess.run(hvd.broadcast_global_variables(0))

        def eval_fn(result):
            """Run ``result`` until the dataset is exhausted.

            Returns:
                (mean per-batch accuracy, macro-F1 over all predictions).
            """
            i = 0
            total_accuracy = 0
            label, label_id = [], []
            while True:
                try:
                    eval_result = sess.run(result)
                    total_accuracy += eval_result["accuracy"]
                    label_id.extend(eval_result["label_ids"])
                    label.extend(eval_result["pred_label"])
                    i += 1
                except tf.errors.OutOfRangeError:
                    print("End of dataset")
                    break
            f1 = f1_score(label_id, label, average="macro")
            accuracy = accuracy_score(label_id, label)
            print("test accuracy accuracy {} {} f1 {}".format(
                total_accuracy / i, accuracy, f1))
            return total_accuracy / i, f1

        print("===========begin to eval============")
        accuracy, f1 = eval_fn(result)
        print("==accuracy {} f1 {}==".format(accuracy, f1))
コード例 #34
0
def main(argv=None):
    """Train ResNet50 on ImageNet TFRecords with Horovod, then evaluate on rank 0.

    Every rank trains on the TFRecord pipeline with a Horovod-wrapped Keras
    Adam optimizer, after broadcasting the initial weights from rank 0.
    Rank 0 then builds the validation pipeline, evaluates the trained weights
    on it and prints loss/accuracy.

    Args:
        argv: optional extra command-line tokens. They are appended to
            ``sys.argv`` (presumably what ``_parser`` parses).
    """
    # Initialize Horovod.
    hvd.init()

    # Pin GPU to be used to process local rank (one GPU per process)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.visible_device_list = str(hvd.local_rank())
    KB.set_session(tf.Session(config=config))

    ngpus = hvd.size()

    # The CLI description is taken from the module docstring.
    main.__doc__ = __doc__
    # list.extend() mutates in place and returns None; the old one-liner
    # ``argv = sys.argv if argv is None else sys.argv.extend(argv)`` bound
    # argv to None whenever extra arguments were supplied.
    if argv is not None:
        sys.argv.extend(argv)
    desc = main.__doc__
    # CLI parser
    args = _parser(desc)

    num_devices_tfrecord = 1
    height, width = 224, 224  # Image dimensions. Gets resized if not match.
    distort_color = args.distort_color
    data_dir = args.datadir
    batch_size = args.batch_size  # per-rank batch size
    epochs = args.epochs
    imgs_per_epoch = args.imgs_per_epoch

    # Fit the model using data from the TFRecord data tensors.
    device_minibatches = RecordInputImagenetPreprocessor.device_minibatches
    images_tfrecord, labels_tfrecord, nrecords = device_minibatches(
        num_devices_tfrecord,
        data_dir,
        batch_size,
        height,
        width,
        distort_color,
        val=False)
    images_tfrecord = images_tfrecord[0]
    labels_tfrecord = labels_tfrecord[0]

    # One-hot labels for categorical_crossentropy.
    nclasses = 1000
    labels_tfrecord = tf.one_hot(labels_tfrecord, nclasses)

    nimgs_to_use = imgs_per_epoch if imgs_per_epoch > 0 else nrecords
    # The epoch's images are split across hvd.size() ranks, each taking
    # batch_size images per step.
    steps_per_epoch = nimgs_to_use // batch_size // hvd.size()

    images = Input(tensor=images_tfrecord)
    model = ResNet50(input_tensor=images, weights=None)
    if hvd.rank() == 0:
        model.summary()

        print('Num images: {}'.format(nrecords))

        if nimgs_to_use < nrecords:
            print('Using {} images per epoch'.format(nimgs_to_use))

    opt = KO.Adam()
    opt = hvd_keras.DistributedOptimizer(opt)

    model.compile(
        loss='categorical_crossentropy',
        optimizer=opt,
        target_tensors=[labels_tfrecord])

    # Broadcast variables from rank 0 to all other processes.
    KB.get_session().run(hvd.broadcast_global_variables(0))

    callbacks = []
    if hvd.rank() == 0:
        callbacks += [BatchTiming(), SamplesPerSec(ngpus * batch_size)]

    start_time = time.time()
    model.fit(steps_per_epoch=steps_per_epoch,
              epochs=epochs,
              callbacks=callbacks,
              verbose=1)
    elapsed_time = time.time() - start_time

    if hvd.rank() == 0:
        print('[{}] finished in {} s'.format('TRAINING',
                                             round(elapsed_time, 3)))

        images_tfrecord_val, labels_tfrecord_val, nrecords_val = \
            device_minibatches(num_devices_tfrecord, data_dir, batch_size,
                               height, width, distort_color, val=True)
        images_tfrecord_val = images_tfrecord_val[0]
        labels_tfrecord_val = labels_tfrecord_val[0]
        labels_tfrecord_val = tf.one_hot(labels_tfrecord_val, nclasses)

        steps_per_epoch_val = nrecords_val // batch_size

        images_val = Input(tensor=images_tfrecord_val)
        # Build a second ResNet50 wired to the validation tensors and copy
        # the trained weights into it. Assigning into ``model.layers`` does
        # not change the inputs the already-built model reads from, so the
        # previous approach evaluated on the training input pipeline.
        model_val = ResNet50(input_tensor=images_val, weights=None)
        model_val.set_weights(model.get_weights())
        model_val.compile(loss='categorical_crossentropy',
                          optimizer=opt,
                          metrics=['accuracy'],
                          target_tensors=[labels_tfrecord_val])
        loss = model_val.evaluate(x=None, y=None, steps=steps_per_epoch_val)

        print('\nNum images evaluated, steps: {}, {}'.format(
            nrecords_val, steps_per_epoch_val))
        print('\nTest loss, acc: {}'.format(loss))

    KB.clear_session()  # do this for Horovod
コード例 #35
0
def train_main(dataset,
               model_name='1250M',
               seed=None,
               msg=True,
               batch_size=16,
               learning_rate=0.00002,
               sample_length=512,
               sample_num=1,
               sample_every=100,
               run_name='run1',
               restore_from='latest',
               save_every=1000,
               combine=50000):
    """Fine-tune a GPT-2 style model on ``dataset`` with Horovod.

    Builds a next-token loss and a sampling op, optimizes with a
    Horovod-wrapped Adam on a staircase-decayed learning rate, restores
    weights from a checkpoint and broadcasts them from rank 0, then trains in
    an endless loop. Rank 0 checkpoints every ``save_every`` steps, writes
    samples every ``sample_every`` steps and prints loss/LR every step. The
    loop only ends on ``KeyboardInterrupt``, after which rank 0 saves once
    more.

    Args:
        dataset: dataset spec forwarded to ``load_dataset``.
        model_name: directory under ``models`` holding ``hparams.json`` and
            the pretrained checkpoint.
        seed: seed for NumPy and the TF graph-level seed.
        msg: if true, gradients come from ``memory_saving_gradients``
            (presumably gradient checkpointing) and are allreduced
            explicitly; otherwise ``opt.minimize`` is used.
        batch_size: sequences per training step on each worker.
        learning_rate: initial learning rate, multiplied by 0.999 every 200
            steps of ``global_step``.
        sample_length: tokens generated per sample; ``None`` selects
            ``hparams.n_ctx``.
        sample_num: number of samples produced per sampling round.
        sample_every: sampling period in steps (rank 0 only).
        run_name: sub-directory used under ``CHECKPOINT_DIR`` and
            ``SAMPLE_DIR``.
        restore_from: ``'latest'``, ``'fresh'`` or a checkpoint directory.
        save_every: checkpoint period in steps (rank 0 only).
        combine: forwarded to ``load_dataset``.

    Raises:
        ValueError: if ``sample_length`` exceeds ``hparams.n_ctx``.
    """

    enc = encoder.get_encoder(model_name)
    hparams = model.default_hparams()
    with open(os.path.join('models', model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))
        print('n_ctx: ', hparams.n_ctx, 'n_head: ', hparams.n_head, 'n_embd: ',
              hparams.n_embd, 'n_layer: ', hparams.n_layer)

    if sample_length is None:
        sample_length = hparams.n_ctx
    elif sample_length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    # TF config: one GPU per Horovod process, selected by its local rank.
    config = tf.ConfigProto()
    config.gpu_options.visible_device_list = str(hvd.local_rank())
    config.gpu_options.allow_growth = True

    # Drives the learning-rate decay; incremented by train_op.
    global_step = tf.Variable(0, trainable=False)

    with tf.Session(config=config) as sess:
        context = tf.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)
        output = model.model(hparams=hparams, X=context)
        # Next-token objective: logits at position t vs. token at t + 1.
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        tf_sample = sample.sample_sequence(hparams=hparams,
                                           length=sample_length,
                                           context=context,
                                           batch_size=batch_size,
                                           temperature=0.9,
                                           top_k=40)

        # Only variables whose name contains 'model' are trained and saved.
        train_vars = [v for v in tf.trainable_variables() if 'model' in v.name]

        # lr = learning_rate * 0.999 ** (global_step // 200)
        decayed_lr = tf.train.exponential_decay(learning_rate,
                                                global_step,
                                                200,
                                                0.999,
                                                staircase=True)
        opt = tf.train.AdamOptimizer(decayed_lr)
        opt = hvd.DistributedOptimizer(opt)
        if msg:
            print('Using memory saving gradients')
            grads = memory_saving_gradients.gradients(loss, train_vars)
            # hvd.DistributedOptimizer averages gradients only inside
            # compute_gradients()/minimize(); apply_gradients() applies
            # whatever it is given. Gradients computed outside the optimizer
            # must therefore be allreduced here. Otherwise each worker
            # updates with its own local gradients and the replicas drift
            # apart.
            if hvd.size() > 1:
                grads = [
                    None if grad is None else hvd.allreduce(grad)
                    for grad in grads
                ]
            opt_grads = list(zip(grads, train_vars))
            train_op = opt.apply_gradients(opt_grads, global_step=global_step)
        else:
            print('Not using memory saving gradients')
            train_op = opt.minimize(loss,
                                    var_list=train_vars,
                                    global_step=global_step)

        # Horovod: broadcast initial variable states from rank 0 to all other processes.
        # This is necessary to ensure consistent initialization of all workers when
        # training is started with random weights or restored from a checkpoint.
        print('Running hvd.broadcast_global_variables')
        bcast = hvd.broadcast_global_variables(0)
        print('Done')

        saver = tf.train.Saver(var_list=train_vars,
                               max_to_keep=5,
                               keep_checkpoint_every_n_hours=2)

        print('Running global_variables_initializer')
        sess.run(tf.global_variables_initializer())
        print('Done')

        if restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(
                    os.path.join('models', model_name))
        elif restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', model_name))
        else:
            ckpt = tf.train.latest_checkpoint(restore_from)
        print(str(hvd.local_rank()), 'Loading checkpoint', ckpt)
        # NOTE: when training a model from scratch there is no checkpoint to
        # restore; skip this restore and rely on the initializer above.
        saver.restore(sess, ckpt)

        # Overwrite every worker's variables with rank 0's values.
        bcast.run()

        print(str(hvd.local_rank()), 'Loading dataset...')
        chunks = load_dataset(enc, dataset, combine)
        data_sampler = Sampler(chunks)
        print(str(hvd.local_rank()), 'dataset has', data_sampler.total_size,
              'tokens')
        print(str(hvd.local_rank()), 'Training...')

        counter = 1
        if os.path.exists(os.path.join(CHECKPOINT_DIR, run_name, 'counter')):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(os.path.join(CHECKPOINT_DIR, run_name, 'counter'),
                      'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            # Write a checkpoint tagged with the current step and record the
            # step in the 'counter' file so a resumed run continues from it.
            maketree(os.path.join(CHECKPOINT_DIR, run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, run_name,
                             'model-{}').format(counter))
            saver.save(sess,
                       os.path.join(CHECKPOINT_DIR, run_name, 'model'),
                       global_step=counter)
            with open(os.path.join(CHECKPOINT_DIR, run_name, 'counter'),
                      'w') as fp:
                fp.write(str(counter) + '\n')

        def generate_samples():
            # Generate sample_num samples from one sampled context and write
            # them to SAMPLE_DIR/run_name/samples-<counter>.
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: batch_size * [context_tokens]})
                for i in range(min(sample_num - index, batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
            print(text)
            maketree(os.path.join(SAMPLE_DIR, run_name))
            with open(
                    os.path.join(SAMPLE_DIR, run_name,
                                 'samples-{}').format(counter), 'w') as fp:
                fp.write('\n'.join(all_text))

        # Exponentially-decayed (0.99) running sums of loss and weight.
        avg_loss = (0.0, 0.0)
        start_time = time.time()

        try:
            while True:

                batch = [data_sampler.sample(1024) for _ in range(batch_size)]

                _, lv = sess.run((train_op, loss), feed_dict={context: batch})

                avg_loss = (avg_loss[0] * 0.99 + lv, avg_loss[1] * 0.99 + 1.0)

                if hvd.rank() == 0:
                    if counter % save_every == 0:
                        save()
                    if counter % sample_every == 0:
                        generate_samples()

                    print(
                        '[{counter} | {time:2.2f}] loss={loss:2.4f} avg={avg:2.4f} lr={lr:.2e}'
                        .format(counter=counter,
                                time=time.time() - start_time,
                                loss=lv,
                                avg=avg_loss[0] / avg_loss[1],
                                lr=decayed_lr.eval()))

                counter += 1

        except KeyboardInterrupt:
            print('interrupted')
            if hvd.rank() == 0:
                save()
コード例 #36
0
def main(_):
    """Builds the GPT-2 based model graph and runs the requested phases.

    The phases run in this order: training (``FLAGS.do_train``) for
    ``config_train.max_train_epoch`` epochs, evaluation on the dev set
    (``FLAGS.do_eval``), and sample generation on the test set
    (``FLAGS.do_test``).

    When ``FLAGS.distributed`` is set, Horovod is used:
      * each rank reads its own shard of the training data with a per-rank
        batch size;
      * the optimizer is wrapped in ``hvd.DistributedOptimizer``;
      * variables are broadcast from rank 0 after initialization and any
        checkpoint restore;
      * only the head (rank 0) process logs, evaluates, samples and writes
        checkpoints.

    Args:
        _: Unused positional argument (presumably the ``argv`` list passed
            by ``tf.app.run`` -- not visible here).
    """
    if FLAGS.distributed:
        import horovod.tensorflow as hvd
        hvd.init()

    tf.logging.set_verbosity(tf.logging.INFO)

    if len(config_train.name) > 0:
        output_dir = os.path.join(FLAGS.output_dir, config_train.name)
    else:
        output_dir = FLAGS.output_dir
    tx.utils.maybe_create_dir(output_dir)

    ## Loads GPT-2 model configuration

    if FLAGS.config_type == "json":
        gpt2_config = model_utils.transform_gpt2_to_texar_config(
            FLAGS.config_model)
    elif FLAGS.config_type == 'texar':
        gpt2_config = importlib.import_module(FLAGS.config_model)
    else:
        raise ValueError('Unknown config_type.')

    # Creates a data pre-processor for, e.g., BPE encoding
    proc = processor.get_encoder(FLAGS.pretrained_model_dir)

    max_decoding_length = config_train.max_decoding_length
    assert max_decoding_length <= gpt2_config.position_size, (
        "max_decoding_length should not be greater than position_size. "
        "{}>{}".format(max_decoding_length, gpt2_config.position_size))

    ## Loads data

    # Configures the training data shard in distributed mode; the global
    # batch size is split evenly across ranks.
    if FLAGS.distributed:
        config_train.train_hparam["dataset"]["num_shards"] = hvd.size()
        config_train.train_hparam["dataset"]["shard_id"] = hvd.rank()
        config_train.train_hparam["batch_size"] //= hvd.size()

    datasets = {}
    #if FLAGS.do_train:
    train_dataset = tx.data.TFRecordData(hparams=config_train.train_hparam)
    datasets['train'] = train_dataset
    #if FLAGS.do_eval:
    dev_dataset = tx.data.TFRecordData(hparams=config_train.dev_hparam)
    datasets['dev'] = dev_dataset
    #if FLAGS.do_test:
    test_dataset = tx.data.TFRecordData(hparams=config_train.test_hparam)
    datasets['test'] = test_dataset
    iterator = tx.data.FeedableDataIterator(datasets)
    batch = iterator.get_next()
    # Dynamic batch size of the current batch (the last batch may be smaller).
    batch_size = tf.shape(batch['x1x2yx1xx2_ids'])[0]

    ## Builds the GPT-2 model
    vocab_size = gpt2_config.vocab_size

    word_embedder = tx.modules.WordEmbedder(vocab_size=vocab_size,
                                            hparams=gpt2_config.embed)

    pos_embedder = tx.modules.PositionEmbedder(
        position_size=gpt2_config.position_size, hparams=gpt2_config.pos_embed)

    # Ties output layer with input word embedding
    output_layer = tf.transpose(word_embedder.embedding, (1, 0))

    decoder = tx.modules.TransformerDecoder(vocab_size=vocab_size,
                                            output_layer=output_layer,
                                            hparams=gpt2_config.decoder)

    # For training
    def _get_recon_loss(ids,
                        full_len,
                        prefix_len,
                        mask_prefix=True,
                        do_print=False):
        """Next-token cross-entropy of ``ids`` under the decoder.

        Args:
            ids: Token id tensor, indexed as ``[batch, time]``.
            full_len: Per-example sequence lengths of ``ids``.
            prefix_len: Per-example prefix lengths. With ``mask_prefix``,
                predictions of tokens inside the prefix are masked out.
            mask_prefix: Whether to exclude the prefix from the loss.
            do_print: If True (and ``mask_prefix``), adds ``tf.print`` ops for
                the mask and per-token losses and returns
                ``(loss, mask, per_token_loss)`` instead of the loss alone.

        Returns:
            The reduced loss (presumably a scalar -- the averaging flags used
            here reduce over time and, by default, batch; confirm), or a
            3-tuple when ``do_print`` and ``mask_prefix`` are both True.
        """
        # Drop padding beyond the longest sequence in the batch.
        max_full_len = tf.reduce_max(full_len)
        ids = ids[:, :max_full_len]
        batch_size__ = tf.shape(ids)[0]
        seq_len = tf.fill([batch_size__], tf.shape(ids)[1])
        pos_embeds = pos_embedder(sequence_length=seq_len)
        input_embeds = word_embedder(ids) + pos_embeds

        outputs = decoder(inputs=input_embeds,
                          decoding_strategy='train_greedy')

        logits = outputs.logits[:, :max_full_len]

        if mask_prefix:
            # Per-token losses; position t predicts token t + 1.
            loss_recon = tx.losses.sequence_sparse_softmax_cross_entropy(
                labels=ids[:, 1:],
                logits=logits[:, :-1, :],
                sequence_length=full_len - 1,
                average_across_timesteps=False,
                sum_over_timesteps=False,
                average_across_batch=False,
                sum_over_batch=False)
            # Keep only positions that are within the sequence and after the
            # prefix.
            mask_recon = tf.sequence_mask(full_len - 1, dtype=tf.float32)
            mask_recon_prefix = 1 - tf.sequence_mask(
                prefix_len - 1,
                maxlen=max_full_len - 1,  #max_decoding_length-1,
                dtype=tf.float32)
            mask_recon = mask_recon * mask_recon_prefix

            if do_print:
                print_op_1 = tf.print(mask_recon)
                loss_recon_flat = tx.utils.reduce_with_weights(
                    tensor=loss_recon,
                    weights=mask_recon,
                    average_across_remaining=False,
                    sum_over_remaining=False,
                    average_across_batch=False)
                print_op_2 = tf.print(loss_recon_flat)
                with tf.control_dependencies([print_op_1, print_op_2]):
                    loss_recon = tx.utils.reduce_with_weights(
                        tensor=loss_recon,
                        weights=mask_recon,
                        average_across_remaining=True,
                        sum_over_remaining=False)
                return loss_recon, mask_recon, loss_recon_flat
            else:
                loss_recon = tx.utils.reduce_with_weights(
                    tensor=loss_recon,
                    weights=mask_recon,
                    average_across_remaining=True,
                    sum_over_remaining=False)
        else:
            loss_recon = tx.losses.sequence_sparse_softmax_cross_entropy(
                labels=ids[:, 1:],
                logits=logits[:, :-1, :],
                sequence_length=full_len - 1,
                average_across_timesteps=True,
                sum_over_timesteps=False,
                average_across_batch=True,
                sum_over_batch=False)

        return loss_recon

    ## Loss-(1): mask reconstruction loss
    x1x2yx1my_ids = tf.placeholder(tf.int32,
                                   shape=[None, None],
                                   name='x1x2yx1my_ids')
    x1x2yx1my_len = tf.placeholder(tf.int32,
                                   shape=[None],
                                   name='x1x2yx1my_len')
    x1x2yx1m_len = tf.placeholder(tf.int32, shape=[None], name='x1x2yx1m_len')

    loss_mask_recon = _get_recon_loss(x1x2yx1my_ids, x1x2yx1my_len,
                                      x1x2yx1m_len)
    ppl_mask_recon = tf.exp(loss_mask_recon)

    ## Loss-(4): fine-tune loss
    x1x2_ids = tf.placeholder(tf.int32, shape=[None, None], name='x1x2_ids')
    x1x2_len = tf.placeholder(tf.int32, shape=[None], name='x1x2_len')
    x1x2y_ids = tf.placeholder(tf.int32, shape=[None, None], name='x1x2y_ids')
    x1x2y_len = tf.placeholder(tf.int32, shape=[None], name='x1x2y_len')

    loss_fine = _get_recon_loss(x1x2y_ids,
                                x1x2y_len,
                                x1x2_len,
                                mask_prefix=False)

    ## Loss-(5): xx2 loss
    x1_len = tf.placeholder(tf.int32, shape=[None], name='x1_len')
    x1xx2_ids = tf.placeholder(tf.int32, shape=[None, None], name='x1xx2_ids')
    x1xx2_len = tf.placeholder(tf.int32, shape=[None], name='x1xx2_len')

    loss_xx2 = _get_recon_loss(x1xx2_ids, x1xx2_len, x1_len, do_print=False)

    ## Loss-(6): yy loss
    x1x2yx1xx2_ids = tf.placeholder(tf.int32,
                                    shape=[None, None],
                                    name='x1x2yx1xx2_ids')
    x1x2yx1xx2_len = tf.placeholder(tf.int32,
                                    shape=[None],
                                    name='x1x2yx1xx2_len')
    x1x2yx1xx2yy_ids = tf.placeholder(tf.int32,
                                      shape=[None, None],
                                      name='x1x2yx1xx2yy_ids')
    x1x2yx1xx2yy_len = tf.placeholder(tf.int32,
                                      shape=[None],
                                      name='x1x2yx1xx2yy_len')

    loss_yy = _get_recon_loss(x1x2yx1xx2yy_ids, x1x2yx1xx2yy_len,
                              x1x2yx1xx2_len)

    ## Loss-(2): back-translation loss
    x1xx2yyx1x2y_ids = tf.placeholder(tf.int32,
                                      shape=[None, None],
                                      name='x1xx2yyx1x2y_ids')
    x1xx2yyx1x2y_len = tf.placeholder(tf.int32,
                                      shape=[None],
                                      name='x1xx2yyx1x2y_len')
    x1xx2yyx1x2_len = tf.placeholder(tf.int32,
                                     shape=[None],
                                     name='x1xx2yyx1x2_len')

    loss_bt = _get_recon_loss(x1xx2yyx1x2y_ids, x1xx2yyx1x2y_len,
                              x1xx2yyx1x2_len)
    ppl_bt = tf.exp(loss_bt)

    ## Loss-(3): contrastive loss
    # The discriminator is still built (its variables are part of the graph
    # and of saved checkpoints) even though its losses are not optimized below.
    D = Discriminator(gpt2_config)

    tau = tf.placeholder(tf.float32, shape=[], name='tau')

    # generate soft yy
    def _soft_embedding_fn(soft_ids, times):
        return word_embedder(soft_ids=soft_ids) + pos_embedder(times)

    end_token = proc.encoder['<|endoftext|>']
    start_tokens = x1x2yx1xx2_ids[:, 0]

    helper_soft = tx.modules.SoftmaxEmbeddingHelper(
        embedding=_soft_embedding_fn,
        start_tokens=start_tokens,
        end_token=end_token,
        tau=tau,
        embedding_size=vocab_size)

    outputs_soft, len_soft = decoder(context=tf.one_hot(x1x2yx1xx2_ids,
                                                        depth=vocab_size),
                                     context_sequence_length=x1x2yx1xx2_len,
                                     max_decoding_length=max_decoding_length,
                                     helper=helper_soft)
    # Roll the context off the front so only the generated part remains.
    yy_soft_ids = tx.utils.varlength_roll(outputs_soft.sample_id,
                                          -x1x2yx1xx2_len)
    yy_soft_len = len_soft - x1x2yx1xx2_len
    yy_soft_ids = yy_soft_ids[:, :tf.reduce_max(yy_soft_len), :]

    def _get_d_loss(prefix_ids, post_soft_ids, prefix_len, post_len):
        """Discriminator loss on ``prefix_ids`` (one-hot) + soft ``post``."""
        onehot_prefix_ids = tf.one_hot(prefix_ids, depth=vocab_size)
        soft_ids = tx.utils.varlength_concat(onehot_prefix_ids, post_soft_ids,
                                             prefix_len)
        soft_len = prefix_len + post_len
        return D.compute_loss(soft_ids, soft_len), soft_ids, soft_len

    loss_d_x2, _, _ = _get_d_loss(x1x2_ids, yy_soft_ids, x1x2_len,
                                  yy_soft_len)  # to maximize
    loss_d_xx2, x1xx2yy_soft_ids, x1xx2yy_len = _get_d_loss(
        x1xx2_ids, yy_soft_ids, x1xx2_len, yy_soft_len)  # to minimize

    x1xx2yy_ids = tf.argmax(x1xx2yy_soft_ids, axis=-1)

    if not FLAGS.supervised:
        loss = config_train.w_recon * loss_mask_recon \
                + config_train.w_fine * loss_fine \
                + config_train.w_xx2 * loss_xx2

        # Disabled terms are reported as constant 0 so logging stays uniform.
        loss_dict = {
            'loss': loss,
            'loss_mask_recon': config_train.w_recon * loss_mask_recon,
            'loss_bt': tf.constant(0),  # disabled: config_train.w_bt * loss_bt
            'loss_d_xx2': tf.constant(0),  # disabled: w_d_xx2 * loss_d_xx2
            'loss_d_x2': tf.constant(0),  # disabled: w_d_x2 * loss_d_x2
            'loss_fine': config_train.w_fine * loss_fine,
            'loss_xx2': config_train.w_xx2 * loss_xx2,
        }
    else:
        loss = loss_yy

        loss_dict = {
            'loss': loss,
            'loss_yy': loss_yy,
            # Placeholders so _log_losses finds every key it formats.
            'loss_mask_recon': tf.constant(0),
            'loss_bt': tf.constant(0),
            'loss_d_xx2': tf.constant(0),
            'loss_d_x2': tf.constant(0),
            'loss_fine': tf.constant(0),
            'loss_xx2': tf.constant(0)
        }

    ## Inference
    def _embedding_fn(ids, times):
        return word_embedder(ids) + pos_embedder(times)

    def _infer(context_name):
        """Top-k samples continuing ``batch['<context_name>_ids']``.

        Returns ``(yy_ids, yy_len)``: the generated continuation with the
        context removed, and its per-example length.
        """
        helper = tx.modules.TopKSampleEmbeddingHelper(
            embedding=_embedding_fn,
            start_tokens=batch['%s_ids' % context_name][:, 0],
            end_token=end_token,
            top_k=FLAGS.top_k,
            softmax_temperature=FLAGS.temperature)
        outputs_infer, len_infer = decoder(
            context=batch['%s_ids' % context_name],
            context_sequence_length=batch['%s_len' % context_name],
            max_decoding_length=max_decoding_length,
            helper=helper)
        yy_ids = tx.utils.varlength_roll(outputs_infer.sample_id,
                                         -batch['%s_len' % context_name])
        yy_len = len_infer - batch['%s_len' % context_name]
        yy_ids = yy_ids[:, :tf.reduce_max(yy_len)]
        return yy_ids, yy_len

    yy_ids, yy_len = _infer('x1x2yx1xx2')
    yy_ids_fine, yy_len_fine = _infer('x1xx2')  # used in fine-tune
    yy_ids_roc, yy_len_roc = _infer('x1x2')  # used in fine-tune
    ## Optimization
    trainable_variables = tx.utils.collect_trainable_variables(
        [word_embedder, pos_embedder, decoder])

    global_step = tf.Variable(0, trainable=False)
    opt = tx.core.get_optimizer(global_step=global_step,
                                hparams=config_train.opt)

    if FLAGS.distributed:
        opt = hvd.DistributedOptimizer(opt)

    train_op = tf.contrib.layers.optimize_loss(loss=loss,
                                               global_step=global_step,
                                               learning_rate=None,
                                               optimizer=opt,
                                               variables=trainable_variables)

    ## Train/eval/test routine
    saver = tf.train.Saver()
    saver_best = tf.train.Saver(max_to_keep=1)
    # Best dev-set averages seen so far; keys mirror loss_dict.
    dev_best = {
        'loss': 1e8,
        'loss_mask_recon': 1e8,
        'loss_bt': 1e8,
        'loss_d_x2': 1e8,
        'loss_d_xx2': 1e8,
        'loss_fine': 1e8,
        'loss_xx2': 1e8
    }

    def _log_losses(losses, step=None):
        """Log the standard loss entries of ``losses`` (optionally with step)."""
        loss_str = 'loss: %.4f, loss_mask_recon: %.4f, loss_bt: %.4f, loss_d_xx2: %.4f, loss_d_x2: %.4f, loss_fine: %.4f, loss_xx2: %.4f' % \
            (losses['loss'], losses['loss_mask_recon'], losses['loss_bt'],
             losses['loss_d_xx2'], losses['loss_d_x2'], losses['loss_fine'], losses['loss_xx2'])

        if step is not None:
            loss_str = 'step: %d, %s' % (step, loss_str)

        _log(loss_str)

    def _insert_yy(rets):
        """Build back-translation inputs (x1 xx2 yy + x1 x2 y) in numpy.

        Not called within this function; kept for the (disabled)
        back-translation loss.
        """
        batch_ = rets['batch']
        batch_size_ = rets['batch_size']
        yy_ids_ = rets['yy_ids']
        yy_len_ = rets['yy_len']

        x1x2y_ids_ = batch_['x1x2y_ids']
        x1x2y_len_ = batch_['x1x2y_len']

        x1xx2_ids_ = batch_['x1xx2_ids']
        x1xx2_len_ = batch_['x1xx2_len']

        x1xx2yy_ids_ = tx.utils.varlength_concat_py(x1xx2_ids_, yy_ids_,
                                                    x1xx2_len_)
        x1xx2yy_len_ = x1xx2_len_ + yy_len_
        x1xx2yyx1x2y_ids_ = tx.utils.varlength_concat_py(
            x1xx2yy_ids_, x1x2y_ids_, x1xx2yy_len_)
        x1xx2yyx1x2y_len_ = x1xx2yy_len_ + x1x2y_len_
        x1xx2yyx1x2y_max_len_ = np.max(x1xx2yyx1x2y_len_)
        x1xx2yyx1x2y_ids_ = x1xx2yyx1x2y_ids_[:, :x1xx2yyx1x2y_max_len_]

        x1xx2yyx1x2_len_ = x1xx2yy_len_ + batch_['x1x2_len']

        return {
            'x1xx2yyx1x2y_ids': x1xx2yyx1x2y_ids_,
            'x1xx2yyx1x2y_len': x1xx2yyx1x2y_len_,
            'x1xx2yyx1x2_len': x1xx2yyx1x2_len_
        }

    def _is_head():
        """True for the process that logs/saves (rank 0, or non-distributed)."""
        if not FLAGS.distributed:
            return True
        else:
            return hvd.rank() == 0

    def _train_epoch(sess, initial=False):
        """Trains on the training set, and evaluates on the dev set
        periodically.

        Args:
            sess: The session to run in.
            initial: If True, first evaluates the losses (and decodes the
                generated x1xx2yy sequences) on the first batch before any
                update, logging them on the head process as step 0.
        """
        iterator.restart_dataset(sess, 'train')

        while True:
            try:
                # (1) Get data and yy sample
                fetches_data = {
                    'batch': batch,
                    'batch_size': batch_size,
                }
                feed_dict_data = {
                    iterator.handle: iterator.get_handle(sess, 'train'),
                    tx.global_mode(): tf.estimator.ModeKeys.PREDICT,
                }
                rets_data = sess.run(fetches_data, feed_dict_data)

                # (2) Optimize loss
                feed_dict = {
                    x1x2yx1my_ids: rets_data['batch']['x1x2yx1my_ids'],
                    x1x2yx1my_len: rets_data['batch']['x1x2yx1my_len'],
                    x1x2yx1m_len: rets_data['batch']['x1x2yx1m_len'],
                    x1x2yx1xx2_ids: rets_data['batch']['x1x2yx1xx2_ids'],
                    x1x2yx1xx2_len: rets_data['batch']['x1x2yx1xx2_len'],
                    #x1_ids: rets_data['batch']['x1_ids'],
                    x1_len: rets_data['batch']['x1_len'],
                    x1x2_ids: rets_data['batch']['x1x2_ids'],
                    x1x2_len: rets_data['batch']['x1x2_len'],
                    x1xx2_ids: rets_data['batch']['x1xx2_ids'],
                    x1xx2_len: rets_data['batch']['x1xx2_len'],
                    x1x2y_ids: rets_data['batch']['x1x2y_ids'],
                    x1x2y_len: rets_data['batch']['x1x2y_len'],
                    x1x2yx1xx2yy_ids: rets_data['batch']['x1x2yx1xx2yy_ids'],
                    x1x2yx1xx2yy_len: rets_data['batch']['x1x2yx1xx2yy_len'],
                    tau: config_train.tau,
                    tx.global_mode(): tf.estimator.ModeKeys.TRAIN,
                }

                if initial:
                    fetches_initial = {
                        'x1xx2yy_ids': x1xx2yy_ids,
                        'x1xx2yy_len': x1xx2yy_len
                    }
                    fetches_initial.update(loss_dict)
                    rets_initial = sess.run(fetches_initial, feed_dict)
                    # Only the head process reports, like all other logging.
                    if _is_head():
                        _log_losses(rets_initial, 0)
                        for t in rets_initial['x1xx2yy_ids']:
                            print(proc.decode(t))
                    initial = False

                fetches = {
                    'train_op': train_op,
                    'step': global_step,
                }
                fetches.update(loss_dict)

                rets = sess.run(fetches, feed_dict)
                step = rets['step']

                dis_steps = config_train.display_steps

                if _is_head() and dis_steps > 0 and step % dis_steps == 0:
                    _log_losses(rets, step)

                eval_steps = config_train.eval_steps
                if _is_head() and eval_steps > 0 and step % eval_steps == 0:
                    _dev_epoch(sess)
                sample_steps = config_train.sample_steps
                if _is_head(
                ) and sample_steps > 0 and step % sample_steps == 0:
                    print('-----------testing-----------------')
                    _test_epoch(sess, step=step)

                ckpt_steps = config_train.checkpoint_steps
                if _is_head() and ckpt_steps > 0 and step % ckpt_steps == 0:
                    ckpt_fn = os.path.join(output_dir, 'model.ckpt')
                    ckpt_fn = saver.save(sess, ckpt_fn, global_step=step)
                    _log('Checkpoint to {}'.format(ckpt_fn))

            except tf.errors.OutOfRangeError:
                break

    def _dev_epoch(sess):
        """Evaluates on the dev set.

        Logs the batch-size-weighted average of every entry in ``loss_dict``
        and, when training, saves ``model_best.ckpt`` if the average ``loss``
        improves on the best seen so far.
        """
        iterator.restart_dataset(sess, 'dev')

        results = tx.utils.AverageRecorder()
        nsamples = 0
        fetches = {}
        fetches.update(loss_dict)

        while True:
            try:

                # (1) Get data and yy sample
                fetches_data = {
                    'batch': batch,
                    'batch_size': batch_size,
                    #'yy_ids': yy_ids,
                    #'yy_len': yy_len
                }
                feed_dict_data = {
                    iterator.handle: iterator.get_handle(sess, 'dev'),
                    tx.global_mode(): tf.estimator.ModeKeys.PREDICT,
                }
                rets_data = sess.run(fetches_data, feed_dict_data)

                # (2) eval loss
                feed_dict = {
                    x1x2yx1my_ids: rets_data['batch']['x1x2yx1my_ids'],
                    x1x2yx1my_len: rets_data['batch']['x1x2yx1my_len'],
                    x1x2yx1m_len: rets_data['batch']['x1x2yx1m_len'],
                    x1x2yx1xx2_ids: rets_data['batch']['x1x2yx1xx2_ids'],
                    x1x2yx1xx2_len: rets_data['batch']['x1x2yx1xx2_len'],
                    x1_len: rets_data['batch']['x1_len'],
                    x1x2_ids: rets_data['batch']['x1x2_ids'],
                    x1x2_len: rets_data['batch']['x1x2_len'],
                    x1xx2_ids: rets_data['batch']['x1xx2_ids'],
                    x1xx2_len: rets_data['batch']['x1xx2_len'],
                    x1x2y_ids: rets_data['batch']['x1x2y_ids'],
                    x1x2y_len: rets_data['batch']['x1x2y_len'],
                    x1x2yx1xx2yy_ids: rets_data['batch']['x1x2yx1xx2yy_ids'],
                    x1x2yx1xx2yy_len: rets_data['batch']['x1x2yx1xx2yy_len'],
                    tau: config_train.tau,
                    tx.global_mode(): tf.estimator.ModeKeys.PREDICT,
                }

                rets = sess.run(fetches, feed_dict)

                results.add(rets, weight=rets_data['batch_size'])
                nsamples += rets_data['batch_size']
            except tf.errors.OutOfRangeError:
                break

        if nsamples == 0:
            # Empty dev set: there is no average to log or compare.
            _log('nsamples: %d' % nsamples)
            return

        _log_losses(results.avg())
        _log('nsamples: %d' % nsamples)

        avg_loss = results.avg('loss')
        if FLAGS.do_train and avg_loss < dev_best['loss']:
            dev_best.update(results.avg())
            ckpt_fn = os.path.join(output_dir, 'model_best.ckpt')
            ckpt_fn = saver_best.save(sess, ckpt_fn)
            _log('Checkpoint best to {}'.format(ckpt_fn))

    def _test_epoch(sess, step=None):
        """Generates samples on the test set.

        Decodes at most 10 batches, using the context selected by
        ``FLAGS.finetune`` / ``FLAGS.roc`` (default: x1x2yx1xx2), and writes
        input/sample pairs to ``test_samples_<context>[_<step>].tsv`` in
        ``output_dir``.

        Raises:
            ValueError: if both ``FLAGS.finetune`` and ``FLAGS.roc`` are set.
        """
        iterator.restart_dataset(sess, 'test')

        _all_inputs = []
        _all_samples = []

        if FLAGS.finetune and FLAGS.roc:
            raise ValueError(
                'Cannot set --finetune and --roc at the same time')

        if FLAGS.finetune:
            _log('Generation input: x1xx2')
            fetches = {
                'inputs': batch['x1xx2_ids'],
                'length': batch['x1xx2_len'],
                'samples_length': yy_len_fine,
                'samples': yy_ids_fine
            }
            res_fn_appendix = "x1xx2"
        elif FLAGS.roc:
            _log('Generation input: x1x2')
            fetches = {
                'inputs': batch['x1x2_ids'],
                'length': batch['x1x2_len'],
                'samples_length': yy_len_roc,
                'samples': yy_ids_roc
            }
            res_fn_appendix = "x1x2"
        else:
            _log('Generation input: x1x2yx1xx2')
            fetches = {
                'inputs': batch['x1x2yx1xx2_ids'],
                'length': batch['x1x2yx1xx2_len'],
                'samples_length': yy_len,
                'samples': yy_ids
            }
            res_fn_appendix = "x1x2yx1xx2"

        counter = 0
        while True:
            try:
                feed_dict = {
                    iterator.handle: iterator.get_handle(sess, 'test'),
                    tx.context.global_mode(): tf.estimator.ModeKeys.PREDICT,
                }
                rets = sess.run(fetches, feed_dict=feed_dict)
                counter += 1
                print(counter)
                _inputs = []
                for i, l in zip(rets['inputs'], rets['length']):
                    # Delete padding
                    _inputs.append(i[:l].tolist())
                _all_inputs.extend(_inputs)

                _samples = []
                for s, l in zip(rets['samples'], rets['samples_length']):
                    _samples.append(s[:l].tolist())
                _all_samples.extend(_samples)

                # Only a limited number of batches is decoded per call.
                if counter >= 10:
                    break
            except tf.errors.OutOfRangeError:
                break

        # Parse samples and write to file

        eos_token_id = proc.encoder['<|endoftext|>']

        _all_input_text = []
        for i in _all_inputs:
            # Drop a leading <|endoftext|>; empty inputs have nothing to drop.
            if i and i[0] == eos_token_id:
                i = i[1:]
            i_text = proc.decode(i)
            _all_input_text.append(i_text)
        _all_input_text = tx.utils.strip_eos(_all_input_text,
                                             eos_token='<|endoftext|>')

        _all_samples_text = []
        for s in _all_samples:
            s_text = proc.decode(s)
            # One sample per TSV row.
            s_text = s_text.replace('\n', ' ')
            _all_samples_text.append(s_text)
        print(_all_samples_text)

        if step is None:
            fn = "test_samples_%s.tsv" % res_fn_appendix
        else:
            fn = "test_samples_%s_%d.tsv" % (res_fn_appendix, step)
        output_file = os.path.join(output_dir, fn)
        _log('Write samples to {}'.format(output_file))
        tx.utils.write_paired_text(_all_input_text,
                                   _all_samples_text,
                                   output_file,
                                   mode='s')

    # Broadcasts global variables from rank-0 process
    if FLAGS.distributed:
        bcast = hvd.broadcast_global_variables(0)

    session_config = tf.ConfigProto()
    if FLAGS.distributed:
        # Pin each process to its own local GPU.
        session_config.gpu_options.visible_device_list = str(hvd.local_rank())

    with tf.Session(config=session_config) as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())

        smry_writer = tf.summary.FileWriter(FLAGS.output_dir, graph=sess.graph)

        # Restores trained model if specified
        if FLAGS.checkpoint:
            _log('Restore from {}'.format(FLAGS.checkpoint))
            saver.restore(sess, FLAGS.checkpoint)
        elif FLAGS.pretrain_checkpoint:
            _log('Restore from {}'.format(FLAGS.pretrain_checkpoint))
            model_utils.init_gpt2_checkpoint(sess, FLAGS.pretrain_checkpoint)
            print("\nFinished loading\n")
            # Only one process may write to the shared output directory.
            if _is_head():
                saver.save(sess, output_dir + '/gpt2_model.ckpt')

        # Broadcast after any restore so every rank starts from rank 0's
        # (possibly restored) weights.
        if FLAGS.distributed:
            bcast.run()

        iterator.initialize_dataset(sess)

        if FLAGS.do_train:
            for epoch in range(config_train.max_train_epoch):
                _train_epoch(sess, epoch == 0)
            if _is_head():
                saver.save(sess, output_dir + '/model.ckpt')

        if FLAGS.do_eval:
            _dev_epoch(sess)

        if FLAGS.do_test:
            _test_epoch(sess)

        smry_writer.close()
コード例 #37
0
def train():
    """Build the training graph and run the train / validate / metrics loop.

    Steps:
      * builds a training dataset plus one dataset per comma-separated source in
        ``FLAGS.dev_files`` and ``FLAGS.metrics_files``; all of them share one
        iterator created with ``Iterator.from_structure`` and are selected by
        running the matching initializer op. Datasets are created with
        ``split_dataset=FLAGS.horovod``;
      * creates the optimizer. Under Horovod the learning rate is multiplied by
        ``hvd.size()`` and the optimizer is wrapped in ``hvd.DistributedOptimizer``;
        ``FLAGS.automatic_mixed_precision`` enables TF's mixed-precision graph rewrite;
      * on the master process only, creates the checkpoint savers and writes the
        flags to ``flags.txt`` inside ``FLAGS.save_checkpoint_dir``;
      * opens a session, loads or initializes the variables (and broadcasts them
        from rank 0 under Horovod), then for each of ``FLAGS.epochs`` epochs:
        trains and checkpoints, validates on the dev sets (best-model saving,
        early stopping, learning-rate reduction on a plateau) and reports the
        loss of each metrics set.

    A ``KeyboardInterrupt`` raised during the epoch loop ends training without
    propagating. Errors stored in ``exception_box`` (which is handed to
    ``create_dataset``) are re-raised from the batch loop.
    """
    exception_box = ExceptionBox()

    if FLAGS.horovod:
        import horovod.tensorflow as hvd

    # Create training and validation datasets; create_dataset is asked to split
    # them (split_dataset) only when running under Horovod.
    split_dataset = FLAGS.horovod

    train_set = create_dataset(FLAGS.train_files.split(','),
                               batch_size=FLAGS.train_batch_size,
                               epochs=FLAGS.epochs,
                               augmentations=Config.augmentations,
                               cache_path=FLAGS.feature_cache,
                               train_phase=True,
                               exception_box=exception_box,
                               process_ahead=Config.num_devices * FLAGS.train_batch_size * 2,
                               reverse=FLAGS.reverse_train,
                               limit=FLAGS.limit_train,
                               buffering=FLAGS.read_buffer,
                               split_dataset=split_dataset)

    # A single iterator whose structure comes from the training set; the dev and
    # metrics sets are attached to it through their own initializer ops.
    iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
                                                 tfv1.data.get_output_shapes(train_set),
                                                 output_classes=tfv1.data.get_output_classes(train_set))

    # Make initialization ops for switching between the sets
    train_init_op = iterator.make_initializer(train_set)

    if FLAGS.dev_files:
        dev_sources = FLAGS.dev_files.split(',')
        dev_sets = [create_dataset([source],
                                   batch_size=FLAGS.dev_batch_size,
                                   train_phase=False,
                                   exception_box=exception_box,
                                   process_ahead=Config.num_devices * FLAGS.dev_batch_size * 2,
                                   reverse=FLAGS.reverse_dev,
                                   limit=FLAGS.limit_dev,
                                   buffering=FLAGS.read_buffer,
                                   split_dataset=split_dataset) for source in dev_sources]
        dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]

    if FLAGS.metrics_files:
        metrics_sources = FLAGS.metrics_files.split(',')
        metrics_sets = [create_dataset([source],
                                       batch_size=FLAGS.dev_batch_size,
                                       train_phase=False,
                                       exception_box=exception_box,
                                       process_ahead=Config.num_devices * FLAGS.dev_batch_size * 2,
                                       reverse=FLAGS.reverse_dev,
                                       limit=FLAGS.limit_dev,
                                       buffering=FLAGS.read_buffer,
                                       split_dataset=split_dataset) for source in metrics_sources]
        metrics_init_ops = [iterator.make_initializer(metrics_set) for metrics_set in metrics_sets]

    # Dropout: six rate placeholders, fed with the configured rates while training
    # and with 0 for dev/metrics evaluation.
    dropout_rates = [tfv1.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]
    dropout_feed_dict = {
        dropout_rates[0]: FLAGS.dropout_rate,
        dropout_rates[1]: FLAGS.dropout_rate2,
        dropout_rates[2]: FLAGS.dropout_rate3,
        dropout_rates[3]: FLAGS.dropout_rate4,
        dropout_rates[4]: FLAGS.dropout_rate5,
        dropout_rates[5]: FLAGS.dropout_rate6,
    }
    no_dropout_feed_dict = {
        rate: 0. for rate in dropout_rates
    }

    # Building the graph. The learning rate lives in a non-trainable variable so
    # that reduce_learning_rate_op can scale it on a plateau.
    learning_rate_var = tfv1.get_variable('learning_rate', initializer=FLAGS.learning_rate, trainable=False)
    reduce_learning_rate_op = learning_rate_var.assign(tf.multiply(learning_rate_var, FLAGS.plateau_reduction))
    if FLAGS.horovod:
        # Effective batch size in synchronous distributed training is scaled by the number of workers. An increase in learning rate compensates for the increased batch size.
        optimizer = create_optimizer(learning_rate_var * hvd.size())
        optimizer = hvd.DistributedOptimizer(optimizer)
    else:
        optimizer = create_optimizer(learning_rate_var)

    # Enable mixed precision training
    if FLAGS.automatic_mixed_precision:
        log_info('Enabling automatic mixed precision training.')
        optimizer = tfv1.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)

    if FLAGS.horovod:
        # One model replica per process; cross-worker gradient reduction is left to
        # hvd.DistributedOptimizer.
        loss, non_finite_files = calculate_mean_edit_distance_and_loss(iterator, dropout_rates, reuse=False)
        gradients = optimizer.compute_gradients(loss)

        tfv1.summary.scalar(name='step_loss', tensor=loss, collections=['step_summaries'])
        log_grads_and_vars(gradients)

        # global_step is automagically incremented by the optimizer
        global_step = tfv1.train.get_or_create_global_step()
        apply_gradient_op = optimizer.apply_gradients(gradients, global_step=global_step)
    else:
        gradients, loss, non_finite_files = get_tower_results(iterator, optimizer, dropout_rates)

        # Average tower gradients across GPUs
        avg_tower_gradients = average_gradients(gradients)
        log_grads_and_vars(avg_tower_gradients)

        # global_step is automagically incremented by the optimizer
        global_step = tfv1.train.get_or_create_global_step()
        apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step)

    # Summaries
    step_summaries_op = tfv1.summary.merge_all('step_summaries')
    step_summary_writers = {
        'train': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120),
        'dev': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'), max_queue=120),
        'metrics': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'metrics'), max_queue=120),
    }

    human_readable_set_names = {
        'train': 'Training',
        'dev': 'Validation',
        'metrics': 'Metrics',
    }

    # Checkpointing (master process only; other processes never write checkpoints)
    if Config.is_master_process:
        checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep)
        checkpoint_path = os.path.join(FLAGS.save_checkpoint_dir, 'train')

        best_dev_saver = tfv1.train.Saver(max_to_keep=1)
        best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, 'best_dev')

        # Save flags next to checkpoints
        if not is_remote_path(FLAGS.save_checkpoint_dir):
            os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
        flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt')
        with open_remote(flags_file, 'w') as fout:
            fout.write(FLAGS.flags_into_string())

    # The broadcast op must be created before the graph is finalized below.
    if FLAGS.horovod:
        bcast = hvd.broadcast_global_variables(0)

    with tfv1.Session(config=Config.session_config) as session:
        log_debug('Session opened.')

        # Prevent further graph changes
        tfv1.get_default_graph().finalize()

        # Load checkpoint or initialize variables, then make every worker start
        # from rank 0's values.
        load_or_init_graph_for_training(session)
        if FLAGS.horovod:
            bcast.run()

        def run_set(set_name, epoch, init_op, dataset=None):
            """Run one full pass over the dataset selected by ``init_op``.

            Applies gradients (with dropout) only when ``set_name == 'train'``.
            Returns ``(mean_loss, step_count)``; the mean is 0.0 when no batches ran.
            """
            is_train = set_name == 'train'
            train_op = apply_gradient_op if is_train else []
            feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict

            total_loss = 0.0
            step_count = 0

            step_summary_writer = step_summary_writers.get(set_name)
            checkpoint_time = time.time()

            # Every FLAGS.cache_for_epochs epochs, drop the feature cache index so it is rebuilt.
            if is_train and FLAGS.cache_for_epochs > 0 and FLAGS.feature_cache:
                feature_cache_index = FLAGS.feature_cache + '.index'
                if epoch % FLAGS.cache_for_epochs == 0 and os.path.isfile(feature_cache_index):
                    log_info('Invalidating feature cache')
                    remove_remote(feature_cache_index)  # this will let TF also overwrite the related cache data files

            # Setup progress bar
            class LossWidget(progressbar.widgets.FormatLabel):
                def __init__(self):
                    progressbar.widgets.FormatLabel.__init__(self, format='Loss: %(mean_loss)f')

                def __call__(self, progress, data, **kwargs):
                    # Reads the enclosing run_set's running totals at render time.
                    data['mean_loss'] = total_loss / step_count if step_count else 0.0
                    return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs)

            if Config.is_master_process:
                prefix = 'Epoch {} | {:>10}'.format(epoch, human_readable_set_names[set_name])
                widgets = [' | ', progressbar.widgets.Timer(),
                           ' | Steps: ', progressbar.widgets.Counter(),
                           ' | ', LossWidget()]
                suffix = ' | Dataset: {}'.format(dataset) if dataset else None
                pbar = create_progressbar(prefix=prefix, widgets=widgets, suffix=suffix).start()

            # Initialize iterator to the appropriate dataset
            session.run(init_op)

            # Batch loop: runs until the iterator is exhausted (OutOfRangeError).
            while True:
                try:
                    _, current_step, batch_loss, problem_files, step_summary = \
                        session.run([train_op, global_step, loss, non_finite_files, step_summaries_op],
                                    feed_dict=feed_dict)
                    exception_box.raise_if_set()
                except tf.errors.OutOfRangeError:
                    exception_box.raise_if_set()
                    break

                if problem_files.size > 0:
                    problem_files = [f.decode('utf8') for f in problem_files[..., 0]]
                    log_error('The following files caused an infinite (or NaN) '
                              'loss: {}'.format(','.join(problem_files)))

                total_loss += batch_loss
                step_count += 1

                if Config.is_master_process:
                    pbar.update(step_count)

                    step_summary_writer.add_summary(step_summary, current_step)

                    # Time-based intermediate checkpoint during training
                    if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
                        checkpoint_saver.save(session, checkpoint_path, global_step=current_step)
                        checkpoint_time = time.time()

            if Config.is_master_process:
                pbar.finish()
            mean_loss = total_loss / step_count if step_count > 0 else 0.0
            return mean_loss, step_count

        log_info('STARTING Optimization')
        train_start_time = datetime.utcnow()
        best_dev_loss = float('inf')
        epochs_without_improvement = 0
        try:
            for epoch in range(FLAGS.epochs):
                # Training
                if Config.is_master_process:
                    log_progress('Training epoch %d...' % epoch)
                train_loss, _ = run_set('train', epoch, train_init_op)
                if Config.is_master_process:
                    log_progress('Finished training epoch %d - loss: %f' % (epoch, train_loss))
                    checkpoint_saver.save(session, checkpoint_path, global_step=global_step)

                if FLAGS.dev_files:
                    # Validation: mean loss over all dev sources, weighted by step count
                    dev_loss = 0.0
                    total_steps = 0
                    for source, init_op in zip(dev_sources, dev_init_ops):
                        if Config.is_master_process:
                            log_progress('Validating epoch %d on %s...' % (epoch, source))
                        set_loss, steps = run_set('dev', epoch, init_op, dataset=source)
                        dev_loss += set_loss * steps
                        total_steps += steps
                        if Config.is_master_process:
                            log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, source, set_loss))

                    if total_steps == 0:
                        # No dev batches were produced (e.g. empty dev files), so there is
                        # no loss to compare; skip best-model tracking, early stopping and
                        # plateau handling for this epoch instead of dividing by zero.
                        log_error('Validation for epoch %d produced no batches; '
                                  'skipping best-model tracking for this epoch' % epoch)
                    else:
                        dev_loss = dev_loss / total_steps

                        # Count epochs without an improvement for early stopping and reduction of learning rate on a plateau
                        # the improvement has to be greater than FLAGS.es_min_delta
                        if dev_loss > best_dev_loss - FLAGS.es_min_delta:
                            epochs_without_improvement += 1
                        else:
                            epochs_without_improvement = 0

                        # Track the best loss on every process: if only the master updated it,
                        # other workers would keep best_dev_loss at infinity, never count an
                        # epoch without improvement, and never early-stop or reduce their
                        # learning rate while the master does. Only the master saves.
                        if dev_loss < best_dev_loss:
                            best_dev_loss = dev_loss
                            if Config.is_master_process:
                                save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step,
                                                                latest_filename='best_dev_checkpoint')
                                log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))

                        # Early stopping
                        if FLAGS.early_stop and epochs_without_improvement == FLAGS.es_epochs:
                            if Config.is_master_process:
                                log_info('Early stop triggered as the loss did not improve the last {} epochs'.format(
                                    epochs_without_improvement))
                            break

                        # Reduce learning rate on plateau
                        # If the learning rate was reduced and there is still no improvement
                        # wait FLAGS.plateau_epochs before the learning rate is reduced again
                        if (
                            FLAGS.reduce_lr_on_plateau
                            and epochs_without_improvement > 0
                            and epochs_without_improvement % FLAGS.plateau_epochs == 0
                        ):
                            # Reload checkpoint that we use the best_dev weights again
                            reload_best_checkpoint(session)

                            # Reduce learning rate
                            session.run(reduce_learning_rate_op)
                            current_learning_rate = session.run(learning_rate_var)
                            if Config.is_master_process:
                                log_info('Encountered a plateau, reducing learning rate to {}'.format(
                                    current_learning_rate))

                                # Overwrite best checkpoint with new learning rate value
                                save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step,
                                                                latest_filename='best_dev_checkpoint')
                                log_info("Saved best validating model with reduced learning rate to: %s" % (save_path))

                if FLAGS.metrics_files:
                    # Read only metrics, not affecting best validation loss tracking
                    for source, init_op in zip(metrics_sources, metrics_init_ops):
                        if Config.is_master_process:
                            log_progress('Metrics for epoch %d on %s...' % (epoch, source))
                        set_loss, _ = run_set('metrics', epoch, init_op, dataset=source)
                        if Config.is_master_process:
                            log_progress('Metrics for epoch %d on %s - loss: %f' % (epoch, source, set_loss))

                # Epoch separator; printed once (by the master) like the rest of the progress output.
                if Config.is_master_process:
                    print('-' * 80)

        except KeyboardInterrupt:
            # Deliberate: Ctrl-C ends training early but still reports the elapsed time below.
            pass
        if Config.is_master_process:
            log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
    log_debug('Session closed.')
コード例 #38
0
ファイル: model.py プロジェクト: chinatian/glow
def abstract_model_xy(sess, hps, feeds, train_iterator, test_iterator, data_init, lr, f_loss):
    """Wire a loss function into train/test/checkpoint closures and initialize the model.

    The returned value is a class used purely as an attribute namespace with:
      * ``sess``, ``feeds``, ``lr`` -- the arguments of the same name;
      * ``train(_lr)`` -- run one optimizer step with learning rate ``_lr`` and
        return the training stats;
      * ``polyak_swap()`` -- run the swap op returned by the optimizer;
      * ``test()`` -- evaluate and return the test stats;
      * ``save(path)`` / ``save_ema(path)`` -- save all variables / the optimizer's
        EMA variables (without a meta graph);
      * ``restore(path)`` -- restore all variables.

    Args:
        sess: session that every closure runs on.
        hps: hyper-parameters; reads ``gradient_checkpointing``, ``optimizer``
            (``'adam'``, ``'adamax'`` or ``'adam2'``), ``direct_iterator`` and
            ``restore_path``.
        feeds: mapping with ``'x'`` and ``'y'`` feedable tensors.
        train_iterator, test_iterator: passed to ``f_loss``; when
            ``hps.direct_iterator`` is false they are also called to get ``(x, y)``
            batches that are fed through ``feeds``.
        data_init: mapping with ``'x'`` / ``'y'`` values fed to the init pass.
        lr: feedable learning-rate tensor.
        f_loss: callable returning ``(loss, stats)``; its second positional
            argument is True for the training graph and False for the test graph.
    """

    # Namespace returned to the caller; attributes are attached below.
    class m(object):
        pass

    m.sess = sess
    m.feeds = feeds
    m.lr = lr

    # --- Training graph: loss, gradients, optimizer update ---------------------
    # NOTE: the order in which graph pieces are created below is kept as-is, since
    # TF op/saver names depend on creation order.
    train_loss, train_stats = f_loss(train_iterator, True)
    params = tf.trainable_variables()
    if hps.gradient_checkpointing != 1:
        grads = tf.gradients(train_loss, params)
    else:
        # Gradient-checkpointing drop-in for tf.gradients (external module).
        from memory_saving_gradients import gradients
        grads = gradients(train_loss, params)

    optimizers_by_name = {
        'adam': optim.adam,
        'adamax': optim.adamax,
        'adam2': optim.adam2,
    }
    build_optimizer = optimizers_by_name[hps.optimizer]
    train_op, polyak_swap_op, ema = build_optimizer(params, grads, alpha=lr, hps=hps)

    if hps.direct_iterator:
        # Inputs come from the iterator inside the graph; only lr is fed.
        def _train(_lr):
            return sess.run([train_op, train_stats], {lr: _lr})[1]
    else:
        def _train(_lr):
            batch_x, batch_y = train_iterator()
            feed = {feeds['x']: batch_x, feeds['y']: batch_y, lr: _lr}
            return sess.run([train_op, train_stats], feed)[1]
    m.train = _train

    def _polyak_swap():
        return sess.run(polyak_swap_op)
    m.polyak_swap = _polyak_swap

    # --- Evaluation graph (reuses the training variables) ----------------------
    _, test_stats = f_loss(test_iterator, False, reuse=True)
    if hps.direct_iterator:
        def _test():
            return sess.run(test_stats)
    else:
        def _test():
            batch_x, batch_y = test_iterator()
            return sess.run(test_stats, {feeds['x']: batch_x, feeds['y']: batch_y})
    m.test = _test

    # --- Checkpointing ----------------------------------------------------------
    full_saver = tf.train.Saver()
    ema_saver = tf.train.Saver(ema.variables_to_restore())

    def _save_ema(path):
        return ema_saver.save(sess, path, write_meta_graph=False)

    def _save(path):
        return full_saver.save(sess, path, write_meta_graph=False)

    def _restore(path):
        return full_saver.restore(sess, path)

    m.save_ema = _save_ema
    m.save = _save
    m.restore = _restore

    # --- Parameter initialization ----------------------------------------------
    if hps.restore_path == '':
        # Fresh start: build the init pass with Z.get_variable_ddi / Z.actnorm in
        # init mode, initialize all variables, then run that pass on data_init.
        with Z.arg_scope([Z.get_variable_ddi, Z.actnorm], init=True):
            init_outputs = f_loss(None, True, reuse=True)
        sess.run(tf.global_variables_initializer())
        sess.run(init_outputs, {feeds['x']: data_init['x'],
                                feeds['y']: data_init['y']})
    else:
        _restore(hps.restore_path)

    # Give every Horovod worker rank 0's parameter values.
    sess.run(hvd.broadcast_global_variables(0))

    return m