Example #1
0
    def _add_train_op(self):
        """Sets self._train_op, the op to run for training.

        Differentiates ``self._loss`` w.r.t. every trainable variable, clips the
        gradients with ``tf.clip_by_global_norm`` using ``hps.max_grad_norm``,
        records the returned global norm as the ``global_norm`` scalar summary,
        and applies the clipped gradients with Adagrad (``hps.lr``,
        ``hps.adagrad_init_acc``), incrementing ``self.global_step``.
        """
        hps = self._hps
        # Take gradients of the trainable variables w.r.t. the loss function to minimize
        tvars = tf.trainable_variables()
        loss_to_minimize = self._loss
        gradients = tf.gradients(
            loss_to_minimize,
            tvars,
            aggregation_method=tf.AggregationMethod.EXPERIMENTAL_TREE)

        # Build the clipping and update ops exactly once. Looping over the GPUs
        # created one redundant copy per GPU (only the last was kept) and left
        # `grads`/`global_norm` unbound on a machine without GPUs. Place them on
        # the last listed GPU (where the kept copy used to live), or with no
        # explicit device constraint (None) when there is no GPU.
        gpus = get_available_gpus()
        train_device = gpus[-1] if gpus else None

        # Clip the gradients
        with tf.device(train_device):
            grads, global_norm = tf.clip_by_global_norm(
                gradients, hps.max_grad_norm)

        # Add a summary
        tf.summary.scalar('global_norm', global_norm)

        # Apply adagrad optimizer
        optimizer = tf.train.AdagradOptimizer(
            hps.lr, initial_accumulator_value=hps.adagrad_init_acc)
        with tf.device(train_device):
            self._train_op = optimizer.apply_gradients(
                zip(grads, tvars),
                global_step=self.global_step,
                name='train_step')
  def _add_train_op(self):
    """Sets self._train_op, the op to run for training.

    The loss minimised is the rewriter's ``_total_loss`` when ``hps.coverage``
    is set (otherwise its ``_loss``), plus the selector loss scaled by
    ``hps.selector_loss_wt``, plus ``self._inconsistent_loss`` when
    ``hps.inconsistent_loss`` is set. Gradients are clipped with
    ``tf.clip_by_global_norm`` (``hps.max_grad_norm``), the returned global
    norm is summarised, and Adagrad applies the update, incrementing
    ``self.global_step``.
    """
    hps = self._hps
    # Take gradients of the trainable variables w.r.t. the loss function to minimize
    loss_to_minimize = self._rewriter._total_loss if hps.coverage else self._rewriter._loss
    loss_to_minimize += (self._selector._loss * hps.selector_loss_wt)

    if hps.inconsistent_loss:
      # NOTE(review): assumes build_graph created self._inconsistent_loss; it is
      # only set there when the rewriter is not in 'greedy_search' mode.
      loss_to_minimize += self._inconsistent_loss

    tvars = tf.trainable_variables()
    gradients = tf.gradients(loss_to_minimize, tvars, aggregation_method=tf.AggregationMethod.EXPERIMENTAL_TREE)

    # Build the clipping and update ops exactly once. A per-GPU loop created
    # one redundant copy per GPU (only the last was kept) and left
    # `grads`/`global_norm` unbound without GPUs. Use the last listed GPU (where
    # the kept copy used to live), or no explicit device (None) without GPUs.
    gpus = get_available_gpus()
    train_device = gpus[-1] if gpus else None

    # Clip the gradients
    with tf.device(train_device):
      grads, global_norm = tf.clip_by_global_norm(gradients, hps.max_grad_norm)

    # Add a summary
    tf.summary.scalar('global_norm', global_norm)

    # Apply adagrad optimizer
    tf.logging.info('Using Adagrad optimizer')
    optimizer = tf.train.AdagradOptimizer(hps.lr, initial_accumulator_value=hps.adagrad_init_acc)
    with tf.device(train_device):
      self._train_op = optimizer.apply_gradients(zip(grads, tvars), global_step=self.global_step, name='train_step')
Example #3
0
 def build_graph(self):
     """Add the placeholders, model, global step, train_op and summaries to the graph.

     The seq2seq model is built exactly once. A train op is only added when
     ``self._hps.mode == 'train'``; all summaries are then merged into
     ``self._summaries``. The elapsed build time is logged.
     """
     tf.logging.info('Building graph...')
     t0 = time.time()
     self._add_placeholders()
     # Calling _add_seq2seq inside a per-GPU loop built the model once per GPU
     # (keeping only the last copy) and not at all on a CPU-only machine. Build
     # it a single time: on the last listed GPU (the copy the loop would have
     # kept) or with no explicit device constraint (None) when there is no GPU.
     gpus = get_available_gpus()
     with tf.device(gpus[-1] if gpus else None):
         self._add_seq2seq()
     self.global_step = tf.Variable(0, name='global_step', trainable=False)
     if self._hps.mode == 'train':
         self._add_train_op()
     self._summaries = tf.summary.merge_all()
     t1 = time.time()
     tf.logging.info('Time to build graph: %i seconds', t1 - t0)
 def build_graph(self):
   """Add the placeholders, model, global step, train_op and summaries to the graph.

   Builds the sentence selector and then the rewriter (fed the selector's
   ``probs``) exactly once. When ``hps.inconsistent_loss`` is set and the
   rewriter is not in ``'greedy_search'`` mode, the inconsistency loss is
   added and summarised as ``inconsist_loss``. A train op is only added in
   ``'train'`` mode; all summaries are merged into ``self._summaries``.
   """
   tf.logging.info('Building graph...')
   t0 = time.time()
   self._selector._add_placeholders()
   self._rewriter._add_placeholders()
   # A per-GPU loop built the selector, the rewriter and the inconsistency
   # loss once per GPU (keeping only the last copy) and not at all on a
   # CPU-only machine. Build them a single time. The selector goes on the last
   # listed GPU (the copy the loop kept), or unconstrained (None) without GPUs.
   gpus = get_available_gpus()
   with tf.device(gpus[-1] if gpus else None):
     self._selector._add_sent_selector()

   # As before, the rewriter is built outside the explicit device scope.
   self._rewriter._add_seq2seq(selector_probs=self._selector.probs)
   if self._hps.inconsistent_loss and self._rewriter._graph_mode != 'greedy_search':
     self._inconsistent_loss = self._add_inconsistent_loss()
     tf.summary.scalar('inconsist_loss', self._inconsistent_loss)
   self.global_step = tf.Variable(0, name='global_step', trainable=False)
   if self._hps.mode == 'train':
     self._add_train_op()
   self._summaries = tf.summary.merge_all()
   t1 = time.time()
   tf.logging.info('Time to build graph: %i seconds', t1 - t0)
