示例#1
0
    def train(self,
              data,
              n_epochs,
              alpha=0.001,
              dropout=0.0,
              batch_size=32,
              print_every=10,
              l2=None,
              augmentation_func=None,
              viz_every=10,
              eval_subset=1000):
        """Train the model for `n_epochs`, given a dictionary of data.

        Args:
            data (dict): Must contain "X_train", "Y_train", "X_valid",
                "Y_valid" arrays.
            n_epochs (int): Number of epochs to train for.
            alpha (float): Learning rate, fed to the `self.alpha` placeholder.
            dropout (float): Dropout rate, fed to `self.dropout`.
            batch_size (int): Mini-batch size.
            print_every (int or None): Print batch loss every this many
                batches; None disables per-batch feedback.
            l2: NOTE(review): accepted but never used anywhere in this
                method — kept only for interface compatibility. Presumably
                intended as an L2-regularization weight; confirm upstream.
            augmentation_func (callable or None): Optional
                `(X_batch, Y_batch) -> (X_batch, Y_batch)` applied to each
                mini-batch before the train step.
            viz_every (int): Visualize segmentations every this many
                global epochs.
            eval_subset (int): Cap on the number of training samples used
                for the per-epoch train-set evaluation (previously a
                hard-coded 1000; default preserves old behavior).

        Side effects: saves snapshots, the evals dict, training-curve
        plots, the best-score file and the status file; prints progress
        to stdout.
        """
        n_samples = len(data["X_train"])  # Num training samples
        # Num batches per epoch (final batch may be smaller than batch_size)
        n_batches = int(np.ceil(n_samples / batch_size))
        print(
            "DEBUG - ", "using aug func"
            if augmentation_func is not None else "NOT using aug func")
        with tf.Session(graph=self.graph) as sess:
            self.initialize_vars(sess)
            t0 = time.time()

            try:
                self.update_status_file("training")
                for epoch in range(1, n_epochs + 1):
                    self.global_epoch += 1
                    print(
                        "=" * 70,
                        "\nEPOCH {}/{} (GLOBAL_EPOCH: {})        ELAPSED TIME: {}"
                        .format(epoch, n_epochs, self.global_epoch,
                                pretty_time(time.time() - t0)),
                        "\n" + ("=" * 70))

                    # Shuffle the data at the start of each epoch
                    data = self.shuffle_train_data(data)

                    # Iterate through each mini-batch
                    for i in range(n_batches):
                        X_batch, Y_batch = self.get_batch(
                            i,
                            X=data["X_train"],
                            Y=data["Y_train"],
                            batch_size=batch_size)
                        if augmentation_func is not None:
                            X_batch, Y_batch = augmentation_func(
                                X_batch, Y_batch)

                        # TRAIN: one optimization step on this mini-batch
                        feed_dict = {
                            self.X: X_batch,
                            self.Y: Y_batch,
                            self.alpha: alpha,
                            self.is_training: True,
                            self.dropout: dropout
                        }
                        loss, _ = sess.run([self.loss, self.train_op],
                                           feed_dict=feed_dict)

                        # Print feedback every so often.
                        # BUG FIX: the trigger uses (i + 1), but the message
                        # previously printed `i`, so the reported batch
                        # number was off by one.
                        if print_every is not None and (i +
                                                        1) % print_every == 0:
                            print("{} {: 5d} Batch_loss: {}".format(
                                pretty_time(time.time() - t0), i + 1, loss))

                    # Save parameters after each epoch
                    self.save_snapshot_in_session(sess, self.snapshot_file)

                    # Evaluate after each epoch: capped train subset (to keep
                    # evaluation cheap) and the full validation set.
                    train_iou, train_loss = self.evaluate_in_session(
                        data["X_train"][:eval_subset],
                        data["Y_train"][:eval_subset], sess)
                    valid_iou, valid_loss = self.evaluate_in_session(
                        data["X_valid"], data["Y_valid"], sess)
                    self.update_evals_dict(train_iou=train_iou,
                                           train_loss=train_loss,
                                           valid_iou=valid_iou,
                                           valid_loss=valid_loss)
                    self.save_evals_dict()

                    # If it's the best model so far (by self.best_evals_metric,
                    # higher is better), save a separate best snapshot.
                    is_best_so_far = self.evals[
                        self.best_evals_metric][-1] >= max(
                            self.evals[self.best_evals_metric])
                    if is_best_so_far:
                        self.save_snapshot_in_session(sess,
                                                      self.best_snapshot_file)

                    # Print evaluations (asterisk marks the best model so far)
                    s = "TR IOU: {: 3.3f} VA IOU: {: 3.3f} TR LOSS: {: 3.5f} VA LOSS: {: 3.5f} {}\n"
                    print(
                        s.format(train_iou, valid_iou, train_loss, valid_loss,
                                 "*" if is_best_so_far else ""))

                    # TRAIN CURVES: save IoU and loss plots after each epoch
                    train_curves(train=self.evals["train_iou"],
                                 valid=self.evals["valid_iou"],
                                 saveto=os.path.join(self.model_dir,
                                                     "iou.png"),
                                 title="IoU over time",
                                 ylab="IoU",
                                 legend_pos="lower right")
                    train_curves(train=self.evals["train_loss"],
                                 valid=self.evals["valid_loss"],
                                 saveto=os.path.join(self.model_dir,
                                                     "loss.png"),
                                 title="Loss over time",
                                 ylab="loss",
                                 legend_pos="upper right")

                    # VISUALIZE PREDICTIONS - once every so many epochs
                    if self.global_epoch % viz_every == 0:
                        self.visualise_semgmentations(data=data, session=sess)

                    # Persist the best score seen so far to a text file
                    str2file(str(max(self.evals[self.best_evals_metric])),
                             file=self.best_score_file)
                self.update_status_file("done")
                print("DONE in ", pretty_time(time.time() - t0))

            except KeyboardInterrupt as e:
                print("Keyboard Interupt detected")
                # TODO: Finish up gracefully. Maybe create recovery snapshots of model
                self.update_status_file("interupted")
                raise e
            except BaseException:
                # Explicit equivalent of the previous bare `except:`:
                # record the crash in the status file, then re-raise so the
                # caller still sees the original exception.
                self.update_status_file("crashed")
                raise
示例#2
0
 def update_status_file(self, status):
     """Record the current training status (e.g. "training", "done")
     by writing it to the model's train-status file."""
     status_path = self.train_status_file
     str2file(status, file=status_path)