Example #1
# Assumed imports for this excerpt; `utils`, `datasets`, `trainer`, and
# `tester` are project-local helpers defined elsewhere in the repository.
import time

import numpy as np
import torch


def main(args):
    start_time = time.time()
    model = utils.build_model(args)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    criterion = torch.nn.CrossEntropyLoss()
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     args.milestone,
                                                     gamma=0.1,
                                                     last_epoch=-1)
    train_loader, test_loader = datasets.dataloader_builder(args)

    train_loss = np.zeros(args.epochs)
    test_loss, test_accu = np.zeros(args.epochs), np.zeros(args.epochs)
    print('\n\t#### Start Training ####')
    for epoch in range(args.epochs):
        train_loss[epoch] = trainer(model, train_loader, optimizer, criterion)
        test_loss[epoch], test_accu[epoch] = tester(model, test_loader)
        scheduler.step()
        # print(scheduler.get_last_lr()[0])  # get_lr() is deprecated in newer PyTorch
        print(
            '| Epoch: {0:3d} | Training Loss: {1:.6f} | Test Accuracy: {2:.2f} | Test Loss {3:.6f} |'
            .format(epoch, train_loss[epoch], test_accu[epoch],
                    test_loss[epoch]))
    print('\t#### Time Consumed: {0:.3f} seconds ####\n'.format(time.time() -
                                                                start_time))
    utils.saveCheckpoint(args.cp_dir, args.model_name, epoch, model, optimizer,
                         test_accu, train_loss, args.bn, args.weight_decay)
    utils.plotCurve(args, train_loss / args.trn_batch, test_loss, test_accu)
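For reference, here is a minimal sketch of how this `main` could be invoked. The argparse flags below cover only the attributes the excerpt itself reads (`lr`, `weight_decay`, `milestone`, `epochs`, `trn_batch`, `bn`, `cp_dir`, `model_name`); the default values are placeholder assumptions, and the project-local `utils.build_model` and `datasets.dataloader_builder` will likely expect further options not shown here.

import argparse

parser = argparse.ArgumentParser()
# Placeholder defaults; tune to the actual experiment.
parser.add_argument('--lr', type=float, default=0.1)
parser.add_argument('--weight_decay', type=float, default=5e-4)
# Epoch indices at which MultiStepLR multiplies the LR by gamma=0.1.
parser.add_argument('--milestone', type=int, nargs='+', default=[100, 150])
parser.add_argument('--epochs', type=int, default=200)
parser.add_argument('--trn_batch', type=int, default=128)
parser.add_argument('--bn', action='store_true')
parser.add_argument('--cp_dir', type=str, default='checkpoints')
parser.add_argument('--model_name', type=str, default='model')

main(parser.parse_args())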
Example #2
    # Assumes module-level imports in the enclosing file: os, shutil, pickle,
    # tensorflow as tf, `callbacks` from tensorflow.keras, plus the
    # project-local Config and plotCurve.
    # NOTE: x_train_u, x_val_u, and n_epochs are accepted but unused below.
    def train(self, x_train_lstm, x_train_svm, y_train,
              x_val_lstm=None, x_val_svm=None, y_val=None,
              x_train_u=None, x_val_u=None,
              n_epochs=20, class_weight=None):
         
        logdir = os.path.join("tensorboard_log", self.detection, self.save_model_name, self.num_set)
        # Clear logs from previous runs, then recreate the directory.
        # (The original shelled out to `mkdir -p` and then `rm -r`, which
        # deleted the directory it had just created.)
        shutil.rmtree(logdir, ignore_errors=True)
        os.makedirs(logdir, exist_ok=True)
        tensorboard_callback = callbacks.TensorBoard(log_dir=logdir)  # only used by the commented-out fit() below
        
        h5_save_path = os.path.join("Models", self.detection, self.model_name + "_" + self.num_set + ".h5")

        acc = []
        loss = []
        val_acc = []
        val_loss = []
        
        x_train_dict = {'lstm_features': x_train_lstm, 'svm_features': x_train_svm}
        x_val_dict = {'lstm_features': x_val_lstm, 'svm_features': x_val_svm}
        
        # Batched tf.data pipelines; the dict keys must match the names of the
        # model's input layers.
        train_set = (tf.data.Dataset.from_tensor_slices((x_train_dict, y_train))
                     .shuffle(buffer_size=Config.BUFFER_SIZE)
                     .batch(32)
                     .prefetch(tf.data.experimental.AUTOTUNE))
        # No shuffle for validation: evaluation is order-independent.
        val_set = (tf.data.Dataset.from_tensor_slices((x_val_dict, y_val))
                   .batch(32)
                   .prefetch(tf.data.experimental.AUTOTUNE))

        model_cp = callbacks.ModelCheckpoint(filepath=h5_save_path,
                                             monitor="val_loss",
                                             verbose=0,
                                             save_best_only=True,
                                             save_weights_only=False,
                                             mode="min",
                                             # `period` is deprecated in tf.keras;
                                             # newer versions use save_freq.
                                             period=10)
        """
        history = self.model.fit(x = x_train_dict, 
                                 y = y_train, 
                                 shuffle = 1,
                                 validation_data = (x_val_dict, y_val),
                                 batch_size = Config.BATCH_SIZE, 
                                 epochs = Config.EPOCHS,
                                 class_weight = class_weight,
                                 callbacks = [model_cp, tensorboard_callback],
                                 use_multiprocessing = False)
        """
        history = self.model.fit(x = train_set,
                                 validation_data = val_set,
                                 epochs = Config.EPOCHS,
                                 verbose = 1,
                                 class_weight = class_weight,
                                 callbacks = [model_cp],
                                 use_multiprocessing = False)
 
        # Loss values on the training and validation sets
        loss = history.history['loss']
        val_loss = history.history["val_loss"]
        
        figfile = os.path.join("Fig", self.detection, self.model_name + "_" + self.num_set + ".png")
        # os.path.splitext is safer than splitting on ".p", which would also
        # match a ".p" occurring earlier in the path.
        base = os.path.splitext(figfile)[0]
        with open(base + "_loss.cpickle", "wb") as f:
            pickle.dump(loss, f)
        with open(base + "_val_loss.cpickle", "wb") as f:
            pickle.dump(val_loss, f)
        plotCurve(loss, val_loss, 'Model Loss', 'loss', figfile)
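Neither example includes the `plotCurve` helper itself. Below is a minimal matplotlib sketch that is consistent with the five-argument call in Example #2, `plotCurve(loss, val_loss, 'Model Loss', 'loss', figfile)`; the actual implementation in the source repository may differ (Example #1 clearly uses a different, args-taking variant).

import matplotlib.pyplot as plt

def plotCurve(train_values, val_values, title, ylabel, figfile):
    # Plot training vs. validation curves over epochs and save to figfile.
    epochs = range(1, len(train_values) + 1)
    plt.figure()
    plt.plot(epochs, train_values, label='train')
    plt.plot(epochs, val_values, label='validation')
    plt.title(title)
    plt.xlabel('epoch')
    plt.ylabel(ylabel)
    plt.legend()
    plt.savefig(figfile)
    plt.close()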