Example #5
0
def train(config, sess):
    """Build the training graph and run the main training loop.

    One ``StandardModel`` replica is created per visible GPU (a single CPU
    replica when no GPU is visible), with variables shared between replicas
    through variable-scope reuse. Training then loops over epochs and
    minibatches, calling ``ModelUpdater.update`` once per batch. At the
    frequencies configured in ``config`` it also:
      * logs throughput;
      * logs ``sample`` / ``beam_search`` output of the first replica;
      * validates, stopping early after ``config.patience`` validations
        without improvement;
      * saves checkpoints plus ``.progress.json`` files.

    Args:
        config: training options object; its attributes (``optimizer``,
            ``saveto``, ``max_epochs``, ``dispFreq``, ...) drive every step.
        sess: TensorFlow session in which variables are initialised or
            restored and updates are run.
    """
    # MAP training (map_decay_c != 0) needs an existing prior checkpoint.
    assert ((config.prior_model is not None
             and tf.train.checkpoint_exists(
                 os.path.abspath(config.prior_model)))
            or config.map_decay_c == 0.0), \
        "MAP training requires a prior model file: Use command-line option --prior_model"

    # Construct the graph, with one model replica per GPU. Fall back to a
    # single CPU replica when no GPU is visible. Otherwise the replica list
    # would be empty and replicas[0] (used for sampling and validation) would
    # not exist.
    num_gpus = len(util.get_available_gpus())
    num_replicas = max(1, num_gpus)
    device_type = "GPU" if num_gpus > 0 else "CPU"

    logging.info('Building model...')
    replicas = []
    for i in range(num_replicas):
        device_spec = tf.DeviceSpec(device_type=device_type, device_index=i)
        with tf.device(device_spec):
            with tf.variable_scope(tf.get_variable_scope(), reuse=(i > 0)):
                replicas.append(StandardModel(config))

    if config.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=config.learning_rate)
    else:
        logging.error('No valid optimizer defined: {}'.format(
            config.optimizer))
        sys.exit(1)

    # Non-trainable int32 update counter, named 'time' in the graph.
    init = tf.zeros_initializer(dtype=tf.int32)
    global_step = tf.get_variable('time', [],
                                  initializer=init,
                                  trainable=False)

    if config.summaryFreq:
        summary_dir = (config.summary_dir if config.summary_dir is not None
                       else os.path.abspath(os.path.dirname(config.saveto)))
        writer = tf.summary.FileWriter(summary_dir, sess.graph)
    else:
        writer = None

    updater = ModelUpdater(config, replicas, optimizer, global_step, writer)

    saver, progress = init_or_restore_variables(config, sess, train=True)

    # Sync the graph's step counter with progress.uidx.
    global_step.load(progress.uidx, sess)

    # Save model options. The context manager closes (and flushes) the file
    # instead of leaving an anonymous handle open until garbage collection.
    config_as_dict = OrderedDict(sorted(vars(config).items()))
    with open('%s.json' % config.saveto, 'wb') as config_file:
        json.dump(config_as_dict, config_file, indent=2)

    text_iterator, valid_text_iterator = load_data(config)
    _, _, num_to_source, num_to_target = load_dictionaries(config)
    total_loss = 0.
    n_sents, n_words = 0, 0
    last_time = time.time()
    logging.info("Initial uidx={}".format(progress.uidx))
    for progress.eidx in xrange(progress.eidx, config.max_epochs):
        logging.info('Starting epoch {0}'.format(progress.eidx))
        for source_sents, target_sents in text_iterator:
            if len(source_sents[0][0]) != config.factors:
                logging.error(
                    'Mismatch between number of factors in settings ({0}), and number in training corpus ({1})\n'
                    .format(config.factors, len(source_sents[0][0])))
                sys.exit(1)
            x_in, x_mask_in, y_in, y_mask_in = util.prepare_data(
                source_sents, target_sents, config.factors, maxlen=None)
            if x_in is None:
                logging.info(
                    'Minibatch with zero sample under length {0}'.format(
                        config.maxlen))
                continue
            write_summary_for_this_batch = config.summaryFreq and (
                (progress.uidx % config.summaryFreq == 0) or
                (config.finish_after
                 and progress.uidx % config.finish_after == 0))
            # x_in is laid out as (factors, seqLen, batch_size).
            _, _, batch_size = x_in.shape

            loss = updater.update(sess, x_in, x_mask_in, y_in, y_mask_in,
                                  write_summary_for_this_batch)
            total_loss += loss
            n_sents += batch_size
            n_words += int(numpy.sum(y_mask_in))
            progress.uidx += 1

            if config.dispFreq and progress.uidx % config.dispFreq == 0:
                duration = time.time() - last_time
                disp_time = datetime.now().strftime('[%Y-%m-%d %H:%M:%S]')
                logging.info(
                    '{0} Epoch: {1} Update: {2} Loss/word: {3} Words/sec: {4} Sents/sec: {5}'
                    .format(disp_time, progress.eidx, progress.uidx,
                            total_loss / n_words, n_words / duration,
                            n_sents / duration))
                last_time = time.time()
                total_loss = 0.
                n_sents = 0
                n_words = 0

            if config.sampleFreq and progress.uidx % config.sampleFreq == 0:
                # Use (at most) the first 10 sentences of the batch.
                x_small = x_in[:, :, :10]
                x_mask_small = x_mask_in[:, :10]
                y_small = y_in[:, :10]
                samples = replicas[0].sample(sess, x_small, x_mask_small)
                assert len(samples) == len(x_small.T) == len(
                    y_small.T), (len(samples), x_small.shape, y_small.shape)
                for xx, yy, ss in zip(x_small.T, y_small.T, samples):
                    source = util.factoredseq2words(xx, num_to_source)
                    target = util.seq2words(yy, num_to_target)
                    sample = util.seq2words(ss, num_to_target)
                    logging.info('SOURCE: {}'.format(source))
                    logging.info('TARGET: {}'.format(target))
                    logging.info('SAMPLE: {}'.format(sample))

            if config.beamFreq and progress.uidx % config.beamFreq == 0:
                # Use (at most) the first 10 sentences of the batch.
                x_small = x_in[:, :, :10]
                x_mask_small = x_mask_in[:, :10]
                y_small = y_in[:, :10]
                samples = replicas[0].beam_search(sess, x_small, x_mask_small,
                                                  config.beam_size)
                # samples is a list with shape batch x beam x len
                assert len(samples) == len(x_small.T) == len(
                    y_small.T), (len(samples), x_small.shape, y_small.shape)
                for xx, yy, ss in zip(x_small.T, y_small.T, samples):
                    source = util.factoredseq2words(xx, num_to_source)
                    target = util.seq2words(yy, num_to_target)
                    logging.info('SOURCE: {}'.format(source))
                    logging.info('TARGET: {}'.format(target))
                    for beam_idx, (sample_seq, cost) in enumerate(ss):
                        sample = util.seq2words(sample_seq, num_to_target)
                        # An empty decoded hypothesis would otherwise raise
                        # ZeroDivisionError and abort training; log NaN.
                        # NOTE(review): len(sample) is the length of the decoded
                        # value (characters if it is a str) — confirm intent.
                        avg = (cost / len(sample) if len(sample) > 0
                               else float('nan'))
                        msg = 'SAMPLE {}: {} Cost/Len/Avg {}/{}/{}'.format(
                            beam_idx, sample, cost, len(sample), avg)
                        logging.info(msg)

            if config.validFreq and progress.uidx % config.validFreq == 0:
                costs = validate(config, sess, valid_text_iterator,
                                 replicas[0])
                # validation loss is mean of normalized sentence log probs
                valid_loss = sum(costs) / len(costs)
                if (len(progress.history_errs) == 0
                        or valid_loss < min(progress.history_errs)):
                    progress.history_errs.append(valid_loss)
                    progress.bad_counter = 0
                    saver.save(sess, save_path=config.saveto)
                    progress_path = '{0}.progress.json'.format(config.saveto)
                    progress.save_to_json(progress_path)
                else:
                    progress.history_errs.append(valid_loss)
                    progress.bad_counter += 1
                    if progress.bad_counter > config.patience:
                        logging.info('Early Stop!')
                        progress.estop = True
                        break
                if config.valid_script is not None:
                    score = validate_with_script(sess, replicas[0], config,
                                                 valid_text_iterator)
                    need_to_save = (
                        score is not None
                        and (len(progress.valid_script_scores) == 0
                             or score > max(progress.valid_script_scores)))
                    if score is None:
                        score = 0.0  # ensure a valid value is written
                    progress.valid_script_scores.append(score)
                    if need_to_save:
                        save_path = config.saveto + ".best-valid-script"
                        saver.save(sess, save_path=save_path)
                        progress_path = '{}.progress.json'.format(save_path)
                        progress.save_to_json(progress_path)

            if config.saveFreq and progress.uidx % config.saveFreq == 0:
                saver.save(sess,
                           save_path=config.saveto,
                           global_step=progress.uidx)
                progress_path = '{0}-{1}.progress.json'.format(
                    config.saveto, progress.uidx)
                progress.save_to_json(progress_path)

            if config.finish_after and progress.uidx % config.finish_after == 0:
                logging.info("Maximum number of updates reached")
                saver.save(sess,
                           save_path=config.saveto,
                           global_step=progress.uidx)
                progress.estop = True
                progress_path = '{0}-{1}.progress.json'.format(
                    config.saveto, progress.uidx)
                progress.save_to_json(progress_path)
                break
        if progress.estop:
            break
