Example #1
File: main.py  Project: ml-jku/gapnet-pl
def main(_):
    config = Config()
    np.random.seed(config.get_value("random_seed", 12345))

    # PARAMETERS
    n_epochs = config.get_value("epochs", 100)
    batchsize = config.get_value("batchsize", 8)
    n_classes = config.get_value("n_classes", 13)
    dropout = config.get_value("dropout", 0.25)  # TODO
    num_threads = config.get_value("num_threads", 5)
    initial_val = config.get_value("initial_val", True)

    # READER, LOADER
    readers = invoke_dataset_from_config(config)
    reader_train = readers["train"]
    reader_val = readers["val"]
    train_loader = torch.utils.data.DataLoader(reader_train,
                                               batch_size=batchsize,
                                               shuffle=True,
                                               num_workers=num_threads)
    val_loader = torch.utils.data.DataLoader(reader_val,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=num_threads)

    # CONFIG
    tell = TeLLSession(config=config,
                       model_params={"shape": reader_train.shape})
    # Get some members from the session for easier usage
    session = tell.tf_session
    model = tell.model
    workspace, config = tell.workspace, tell.config

    # For validation, the batch axis apparently holds the patches of one
    # sample, so the per-patch sigmoid outputs are averaged into a single
    # prediction (cf. the "Loss_per_Class_Patching" block below).
    prediction = tf.sigmoid(model.output)
    prediction_val = tf.reduce_mean(tf.sigmoid(model.output),
                                    axis=0,
                                    keepdims=True)

    # LOSS
    if hasattr(model, "loss"):
        loss = model.loss()
    else:
        with tf.name_scope("Loss_per_Class"):
            loss = 0
            for i in range(n_classes):
                loss_batch = tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=model.output[:, i], labels=model.y_[:, i])
                loss_mean = tf.reduce_mean(loss_batch)
                loss += loss_mean

    # Validation loss after patching
    if hasattr(model, "loss"):
        loss_val = model.loss()
    else:
        with tf.name_scope("Loss_per_Class_Patching"):
            loss_val = 0
            for i in range(n_classes):
                loss_batch = tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=tf.reduce_mean(model.output[:, i],
                                          axis=0,
                                          keepdims=True),
                    labels=model.y_[:, i])
                loss_mean = tf.reduce_mean(loss_batch)
                loss_val += loss_mean

    # REGULARIZATION
    reg_penalty = regularize(layers=model.layers,
                             l1=config.l1,
                             l2=config.l2,
                             regularize_weights=True,
                             regularize_biases=True)

    # LEARNING RATE (SCHEDULE)
    # If an LRS is defined, always use MomentumOptimizer and pass the learning rate to the optimizer
    lrs_plateau = False
    if config.get_value("lrs", None) is not None:
        lr_sched_type = config.lrs["type"]
        if lr_sched_type == "plateau":
            lrs_plateau = True
            learning_rate = tf.placeholder(tf.float32, [],
                                           name='learning_rate')
            lrs_learning_rate = config.get_value(
                "optimizer_params")["learning_rate"]
            lrs_n_bad_epochs = 0  # counter for plateau LRS
            lrs_patience = config.lrs["patience"]
            lrs_factor = config.lrs["factor"]
            lrs_threshold = config.lrs["threshold"]
            lrs_mode = config.lrs["mode"]
            lrs_best = -np.inf if lrs_mode == "max" else np.inf
            if lrs_mode == "max":
                lrs_is_better = lambda old, new: new > old * (1 + lrs_threshold)
            else:
                lrs_is_better = lambda old, new: new < old * (1 - lrs_threshold)
    else:
        learning_rate = None  # if no LRS is defined, the default optimizer is used with its own configured learning rate

    # LOAD WEIGHTS and get list of trainables if specified
    assign_loaded_variables = None
    trainables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    if config.get_value("checkpoint", None) is not None:
        with Timer(name="Loading Checkpoint", verbose=True):
            assign_loaded_variables, trainables = tell.load_weights(
                config.get_value("checkpoint", None),
                config.get_value("freeze", False),
                config.get_value("exclude_weights", None),
                config.get_value("exclude_freeze", None))

    # Update step
    if len(trainables) > 0:
        update, gradients, gradient_name_dict = update_step(
            loss + reg_penalty,
            config,
            tell,
            lr=learning_rate,
            trainables=trainables)

    # INITIALIZE Tensorflow VARIABLES
    step = tell.initialize_tf_variables().global_step

    # ASSIGN LOADED WEIGHTS (overriding initializations) if available
    if assign_loaded_variables is not None:
        session.run(assign_loaded_variables)

    # -------------------------------------------------------------------------
    # Start training
    # -------------------------------------------------------------------------
    try:
        n_mbs = len(train_loader)
        epoch = int(step / n_mbs)  # batchsize cancels: one step per minibatch
        epochs = range(epoch, n_epochs)

        if len(trainables) == 0:
            validate(val_loader, n_classes, session, loss_val, prediction_val,
                     model, workspace, step, batchsize, tell)
            return

        print("Epoch: {}/{} (step: {}, nmbs: {}, batchsize: {})".format(
            epoch + 1, n_epochs, step, n_mbs, batchsize))
        for ep in epochs:
            if ep == 0 and initial_val:
                f1 = validate(val_loader, n_classes, session, loss_val,
                              prediction_val, model, workspace, step,
                              batchsize, tell)
            else:
                if config.has_value("lrs_best") and config.has_value(
                        "lrs_learning_rate") and config.has_value(
                            "lrs_n_bad_epochs"):
                    f1 = config.get_value("lrs_f1")
                    lrs_best = config.get_value("lrs_best")
                    lrs_learning_rate = config.get_value("lrs_learning_rate")
                    lrs_n_bad_epochs = config.get_value("lrs_n_bad_epochs")
                else:
                    f1 = 0

            # LRS "Plateu"
            if lrs_plateu:
                # update scheduler
                if lrs_is_better(lrs_best, f1):
                    lrs_best = f1
                    lrs_n_bad_epochs = 0
                else:
                    lrs_n_bad_epochs += 1
                # update learning rate
                if lrs_n_bad_epochs > lrs_patience:
                    lrs_learning_rate = max(lrs_learning_rate * lrs_factor, 0)
                    lrs_n_bad_epochs = 0

            with tqdm(total=len(train_loader),
                      desc="Training [{}/{}]".format(ep + 1,
                                                     len(epochs))) as pbar:
                for mbi, mb in enumerate(train_loader):
                    # LRS "Plateu"
                    if lrs_plateu:
                        feed_dict = {
                            model.X: mb['input'].numpy(),
                            model.y_: mb['target'].numpy(),
                            model.dropout: dropout,
                            learning_rate: lrs_learning_rate
                        }
                    else:
                        feed_dict = {
                            model.X: mb['input'].numpy(),
                            model.y_: mb['target'].numpy(),
                            model.dropout: dropout
                        }

                    # TRAINING
                    pred, loss_train, _ = session.run(
                        [prediction, loss, update], feed_dict=feed_dict)

                    # Update status
                    pbar.set_description_str(
                        "Training [{}/{}] Loss: {:.4f}".format(
                            ep + 1, len(epochs), loss_train))
                    pbar.update()
                    step += 1

            validate(val_loader, n_classes, session, loss_val, prediction_val,
                     model, workspace, step, batchsize, tell)
    except AbortRun:
        print("Aborting...")
    finally:
        tell.close(global_step=step, save_checkpoint=True)
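
The inline "plateau" schedule above mirrors PyTorch's ReduceLROnPlateau: track the best validation score, count epochs without sufficient relative improvement, and decay the learning rate once the patience is exceeded. A minimal standalone sketch of that bookkeeping (illustrative names only, not part of TeLL):

class PlateauScheduler:
    """Reduce-on-plateau bookkeeping, as done inline in main() above."""

    def __init__(self, lr, patience, factor, threshold, mode="max"):
        self.lr = lr
        self.patience = patience
        self.factor = factor
        self.threshold = threshold
        self.mode = mode
        self.best = float("-inf") if mode == "max" else float("inf")
        self.n_bad_epochs = 0

    def step(self, metric):
        # A metric counts as an improvement only if it beats the best
        # value by a relative threshold (direction depends on mode).
        if self.mode == "max":
            improved = metric > self.best * (1 + self.threshold)
        else:
            improved = metric < self.best * (1 - self.threshold)
        if improved:
            self.best = metric
            self.n_bad_epochs = 0
        else:
            self.n_bad_epochs += 1
        # Too many bad epochs in a row: decay the learning rate.
        if self.n_bad_epochs > self.patience:
            self.lr = max(self.lr * self.factor, 0)
            self.n_bad_epochs = 0
        return self.lr
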
Example #2
def main(_):

    # ------------------------------------------------------------------------------------------------------------------
    # Setup training
    # ------------------------------------------------------------------------------------------------------------------

    # Initialize config; parses the command line and reads the specified config file (values can also be overridden from the command line)
    config = Config()

    #
    # Load datasets for training and validation
    #
    with Timer(name="Loading Data", verbose=True):
        # Make sure datareader is reproducible
        random_seed = config.get_value('random_seed', 12345)
        np.random.seed(
            random_seed)  # not threadsafe, use rnd_gen object where possible
        rnd_gen = np.random.RandomState(seed=random_seed)

        print("Loading training data...")
        trainingset = MovingDotDataset(n_timesteps=5,
                                       n_samples=50,
                                       batchsize=config.batchsize,
                                       rnd_gen=rnd_gen)
        print("Loading validation data...")
        validationset = MovingDotDataset(n_timesteps=5,
                                         n_samples=25,
                                         batchsize=config.batchsize,
                                         rnd_gen=rnd_gen)

    #
    # Initialize TeLL session
    #
    tell = TeLLSession(config=config,
                       summaries=["train", "validation"],
                       model_params={"dataset": trainingset})

    # Get some members from the session for easier usage
    sess = tell.tf_session
    summary_writer_train, summary_writer_validation = tell.tf_summaries[
        "train"], tell.tf_summaries["validation"]
    model = tell.model
    workspace, config = tell.workspace, tell.config

    #
    # Define loss functions and update steps
    #
    print("Initializing loss calculation...")
    pos_target_weight = np.prod(
        trainingset.y_shape[2:]
    ) - 1  # only 1 pixel per sample is of positive class -> up-weight!
    loss = tf.reduce_mean(
        tf.nn.weighted_cross_entropy_with_logits(targets=model.y_,
                                                 logits=model.output,
                                                 pos_weight=pos_target_weight))
    # loss = tf.reduce_mean(-tf.reduce_sum((model.y_ * tf.log(model.output)) *
    #                                      -tf.reduce_sum(model.y_ - 1) / tf.reduce_sum(model.y_),
    #                                      axis=[1, 2, 3, 4]))
    train_summary = tf.summary.scalar(
        "Training Loss", loss)  # create summary to add to tensorboard

    # Loss function for validationset
    val_loss = tf.reduce_mean(
        tf.nn.weighted_cross_entropy_with_logits(targets=model.y_,
                                                 logits=model.output,
                                                 pos_weight=pos_target_weight))
    # val_loss = tf.reduce_mean(-tf.reduce_sum(model.y_ * tf.log(model.output) *
    #                                          -tf.reduce_sum(model.y_ - 1) / tf.reduce_sum(model.y_),
    #                                          axis=[1, 2, 3, 4]))
    val_loss_summary = tf.summary.scalar(
        "Validation Loss", val_loss)  # create summary to add to tensorboard

    # Regularization
    reg_penalty = regularize(layers=model.get_layers(),
                             l1=config.l1,
                             l2=config.l2,
                             regularize_weights=True,
                             regularize_biases=True)
    regpen_summary = tf.summary.scalar(
        "Regularization Penalty",
        reg_penalty)  # create summary to add to tensorboard

    # Update step for weights
    update = update_step(loss + reg_penalty, config)

    #
    # Prepare plotting
    #
    plot_elements_sym = list(model.get_plot_dict().values())
    plot_elements = list()
    plot_ranges = model.get_plot_range_dict()

    #
    # Initialize tensorflow variables (either initializes them from scratch or restores from checkpoint)
    #
    global_step = tell.initialize_tf_variables().global_step

    #
    # Finalize graph
    #  This makes our tensorflow graph read-only and prevents further additions to the graph
    #
    sess.graph.finalize()
    if sess.graph.finalized:
        print("Graph is finalized!")
    else:
        raise ValueError("Could not finalize graph!")

    sys.stdout.flush()

    # ------------------------------------------------------------------------------------------------------------------
    # Start training
    # ------------------------------------------------------------------------------------------------------------------

    try:
        epoch = int(global_step / trainingset.n_mbs)
        epochs = range(epoch, config.n_epochs)

        # Loop through epochs
        print("Starting training")

        for ep in epochs:
            epoch = ep
            print("Starting training epoch: {}".format(ep))
            # Initialize variables for over-all loss per epoch
            train_loss = 0

            # Load one minibatch at a time and perform a training step
            t_mb = Timer(verbose=True, name="Load Minibatch")
            mb_training = trainingset.batch_loader(rnd_gen=rnd_gen)

            #
            # Loop through minibatches
            #
            for mb_i, mb in enumerate(mb_training):
                sys.stdout.flush()
                # Print minibatch load time
                t_mb.print()

                # Abort if indicated by file
                check_kill_file(workspace)

                #
                # Calculate scores on validation set
                #
                if global_step % config.score_at == 0:
                    print("Starting scoring on validation set...")
                    evaluate_on_validation_set(validationset, global_step,
                                               sess, model,
                                               summary_writer_validation,
                                               val_loss_summary, val_loss,
                                               workspace)

                #
                # Perform weight updates and do plotting
                #
                if (mb_i % config.plot_at) == 0 and os.path.isfile(
                        workspace.get_plot_file()):
                    # Perform weight update, return summary_str and values for plotting
                    with Timer(verbose=True, name="Weight Update"):
                        train_summ, regpen_summ, _, cur_loss, cur_output, *plot_elements = sess.run(
                            [
                                train_summary, regpen_summary, update, loss,
                                model.output, *plot_elements_sym
                            ],
                            feed_dict={
                                model.X: mb['X'],
                                model.y_: mb['y']
                            })

                    # Add current summary values to tensorboard
                    summary_writer_train.add_summary(train_summ,
                                                     global_step=global_step)
                    summary_writer_train.add_summary(regpen_summ,
                                                     global_step=global_step)

                    # Re-associate returned tensorflow values to plotting keys
                    plot_dict = OrderedDict(
                        zip(list(model.get_plot_dict().keys()), plot_elements))

                    #
                    # Plot subplots in plot_dict
                    # Loop through each element in plotlist and pass it to the save_subplots function for plotting
                    # (adapt this to your needs for plotting)
                    #
                    with Timer(verbose=True, name="Plotting",
                               precision="msec"):
                        for plotlist_i, plotlist in enumerate(
                                model.get_plotsink()):
                            for frame in range(len(plot_dict[plotlist[0]])):
                                subplotlist = []
                                subfigtitles = []
                                subplotranges = []
                                n_cols = int(np.ceil(np.sqrt(len(plotlist))))

                                for col_i in range(n_cols):
                                    subfigtitles.append(
                                        plotlist[n_cols *
                                                 col_i:n_cols * col_i +
                                                 n_cols])
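                                    # Index trick: frame * (frame < len(...))
                                    # falls back to frame 0 once frame runs
                                    # past the end of a shorter sequence.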
                                    subplotlist.append([
                                        plot_dict[p]
                                        [frame *
                                         (frame < len(plot_dict[p])), :]
                                        for p in plotlist[n_cols *
                                                          col_i:n_cols *
                                                          col_i + n_cols]
                                    ])
                                    subplotranges.append([
                                        plot_ranges.get(p, False)
                                        for p in plotlist[n_cols *
                                                          col_i:n_cols *
                                                          col_i + n_cols]
                                    ])

                                # remove rows/columns without images
                                subplotlist = [
                                    p for p in subplotlist if p != []
                                ]

                                plot_args = dict(
                                    images=subplotlist,
                                    filename=os.path.join(
                                        workspace.get_result_dir(),
                                        "plot{}_ep{}_mb{}_fr{}.png".format(
                                            plotlist_i, ep, mb_i, frame)),
                                    subfigtitles=subfigtitles,
                                    subplotranges=subplotranges)
                                plotter.set_plot_kwargs(plot_args)
                                plotter.plot()

                    # Plot outputs and cell states over frames if specified
                    if config.store_states and 'ConvLSTMLayer_h' in plot_dict:
                        convh = plot_dict['ConvLSTMLayer_h']
                        convrh = [c[0, :, :, 0] for c in convh]
                        convrh = [
                            convrh[:6], convrh[6:12], convrh[12:18],
                            convrh[18:24], convrh[24:]
                        ]
                        plot_args = dict(images=convrh,
                                         filename=os.path.join(
                                             workspace.get_result_dir(),
                                             "plot{}_ep{}_mb{}_h.png".format(
                                                 plotlist_i, ep, mb_i)))
                        plotter.set_plot_kwargs(plot_args)
                        plotter.plot()

                    if config.store_states and 'ConvLSTMLayer_c' in plot_dict:
                        convc = plot_dict['ConvLSTMLayer_c']
                        convrc = [c[0, :, :, 0] for c in convc]
                        convrc = [
                            convrc[:6], convrc[6:12], convrc[12:18],
                            convrc[18:24], convrc[24:]
                        ]
                        plot_args = dict(images=convrc,
                                         filename=os.path.join(
                                             workspace.get_result_dir(),
                                             "plot{}_ep{}_mb{}_c.png".format(
                                                 plotlist_i, ep, mb_i)))
                        plotter.set_plot_kwargs(plot_args)
                        plotter.plot()

                else:
                    #
                    # Perform weight update without plotting
                    #
                    with Timer(verbose=True, name="Weight Update"):
                        train_summ, regpen_summ, _, cur_loss = sess.run(
                            [train_summary, regpen_summary, update, loss],
                            feed_dict={
                                model.X: mb['X'],
                                model.y_: mb['y']
                            })

                    # Add current summary values to tensorboard
                    summary_writer_train.add_summary(train_summ,
                                                     global_step=global_step)
                    summary_writer_train.add_summary(regpen_summ,
                                                     global_step=global_step)

                # Add current loss to running average loss
                train_loss += cur_loss

                # Print some status info
                print("ep {} mb {} loss {} (avg. loss {})".format(
                    ep, mb_i, cur_loss, train_loss / (mb_i + 1)))

                # Reset timer
                t_mb = Timer(name="Load Minibatch")

                # Free the memory allocated for the minibatch data
                mb.clear()
                del mb

                global_step += 1

            #
            # Calculate scores on validation set
            #

            # Perform scoring on validation set
            print("Starting scoring on validation set...")
            evaluate_on_validation_set(validationset, global_step, sess, model,
                                       summary_writer_validation,
                                       val_loss_summary, val_loss, workspace)

            # Save the model
            tell.save_checkpoint(global_step=global_step)

            # Abort if indicated by file
            check_kill_file(workspace)

    except AbortRun:
        print("Detected kill file, aborting...")

    finally:
        #
        # If the program executed correctly or an error was raised, close the data readers and save the model and exit
        #
        trainingset.close()
        validationset.close()
        tell.close(save_checkpoint=True, global_step=global_step)
        plotter.close()
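
Both loss terms in this example rely on tf.nn.weighted_cross_entropy_with_logits, which scales only the positive-class term of the sigmoid cross entropy by pos_weight. A NumPy sketch of that formula (following the TensorFlow documentation) shows why pos_weight = n_pixels - 1 rebalances a target with a single positive pixel:

import numpy as np

def weighted_xent_with_logits(targets, logits, pos_weight):
    # Numerically stable log(sigmoid(x)) and log(1 - sigmoid(x)).
    log_sig = -np.logaddexp(0.0, -logits)
    log_one_minus_sig = -np.logaddexp(0.0, logits)
    # Only the positive-class term is scaled by pos_weight.
    return -(pos_weight * targets * log_sig
             + (1.0 - targets) * log_one_minus_sig)

# With one positive pixel among np.prod(y_shape[2:]) pixels, setting
# pos_weight to n_pixels - 1 lets that single positive contribute about
# as much to the loss as all negative pixels combined.
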
Example #3
def main(_):
    # ------------------------------------------------------------------------------------------------------------------
    # Setup training
    # ------------------------------------------------------------------------------------------------------------------

    # Initialize config; parses the command line and reads the specified config file (values can also be overridden from the command line)
    config = Config()

    # Load datasets for trainingset
    with Timer(name="Loading Training Data"):
        # Make sure datareader is reproducible
        random_seed = config.get_value('random_seed', 12345)
        np.random.seed(
            random_seed)  # not threadsafe, use rnd_gen object where possible
        rnd_gen = np.random.RandomState(seed=random_seed)

        print("Loading training data...")
        trainingset = ShortLongDataset(n_timesteps=250,
                                       n_samples=3000,
                                       batchsize=config.batchsize,
                                       rnd_gen=rnd_gen)

        # Load datasets for validationset
        validationset = ShortLongDataset(n_timesteps=250,
                                         n_samples=300,
                                         batchsize=config.batchsize,
                                         rnd_gen=rnd_gen)

    # Initialize TeLL session
    tell = TeLLSession(config=config,
                       summaries=["train"],
                       model_params={"dataset": trainingset})

    # Get some members from the session for easier usage
    session = tell.tf_session
    summary_writer = tell.tf_summaries["train"]
    model = tell.model
    workspace, config = tell.workspace, tell.config

    # Loss function for trainingset
    print("Initializing loss calculation...")
    loss = tf.reduce_mean(
        tf.reduce_mean(tf.nn.weighted_cross_entropy_with_logits(
            model.y_, model.output,
            -tf.reduce_sum(model.y_ - 1) / tf.reduce_sum(model.y_)),
                       axis=[1]))
    train_summary = tf.summary.scalar("Training Loss",
                                      loss)  # add loss to tensorboard

    # Loss function for validationset
    val_loss = tf.reduce_mean(
        tf.reduce_mean(tf.nn.weighted_cross_entropy_with_logits(
            model.y_, model.output,
            -tf.reduce_sum(model.y_ - 1) / tf.reduce_sum(model.y_)),
                       axis=[1]))
    val_loss_summary = tf.summary.scalar(
        "Validation Loss", val_loss)  # add val_loss to tensorboard

    # Regularization
    reg_penalty = regularize(layers=model.get_layers(),
                             l1=config.l1,
                             l2=config.l2,
                             regularize_weights=True,
                             regularize_biases=True)
    regpen_summary = tf.summary.scalar(
        "Regularization Penalty",
        reg_penalty)  # add reg_penalty to tensorboard

    # Update step for weights
    update = update_step(loss + reg_penalty, config)

    # Initialize Tensorflow variables
    global_step = tell.initialize_tf_variables().global_step

    sys.stdout.flush()

    # ------------------------------------------------------------------------------------------------------------------
    # Start training
    # ------------------------------------------------------------------------------------------------------------------

    try:
        epoch = int(global_step / trainingset.n_mbs)
        epochs = range(epoch, config.n_epochs)

        #
        # Loop through epochs
        #
        print("Starting training")

        for ep in epochs:
            print("Starting training epoch: {}".format(ep))
            # Initialize variables for over-all loss per epoch
            train_loss = 0

            # Load one minibatch at a time and perform a training step
            t_mb = Timer(name="Load Minibatch")
            mb_training = trainingset.batch_loader(rnd_gen=rnd_gen)

            #
            # Loop through minibatches
            #
            for mb_i, mb in enumerate(mb_training):
                sys.stdout.flush()
                # Print minibatch load time
                t_mb.print()

                # Abort if indicated by file
                check_kill_file(workspace)

                #
                # Calculate scores on validation set
                #
                if global_step % config.score_at == 0:
                    print("Starting scoring on validation set...")
                    evaluate_on_validation_set(validationset, global_step,
                                               session, model, summary_writer,
                                               val_loss_summary, val_loss,
                                               workspace)

                #
                # Perform weight update
                #
                with Timer(name="Weight Update"):
                    train_summ, regpen_summ, _, cur_loss = session.run(
                        [train_summary, regpen_summary, update, loss],
                        feed_dict={
                            model.X: mb['X'],
                            model.y_: mb['y']
                        })

                # Add current summary values to tensorboard
                summary_writer.add_summary(train_summ, global_step=global_step)
                summary_writer.add_summary(regpen_summ,
                                           global_step=global_step)

                # Add current loss to running average loss
                train_loss += cur_loss

                # Print some status info
                print("ep {} mb {} loss {} (avg. loss {})".format(
                    ep, mb_i, cur_loss, train_loss / (mb_i + 1)))

                # Reset timer
                t_mb = Timer(name="Load Minibatch")

                # Free the memory allocated for the minibatch data
                mb.clear()
                del mb

                global_step += 1

            #
            # Calculate scores on validation set at the end of the epoch
            #

            # Perform scoring on validation set
            print("Starting scoring on validation set...")
            evaluate_on_validation_set(validationset, global_step, session,
                                       model, summary_writer, val_loss_summary,
                                       val_loss, workspace)

            tell.save_checkpoint(global_step=global_step)

            # Abort if indicated by file
            check_kill_file(workspace)

    except AbortRun:
        print("Detected kill file, aborting...")

    finally:
        tell.close(save_checkpoint=True, global_step=global_step)
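
The pos_weight expression used here, -tf.reduce_sum(model.y_ - 1) / tf.reduce_sum(model.y_), is the ratio of negative to positive labels: for binary targets, -(y - 1) is 1 exactly where y is 0. A quick NumPy check (illustrative values):

import numpy as np

y = np.array([0., 0., 0., 1., 0., 1.])  # binary targets
n_neg = -np.sum(y - 1)                  # counts the zeros -> 4.0
n_pos = np.sum(y)                       # counts the ones  -> 2.0
pos_weight = n_neg / n_pos              # 2.0: up-weights the positives
assert pos_weight == (y == 0).sum() / (y == 1).sum()
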
Example #4
def main(_):
    # ------------------------------------------------------------------------------------------------------------------
    # Setup training
    # ------------------------------------------------------------------------------------------------------------------
    
    # Initialize config; parses the command line and reads the specified config file (values can also be overridden from the command line)
    config = Config()
    
    #
    # Prepare input data
    #
    
    # Make sure datareader is reproducible
    random_seed = config.get_value('random_seed', 12345)
    np.random.seed(random_seed)  # not threadsafe, use rnd_gen object where possible
    rnd_gen = np.random.RandomState(seed=random_seed)
    
    # Set datareaders
    n_timesteps = config.get_value('mnist_n_timesteps', 20)
    
    # Load datasets for trainingset
    with Timer(name="Loading Data"):
        readers = initialize_datareaders(config, required=("train", "val"))
    
    # Set Preprocessing
    trainingset = Normalize(readers["train"], apply_to=['X', 'y'])
    validationset = Normalize(readers["val"], apply_to=['X', 'y'])
    
    # Set minibatch loaders
    trainingset = DataLoader(trainingset, batchsize=2, batchsize_method='zeropad', verbose=False)
    validationset = DataLoader(validationset, batchsize=2, batchsize_method='zeropad', verbose=False)
    
    #
    # Initialize TeLL session
    #
    tell = TeLLSession(config=config, summaries=["train", "validation"], model_params={"dataset": trainingset})
    
    # Get some members from the session for easier usage
    sess = tell.tf_session
    summary_writer_train, summary_writer_validation = tell.tf_summaries["train"], tell.tf_summaries["validation"]
    model = tell.model
    workspace, config = tell.workspace, tell.config
    
    #
    # Define loss functions and update steps
    #
    print("Initializing loss calculation...")
    loss, _ = image_crossentropy(target=model.y_[:, 10:, :, :], pred=model.output[:, 10:, :, :, :],
                                 pixel_weights=model.pixel_weights[:, 10:, :, :], reduce_by='mean')
    train_summary = tf.summary.scalar("Training Loss", loss)  # create summary to add to tensorboard
    
    # Loss function for validationset
    val_loss = loss
    val_loss_summary = tf.summary.scalar("Validation Loss", val_loss)  # create summary to add to tensorboard
    
    # Regularization
    reg_penalty = regularize(layers=model.get_layers(), l1=config.l1, l2=config.l2,
                             regularize_weights=True, regularize_biases=True)
    regpen_summary = tf.summary.scalar("Regularization Penalty", reg_penalty)  # create summary to add to tensorboard
    
    # Update step for weights
    update = update_step(loss + reg_penalty, config)
    
    #
    # Initialize tensorflow variables (either initializes them from scratch or restores from checkpoint)
    #
    global_step = tell.initialize_tf_variables().global_step
    
    #
    # Set up plotting
    #  (store tensors we want to plot in a dictionary for easier tensor-evaluation)
    #
    # We want to plot the input, target, and network output for the 1st sample across all frames in subplots 1-3
    tensors_subplot1 = OrderedDict()
    tensors_subplot2 = OrderedDict()
    tensors_subplot3 = OrderedDict()
    for frame in range(n_timesteps):
        tensors_subplot1['input_{}'.format(frame)] = model.X[0, frame, :, :]
        tensors_subplot2['target_{}'.format(frame)] = model.y_[0, frame, :, :] - 1
        tensors_subplot3['network_output_{}'.format(frame)] = tf.argmax(model.output[0, frame, :, :, :], axis=-1) - 1
    # We also want to plot the hidden states and cell states for each frame (again of the 1st sample and 1st lstm unit)
    # in subplots 4 and 5
    tensors_subplot4 = OrderedDict()
    tensors_subplot5 = OrderedDict()
    for frame in range(len(model.lstm_layer.c)):
        tensors_subplot4['hiddenstate_{}'.format(frame)] = model.lstm_layer.h[frame][0, :, :, 0]
        tensors_subplot5['cellstate_{}'.format(frame)] = model.lstm_layer.c[frame][0, :, :, 0]
    # Create a list to store all symbolic tensors for plotting
    plotting_tensors = list(tensors_subplot1.values()) + list(tensors_subplot2.values()) + \
                       list(tensors_subplot3.values()) + list(tensors_subplot4.values()) + \
                       list(tensors_subplot5.values())
    
    #
    # Finalize graph
    #  This makes our tensorflow graph read-only and prevents further additions to the graph
    #
    sess.graph.finalize()
    if sess.graph.finalized:
        print("Graph is finalized!")
    else:
        raise ValueError("Could not finalize graph!")
    
    sys.stdout.flush()
    
    # ------------------------------------------------------------------------------------------------------------------
    # Start training
    # ------------------------------------------------------------------------------------------------------------------
    
    try:
        epoch = int(global_step / trainingset.n_mbs)
        epochs = range(epoch, config.n_epochs)
        
        # Loop through epochs
        print("Starting training")
        
        for ep in epochs:
            epoch = ep
            print("Starting training epoch: {}".format(ep))
            # Initialize variables for over-all loss per epoch
            train_loss = 0
            
            # Load one minibatch at a time and perform a training step
            t_mb = Timer(verbose=True, name="Load Minibatch")
            mb_training = trainingset.batch_loader(rnd_gen=rnd_gen)
            
            #
            # Loop through minibatches
            #
            for mb_i, mb in enumerate(mb_training):
                sys.stdout.flush()
                # Print minibatch load time
                t_mb.print()
                
                # Abort if indicated by file
                check_kill_file(workspace)
                
                #
                # Calculate scores on validation set
                #
                if global_step % config.score_at == 0:
                    print("Starting scoring on validation set...")
                    evaluate_on_validation_set(validationset, global_step, sess, model, summary_writer_validation,
                                               val_loss_summary, val_loss, workspace)
                
                #
                # Perform weight updates and do plotting
                #
                if (mb_i % config.plot_at) == 0 and os.path.isfile(workspace.get_plot_file()):
                    # Perform weight update, return summary values and values for plotting
                    with Timer(verbose=True, name="Weight Update"):
                        train_summ, regpen_summ, _, cur_loss, *plotting_values = sess.run(
                            [train_summary, regpen_summary, update, loss, *plotting_tensors],
                            feed_dict={model.X: mb['X'], model.y_: mb['y']})
                    
                    # Add current summary values to tensorboard
                    summary_writer_train.add_summary(train_summ, global_step=global_step)
                    summary_writer_train.add_summary(regpen_summ, global_step=global_step)
                    
                    # Create and save subplot 1 (input)
                    save_subplots(images=plotting_values[:len(tensors_subplot1)],
                                  subfigtitles=list(tensors_subplot1.keys()),
                                  subplotranges=[(0, 1)] * n_timesteps, colorbar=True, automatic_positioning=True,
                                  tight_layout=True,
                                  filename=os.path.join(workspace.get_result_dir(),
                                                        "input_ep{}_mb{}.png".format(ep, mb_i)))
                    del plotting_values[:len(tensors_subplot1)]
                    
                    # Create and save subplot 2 (target)
                    save_subplots(images=plotting_values[:len(tensors_subplot2)],
                                  subfigtitles=list(tensors_subplot2.keys()),
                                  subplotranges=[(0, 10)] * n_timesteps, colorbar=True, automatic_positioning=True,
                                  tight_layout=True,
                                  filename=os.path.join(workspace.get_result_dir(),
                                                        "target_ep{}_mb{}.png".format(ep, mb_i)))
                    del plotting_values[:len(tensors_subplot2)]
                    
                    # Create and save subplot 3 (output)
                    save_subplots(images=plotting_values[:len(tensors_subplot3)],
                                  subfigtitles=list(tensors_subplot3.keys()),
                                  # subplotranges=[(0, 10)] * n_timesteps,
                                  colorbar=True, automatic_positioning=True,
                                  tight_layout=True,
                                  filename=os.path.join(workspace.get_result_dir(),
                                                        "output_ep{}_mb{}.png".format(ep, mb_i)))
                    del plotting_values[:len(tensors_subplot3)]
                    
                    # Create and save subplot 4 (hidden states, i.e. ConvLSTM outputs)
                    save_subplots(images=plotting_values[:len(tensors_subplot4)],
                                  subfigtitles=list(tensors_subplot4.keys()),
                                  title='ConvLSTM hidden states (outputs)', colorbar=True, automatic_positioning=True,
                                  tight_layout=True,
                                  filename=os.path.join(workspace.get_result_dir(),
                                                        "hidden_ep{}_mb{}.png".format(ep, mb_i)))
                    del plotting_values[:len(tensors_subplot4)]
                    
                    # Create and save subplot 5 (cell states)
                    save_subplots(images=plotting_values[:len(tensors_subplot5)],
                                  subfigtitles=list(tensors_subplot5.keys()),
                                  title='ConvLSTM cell states', colorbar=True, automatic_positioning=True,
                                  tight_layout=True,
                                  filename=os.path.join(workspace.get_result_dir(),
                                                        "cell_ep{}_mb{}.png".format(ep, mb_i)))
                    del plotting_values[:len(tensors_subplot5)]
                
                else:
                    #
                    # Perform weight update without plotting
                    #
                    with Timer(verbose=True, name="Weight Update"):
                        train_summ, regpen_summ, _, cur_loss = sess.run([
                            train_summary, regpen_summary, update, loss],
                            feed_dict={model.X: mb['X'], model.y_: mb['y']})
                    
                    # Add current summary values to tensorboard
                    summary_writer_train.add_summary(train_summ, global_step=global_step)
                    summary_writer_train.add_summary(regpen_summ, global_step=global_step)
                
                # Add current loss to running average loss
                train_loss += cur_loss
                
                # Print some status info
                print("ep {} mb {} loss {} (avg. loss {})".format(ep, mb_i, cur_loss, train_loss / (mb_i + 1)))
                
                # Reset timer
                t_mb = Timer(name="Load Minibatch")
                
                # Free the memory allocated for the minibatch data
                mb.clear()
                del mb
                
                global_step += 1
            
            #
            # Calculate scores on validation set
            #
            
            # Perform scoring on validation set
            print("Starting scoring on validation set...")
            evaluate_on_validation_set(validationset, global_step, sess, model, summary_writer_validation,
                                       val_loss_summary, val_loss, workspace)
            
            # Save the model
            tell.save_checkpoint(global_step=global_step)
            
            # Abort if indicated by file
            check_kill_file(workspace)
    
    except AbortRun:
        print("Detected kill file, aborting...")
    
    finally:
        #
        # If the program executed correctly or an error was raised, close the data readers and save the model and exit
        #
        trainingset.close()
        validationset.close()
        tell.close(save_checkpoint=True, global_step=global_step)
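
The repeated plotting_values[:len(...)] / del plotting_values[:len(...)] pattern above slices the flat list returned by sess.run back into the five subplot groups. A hypothetical helper (not part of TeLL) that makes the same bookkeeping explicit:

def split_fetched(values, *groups):
    """Split a flat list of fetched values into one list per group of
    plotting tensors (each group is an OrderedDict of tensors)."""
    out, i = [], 0
    for group in groups:
        out.append(values[i:i + len(group)])
        i += len(group)
    return out

# Usage sketch:
# inputs, targets, outputs, hidden, cells = split_fetched(
#     plotting_values, tensors_subplot1, tensors_subplot2,
#     tensors_subplot3, tensors_subplot4, tensors_subplot5)
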
Example #5
                            padding='VALID',
                            stride=1)
aux_target_2 = tf.reduce_sum(rewards_placeholder, axis=1) - tf.cumsum(
    rewards_placeholder, axis=1)
aux_target_3 = tf.cumsum(rewards_placeholder, axis=1)
targets = tf.concat(
    [rewards_placeholder, aux_target_1, aux_target_2, aux_target_3], axis=-1)
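# Interpretation of the auxiliary targets: aux_target_2 is the return still
# to come after each timestep (total return minus the running cumulative
# sum), aux_target_3 the return accumulated so far; both presumably act as
# auxiliary supervision for the return prediction.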

return_prediction = rr_returns['predictions'][0, -1, 0]
true_return = tf.reduce_sum(targets[0, :, 0])
reward_prediction_error = tf.square(true_return - return_prediction)
auxiliary_losses = tf.reduce_mean(
    tf.square(targets[0, :, 1:] - rr_returns['predictions'][0, :, 1:]), axis=1)
# Add regularization penalty
rr_reg_penalty = regularize(layers=[lstm_layer, output_layer],
                            l1=1e-6,
                            regularize_weights=True,
                            regularize_biases=True)
total_loss = (reward_prediction_error +
              tf.reduce_mean(auxiliary_losses)) / 2 + rr_reg_penalty

trainables = tf.trainable_variables()
grads = tf.gradients(total_loss, trainables)
grads, _ = tf.clip_by_global_norm(grads, 0.5)

optimizer = tf.train.AdamOptimizer(learning_rate=1e-2)
rr_update = optimizer.apply_gradients(zip(grads, trainables))

#
# Set up Integrated Gradients for contribution analysis
#
n_intgrd_steps = 500
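
The snippet ends where the integrated-gradients setup begins (n_intgrd_steps = 500). Integrated gradients approximate the path integral of the model's gradient along the straight line from a baseline to the input. A minimal NumPy sketch, assuming a grad_fn(x) that returns the gradient of the scalar prediction w.r.t. its input:

import numpy as np

def integrated_gradients(x, baseline, grad_fn, n_steps=500):
    # Midpoint Riemann approximation of the average gradient along the
    # straight-line path from baseline to x.
    alphas = (np.arange(n_steps) + 0.5) / n_steps
    avg_grad = np.mean(
        [grad_fn(baseline + a * (x - baseline)) for a in alphas], axis=0)
    # IG attribution: (x - baseline) times the average path gradient.
    return (x - baseline) * avg_grad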