Example no. 1
    def train(self, image, epochs, enable_es=1):
        graph = tf.Graph()
        with tf.Session(graph=graph) as session:
            tf.set_random_seed(1234)

            self.__create_inputs()
            new_saver = self.__create_graph(self.meta_file)
            self.__create_loss_optimizer()

            # slim.model_analyzer.analyze_vars(tf.trainable_variables() , print_info=True)

            early_stopping = EarlyStopping(patience=30, min_delta=1e-1)

            tf.global_variables_initializer().run()

            new_saver.restore(session, self.latest_checkpoint)
            
            recons_loss = list()
            print('Starting optimization...')
            for cur_epoch in range(epochs + 1):


                dict_loss = self.__train_epoch(session,image)
                list_loss = list(dict_loss.values())

                if np.isnan(list_loss[0]):
                    print('Encountered NaN, stopping training. Please check the learning_rate settings and the momentum.')
                    sys.exit()

                if cur_epoch % 20 == 0:
                    print('EPOCH: {} | dist: {} '.format(cur_epoch, list_loss[0]))
                    
                recons_loss.append(list_loss[0])
                # Early stopping
                if cur_epoch > 50 and enable_es == 1 and early_stopping.stop(list_loss[0]):
                    print('Early Stopping!')
                    print('EPOCH: {} | dist: {} '.format(cur_epoch, list_loss[0]))
                    break


            z_infer = session.run(self.z)
            x_recons = session.run(self.x_recons)

        return z_infer, x_recons, recons_loss
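The EarlyStopping helper these examples rely on is never shown in the listing. A minimal sketch consistent with the patience/min_delta usage above (an assumption, not the original implementation) could be:

class EarlyStopping:
    # Sketch only: stop(loss) returns True once the monitored loss has failed
    # to improve by at least min_delta for patience consecutive calls.
    def __init__(self, patience=30, min_delta=1e-1):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float('inf')
        self.num_bad = 0

    def stop(self, loss):
        if loss < self.best - self.min_delta:
            self.best = loss
            self.num_bad = 0
        else:
            self.num_bad += 1
        return self.num_bad >= self.patience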
    def fit(self, X, y=None):
        print('\nProcessing data...')

        self.data_train = data_utils.process_data(X, y, test_size=0)
        if self.config.plot:
            self.data_plot = self.data_train

        self.config.num_batches = self.data_train.num_batches(
            self.config.batch_size)

        if not self.config.isBuilt:
            self.config.restore = True
            self.build_model(self.data_train.height, self.data_train.width,
                             self.data_train.num_channels)
        else:
            expected = (self.config.height, self.config.width,
                        self.config.num_channels)
            actual = (self.data_train.height, self.data_train.width,
                      self.data_train.num_channels)
            assert expected == actual, \
                'Wrong dimension of data. Expected shape {}, and got {}'.format(
                    expected, actual)
        ''' 
         -------------------------------------------------------------------------------
                                        TRAIN THE MODEL
        ------------------------------------------------------------------------------------- 
        '''
        print('\nTraining a model...')
        with tf.Session(graph=self.graph) as session:
            tf.set_random_seed(self.config.seeds)
            self.session = session
            logger = Logger(self.session, self.config.log_dir)
            saver = tf.train.Saver()

            early_stopper = EarlyStopping(name='total loss',
                                          decay_fn=self.decay_fn)

            if (self.config.restore and self.load(self.session, saver)):
                load_config = file_utils.load_args(self.config.model_name,
                                                   self.config.config_dir)
                self.config.update(load_config)

                num_epochs_trained = self.model_graph.cur_epoch_tensor.eval(
                    self.session)
                print('EPOCHS trained: ', num_epochs_trained)
            else:
                print('Initializing Variables ...')
                tf.global_variables_initializer().run()

            for cur_epoch in range(
                    self.model_graph.cur_epoch_tensor.eval(self.session),
                    self.config.epochs + 1, 1):
                print('EPOCH: ', cur_epoch)
                self.current_epoch = cur_epoch

                losses_tr = self._train(self.data_train, self.session, logger)

                if np.isnan(losses_tr[0]):
                    print(
                        'Encountered NaN, stopping training. Please check the learning_rate settings and the momentum.'
                    )
                    for lname, lval in zip(self.model_graph.losses, losses_tr):
                        print(lname, lval)
                    sys.exit()

                train_msg = 'TRAIN: \n'
                for lname, lval in zip(self.model_graph.losses, losses_tr):
                    train_msg += str(lname) + ': ' + str(lval) + ' | '

                print(train_msg)
                print()

                if cur_epoch == 1 or (cur_epoch % self.config.save_epoch == 0
                                      and cur_epoch != 0):
                    gc.collect()
                    self.save(
                        self.session, saver,
                        self.model_graph.global_step_tensor.eval(self.session))
                    if self.config.plot:
                        self.plot_latent(cur_epoch)

                self.session.run(self.model_graph.increment_cur_epoch_tensor)

                # Early stopping
                if (self.config.early_stopping
                        and early_stopper.stop(losses_tr[0])):
                    print('Early Stopping!')
                    break

                if cur_epoch % self.config.colab_save == 0:
                    if self.config.colab:
                        self.push_colab()

            self.save(self.session, saver,
                      self.model_graph.global_step_tensor.eval(self.session))
            if self.config.plot:
                self.plot_latent(cur_epoch)

            if self.config.colab:
                self.push_colab()

        return
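cur_epoch_tensor and increment_cur_epoch_tensor are not defined in the excerpt. In TF1 template projects they are usually a non-trainable counter variable plus an assign op, roughly like the sketch below (an assumption based on the usage above); because the counter lives in the graph, tf.train.Saver checkpoints it, which is what lets fit resume from the stored epoch:

import tensorflow as tf  # TF 1.x

with tf.variable_scope('cur_epoch'):
    # non-trainable counter, saved and restored with the model checkpoint
    cur_epoch_tensor = tf.Variable(0, trainable=False, name='cur_epoch')
    increment_cur_epoch_tensor = tf.assign(cur_epoch_tensor,
                                           cur_epoch_tensor + 1)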
    def train(self, data_train, data_valid, enable_es=1):

        with tf.Session(graph=self.graph) as session:
            tf.set_random_seed(1234)

            logger = Logger(session, self.summary_dir)
            # here you initialize the tensorflow saver that will be used in saving the checkpoints.
            # max_to_keep: defaults to keeping the 5 most recent checkpoints of your model
            saver = tf.train.Saver()
            self.session = session
            early_stopping = EarlyStopping(name='total loss',
                                           decay_fn=self.decay_fn)

            if (self.restore and self.load(session, saver)):
                num_epochs_trained = self.model_graph.cur_epoch_tensor.eval(
                    session)
                print('EPOCHS trained: ', num_epochs_trained)
            else:
                print('Initializing Variables ...')
                tf.global_variables_initializer().run()

            if self.model_graph.cur_epoch_tensor.eval(session) == self.epochs:
                return

            for cur_epoch in range(
                    self.model_graph.cur_epoch_tensor.eval(session),
                    self.epochs + 1, 1):

                print('EPOCH: ', cur_epoch)
                self.current_epoch = cur_epoch

                loss_tr, recons_tr, L2_loss = self.train_epoch(
                    session, logger, data_train)
                if np.isnan(loss_tr):
                    print(
                        'Encountered NaN, stopping training. Please check the learning_rate settings and the momentum.'
                    )
                    print('Recons: ', recons_tr)
                    sys.exit()

                loss_val, recons_val = self.valid_epoch(
                    session, logger, data_valid)

                print('TRAIN | AE Loss: ', loss_tr, ' | Recons: ', recons_tr,
                      ' | L2_loss: ', L2_loss)
                print('VALID | AE Loss: ', loss_val, ' | Recons: ', recons_val)

                if cur_epoch == 1 or (cur_epoch % const.SAVE_EPOCH == 0
                                      and cur_epoch != 0):
                    self.save(
                        session, saver,
                        self.model_graph.global_step_tensor.eval(session))
                    if self.plot:
                        self.generate_samples(data_train, session, cur_epoch)

                    if self.clustering:
                        self.generate_clusters(logger, cur_epoch, data_train,
                                               data_valid)

                session.run(self.model_graph.increment_cur_epoch_tensor)

                # Early stopping
                if enable_es == 1 and early_stopping.stop(loss_val):
                    print('Early Stopping!')
                    break

                if cur_epoch % 50 == 0:
                    if self.colab:
                        self.push_colab()

            self.save(session, saver,
                      self.model_graph.global_step_tensor.eval(session))
            if self.plot:
                self.generate_samples(data_train, session, cur_epoch)

            if self.clustering:
                self.generate_clusters(logger, cur_epoch, data_train,
                                       data_valid)

            if self.colab:
                self.push_colab()
        return
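The self.save / self.load helpers are likewise not part of the excerpt. A plausible minimal pair built on the tf.train.Saver created above (an assumption; checkpoint_dir is a hypothetical attribute) would be:

import os
import tensorflow as tf  # TF 1.x

def save(self, session, saver, step):
    # tag the checkpoint with the current global step
    saver.save(session, os.path.join(self.checkpoint_dir, 'model.ckpt'),
               global_step=step)

def load(self, session, saver):
    # restore the most recent checkpoint if one exists
    ckpt = tf.train.latest_checkpoint(self.checkpoint_dir)
    if ckpt is None:
        return False
    saver.restore(session, ckpt)
    return True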
Example no. 4
    def train(self, data_train, data_valid, enable_es=1):

        with tf.Session(graph=self.graph) as session:
            tf.set_random_seed(1234)

            logger = Logger(session, self.summary_dir)
            # here you initialize the tensorflow saver that will be used in saving the checkpoints.
            # max_to_keep: defaults to keeping the 5 most recent checkpoints of your model
            saver = tf.train.Saver()
            early_stopping = EarlyStopping()

            if (self.restore == 1 and self.load(session, saver)):
                num_epochs_trained = self.model_graph.cur_epoch_tensor.eval(
                    session)
                print('EPOCHS trained: ', num_epochs_trained)
            else:
                print('Initializing Variables ...')
                tf.global_variables_initializer().run()

            if self.model_graph.cur_epoch_tensor.eval(session) == self.epochs:
                return

            for cur_epoch in range(
                    self.model_graph.cur_epoch_tensor.eval(session),
                    self.epochs + 1, 1):

                print('EPOCH: ', cur_epoch)
                self.current_epoch = cur_epoch
                # beta=utils.sigmoid(cur_epoch- 50)
                beta = 1.
                losses, recons, cond_prior, KL_w, y_prior, L2_loss = self.train_epoch(
                    session, logger, data_train, beta=beta)
                train_string = 'TRAIN | Loss: ' + str(losses) + \
                            ' | Recons: ' + str(recons) + \
                            ' | CP: ' + str(cond_prior) + \
                            ' | KL_w: ' + str(KL_w) + \
                            ' | KL_y: ' + str(y_prior) + \
                            ' | L2_loss: '+  str(L2_loss)
                # train_string = colored(train_string, 'red', attrs=['reverse', 'blink'])
                train_string = colored(train_string, 'red')
                if np.isnan(losses):
                    print(
                        'Encountered NaN, stopping training. Please check the learning_rate settings and the momentum.'
                    )
                    print('Recons: ', recons)
                    print('CP: ', cond_prior)
                    print('KL_w: ', KL_w)
                    print('KL_y: ', y_prior)
                    sys.exit()

                loss_val, recons, cond_prior, KL_w, y_prior, L2_loss = self.valid_epoch(
                    session, logger, data_valid, beta=beta)
                valid_string = 'VALID | Loss: ' + str(loss_val) + \
                            ' | Recons: ' + str(recons) + \
                            ' | CP: ' + str(cond_prior) + \
                            ' | KL_w: ' + str(KL_w) + \
                            ' | KL_y: ' + str(y_prior) + \
                            ' | L2_loss: '+  str(L2_loss)

                print(train_string)
                print(valid_string)

                if (cur_epoch > 0 and cur_epoch % 10 == 0):
                    self.save(
                        session, saver,
                        self.model_graph.global_step_tensor.eval(session))

                session.run(self.model_graph.increment_cur_epoch_tensor)

                # Early stopping
                if enable_es == 1 and early_stopping.stop(loss_val):
                    print('Early Stopping!')
                    break

            self.save(session, saver,
                      self.model_graph.global_step_tensor.eval(session))

        return
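The commented-out line beta=utils.sigmoid(cur_epoch- 50) hints at KL warm-up: instead of the constant beta = 1., beta is annealed from roughly 0 to 1 around epoch 50. A sketch of that schedule, assuming utils.sigmoid is the standard logistic function:

import numpy as np

def sigmoid(x):
    # standard logistic function
    return 1.0 / (1.0 + np.exp(-x))

for cur_epoch in range(100):
    beta = sigmoid(cur_epoch - 50)  # ~0 early on, ~0.5 at epoch 50, ~1 later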
    def fit(self, dataset):
        assert '.'.join(dataset.__class__.__module__.split('.')[:2]) == 'tensorflow_datasets.image', \
            'The dataset type is not an image tensorflow_datasets dataset'

        self.data_train = dataset.as_dataset(split=tfds.Split.TRAIN, shuffle_files=True, batch_size=self.config.batch_size)
        try:
            self.data_test = dataset.as_dataset(split=tfds.Split.TEST, shuffle_files=True, batch_size=self.config.batch_size)
        except Exception:
            # fall back to the train split when the dataset has no test split
            self.data_test = dataset.as_dataset(split=tfds.Split.TRAIN, shuffle_files=True, batch_size=self.config.batch_size)
            
        height = dataset.info.features['image'].shape[0]
        width = dataset.info.features['image'].shape[1]
        num_channels = dataset.info.features['image'].shape[2]

        self.config.ntrain_batches = dataset.info.splits['train'].num_examples // self.config.batch_size
        self.config.ntest_batches = dataset.info.splits['test'].num_examples // self.config.batch_size

        if not self.config.isBuilt:
            self.config.restore = True
            self.build_model(height, width, num_channels)
        else:
            expected = (self.config.height, self.config.width,
                        self.config.num_channels)
            assert expected == (height, width, num_channels), \
                'Wrong dimension of data. Expected shape {}, and got {}'.format(
                    expected, (height, width, num_channels))

        '''  
        -------------------------------------------------------------------------------
                                        TRAIN THE MODEL
        ------------------------------------------------------------------------------------- 
        '''
        print('\nTraining a model...')
        with tf.Session(graph=self.graph) as session:
            tf.set_random_seed(self.config.seeds)
            self.session = session
            logger = Logger(self.session, self.config.log_dir)
            self.saver = tf.train.Saver()
            early_stopper = EarlyStopping(name='total loss', decay_fn=self.decay_fn)

            if self.config.restore and self.load(self.session, self.saver):
                load_config = file_utils.load_args(self.config.model_name, self.config.config_dir,
                                                   ['latent_mean', 'latent_std', 'samples', 'y_uniqs'])
                self.config.update(load_config)

                num_epochs_trained = self.model_graph.cur_epoch_tensor.eval(self.session)
                print('EPOCHS trained: ', num_epochs_trained)
            else:
                print('Initializing Variables ...')
                tf.global_variables_initializer().run()

            if self.config.plot:
                if self.config.y_uniqs is None:
                    print('\nFinding the unique categories...')
                    y_uniqs = list()
                    iterator = self.data_train.make_one_shot_iterator()
                    for t in tqdm(range(self.config.ntrain_batches)):
                        batch = session.run(iterator.get_next())
                        y, _ = tf.unique(batch[self.config.y_index])
                        y_uniqs += y.eval().tolist()

                    self.config.y_uniqs = np.unique(y_uniqs)

                if self.config.samples is None:
                    y_uniqs = self.config.y_uniqs[:10]
                    y_uniqs = np.array(list(itertools.repeat(y_uniqs, 10))).flatten()[:10]

                    print('\nSampling from the unique categories...')
                    samples = {yi: [] for yi in y_uniqs}  # one independent list per label
                    iterator = self.data_train.make_one_shot_iterator()
                    for t in tqdm(range(self.config.ntrain_batches)):
                        batch = session.run(iterator.get_next())
                        for yi in y_uniqs:
                        if len(samples[yi]) <= 10:
                            mask = batch[self.config.y_index] == yi
                            samples[yi] = samples[yi] + da.from_array(
                                tf.boolean_mask(mask=mask, tensor=batch['image']).eval(),
                                chunks=10).compute().tolist()
                        samples[yi] = samples[yi][:10]
                    self.config.samples = da.from_array(da.vstack(samples.values()), chunks=10).compute()

            if not self.config.isTrained:
                for cur_epoch in range(self.model_graph.cur_epoch_tensor.eval(self.session), self.config.epochs+1, 1):
                    print('EPOCH: ', cur_epoch)
                    self.current_epoch = cur_epoch

                    losses_tr = self._train(self.data_train, self.session, logger, self.config.ntrain_batches)
                    if np.isnan(losses_tr[0]):
                        print('Encountered NaN, stopping training. Please check the learning_rate settings and the momentum.')
                        for lname, lval in zip(self.model_graph.losses, losses_tr):
                            print(lname, lval)
                        sys.exit()

                    losses_test = self._test(self.data_test, self.session, logger, self.config.ntest_batches)
                    train_msg = 'TRAIN: \n'
                    for lname, lval in zip(self.model_graph.losses, losses_tr):
                        train_msg += str(lname) + ': ' + str(lval) + ' | '

                    eval_msg = 'TEST: \n'
                    for lname, lval in zip(self.model_graph.losses, losses_test):
                        eval_msg += str(lname) + ': ' + str(lval) + ' | '

                    print(train_msg)
                    print(eval_msg)
                    print()

                    if (cur_epoch == 1) or ((cur_epoch % self.config.save_epoch == 0) and (cur_epoch != 0)):
                        self.save_model()
                        if self.config.plot:
                            self.plot_latent(cur_epoch)
                            self.plot_reconst(cur_epoch)

                    self.session.run(self.model_graph.increment_cur_epoch_tensor)

                    # Early stopping
                    if (self.config.early_stopping and early_stopper.stop(losses_test[0])):
                        print('Early Stopping!')
                        break

                    if cur_epoch % self.config.colab_save == 0:
                        if self.config.colab:
                            self.push_colab()

                self.config.isTrained = True
                self.save_model()

                if self.config.plot:
                    self.plot_latent(cur_epoch)
                    self.plot_reconst(cur_epoch)

            if self.config.colab:
                self.push_colab()
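One caveat in this example: iterator.get_next() is called inside the sampling loops, which in TF1 adds a new op to the graph on every iteration and slows the session down over time. The usual fix is to build the op once and reuse it, sketched here against the names used above:

iterator = self.data_train.make_one_shot_iterator()
next_batch = iterator.get_next()  # build the op once, outside the loop
for t in tqdm(range(self.config.ntrain_batches)):
    batch = session.run(next_batch)  # only run the existing op each step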
Example no. 6
def sent_clf(dataset, config, opts, transfer=False):
    from logger.experiment import Experiment

    opts.name = config["name"]
    X_train, y_train, X_val, y_val = dataset
    vocab = None
    if transfer:
        opts.transfer = config["pretrained_lm"]
        checkpoint = load_checkpoint(opts.transfer)
        config["vocab"].update(checkpoint["config"]["vocab"])
        dict_pattern_rename(checkpoint["config"]["model"],
                            {"rnn_": "bottom_rnn_"})
        config["model"].update(checkpoint["config"]["model"])
        vocab = checkpoint["vocab"]

    ####################################################################
    # Load Preprocessed Datasets
    ####################################################################
    if config["preprocessor"] == "twitter":
        preprocessor = twitter_preprocessor()
    else:
        preprocessor = None

    print("Building training dataset...")
    train_set = ClfDataset(X_train,
                           y_train,
                           vocab=vocab,
                           preprocess=preprocessor,
                           vocab_size=config["vocab"]["size"],
                           seq_len=config["data"]["seq_len"])

    print("Building validation dataset...")
    val_set = ClfDataset(X_val,
                         y_val,
                         seq_len=train_set.seq_len,
                         preprocess=preprocessor,
                         vocab=train_set.vocab)

    src_lengths = [len(x) for x in train_set.data]
    val_lengths = [len(x) for x in val_set.data]

    # select sampler & dataloader
    train_sampler = BucketBatchSampler(src_lengths, config["batch_size"], True)
    val_sampler = SortedSampler(val_lengths)
    val_sampler_train = SortedSampler(src_lengths)

    train_loader = DataLoader(train_set,
                              batch_sampler=train_sampler,
                              num_workers=opts.cores,
                              collate_fn=ClfCollate())
    val_loader = DataLoader(val_set,
                            sampler=val_sampler,
                            batch_size=config["batch_size"],
                            num_workers=opts.cores,
                            collate_fn=ClfCollate())
    val_loader_train_dataset = DataLoader(train_set,
                                          sampler=val_sampler_train,
                                          batch_size=config["batch_size"],
                                          num_workers=opts.cores,
                                          collate_fn=ClfCollate())
    ####################################################################
    # Model
    ####################################################################
    ntokens = len(train_set.vocab)
    model = Classifier(ntokens, len(set(train_set.labels)), **config["model"])
    model.to(opts.device)

    clf_criterion = nn.CrossEntropyLoss()
    lm_criterion = nn.CrossEntropyLoss(ignore_index=0)

    embed_parameters = filter(lambda p: p.requires_grad,
                              model.embed.parameters())
    bottom_parameters = filter(
        lambda p: p.requires_grad,
        chain(model.bottom_rnn.parameters(), model.vocab.parameters()))
    if config["model"]["has_att"]:
        top_parameters = filter(
            lambda p: p.requires_grad,
            chain(model.top_rnn.parameters(), model.attention.parameters(),
                  model.classes.parameters()))
    else:
        top_parameters = filter(
            lambda p: p.requires_grad,
            chain(model.top_rnn.parameters(), model.classes.parameters()))

    embed_optimizer = optim.ASGD(embed_parameters, lr=0.0001)
    rnn_optimizer = optim.ASGD(bottom_parameters)
    top_optimizer = Adam(top_parameters, lr=config["top_lr"])
    ####################################################################
    # Training Pipeline
    ####################################################################

    # Trainer: responsible for managing the training process
    trainer = SentClfTrainer(model,
                             train_loader,
                             val_loader, (lm_criterion, clf_criterion),
                             [embed_optimizer, rnn_optimizer, top_optimizer],
                             config,
                             opts.device,
                             valid_loader_train_set=val_loader_train_dataset,
                             unfreeze_embed=config["unfreeze_embed"],
                             unfreeze_rnn=config["unfreeze_rnn"])

    ####################################################################
    # Experiment: logging and visualizing the training process
    ####################################################################

    exp = Experiment(opts.name,
                     config,
                     src_dirs=opts.source,
                     output_dir=EXP_DIR)
    exp.add_metric("ep_loss_lm", "line", "epoch loss lm", ["TRAIN", "VAL"])
    exp.add_metric("ep_loss_cls", "line", "epoch loss class", ["TRAIN", "VAL"])
    exp.add_metric("ep_f1", "line", "epoch f1", ["TRAIN", "VAL"])
    exp.add_metric("ep_acc", "line", "epoch accuracy", ["TRAIN", "VAL"])

    exp.add_value("epoch", title="epoch summary")
    exp.add_value("progress", title="training progress")

    ####################################################################
    # Resume Training from a previous checkpoint
    ####################################################################
    if transfer:
        print("Transferring Encoder weights ...")
        dict_pattern_rename(checkpoint["model"], {
            "encoder": "bottom_rnn",
            "decoder": "vocab"
        })
        load_state_dict_subset(model, checkpoint["model"])
    print(model)

    ####################################################################
    # Training Loop
    ####################################################################
    best_loss = None
    early_stopping = EarlyStopping("min", config["patience"])

    for epoch in range(0, config["epochs"]):

        train_loss = trainer.train_epoch()
        val_loss, y, y_pred = trainer.eval_epoch(val_set=True)
        _, y_train, y_pred_train = trainer.eval_epoch(train_set=True)
        exp.update_metric("ep_loss_lm", train_loss[0], "TRAIN")
        exp.update_metric("ep_loss_lm", val_loss[0], "VAL")

        exp.update_metric("ep_loss_cls", train_loss[1], "TRAIN")
        exp.update_metric("ep_loss_cls", val_loss[1], "VAL")

        exp.update_metric("ep_f1", f1_macro(y_train, y_pred_train), "TRAIN")
        exp.update_metric("ep_f1", f1_macro(y, y_pred), "VAL")

        exp.update_metric("ep_acc", acc(y_train, y_pred_train), "TRAIN")
        exp.update_metric("ep_acc", acc(y, y_pred), "VAL")

        print()
        epoch_log = exp.log_metrics(
            ["ep_loss_lm", "ep_loss_cls", "ep_f1", "ep_acc"])
        print(epoch_log)
        exp.update_value("epoch", epoch_log)

        # Save the model if the val loss is the best we've seen so far.
        if not best_loss or val_loss[1] < best_loss:
            best_loss = val_loss[1]
            trainer.best_acc = acc(y, y_pred)
            trainer.best_f1 = f1_macro(y, y_pred)
            trainer.checkpoint(name=opts.name, timestamp=True)

        if early_stopping.stop(val_loss[1]):
            print("Early Stopping (according to classification loss)....")
            break

        print("\n" * 2)

    return best_loss, trainer.best_acc, trainer.best_f1
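The trainer is handed unfreeze_embed and unfreeze_rnn, but the excerpt does not show what SentClfTrainer does with them. A hypothetical sketch of the gradual unfreezing they suggest (function name and behavior are assumptions; the config keys come from the call above):

def apply_unfreezing(model, epoch, config):
    # keep the pretrained lower layers frozen until the configured epoch
    for p in model.embed.parameters():
        p.requires_grad = epoch >= config["unfreeze_embed"]
    for p in model.bottom_rnn.parameters():
        p.requires_grad = epoch >= config["unfreeze_rnn"]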
Example no. 7
    def fit(self, X, y=None):
        print('\nProcessing data...')
        self.data_train, self.data_eval = utils.process_data(X, y)
        self.config['num_batches'] = self.data_train.num_batches(
            self.config.batch_size)

        if not self.isBuild:
            self.config.restore = True
            self.build_model(self.data_train.height, self.data_train.width,
                             self.data_train.num_channels)
        else:
            expected = (self.config.height, self.config.width,
                        self.config.num_channels)
            actual = (self.data_train.height, self.data_train.width,
                      self.data_train.num_channels)
            assert expected == actual, \
                'Wrong dimension of data. Expected shape {}, and got {}'.format(
                    expected, actual)
        '''  -------------------------------------------------------------------------------
                                        TRAIN THE MODEL
        ------------------------------------------------------------------------------------- '''
        print('\nTraining a model...')

        with tf.Session(graph=self.graph) as session:
            tf.set_random_seed(222222)
            self.session = session
            logger = Logger(self.session, self.config.summary_dir)
            saver = tf.train.Saver()

            early_stopper = EarlyStopping(name='total loss',
                                          decay_fn=self.decay_fn)

            if (self.config.restore and self.load(self.session, saver)):
                num_epochs_trained = self.model_graph.cur_epoch_tensor.eval(
                    self.session)
                print('EPOCHS trained: ', num_epochs_trained)
            else:
                print('Initializing Variables ...')
                tf.global_variables_initializer().run()

            for cur_epoch in range(
                    self.model_graph.cur_epoch_tensor.eval(self.session),
                    self.config.epochs + 1, 1):

                print('EPOCH: ', cur_epoch)
                self.current_epoch = cur_epoch

                losses_tr = self._train(self.data_train, self.session, logger)

                if np.isnan(losses_tr[0]):
                    print(
                        'Encountered NaN, stopping training. Please check the learning_rate settings and the momentum.'
                    )
                    for lname, lval in zip(self.model_graph.losses, losses_tr):
                        print(lname, lval)
                    sys.exit()

                losses_eval = self._evaluate(self.data_eval, self.session,
                                             logger)

                train_msg = 'TRAIN: \n'
                for lname, lval in zip(self.model_graph.losses, losses_tr):
                    train_msg += str(lname) + ': ' + str(lval) + ' | '

                eval_msg = 'EVALUATE: \n'
                for lname, lval in zip(self.model_graph.losses, losses_eval):
                    eval_msg += str(lname) + ': ' + str(lval) + ' | '

                print(train_msg)
                print(eval_msg)
                print()

                if cur_epoch == 1 or (cur_epoch % const.SAVE_EPOCH == 0
                                      and cur_epoch != 0):
                    self.save(
                        self.session, saver,
                        self.model_graph.global_step_tensor.eval(self.session))
                    if self.config.plot:
                        self.reconst_samples_from_data(self.data_train,
                                                       self.session, cur_epoch)

                self.session.run(self.model_graph.increment_cur_epoch_tensor)

                # Early stopping
                if (self.config.early_stopping
                        and early_stopper.stop(losses_eval[0])):
                    print('Early Stopping!')
                    break

                if cur_epoch % const.COLAB_SAVE == 0:
                    if self.config.colab:
                        self.push_colab()

            self.save(self.session, saver,
                      self.model_graph.global_step_tensor.eval(self.session))
            if self.config.plot:
                self.reconst_samples_from_data(self.data_train, self.session,
                                               cur_epoch)

            if self.config.colab:
                self.push_colab()
            z = self.encode(self.data_train.x)
            self.config.latent_max = z.max().compute()
            self.config.latent_min = z.min().compute()
            self.config.latent_std = z.std().compute()
            del z
        return
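Several examples construct EarlyStopping(name='total loss', decay_fn=self.decay_fn). The decay_fn hook is not defined anywhere in the listing; a plausible reading (an assumption, not the original implementation) is that the stopper calls it once as a rescue step, e.g. to decay the learning rate, when the loss plateaus, and only stops if the plateau persists:

class EarlyStopping:
    # Sketch of the assumed decay_fn variant.
    def __init__(self, name='total loss', patience=10, decay_fn=None):
        self.name = name
        self.patience = patience
        self.decay_fn = decay_fn
        self.best = float('inf')
        self.num_bad = 0
        self.decayed = False

    def stop(self, loss):
        if loss < self.best:
            self.best = loss
            self.num_bad = 0
            return False
        self.num_bad += 1
        if self.num_bad >= self.patience:
            if self.decay_fn is not None and not self.decayed:
                self.decay_fn()  # one rescue attempt before giving up
                self.decayed = True
                self.num_bad = 0
                return False
            return True
        return False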
Example no. 8
def train(classifier, config, train_batcher, valid_batcher, test_batcher):
    """
    开启session, 真正的执行训练步骤。
    每训练完一个epoch之后使用验证集数据验证模型,并写入summary。
    指定的epoch数(配置文件的train_epoch项)训练完后,使用测试集数据测试模型。

    :param classifier: 分类器
    :param config: 配置文件
    :param train_batcher: 训练集数据的batch生成器。
                            batch的shape:[config.batch_size, config.max_len]
    :param valid_batcher: 验证集数据的batch生成器。
    :param test_batcher: 测试集数据的batch生成器。
    :return: 无。模型的训练进度保存在路径:config.log_root/config.model_name/train
    """
    # Create eval_config; the only difference from config is that dropout
    # keep_prob is set to 1.0. eval_config is used for validation and testing.
    eval_config = copy.deepcopy(config)
    eval_config.keep_prob = 1.0
    # train_dir holds the training progress and summaries
    train_dir = os.path.join(config.log_root, config.model_name, "train")
    # summary writers for the training and validation sets
    summary_writer_train = tf.summary.FileWriter(
        os.path.join(train_dir, "summaries", "summaries_train"))
    summary_writer_valid = tf.summary.FileWriter(
        os.path.join(train_dir, "summaries", "summaries_valid"))
    saver = tf.train.Saver(max_to_keep=3)
    # early stopping monitors the validation-set loss
    early_stop = EarlyStopping(config.patience, mode='min')
    # session config: allow GPU memory to grow on demand
    config_proto = tf.ConfigProto(allow_soft_placement=True)
    config_proto.gpu_options.allow_growth = True
    # create a MonitoredTrainingSession
    with tf.train.MonitoredTrainingSession(
            is_chief=True,
            checkpoint_dir=train_dir,
            scaffold=None,  # ?
            hooks=None,  # ?
            chief_only_hooks=None,  # ?
            save_checkpoint_secs=None,  # config.
            save_summaries_steps=None,
            save_summaries_secs=None,
            config=config_proto,
    ) as sess:
        # compute the number of steps to run
        if config.train_epoch:
            step_num = int(config.train_epoch * train_batcher.data_size /
                           train_batcher.batch_size)
        else:
            step_num = config.train_step
        start = time.time()
        for step in range(step_num):
            if sess.should_stop():  # the session may need to stop for unforeseen reasons
                break
            # run train data one epoch
            t0 = time.time()
            global_step = classifier.global_step.eval(sess)
            print("\nglobal_step : {}".format(global_step))
            return_dict = run_some_steps(classifier,
                                         sess,
                                         config,
                                         run_train_step,
                                         train_batcher,
                                         summary_writer_train,
                                         step_num=config.once_train_step)
            print("Train time: %.3f s" % (time.time() - t0))
            print("Train accuracy: %.3f %%, loss: %.4f" %
                  (return_dict['acc'] * 100, return_dict['loss']))
            # evaluate valid data
            tf.logging.set_verbosity(tf.logging.ERROR)
            return_dict = run_one_epoch(classifier, sess, config,
                                        run_eval_step, valid_batcher,
                                        summary_writer_valid)
            tf.logging.set_verbosity(tf.logging.WARN)
            print("Valid accuracy: %.3f %%, loss: %.4f" %
                  (return_dict['acc'] * 100, return_dict['loss']))
            # early stopping
            if early_stop.add_monitor(return_dict['loss']):
                # save a checkpoint
                print('Found better model parameters; saving...')
                # This odd attribute chain unwraps the MonitoredSession to reach
                # the raw session; see https://github.com/tensorflow/tensorflow/issues/8425
                saver.save(sess._sess._sess._sess._sess,
                           os.path.join(train_dir, 'model.ckpt'),
                           global_step=global_step)
            if early_stop.stop():
                print('Early stopping triggered! The last {} monitored values were {}'.format(
                    early_stop.patience, early_stop.pre_monitor))
                break
        print("Training finished, time consumed : %.3f s" %
              (time.time() - start))

        # evaluate test data
        print("\nStart evaluating:")
        tf.logging.set_verbosity(tf.logging.ERROR)
        return_dict = run_one_epoch(classifier, sess, config, run_eval_step,
                                    test_batcher)
        tf.logging.set_verbosity(tf.logging.WARN)
        print("Test  accuracy: %.3f %%, loss: %.4f" %
              (return_dict['acc'] * 100, return_dict['loss']))
        # compute the evaluation metrics
        metrics_model(return_dict['real_label'], return_dict['predict'])
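This example uses a different EarlyStopping interface: add_monitor(value) records a new validation loss and reports whether it improved (triggering a checkpoint save), while stop() is queried separately. A minimal sketch matching that usage (assumed, including the pre_monitor attribute printed above):

class EarlyStopping:
    def __init__(self, patience, mode='min'):
        self.patience = patience
        self.sign = 1 if mode == 'min' else -1
        self.best = None
        self.num_bad = 0
        self.pre_monitor = []  # recent monitored values, kept for logging

    def add_monitor(self, value):
        # returns True when the new value improves on the best so far
        self.pre_monitor = (self.pre_monitor + [value])[-self.patience:]
        if self.best is None or self.sign * value < self.sign * self.best:
            self.best = value
            self.num_bad = 0
            return True
        self.num_bad += 1
        return False

    def stop(self):
        return self.num_bad >= self.patience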
Example no. 9
    def train(self, data_train, data_valid, enable_es=1):

        with tf.Session(graph=self.graph) as session:
            tf.set_random_seed(1234)

            logger = Logger(session, self.summary_dir)
            # here you initialize the tensorflow saver that will be used in saving the checkpoints.
            # max_to_keep: defaults to keeping the 5 most recent checkpoints of your model
            saver = tf.train.Saver()
            early_stopping = EarlyStopping()

            if (self.restore == 1 and self.load(session, saver)):
                num_epochs_trained = self.vae_graph.cur_epoch_tensor.eval(
                    session)
                print('EPOCHS trained: ', num_epochs_trained)
            else:
                print('Initializing Variables ...')
                tf.global_variables_initializer().run()

            if (self.vae_graph.cur_epoch_tensor.eval(session) == self.epochs):
                return

            for cur_epoch in range(
                    self.vae_graph.cur_epoch_tensor.eval(session),
                    self.epochs + 1, 1):

                print('EPOCH: ', cur_epoch)
                self.current_epoch = cur_epoch
                # beta=utils.sigmoid(cur_epoch- 50)
                beta = 1.
                loss_tr, recons_tr, cond_prior_tr, L2_loss = self.train_epoch(
                    session, logger, data_train, beta=beta)
                if np.isnan(loss_tr):
                    print(
                        'Encountered NaN, stopping training. Please check the learning_rate settings and the momentum.'
                    )
                    print('Recons: ', recons_tr)
                    print('KL: ', cond_prior_tr)
                    sys.exit()

                loss_val, recons_val, cond_prior_val = self.valid_epoch(
                    session, logger, data_valid, beta=beta)

                print('TRAIN | Loss: ', loss_tr, ' | Recons: ', recons_tr,
                      ' | KL: ', cond_prior_tr, ' | L2_loss: ', L2_loss)
                print('VALID | Loss: ', loss_val, ' | Recons: ', recons_val,
                      ' | KL: ', cond_prior_val)

                if (cur_epoch > 0 and cur_epoch % 10 == 0):
                    self.save(session, saver,
                              self.vae_graph.global_step_tensor.eval(session))
                    z_matrix = self.vae_graph.get_z_matrix(
                        session, data_valid.random_batch(self.batch_size))
                    np.savez(self.z_file, z_matrix)

                session.run(self.vae_graph.increment_cur_epoch_tensor)

                # Early stopping
                if enable_es == 1 and early_stopping.stop(loss_val):
                    print('Early Stopping!')
                    break

            self.save(session, saver,
                      self.vae_graph.global_step_tensor.eval(session))
            z_matrix = self.vae_graph.get_z_matrix(
                session, data_valid.random_batch(self.batch_size))
            np.savez(self.z_file, z_matrix)
        return
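A small usage note: np.savez(self.z_file, z_matrix) stores the array under the default key 'arr_0', so reading the latent matrix back later looks like this (the file name is illustrative):

import numpy as np

z_matrix = np.load('z_file.npz')['arr_0']  # default key for positional savez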