Example #6
0
def train(config, sess):
    """Build the training graph and run the main training loop.

    One replica (``TransformerModel`` or ``rnn_model.RNNModel``, depending on
    ``config.model_type``) is created per visible GPU, or a single CPU replica
    when there is none. Variables are shared through variable-scope reuse.
    The learning rate follows a constant or transformer schedule and Adam is
    the only supported optimizer. Training loops over epochs and minibatches,
    calling ``ModelUpdater.update`` once per batch. At the frequencies in
    ``config`` it also:
      * updates the exponentially smoothed weights;
      * logs throughput, samples and beam-search output;
      * validates (using the smoothed weights when smoothing is enabled),
        stopping early after ``config.patience`` non-improving validations;
      * saves checkpoints, config JSON and ``.progress.json`` files.

    Args:
        config: training options object; its attributes drive every step.
        sess: TensorFlow session in which variables are initialised or
            restored and updates are run.
    """
    # MAP training (map_decay_c != 0) needs an existing prior checkpoint.
    assert ((config.prior_model is not None
             and tf.train.checkpoint_exists(
                 os.path.abspath(config.prior_model)))
            or config.map_decay_c == 0.0), \
        "MAP training requires a prior model file: Use command-line option --prior_model"

    # Construct the graph, with one model replica per GPU (one CPU replica
    # when no GPU is visible).

    num_gpus = len(util.get_available_gpus())
    num_replicas = max(1, num_gpus)

    logging.info('Building model...')
    replicas = []
    for i in range(num_replicas):
        device_type = "GPU" if num_gpus > 0 else "CPU"
        device_spec = tf.DeviceSpec(device_type=device_type, device_index=i)
        with tf.device(device_spec):
            with tf.variable_scope(tf.get_variable_scope(), reuse=(i > 0)):
                if config.model_type == "transformer":
                    model = TransformerModel(config)
                else:
                    model = rnn_model.RNNModel(config)
                replicas.append(model)

    # Non-trainable int32 update counter, named 'time' in the graph.
    init = tf.zeros_initializer(dtype=tf.int32)
    global_step = tf.get_variable('time', [],
                                  initializer=init,
                                  trainable=False)

    if config.learning_schedule == "constant":
        schedule = ConstantSchedule(config.learning_rate)
    elif config.learning_schedule == "transformer":
        schedule = TransformerSchedule(global_step=global_step,
                                       dim=config.state_size,
                                       warmup_steps=config.warmup_steps)
    else:
        logging.error('Learning schedule type is not valid: {}'.format(
            config.learning_schedule))
        sys.exit(1)

    if config.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(
            learning_rate=schedule.learning_rate,
            beta1=config.adam_beta1,
            beta2=config.adam_beta2,
            epsilon=config.adam_epsilon)
    else:
        logging.error('No valid optimizer defined: {}'.format(
            config.optimizer))
        sys.exit(1)

    if config.summary_freq:
        summary_dir = (config.summary_dir if config.summary_dir is not None
                       else os.path.abspath(os.path.dirname(config.saveto)))
        writer = tf.summary.FileWriter(summary_dir, sess.graph)
    else:
        writer = None

    updater = ModelUpdater(config, num_gpus, replicas, optimizer, global_step,
                           writer)

    # Created before variables are initialised/restored, as before. None means
    # smoothing is disabled.
    smoothing = None
    if config.exponential_smoothing > 0.0:
        smoothing = ExponentialSmoothing(config.exponential_smoothing)

    def _run_with_smoothed_weights(fn):
        """Return fn(), evaluated with the smoothed weights swapped in if enabled.

        The second swap_ops run (which the original swap/evaluate/swap sequence
        relies on to restore the live weights) is in a ``finally`` block, so an
        exception in ``fn`` does not leave the smoothed weights in place.
        """
        if smoothing is None:
            return fn()
        sess.run(fetches=smoothing.swap_ops)
        try:
            return fn()
        finally:
            sess.run(fetches=smoothing.swap_ops)

    saver, progress = model_loader.init_or_restore_variables(config,
                                                             sess,
                                                             train=True)

    # Sync the graph's step counter with progress.uidx.
    global_step.load(progress.uidx, sess)

    # Use an InferenceModelSet to abstract over model types for sampling and
    # beam search. Multi-GPU sampling and beam search are not currently
    # supported, so we just use the first replica.
    model_set = inference.InferenceModelSet([replicas[0]], [config])

    # Save model options
    write_config_to_json_file(config, config.saveto)

    text_iterator, valid_text_iterator = load_data(config)
    _, _, num_to_source, num_to_target = util.load_dictionaries(config)
    total_loss = 0.
    n_sents, n_words = 0, 0
    last_time = time.time()
    logging.info("Initial uidx={}".format(progress.uidx))
    for progress.eidx in range(progress.eidx, config.max_epochs):
        logging.info('Starting epoch {0}'.format(progress.eidx))
        for source_sents, target_sents in text_iterator:
            if len(source_sents[0][0]) != config.factors:
                logging.error(
                    'Mismatch between number of factors in settings ({0}), and number in training corpus ({1})\n'
                    .format(config.factors, len(source_sents[0][0])))
                sys.exit(1)
            x_in, x_mask_in, y_in, y_mask_in = util.prepare_data(
                source_sents, target_sents, config.factors, maxlen=None)
            if x_in is None:
                logging.info(
                    'Minibatch with zero sample under length {0}'.format(
                        config.maxlen))
                continue
            write_summary_for_this_batch = config.summary_freq and (
                (progress.uidx % config.summary_freq == 0) or
                (config.finish_after
                 and progress.uidx % config.finish_after == 0))
            # x_in is laid out as (factors, seqLen, batch_size).
            _, _, batch_size = x_in.shape

            loss = updater.update(sess, x_in, x_mask_in, y_in, y_mask_in,
                                  write_summary_for_this_batch)
            total_loss += loss
            n_sents += batch_size
            n_words += int(numpy.sum(y_mask_in))
            progress.uidx += 1

            # Update the smoothed version of the model variables.
            # To reduce the performance overhead, we only do this once every
            # N steps (the smoothing factor is adjusted accordingly).
            if (smoothing is not None
                    and progress.uidx % smoothing.update_frequency == 0):
                sess.run(fetches=smoothing.update_ops)

            if config.disp_freq and progress.uidx % config.disp_freq == 0:
                duration = time.time() - last_time
                disp_time = datetime.now().strftime('[%Y-%m-%d %H:%M:%S]')
                logging.info(
                    '{0} Epoch: {1} Update: {2} Loss/word: {3} Words/sec: {4} Sents/sec: {5}'
                    .format(disp_time, progress.eidx, progress.uidx,
                            total_loss / n_words, n_words / duration,
                            n_sents / duration))
                last_time = time.time()
                total_loss = 0.
                n_sents = 0
                n_words = 0

            if config.sample_freq and progress.uidx % config.sample_freq == 0:
                # Use (at most) the first 10 sentences of the batch.
                x_small = x_in[:, :, :10]
                x_mask_small = x_mask_in[:, :10]
                y_small = y_in[:, :10]
                samples = model_set.sample(sess, x_small, x_mask_small)
                assert len(samples) == len(x_small.T) == len(
                    y_small.T), (len(samples), x_small.shape, y_small.shape)
                for xx, yy, ss in zip(x_small.T, y_small.T, samples):
                    source = util.factoredseq2words(xx, num_to_source)
                    target = util.seq2words(yy, num_to_target)
                    sample = util.seq2words(ss, num_to_target)
                    logging.info('SOURCE: {}'.format(source))
                    logging.info('TARGET: {}'.format(target))
                    logging.info('SAMPLE: {}'.format(sample))

            if config.beam_freq and progress.uidx % config.beam_freq == 0:
                # Use (at most) the first 10 sentences of the batch.
                x_small = x_in[:, :, :10]
                x_mask_small = x_mask_in[:, :10]
                y_small = y_in[:, :10]
                samples = model_set.beam_search(
                    sess,
                    x_small,
                    x_mask_small,
                    config.beam_size,
                    normalization_alpha=config.normalization_alpha)
                # samples is a list with shape batch x beam x len
                assert len(samples) == len(x_small.T) == len(
                    y_small.T), (len(samples), x_small.shape, y_small.shape)
                for xx, yy, ss in zip(x_small.T, y_small.T, samples):
                    source = util.factoredseq2words(xx, num_to_source)
                    target = util.seq2words(yy, num_to_target)
                    logging.info('SOURCE: {}'.format(source))
                    logging.info('TARGET: {}'.format(target))
                    for beam_idx, (sample_seq, cost) in enumerate(ss):
                        sample = util.seq2words(sample_seq, num_to_target)
                        # An empty decoded hypothesis would otherwise raise
                        # ZeroDivisionError and abort training; log NaN.
                        avg = (cost / len(sample) if len(sample) > 0
                               else float('nan'))
                        msg = 'SAMPLE {}: {} Cost/Len/Avg {}/{}/{}'.format(
                            beam_idx, sample, cost, len(sample), avg)
                        logging.info(msg)

            if config.valid_freq and progress.uidx % config.valid_freq == 0:
                valid_ce = _run_with_smoothed_weights(
                    lambda: validate(sess, replicas[0], config,
                                     valid_text_iterator))
                if (len(progress.history_errs) == 0
                        or valid_ce < min(progress.history_errs)):
                    progress.history_errs.append(valid_ce)
                    progress.bad_counter = 0
                    save_non_checkpoint(sess, saver, config.saveto)
                    progress_path = '{0}.progress.json'.format(config.saveto)
                    progress.save_to_json(progress_path)
                else:
                    progress.history_errs.append(valid_ce)
                    progress.bad_counter += 1
                    if progress.bad_counter > config.patience:
                        logging.info('Early Stop!')
                        progress.estop = True
                        break
                if config.valid_script is not None:
                    score = _run_with_smoothed_weights(
                        lambda: validate_with_script(sess, replicas[0], config))
                    need_to_save = (
                        score is not None
                        and (len(progress.valid_script_scores) == 0
                             or score > max(progress.valid_script_scores)))
                    if score is None:
                        score = 0.0  # ensure a valid value is written
                    progress.valid_script_scores.append(score)
                    if need_to_save:
                        progress.bad_counter = 0
                        save_path = config.saveto + ".best-valid-script"
                        save_non_checkpoint(sess, saver, save_path)
                        write_config_to_json_file(config, save_path)

                        progress_path = '{}.progress.json'.format(save_path)
                        progress.save_to_json(progress_path)

            if config.save_freq and progress.uidx % config.save_freq == 0:
                saver.save(sess,
                           save_path=config.saveto,
                           global_step=progress.uidx)
                write_config_to_json_file(
                    config, "%s-%s" % (config.saveto, progress.uidx))

                progress_path = '{0}-{1}.progress.json'.format(
                    config.saveto, progress.uidx)
                progress.save_to_json(progress_path)

            if config.finish_after and progress.uidx % config.finish_after == 0:
                logging.info("Maximum number of updates reached")
                saver.save(sess,
                           save_path=config.saveto,
                           global_step=progress.uidx)
                write_config_to_json_file(
                    config, "%s-%s" % (config.saveto, progress.uidx))

                progress.estop = True
                progress_path = '{0}-{1}.progress.json'.format(
                    config.saveto, progress.uidx)
                progress.save_to_json(progress_path)
                break
        if progress.estop:
            break
def main():
    """Train the VGG model with one tower per GPU and averaged gradients.

    Each tower is built by ``initiate_vgg_model`` under
    ``util.assign_to_device('/gpu:<i>', ps_device='/cpu:0')``. The towers'
    gradients are averaged with ``util.average_gradients`` and applied once.
    A short training loop then runs ``util.training`` ``ITERATIONS_PER_EPOCH``
    times per epoch. It logs accuracy, loss and a confidence interval, saves a
    checkpoint whenever training accuracy matches or beats the best seen so
    far, and writes TensorBoard summaries.

    Raises:
        RuntimeError: if no GPU is reported, because then no tower (and no
            optimizer to apply the averaged gradients) would be built.
    """
    # Constants
    DATASET_PATH = os.path.join(".")
    LEARNING_RATE_1 = 0.0001
    EPOCHS = 2
    BATCH_SIZE = 32
    NUM_CLASSES = 48
    Z_SCORE = 1.96
    WEIGHT_DECAY_1 = 0.0005
    # Training iterations per epoch. This is also the denominator of the
    # progress message, which used to be hard-coded as 10 while 2 ran.
    ITERATIONS_PER_EPOCH = 2
    CHECKPOINT_DIR = "./multi_dl_research_train/"
    SUMMARY_DIR = "./multi_tensorboard_logs/"
    # Epoch at which the learning rate / weight decay switch to a new value.
    LEARNING_RATE_SCHEDULE = {18: 0.00005, 29: 0.00001, 42: 0.000005,
                              51: 0.000001}
    WEIGHT_DECAY_SCHEDULE = {29: 0.0}

    print("Current Setup:-")
    print(
        "Starting Learning Rate: {}, Epochs: {}, Batch Size: {}, Confidence Interval Z-Score {}, Number of classes: {}, Starting Weight Decay: {}"
        .format(LEARNING_RATE_1, EPOCHS, BATCH_SIZE, Z_SCORE, NUM_CLASSES,
                WEIGHT_DECAY_1))

    # Get the number of GPUs
    NUM_GPUS = util.get_available_gpus()

    print("Number of GPUs available : {}".format(NUM_GPUS))
    if not NUM_GPUS:
        # Without a tower there is no optimizer to apply gradients with; fail
        # clearly here instead of with a NameError after graph construction.
        raise RuntimeError(
            "No GPU available: at least one GPU is required to build the "
            "training towers")

    with tf.device('/cpu:0'):
        tower_grads = []
        reuse_vars = False
        dataset_len = 1207350

        # Placeholders
        learning_rate = tf.placeholder(tf.float32,
                                       shape=[],
                                       name='learning_rate')
        weight_decay = tf.placeholder(tf.float32,
                                      shape=[],
                                      name="weight_decay")

        for i in range(NUM_GPUS):
            with tf.device(
                    util.assign_to_device('/gpu:{}'.format(i),
                                          ps_device='/cpu:0')):

                # TODO: split the data between GPUs; every tower currently
                # calls train_input_fn with identical arguments.
                train_features, train_labels, train_filenames = util.train_input_fn(
                    DATASET_PATH, BATCH_SIZE, EPOCHS)
                print("At GPU {}, Train Features : {}".format(
                    i, train_features))

                # Model. The second returned value is only used for its
                # apply_gradients method below, i.e. as an optimizer.
                _, optimizer, tower_grads, train_cross_entropy, train_conf_matrix_op, train_accuracy, reuse_vars = initiate_vgg_model(
                    train_features,
                    train_labels,
                    train_filenames,
                    NUM_CLASSES,
                    weight_decay,
                    learning_rate,
                    reuse=reuse_vars,
                    tower_grads=tower_grads,
                    gpu_num=i,
                    handle="training")

        tower_grads = util.average_gradients(tower_grads)
        train_op = optimizer.apply_gradients(tower_grads)

        saver = tf.train.Saver()

        if not os.path.exists(CHECKPOINT_DIR):
            os.mkdir(CHECKPOINT_DIR)

        # Summary step offset per epoch.
        # NOTE(review): 0.8 presumably the training split fraction — confirm.
        steps_per_epoch = int((dataset_len * 0.8) / BATCH_SIZE)

        with tf.Session() as sess:
            with np.printoptions(threshold=np.inf):
                writer = tf.summary.FileWriter(SUMMARY_DIR)
                try:
                    writer.add_graph(sess.graph)
                    merged_summary = tf.summary.merge_all()
                    train_highest_acc = 0
                    sess.run([
                        tf.global_variables_initializer(),
                        tf.local_variables_initializer()
                    ])

                    for epoch in range(EPOCHS):
                        if epoch in LEARNING_RATE_SCHEDULE:
                            LEARNING_RATE_1 = LEARNING_RATE_SCHEDULE[epoch]
                            print("Learning Rate changed to {} at epoch {}".format(
                                LEARNING_RATE_1, epoch))
                        if epoch in WEIGHT_DECAY_SCHEDULE:
                            WEIGHT_DECAY_1 = WEIGHT_DECAY_SCHEDULE[epoch]
                            print("Weight Decay changed to {} at epoch {}".format(
                                WEIGHT_DECAY_1, epoch))

                        print("Current Epoch: {}".format(epoch))
                        for i in range(ITERATIONS_PER_EPOCH):
                            print("Current Training Iteration : {}/{}".format(
                                i, ITERATIONS_PER_EPOCH))
                            train_acc, _, _, train_ce, train_summary = util.training(
                                BATCH_SIZE, NUM_CLASSES, learning_rate,
                                weight_decay, sess, train_op,
                                train_conf_matrix_op, LEARNING_RATE_1,
                                WEIGHT_DECAY_1, train_cross_entropy,
                                merged_summary, train_accuracy)
                            # NOTE(review): sample size assumed to be one
                            # training batch (was the literal 32).
                            train_value1, train_value2 = util.confidence_interval(
                                train_acc, Z_SCORE, BATCH_SIZE)
                            print("Training Accuracy : {}".format(train_acc))
                            print("Training Loss (Cross Entropy) : {}".format(
                                train_ce))
                            print("Training Confidence Interval: [{} , {}]".format(
                                train_value2, train_value1))
                            if train_highest_acc <= train_acc:
                                train_highest_acc = train_acc
                                print(
                                    "Highest Training Accuracy Reached: {}".format(
                                        train_highest_acc))
                                # Checkpoint whenever training accuracy matches
                                # or beats the best seen so far.
                                saver.save(
                                    sess,
                                    os.path.join(CHECKPOINT_DIR, "model.ckpt"))
                                print(
                                    "Latest Model is saving and Tensorboard Logs are updated"
                                )
                            writer.add_summary(
                                train_summary,
                                epoch * steps_per_epoch + i)
                finally:
                    # Flush buffered summaries to disk and close the event file.
                    writer.close()
# Number GPUs by PCI bus order and expose only GPUs 0 and 1 to this process.
# These variables are set before any TensorFlow device work below.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0, 1",
})
tf.random.set_seed(1234)

# Training hyper-parameters come from the "train" section of config.yaml.
with open('config.yaml') as f:
    train_config = yaml.load(f, Loader=yaml.FullLoader)["train"]

MAX_LENGTH, BATCH_SIZE, BUFFER_SIZE, EPOCHS = (
    int(train_config[option])
    for option in ("MAX_LENGTH", "BATCH_SIZE", "BUFFER_SIZE", "EPOCHS"))
language = train_config["language"]

# When at least one GPU is visible, build the dataset, embeddings and model
# inside a MirroredStrategy scope so they are mirrored across the GPUs.
num_gpus = get_available_gpus()
if num_gpus:
    mirrored_strategy = tf.distribute.MirroredStrategy()
    with mirrored_strategy.scope():
        (dataset_train, VOCAB_SIZE, tokenizer,
         START_TOKEN, END_TOKEN) = getDataset(MAX_LENGTH, BUFFER_SIZE,
                                              BATCH_SIZE)
        print(dataset_train)

        emb_matrix = load_embeddings(
            vocab_size=VOCAB_SIZE, tokenizer=tokenizer, language=language)

        Transformer = TransformerModel(
            max_length=MAX_LENGTH,
            vocab_size=VOCAB_SIZE,
            embedding_matrix=emb_matrix)
        model = Transformer.model
Example #9
0
def _pretrain_embeddings(config, sess, model,
                         batch_size=1000,
                         epochs=20,
                         learning_rate=0.001,
                         counts_path='/media/ntfs-3/EXP/MULTI/mix/zh-en/data3/glove/vocab.txt'):
    """Pre-train `model`'s embedding sub-network and dump the result to pre_emb/.

    Minimizes ``model.loss_pre_train`` with plain gradient descent for
    ``epochs`` passes. Every pass except the last draws shuffled batches from
    an oversampled word list (each word of ``counts_path`` repeated by its
    count). The last pass walks the pre-training vocabulary once, unshuffled,
    and collects ``model.pre_embedding`` for every batch.

    Args:
        config: training configuration; reads ``pretrain_vocab``,
            ``pretrain_vectors`` and ``source_dicts[0]``.
        sess: TensorFlow session whose variables are already initialized.
        model: the replica whose pre-training tensors are trained/fetched.
        batch_size: batch size handed to ``util.get_data``.
        epochs: number of passes; must be >= 1 because embeddings are only
            collected on the final pass.
        learning_rate: gradient-descent learning rate.
        counts_path: text file of ``<word> <count>`` lines used for
            oversampling.

    Side effects:
        Writes ``pre_emb/pre_emb.npy`` and ``pre_emb/vocab``, and calls
        ``util.get_tsne`` with the output path ``pre_emb/tsne.npy``.

    Raises:
        ValueError: if ``epochs`` < 1.
    """
    if epochs < 1:
        raise ValueError("epochs must be >= 1, got {}".format(epochs))

    logging.info("Start pre-training")
    pre_optimizer = tf.train.GradientDescentOptimizer(
        learning_rate).minimize(model.loss_pre_train)

    # Load the pre-training vocabulary and its target vectors.
    gvocab, gvectors = util.pre_load_data(config.pretrain_vocab,
                                          config.pretrain_vectors)
    pre_vocab_list = list(gvocab.keys())

    # Oversampling: repeat each word as many times as its listed count.
    pre_train_list = []
    with open(counts_path, 'r', encoding='utf-8') as f:
        for line in f:
            k, v = line.strip().split()
            pre_train_list.extend([k] * int(v))

    # `with` closes the dictionary file as soon as it has been parsed.
    with open(config.source_dicts[0], 'r', encoding='utf-8') as f:
        utf8_dict = json.load(f)

    embedding_list = []
    _vocab = None
    for epoch in range(epochs):
        logging.info("epoch:{}".format(epoch))
        is_last_epoch = (epoch == epochs - 1)
        if is_last_epoch:
            # Final pass: every vocabulary entry once, in order; the assert
            # after the loop checks the returned vocab matches pre_vocab_list.
            source_x, source_y, _vocab = util.get_data(pre_vocab_list,
                                                       batch_size,
                                                       gvocab,
                                                       gvectors,
                                                       utf8_dict,
                                                       shuffle=False)
        else:
            source_x, source_y, _vocab = util.get_data(pre_train_list,
                                                       batch_size,
                                                       gvocab,
                                                       gvectors,
                                                       utf8_dict,
                                                       shuffle=True)
        for idx, (s_x, s_y) in enumerate(zip(source_x, source_y)):
            assert len(s_x) == len(s_y), "{}, {}".format(len(s_x), len(s_y))
            sx, sy = util.pre_prepare_data(s_x, s_y)
            feed_dict = {model.pre_inputs.x: sx, model.pre_inputs.y: sy}
            _, loss, embedding = sess.run(
                [pre_optimizer, model.loss_pre_train, model.pre_embedding],
                feed_dict=feed_dict)
            if idx % 100 == 0:
                logging.info("loss:{}".format(loss))
            if is_last_epoch:
                embedding_list.append(embedding)
    assert _vocab == pre_vocab_list

    # A single concatenate instead of growing the array pairwise, which
    # re-copied everything accumulated so far on every step.
    emb = numpy.concatenate(embedding_list)
    numpy.save("pre_emb/pre_emb.npy", emb)
    with open("pre_emb/vocab", "w", encoding="utf-8") as f:
        f.write("\n".join(pre_vocab_list))
    # t-SNE projection of the learned embeddings (for visualization).
    util.get_tsne(emb, "pre_emb/tsne.npy")


def train(config, sess):
    """Build one model replica per GPU and run the training loop.

    Builds the replicas (a single CPU replica when no GPU is found), the
    learning-rate schedule, the Adam optimizer and the ModelUpdater, and
    restores or initializes variables. If ``config.pretrain`` is set, it then
    pre-trains the embedding network. Finally it iterates over the training
    data for up to ``config.max_epochs`` epochs. Along the way it periodically:

    - logs throughput;
    - samples and beam-searches on the first 10 sentences of the batch;
    - validates;
    - saves checkpoints.

    Training stops once more than ``config.patience`` validations pass
    without improvement, or after ``config.finish_after`` updates.

    Args:
        config: parsed training configuration.
        sess: TensorFlow session in which the graph is built and run.

    Raises:
        AssertionError: if MAP decay is requested without an existing prior
            model checkpoint.
        SystemExit: on an unknown learning schedule or optimizer.
    """
    assert ((config.prior_model is not None
             and tf.train.checkpoint_exists(
                 os.path.abspath(config.prior_model)))
            or config.map_decay_c == 0.0), \
        "MAP training requires a prior model file: Use command-line option --prior_model"

    # Construct the graph, with one model replica per GPU.
    num_gpus = len(util.get_available_gpus())
    num_replicas = max(1, num_gpus)

    logging.info('Building model...')
    replicas = []
    for i in range(num_replicas):
        device_type = "GPU" if num_gpus > 0 else "CPU"
        device_spec = tf.DeviceSpec(device_type=device_type, device_index=i)
        with tf.device(device_spec):
            # Every replica after the first reuses the first one's variables.
            with tf.variable_scope(tf.get_variable_scope(), reuse=(i > 0)):
                if config.model_type == "transformer":
                    model = TransformerModel(config)
                else:
                    model = rnn_model.RNNModel(config)
                replicas.append(model)

    init = tf.zeros_initializer(dtype=tf.int32)
    global_step = tf.get_variable('time', [],
                                  initializer=init,
                                  trainable=False)

    if config.learning_schedule == "constant":
        schedule = ConstantSchedule(config.learning_rate)
    elif config.learning_schedule == "transformer":
        schedule = TransformerSchedule(global_step=global_step,
                                       dim=config.state_size,
                                       warmup_steps=config.warmup_steps)
    else:
        logging.error('Learning schedule type is not valid: {}'.format(
            config.learning_schedule))
        sys.exit(1)

    if config.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(
            learning_rate=schedule.learning_rate,
            beta1=config.adam_beta1,
            beta2=config.adam_beta2,
            epsilon=config.adam_epsilon)
    else:
        logging.error('No valid optimizer defined: {}'.format(
            config.optimizer))
        sys.exit(1)

    if config.summary_freq:
        summary_dir = (config.summary_dir if config.summary_dir is not None
                       else os.path.abspath(os.path.dirname(config.saveto)))
        writer = tf.summary.FileWriter(summary_dir, sess.graph)
    else:
        writer = None

    updater = ModelUpdater(config, num_gpus, replicas, optimizer, global_step,
                           writer)

    saver, progress = model_loader.init_or_restore_variables(config,
                                                             sess,
                                                             train=True)

    # Optional embedding pre-training on the first replica.
    if config.pretrain:
        _pretrain_embeddings(config, sess, replicas[0])

    global_step.load(progress.uidx, sess)

    # Use an InferenceModelSet to abstract over model types for sampling and
    # beam search. Multi-GPU sampling and beam search are not currently
    # supported, so we just use the first replica.
    model_set = inference.InferenceModelSet([replicas[0]], [config])

    # Save model options.
    write_config_to_json_file(config, config.saveto)

    text_iterator, valid_text_iterator = load_data(config)
    _, _, _, num_to_target = util.load_dictionaries(config)
    total_loss = 0.
    n_sents, n_words = 0, 0
    last_time = time.time()
    logging.info("Initial uidx={}".format(progress.uidx))
    for progress.eidx in range(progress.eidx, config.max_epochs):
        logging.info('Starting epoch {0}'.format(progress.eidx))
        for pre_source_sents, source_sents, target_sents in text_iterator:
            px_in, x_in, x_mask_in, y_in, y_mask_in = util.prepare_data(
                source_sents,
                target_sents,
                config.factors,
                pre_source_sents,
                maxlen=None)

            if x_in is None:
                logging.info(
                    'Minibatch with zero sample under length {0}'.format(
                        config.maxlen))
                continue

            write_summary_for_this_batch = config.summary_freq and (
                (progress.uidx % config.summary_freq == 0) or
                (config.finish_after
                 and progress.uidx % config.finish_after == 0))
            # x_in is 4-D; the batch dimension is last.
            _, _, _, batch_size = x_in.shape

            loss = updater.update(sess, px_in, x_in, x_mask_in, y_in,
                                  y_mask_in, write_summary_for_this_batch)

            total_loss += loss
            n_sents += batch_size
            n_words += int(numpy.sum(y_mask_in))
            progress.uidx += 1

            if config.disp_freq and progress.uidx % config.disp_freq == 0:
                duration = time.time() - last_time
                disp_time = datetime.now().strftime('[%Y-%m-%d %H:%M:%S]')
                logging.info(
                    '{0} Epoch: {1} Update: {2} Loss/word: {3} Words/sec: {4} Sents/sec: {5}'
                    .format(disp_time, progress.eidx, progress.uidx,
                            total_loss / n_words, n_words / duration,
                            n_sents / duration))
                last_time = time.time()
                total_loss = 0.
                n_sents = 0
                n_words = 0

            if config.sample_freq and progress.uidx % config.sample_freq == 0:
                # First 10 sentences of the batch (batch is the last axis).
                x_small = x_in[:, :, :, :10]
                x_mask_small = x_mask_in[:, :, :10]
                y_small = y_in[:, :10]
                samples = model_set.sample(sess, x_small, x_mask_small)
                assert len(samples) == len(x_small.T) == len(
                    y_small.T), (len(samples), x_small.shape, y_small.shape)
                for xx, yy, ss in zip(x_small.T, y_small.T, samples):
                    target = util.seq2words(yy, num_to_target)
                    sample = util.seq2words(ss, num_to_target)
                    logging.info('TARGET: {}'.format(target))
                    logging.info('SAMPLE: {}'.format(sample))

            if config.beam_freq and progress.uidx % config.beam_freq == 0:
                x_small = x_in[:, :, :, :10]
                x_mask_small = x_mask_in[:, :, :10]
                y_small = y_in[:, :10]
                samples = model_set.beam_search(
                    sess,
                    x_small,
                    x_mask_small,
                    config.beam_size,
                    normalization_alpha=config.normalization_alpha)
                # samples is a list with shape batch x beam x len
                assert len(samples) == len(x_small.T) == len(
                    y_small.T), (len(samples), x_small.shape, y_small.shape)
                for xx, yy, ss in zip(x_small.T, y_small.T, samples):
                    target = util.seq2words(yy, num_to_target)
                    logging.info('TARGET: {}'.format(target))
                    for i, (sample_seq, cost) in enumerate(ss):
                        sample = util.seq2words(sample_seq, num_to_target)
                        # An empty hypothesis would otherwise divide by zero
                        # (ZeroDivisionError or inf, depending on cost's type).
                        # NOTE(review): if seq2words returns a joined string,
                        # Len counts characters, not tokens.
                        sample_len = len(sample)
                        avg = (cost / sample_len if sample_len > 0
                               else float('nan'))
                        msg = 'SAMPLE {}: {} Cost/Len/Avg {}/{}/{}'.format(
                            i, sample, cost, sample_len, avg)
                        logging.info(msg)

            if config.valid_freq and progress.uidx % config.valid_freq == 0:
                valid_ce = validate(sess, replicas[0], config,
                                    valid_text_iterator)
                if (len(progress.history_errs) == 0
                        or valid_ce < min(progress.history_errs)):
                    progress.history_errs.append(valid_ce)
                    progress.bad_counter = 0
                    save_non_checkpoint(sess, saver, config.saveto)
                    progress_path = '{0}.progress.json'.format(config.saveto)
                    progress.save_to_json(progress_path)
                else:
                    progress.history_errs.append(valid_ce)
                    progress.bad_counter += 1
                    if progress.bad_counter > config.patience:
                        logging.info('Early Stop!')
                        progress.estop = True
                        break
                if config.valid_script is not None:
                    score = validate_with_script(sess, replicas[0], config)
                    need_to_save = (
                        score is not None
                        and (len(progress.valid_script_scores) == 0
                             or score > max(progress.valid_script_scores)))
                    if score is None:
                        score = 0.0  # ensure a valid value is written
                    progress.valid_script_scores.append(score)
                    if need_to_save:
                        progress.bad_counter = 0
                        save_path = config.saveto + ".best-valid-script"
                        save_non_checkpoint(sess, saver, save_path)
                        write_config_to_json_file(config, save_path)

                        progress_path = '{}.progress.json'.format(save_path)
                        progress.save_to_json(progress_path)

            if config.save_freq and progress.uidx % config.save_freq == 0:
                saver.save(sess,
                           save_path=config.saveto,
                           global_step=progress.uidx)
                write_config_to_json_file(
                    config, "%s-%s" % (config.saveto, progress.uidx))

                progress_path = '{0}-{1}.progress.json'.format(
                    config.saveto, progress.uidx)
                progress.save_to_json(progress_path)

            if config.finish_after and progress.uidx % config.finish_after == 0:
                logging.info("Maximum number of updates reached")
                saver.save(sess,
                           save_path=config.saveto,
                           global_step=progress.uidx)
                write_config_to_json_file(
                    config, "%s-%s" % (config.saveto, progress.uidx))

                progress.estop = True
                progress_path = '{0}-{1}.progress.json'.format(
                    config.saveto, progress.uidx)
                progress.save_to_json(progress_path)
                break
        if progress.estop:
            break
Exemple #10
0
def train(config, sess):
    """Build one model replica per GPU and run the training loop.

    Builds the replicas (a single CPU replica when no GPU is found), the
    learning-rate schedule, the Adam optimizer and the ModelUpdater, and
    restores or initializes variables. It then iterates over the training
    data for up to ``config.max_epochs`` epochs. Along the way it periodically:

    - logs throughput;
    - samples and beam-searches on the first 10 sentences of the batch;
    - validates;
    - saves checkpoints.

    Training stops once more than ``config.patience`` validations pass
    without improvement, or after ``config.finish_after`` updates.

    Args:
        config: parsed training configuration.
        sess: TensorFlow session in which the graph is built and run.

    Raises:
        AssertionError: if MAP decay is requested without an existing prior
            model checkpoint.
        SystemExit: on an unknown learning schedule or optimizer, or when the
            corpus factor count does not match ``config.factors``.
    """
    assert ((config.prior_model is not None
             and tf.train.checkpoint_exists(
                 os.path.abspath(config.prior_model)))
            or config.map_decay_c == 0.0), \
        "MAP training requires a prior model file: Use command-line option --prior_model"

    # Construct the graph, with one model replica per GPU.
    num_gpus = len(util.get_available_gpus())
    num_replicas = max(1, num_gpus)

    logging.info('Building model...')
    replicas = []
    for i in range(num_replicas):
        device_type = "GPU" if num_gpus > 0 else "CPU"
        device_spec = tf.DeviceSpec(device_type=device_type, device_index=i)
        with tf.device(device_spec):
            # Every replica after the first reuses the first one's variables.
            with tf.variable_scope(tf.get_variable_scope(), reuse=(i > 0)):
                if config.model_type == "transformer":
                    model = TransformerModel(config)
                else:
                    model = rnn_model.RNNModel(config)
                replicas.append(model)

    init = tf.zeros_initializer(dtype=tf.int32)
    global_step = tf.get_variable('time', [], initializer=init,
                                  trainable=False)

    if config.learning_schedule == "constant":
        schedule = ConstantSchedule(config.learning_rate)
    elif config.learning_schedule == "transformer":
        schedule = TransformerSchedule(global_step=global_step,
                                       dim=config.state_size,
                                       warmup_steps=config.warmup_steps)
    else:
        logging.error('Learning schedule type is not valid: {}'.format(
            config.learning_schedule))
        sys.exit(1)

    if config.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(
            learning_rate=schedule.learning_rate,
            beta1=config.adam_beta1,
            beta2=config.adam_beta2,
            epsilon=config.adam_epsilon)
    else:
        logging.error('No valid optimizer defined: {}'.format(
            config.optimizer))
        sys.exit(1)

    if config.summary_freq:
        summary_dir = (config.summary_dir if config.summary_dir is not None
                       else os.path.abspath(os.path.dirname(config.saveto)))
        writer = tf.summary.FileWriter(summary_dir, sess.graph)
    else:
        writer = None

    updater = ModelUpdater(config, num_gpus, replicas, optimizer, global_step,
                           writer)

    saver, progress = model_loader.init_or_restore_variables(
        config, sess, train=True)

    global_step.load(progress.uidx, sess)

    # Use an InferenceModelSet to abstract over model types for sampling and
    # beam search. Multi-GPU sampling and beam search are not currently
    # supported, so we just use the first replica.
    model_set = inference.InferenceModelSet([replicas[0]], [config])

    # Save model options.
    write_config_to_json_file(config, config.saveto)

    text_iterator, valid_text_iterator = load_data(config)
    _, _, num_to_source, num_to_target = util.load_dictionaries(config)
    total_loss = 0.
    n_sents, n_words = 0, 0
    last_time = time.time()
    logging.info("Initial uidx={}".format(progress.uidx))
    for progress.eidx in range(progress.eidx, config.max_epochs):
        logging.info('Starting epoch {0}'.format(progress.eidx))
        for source_sents, target_sents in text_iterator:
            if len(source_sents[0][0]) != config.factors:
                logging.error(
                    'Mismatch between number of factors in settings ({0}), and number in training corpus ({1})\n'
                    .format(config.factors, len(source_sents[0][0])))
                sys.exit(1)
            x_in, x_mask_in, y_in, y_mask_in = util.prepare_data(
                source_sents, target_sents, config.factors, maxlen=None)
            if x_in is None:
                logging.info(
                    'Minibatch with zero sample under length {0}'.format(
                        config.maxlen))
                continue
            write_summary_for_this_batch = config.summary_freq and (
                (progress.uidx % config.summary_freq == 0) or
                (config.finish_after
                 and progress.uidx % config.finish_after == 0))
            # x_in is 3-D; the batch dimension is last.
            _, _, batch_size = x_in.shape

            loss = updater.update(sess, x_in, x_mask_in, y_in, y_mask_in,
                                  write_summary_for_this_batch)
            total_loss += loss
            n_sents += batch_size
            n_words += int(numpy.sum(y_mask_in))
            progress.uidx += 1

            if config.disp_freq and progress.uidx % config.disp_freq == 0:
                duration = time.time() - last_time
                disp_time = datetime.now().strftime('[%Y-%m-%d %H:%M:%S]')
                logging.info(
                    '{0} Epoch: {1} Update: {2} Loss/word: {3} Words/sec: {4} Sents/sec: {5}'
                    .format(disp_time, progress.eidx, progress.uidx,
                            total_loss / n_words, n_words / duration,
                            n_sents / duration))
                last_time = time.time()
                total_loss = 0.
                n_sents = 0
                n_words = 0

            if config.sample_freq and progress.uidx % config.sample_freq == 0:
                # First 10 sentences of the batch (batch is the last axis).
                x_small = x_in[:, :, :10]
                x_mask_small = x_mask_in[:, :10]
                y_small = y_in[:, :10]
                samples = model_set.sample(sess, x_small, x_mask_small)
                assert len(samples) == len(x_small.T) == len(y_small.T), \
                    (len(samples), x_small.shape, y_small.shape)
                for xx, yy, ss in zip(x_small.T, y_small.T, samples):
                    source = util.factoredseq2words(xx, num_to_source)
                    target = util.seq2words(yy, num_to_target)
                    sample = util.seq2words(ss, num_to_target)
                    logging.info('SOURCE: {}'.format(source))
                    logging.info('TARGET: {}'.format(target))
                    logging.info('SAMPLE: {}'.format(sample))

            if config.beam_freq and progress.uidx % config.beam_freq == 0:
                x_small = x_in[:, :, :10]
                x_mask_small = x_mask_in[:, :10]
                y_small = y_in[:, :10]
                samples = model_set.beam_search(
                    sess, x_small, x_mask_small,
                    config.beam_size,
                    normalization_alpha=config.normalization_alpha)
                # samples is a list with shape batch x beam x len
                assert len(samples) == len(x_small.T) == len(y_small.T), \
                    (len(samples), x_small.shape, y_small.shape)
                for xx, yy, ss in zip(x_small.T, y_small.T, samples):
                    source = util.factoredseq2words(xx, num_to_source)
                    target = util.seq2words(yy, num_to_target)
                    logging.info('SOURCE: {}'.format(source))
                    logging.info('TARGET: {}'.format(target))
                    for i, (sample_seq, cost) in enumerate(ss):
                        sample = util.seq2words(sample_seq, num_to_target)
                        # An empty hypothesis would otherwise divide by zero
                        # (ZeroDivisionError or inf, depending on cost's type).
                        # NOTE(review): if seq2words returns a joined string,
                        # Len counts characters, not tokens.
                        sample_len = len(sample)
                        avg = (cost / sample_len if sample_len > 0
                               else float('nan'))
                        msg = 'SAMPLE {}: {} Cost/Len/Avg {}/{}/{}'.format(
                            i, sample, cost, sample_len, avg)
                        logging.info(msg)

            if config.valid_freq and progress.uidx % config.valid_freq == 0:
                valid_ce = validate(sess, replicas[0], config,
                                    valid_text_iterator)
                if (len(progress.history_errs) == 0 or
                        valid_ce < min(progress.history_errs)):
                    progress.history_errs.append(valid_ce)
                    progress.bad_counter = 0
                    save_non_checkpoint(sess, saver, config.saveto)
                    progress_path = '{0}.progress.json'.format(config.saveto)
                    progress.save_to_json(progress_path)
                else:
                    progress.history_errs.append(valid_ce)
                    progress.bad_counter += 1
                    if progress.bad_counter > config.patience:
                        logging.info('Early Stop!')
                        progress.estop = True
                        break
                if config.valid_script is not None:
                    score = validate_with_script(sess, replicas[0], config)
                    need_to_save = (score is not None and
                                    (len(progress.valid_script_scores) == 0 or
                                     score > max(progress.valid_script_scores)))
                    if score is None:
                        score = 0.0  # ensure a valid value is written
                    progress.valid_script_scores.append(score)
                    if need_to_save:
                        progress.bad_counter = 0
                        save_path = config.saveto + ".best-valid-script"
                        save_non_checkpoint(sess, saver, save_path)
                        write_config_to_json_file(config, save_path)

                        progress_path = '{}.progress.json'.format(save_path)
                        progress.save_to_json(progress_path)

            if config.save_freq and progress.uidx % config.save_freq == 0:
                saver.save(sess, save_path=config.saveto,
                           global_step=progress.uidx)
                write_config_to_json_file(
                    config, "%s-%s" % (config.saveto, progress.uidx))

                progress_path = '{0}-{1}.progress.json'.format(
                    config.saveto, progress.uidx)
                progress.save_to_json(progress_path)

            if config.finish_after and progress.uidx % config.finish_after == 0:
                logging.info("Maximum number of updates reached")
                saver.save(sess, save_path=config.saveto,
                           global_step=progress.uidx)
                write_config_to_json_file(
                    config, "%s-%s" % (config.saveto, progress.uidx))

                progress.estop = True
                progress_path = '{0}-{1}.progress.json'.format(
                    config.saveto, progress.uidx)
                progress.save_to_json(progress_path)
                break
        if progress.estop:
            break