Example #1
    def __init__(self, args, logger, model='top'):  # model is 'res' or 'top'
        # Build a filesystem-safe timestamp ID for this run.
        id_str = str(datetime.datetime.now()).replace(' ', '_').replace(':', '-')
        id_str = '-time_' + id_str.replace('.', '-')

        self.model = model
        self.model_name = self._parse_args(args)

        self.tensorboard_path = os.path.join(args.tb_path,
                                             self.model_name + id_str)
        self.logger = logger
        self.writer = SummaryWriter(self.tensorboard_path)

        if args.pretrained_model_dir == '':
            self.save_path = os.path.join(args.save_path,
                                          self.model_name + id_str)
            utils.create_save_path(self.save_path, id_str, self.logger)
        else:
            self.logger.info(
                f"Using pretrained path... \nargs.pretrained_model_dir: {args.pretrained_model_dir}"
            )
            self.save_path = os.path.join(args.save_path,
                                          args.pretrained_model_dir)

        self.new_split_type = args.dataset_split_type == 'new'

        self.logger.info("** Save path: " + self.save_path)
        self.logger.info("** Tensorboard path: " + self.tensorboard_path)
Example #2
def main():
    args = parser.parse_args()

    # Use the 24hr datasets!

    if args.local_files:
        if parameters.LOSS.upper() == "FOCAL":
            train_data_path = parameters.LOCAL_FULL_TRAIN
            test_data_path = parameters.LOCAL_FULL_TEST
        else:
            train_data_path = parameters.LOCAL_TRAIN_FILES
            test_data_path = parameters.LOCAL_TEST_FILES
    else:
        if parameters.LOSS.upper() == "FOCAL":
            train_data_path = parameters.REMOTE_FULL_TRAIN
            test_data_path = parameters.REMOTE_FULL_TEST
        else:
            train_data_path = parameters.REMOTE_TRAIN_FILES
            test_data_path = parameters.REMOTE_TEST_FILES

    if parameters.LOSS.upper() != "FOCAL":
        train_data_path, _ = create_dataset_path(
            train_data_path,
            neg_samples=parameters.NEG_SAMPLES,
            call_repeats=parameters.CALL_REPEATS,
            shift_windows=parameters.SHIFT_WINDOWS)
        test_data_path, _ = create_dataset_path(
            test_data_path,
            neg_samples=parameters.TEST_NEG_SAMPLES,
            call_repeats=1)
    else:
        if parameters.SHIFT_WINDOWS:
            train_data_path += '_OversizeCalls'

    train_loader = get_loader_fuzzy(train_data_path,
                                    parameters.BATCH_SIZE,
                                    random_seed=parameters.DATA_LOADER_SEED,
                                    norm=parameters.NORM,
                                    scale=parameters.SCALE,
                                    shift_windows=parameters.SHIFT_WINDOWS,
                                    full_window_predict=True)
    test_loader = get_loader_fuzzy(test_data_path,
                                   parameters.BATCH_SIZE,
                                   random_seed=parameters.DATA_LOADER_SEED,
                                   norm=parameters.NORM,
                                   scale=parameters.SCALE,
                                   full_window_predict=True)

    # For now, model 18 signifies this.
    save_path = create_save_path(
        time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()), args.save_local)

    dloaders = {'train': train_loader, 'valid': test_loader}

    ## Training
    model = get_model(parameters.MODEL_ID)
    model.to(parameters.device)

    print(model)

    writer = SummaryWriter(save_path)
    writer.add_scalar('batch_size', parameters.BATCH_SIZE)
    writer.add_scalar(
        'weight_decay',
        parameters.HYPERPARAMETERS[parameters.MODEL_ID]['l2_reg'])

    loss_func, include_boundaries = get_loss()

    # Per-model hyper-parameters are likely unnecessary, but keep them for now.
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=parameters.HYPERPARAMETERS[parameters.MODEL_ID]['lr'],
        weight_decay=parameters.HYPERPARAMETERS[parameters.MODEL_ID]['l2_reg'])
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        parameters.HYPERPARAMETERS[parameters.MODEL_ID]['lr_decay_step'],
        gamma=parameters.HYPERPARAMETERS[parameters.MODEL_ID]['lr_decay'])

    start_time = time.time()

    model_wts = train(dloaders,
                      model,
                      loss_func,
                      optimizer,
                      scheduler,
                      writer,
                      parameters.NUM_EPOCHS,
                      include_boundaries=include_boundaries)

    if model_wts:
        model.load_state_dict(model_wts)
        if not os.path.exists(save_path):
            os.makedirs(save_path)

        save_path = os.path.join(save_path, "model.pt")
        torch.save(model, save_path)
        print('Saved best model based on {} to path {}'.format(
            parameters.TRAIN_MODEL_SAVE_CRITERIA.upper(), save_path))
    else:
        print('No model weights were returned; nothing to save.')

    print('Training time: {:10f} minutes'.format(
        (time.time() - start_time) / 60))

    writer.close()
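The optimizer and scheduler above look up settings in `parameters.HYPERPARAMETERS` keyed by model ID. A minimal sketch of the shape that dictionary would need; the keys come from the lookups above, but the values are placeholders, not the real configuration:

# Assumed shape of parameters.HYPERPARAMETERS; values are illustrative only.
HYPERPARAMETERS = {
    18: {
        'lr': 1e-3,           # Adam learning rate
        'l2_reg': 1e-5,       # weight decay
        'lr_decay_step': 4,   # StepLR step size, in epochs
        'lr_decay': 0.95,     # StepLR gamma
    },
}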
Example #3
def main():
    # Shared across all settings: parse the arguments and build the data loaders.
    args = parser.parse_args()

    if args.local_files:
        train_data_path = parameters.LOCAL_TRAIN_FILES
        test_data_path = parameters.LOCAL_TEST_FILES
        full_train_path = parameters.LOCAL_FULL_TRAIN
        full_test_path = parameters.LOCAL_FULL_TEST
    else:
        train_data_path = parameters.REMOTE_TRAIN_FILES
        test_data_path = parameters.REMOTE_TEST_FILES
        full_train_path = parameters.REMOTE_FULL_TRAIN
        full_test_path = parameters.REMOTE_FULL_TEST

    if parameters.HIERARCHICAL_SHIFT_WINDOWS:
        full_train_path += '_OversizeCalls'

    model_0_train_data_path, include_boundaries = create_dataset_path(
        train_data_path,
        neg_samples=parameters.NEG_SAMPLES,
        call_repeats=parameters.CALL_REPEATS,
        shift_windows=parameters.SHIFT_WINDOWS)
    model_0_test_data_path, _ = create_dataset_path(
        test_data_path,
        neg_samples=parameters.TEST_NEG_SAMPLES,
        call_repeats=1)

    # Check if a different dataset is being used for Model_1
    model_1_train_data_path = model_0_train_data_path
    model_1_test_data_path = model_0_test_data_path
    if str(parameters.HIERARCHICAL_REPEATS).lower() != "same":
        # Negative samples could likely stay at 1 here, since they do not matter.
        model_1_train_data_path, _ = create_dataset_path(
            train_data_path,
            neg_samples=parameters.NEG_SAMPLES,
            call_repeats=parameters.HIERARCHICAL_REPEATS,
            shift_windows=parameters.HIERARCHICAL_SHIFT_WINDOWS)

    # Model 0 Loaders
    model_0_train_loader = get_loader_fuzzy(
        model_0_train_data_path,
        parameters.BATCH_SIZE,
        random_seed=parameters.DATA_LOADER_SEED,
        norm=parameters.NORM,
        scale=parameters.SCALE,
        include_boundaries=include_boundaries,
        shift_windows=parameters.SHIFT_WINDOWS,
        full_window_predict=True)
    model_0_test_loader = get_loader_fuzzy(
        model_0_test_data_path,
        parameters.BATCH_SIZE,
        random_seed=parameters.DATA_LOADER_SEED,
        norm=parameters.NORM,
        scale=parameters.SCALE,
        include_boundaries=include_boundaries,
        full_window_predict=True)

    # Model 1 Loaders
    model_1_train_loader = get_loader_fuzzy(
        model_1_train_data_path,
        parameters.BATCH_SIZE,
        random_seed=parameters.DATA_LOADER_SEED,
        norm=parameters.NORM,
        scale=parameters.SCALE,
        include_boundaries=include_boundaries,
        shift_windows=parameters.HIERARCHICAL_SHIFT_WINDOWS,
        full_window_predict=True)
    model_1_test_loader = get_loader_fuzzy(
        model_1_test_data_path,
        parameters.BATCH_SIZE,
        random_seed=parameters.DATA_LOADER_SEED,
        norm=parameters.NORM,
        scale=parameters.SCALE,
        include_boundaries=include_boundaries,
        full_window_predict=True)

    if args.models_path is None:
        save_path = create_save_path(time.strftime("%Y-%m-%d_%H:%M:%S",
                                                   time.localtime()),
                                     args.save_local,
                                     save_prefix='Hierarchical_')
    else:
        save_path = args.models_path

    # Case 1) Run the entire pipeline, broken into three helper functions.
    if args.full_pipeline:
        # Train and save model_0
        model_0 = train_model_0(model_0_train_loader, model_0_test_loader,
                                save_path)
        # Do the adversarial discovery
        adversarial_train_files, adversarial_test_files = adversarial_discovery(
            full_train_path, full_test_path, model_0, save_path)
        # Train and save model 1
        train_model_1(adversarial_train_files, adversarial_test_files,
                      model_1_train_loader, model_1_test_loader, save_path)

    # Just generate new adversarial examples
    elif args.adversarial:
        # Load model_0
        model_0_path = os.path.join(save_path, "Model_0/model.pt")
        model_0 = torch.load(model_0_path, map_location=parameters.device)
        adversarial_discovery(full_train_path, full_test_path, model_0,
                              save_path)

    # Train just model_1
    elif args.model1:
        # Read in the adversarial files
        train_adversarial_file = "model_0-False_Pos_Train.txt"
        if parameters.HIERARCHICAL_SHIFT_WINDOWS:
            train_adversarial_file = "model_0-False_Pos_Train_Shift.txt"
        adversarial_train_save_path = os.path.join(save_path,
                                                   train_adversarial_file)
        with open(adversarial_train_save_path, 'r') as f:
            adversarial_train_files = [line.strip() for line in f]

        adversarial_test_save_path = os.path.join(
            save_path, "model_0-False_Pos_Test.txt")
        with open(adversarial_test_save_path, 'r') as f:
            adversarial_test_files = [line.strip() for line in f]

        train_model_1(adversarial_train_files, adversarial_test_files,
                      model_1_train_loader, model_1_test_loader, save_path)

    else:
        print("Invalid running mode!")
Example #4
def train_model(run_number):
    logger = console_logger(__name__, FLAGS.logger_level)

    logger.info("Clearning Keras session.")
    K.clear_session()
    save_path = create_save_path()
    logger.debug(f"New save path: {save_path}")

    # __INPUT PIPELINE__ #
    logger.info("Loading datasets.")
    ds_train, ds_test, num_train_batches, num_test_batches = load_datasets(
        FLAGS.load_dir)
    logger.debug(
        f"num_train_batches: {num_train_batches}, num_test_batches: {num_test_batches}"
    )

    # __MODEL__ #
    logger.info("Initializing kernel.")
    kernel_initializer = tf.initializers.TruncatedNormal(
        mean=FLAGS.weight_init_mean, stddev=FLAGS.weight_init_stddev)
    logger.info("Building model")
    model = build_model(kernel_initializer)
    #        print('Trainable params: {}'.format(model.count_params()))
    logger.info(model.summary())

    # Load model weights from checkpoint if checkpoint_path is provided
    if FLAGS.checkpoint_path:
        logger.info("Loading model weights from checkpoint.")
        model.load_weights(FLAGS.checkpoint_path)

    # save model FLAGS to save_path
    logger.info(f"Saving flags (settings) to {save_path}")
    save_config(save_path)

    # save model architecture image to save_path directory
    if FLAGS.save_architecture_image:
        logger.info(f"Saving model architecture image to {save_path}")
        tf.keras.utils.plot_model(model,
                                  os.path.join(save_path, 'architecture.png'),
                                  show_shapes=FLAGS.show_shapes)

    # __LOGGING__ #
    logger.info(f"Initializing summyary writers.")
    summary_writer_train = tf.summary.create_file_writer(save_path + '/train',
                                                         name='sw-train')
    summary_writer_test = tf.summary.create_file_writer(save_path + '/test',
                                                        name='sw-test')

    # __TRAINING__ #
    logger.info(f"Initializing variables, optimizer and metrics.")
    best_cer = 100.0
    best_epoch = -1
    optimizer = tf.keras.optimizers.Adam(learning_rate=FLAGS.lr,
                                         epsilon=FLAGS.epsilon,
                                         amsgrad=FLAGS.amsgrad)
    train_loss = tf.keras.metrics.Mean(name='train_loss', dtype=tf.float32)
    test_loss = tf.keras.metrics.Mean(name='test_loss', dtype=tf.float32)
    test_cer = tf.keras.metrics.Mean(name='test_cer', dtype=tf.float32)
    test_ccer = tf.keras.metrics.Mean(name="test_ccer", dtype=tf.float32)

    for epoch in range(FLAGS.max_epochs):
        logger.log(35, f'_______| Run {run_number} | Epoch {epoch} |_______')

        # TRAINING DATA
        with summary_writer_train.as_default():
            logger.info(f"Training model.")
            train_fn(model, ds_train, optimizer, train_loss, num_train_batches,
                     epoch)

        # TESTING DATA
        with summary_writer_test.as_default():
            logger.info(f"Testing model.")
            _, test_mean_cer = test_fn(model, ds_test, test_loss, test_cer,
                                       test_ccer, num_test_batches, epoch)

        # EARLY STOPPING AND KEEPING THE BEST MODEL
        stop_training, best_cer, best_epoch = early_stopping(
            model, test_mean_cer, best_cer, epoch, best_epoch, save_path)
        logger.log(35, f'| Best CER {best_cer} | Best epoch {best_epoch} |')
        if stop_training:
            logger.log(35, f'Model stopped early at epoch {epoch}')
            break
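`early_stopping` is called with `(model, test_mean_cer, best_cer, epoch, best_epoch, save_path)` and returns `(stop_training, best_cer, best_epoch)`. A plausible sketch consistent with that signature; the patience value and the weight-file name are assumptions:

import os

PATIENCE = 10  # assumed: epochs to wait for an improvement before stopping

def early_stopping(model, test_mean_cer, best_cer, epoch, best_epoch, save_path):
    # Keep the weights of the best model seen so far (lowest CER).
    if test_mean_cer < best_cer:
        best_cer = test_mean_cer
        best_epoch = epoch
        model.save_weights(os.path.join(save_path, 'best_weights.h5'))
    # Stop once no improvement has been seen for PATIENCE epochs.
    stop_training = (epoch - best_epoch) >= PATIENCE
    return stop_training, best_cer, best_epoch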
Example #5
def main():
    # What do we need to do across all of the settings!
    # Get the data loaders!
    args = parser.parse_args()

    if args.local_files:
        train_data_path = parameters.LOCAL_TRAIN_FILES
        test_data_path = parameters.LOCAL_TEST_FILES
        full_train_path = parameters.LOCAL_FULL_TRAIN
        full_test_path = parameters.LOCAL_FULL_TEST
    else:
        if parameters.DATASET.lower() == "noab":
            train_data_path = parameters.REMOTE_TRAIN_FILES
            test_data_path = parameters.REMOTE_TEST_FILES
            full_train_path = parameters.REMOTE_FULL_TRAIN
            full_test_path = parameters.REMOTE_FULL_TEST
        else:
            train_data_path = parameters.REMOTE_BAI_TRAIN_FILES
            test_data_path = parameters.REMOTE_BAI_TEST_FILES
            full_train_path = parameters.REMOTE_FULL_TRAIN_BAI
            full_test_path = parameters.REMOTE_FULL_TEST_BAI


    # Oversized calls are needed when shifting windows, and for model 1 when
    # generating randomly located call repeats.
    if (parameters.HIERARCHICAL_SHIFT_WINDOWS
            or parameters.HIERARCHICAL_REPEATS_POS > 1
            or parameters.HIERARCHICAL_REPEATS_NEG > 1):
        full_train_path += '_OversizeCalls'

    model_0_train_data_path, include_boundaries = create_dataset_path(
        train_data_path,
        neg_samples=parameters.NEG_SAMPLES,
        call_repeats=parameters.CALL_REPEATS,
        shift_windows=parameters.SHIFT_WINDOWS)
    model_0_test_data_path, _ = create_dataset_path(
        test_data_path,
        neg_samples=parameters.TEST_NEG_SAMPLES,
        call_repeats=1)
    

    # Check if a different dataset is being used for Model_1
    model_1_train_data_path = model_0_train_data_path
    model_1_test_data_path = model_0_test_data_path
    if parameters.HIERARCHICAL_REPEATS_POS > 1 or parameters.HIERARCHICAL_REPEATS_NEG > 1:
        # Negative samples could likely stay at 1 here, since they do not matter.
        # Shift windows whenever repeating: only the positive examples can be
        # repeated, so repeats without shifted windows would be identical.
        shift_windows = True
        # Call repeats are set to 1 here; the repetition itself happens later.
        call_repeats = 1
        model_1_train_data_path, _ = create_dataset_path(
            train_data_path,
            neg_samples=parameters.NEG_SAMPLES,
            call_repeats=call_repeats,
            shift_windows=shift_windows)
    
    
    # Model 0 Loaders
    model_0_train_loader = get_loader_fuzzy(
        model_0_train_data_path,
        parameters.BATCH_SIZE,
        random_seed=parameters.DATA_LOADER_SEED,
        norm=parameters.NORM,
        scale=parameters.SCALE,
        include_boundaries=include_boundaries,
        shift_windows=parameters.SHIFT_WINDOWS)
    model_0_test_loader = get_loader_fuzzy(
        model_0_test_data_path,
        parameters.BATCH_SIZE,
        random_seed=parameters.DATA_LOADER_SEED,
        norm=parameters.NORM,
        scale=parameters.SCALE,
        include_boundaries=include_boundaries)

    # Model 1 Loaders
    model_1_train_loader = get_loader_fuzzy(
        model_1_train_data_path,
        parameters.BATCH_SIZE,
        random_seed=parameters.DATA_LOADER_SEED,
        norm=parameters.NORM,
        scale=parameters.SCALE,
        include_boundaries=include_boundaries,
        shift_windows=parameters.HIERARCHICAL_SHIFT_WINDOWS)
    model_1_test_loader = get_loader_fuzzy(
        model_1_test_data_path,
        parameters.BATCH_SIZE,
        random_seed=parameters.DATA_LOADER_SEED,
        norm=parameters.NORM,
        scale=parameters.SCALE,
        include_boundaries=include_boundaries)
    

    if args.path is None:
        save_path = create_save_path(
            time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()),
            args.save_local,
            save_prefix='Hierarchical_')
    else:
        save_path = args.path

    # Case 1) Run the entire pipeline, broken into three helper functions.
    if args.full_pipeline:
        # Train and save model_0
        if args.model_0 is None:
            model_0 = train_model_0(model_0_train_loader, model_0_test_loader,
                                    save_path, args.pre_train_0)
        else: # Load and save model_0
            model_0 = torch.load(args.model_0, map_location=parameters.device)
            first_model_save_path = os.path.join(save_path, "Model_0")
            if not os.path.exists(first_model_save_path):
                os.makedirs(first_model_save_path)

            model_save_path = os.path.join(first_model_save_path, "model.pt")
            torch.save(model_0, model_save_path)

        # Do the adversarial discovery
        adversarial_train_files, adversarial_test_files = adversarial_discovery(
            full_train_path, full_test_path, model_0_train_loader,
            model_0_test_loader, model_0, save_path)
        # Train and save model 1
        train_model_1(adversarial_train_files, adversarial_test_files,
                      model_1_train_loader, model_1_test_loader, save_path,
                      args.pre_train_1)
    
    # Just generate new adversarial examples 
    elif args.adversarial:
        # Load model_0
        model_0_path = os.path.join(save_path, "Model_0/model.pt")
        model_0 = torch.load(model_0_path, map_location=parameters.device)
        adversarial_discovery(full_train_path, full_test_path,
                              model_0_train_loader, model_0_test_loader,
                              model_0, save_path)

    # Train just model_1
    elif args.model1:
        # Read in the adversarial files
        train_adversarial_file = "model_0-False_Pos_Train.txt"
        if parameters.HIERARCHICAL_REPEATS_POS > 1 or parameters.HIERARCHICAL_REPEATS_NEG > 1:
            train_adversarial_file = "model_0-False_Pos_Train_Shift.txt"

        adversarial_train_save_path = os.path.join(save_path,
                                                   train_adversarial_file)
        with open(adversarial_train_save_path, 'r') as f:
            adversarial_train_files = [line.strip() for line in f]

        adversarial_test_save_path = os.path.join(save_path,
                                                  "model_0-False_Pos_Test.txt")
        with open(adversarial_test_save_path, 'r') as f:
            adversarial_test_files = [line.strip() for line in f]

        train_model_1(adversarial_train_files, adversarial_test_files,
                      model_1_train_loader, model_1_test_loader, save_path,
                      args.pre_train_1)

    elif args.visualize:
        model_0_path = os.path.join(save_path, "Model_0/model.pt")
        model_0 = torch.load(model_0_path, map_location=parameters.device)
        model_1_name = hierarchical_model_1_path()
        model_1_path = os.path.join(save_path, model_1_name, 'model.pt')
        model_1 = torch.load(model_1_path, map_location=parameters.device)

        # Read in the adversarial files
        train_adversarial_file = "model_0-False_Pos_Train.txt"
        if parameters.HIERARCHICAL_REPEATS_POS > 1 or parameters.HIERARCHICAL_REPEATS_NEG > 1:
            train_adversarial_file = "model_0-False_Pos_Train_Shift.txt"

        adversarial_train_save_path = os.path.join(save_path,
                                                   train_adversarial_file)
        with open(adversarial_train_save_path, 'r') as f:
            adversarial_train_files = [line.strip() for line in f]

        visualize_adversarial(adversarial_train_files, model_1_train_loader,
                              model_0, model_1)

    else:
        print ("Invalid running mode!")
Example #6
def main():
    args = parser.parse_args()

    if args.local_files:
        train_data_path = parameters.LOCAL_TRAIN_FILES
        test_data_path = parameters.LOCAL_TEST_FILES
        full_train_path = parameters.LOCAL_FULL_TRAIN
        full_test_path = parameters.LOCAL_FULL_TEST
    else:
        if parameters.DATASET.lower() == "noab":
            train_data_path = parameters.REMOTE_TRAIN_FILES
            test_data_path = parameters.REMOTE_TEST_FILES
            full_train_path = parameters.REMOTE_FULL_TRAIN
            full_test_path = parameters.REMOTE_FULL_TEST
        else:
            train_data_path = parameters.REMOTE_BAI_TRAIN_FILES
            test_data_path = parameters.REMOTE_BAI_TEST_FILES
            full_train_path = parameters.REMOTE_FULL_TRAIN_BAI
            full_test_path = parameters.REMOTE_FULL_TEST_BAI

    
    
    train_data_path, include_boundaries = create_dataset_path(
        train_data_path,
        neg_samples=parameters.NEG_SAMPLES,
        call_repeats=parameters.CALL_REPEATS,
        shift_windows=parameters.SHIFT_WINDOWS)
    test_data_path, _ = create_dataset_path(
        test_data_path,
        neg_samples=parameters.TEST_NEG_SAMPLES,
        call_repeats=1)
    
    train_loader = get_loader_fuzzy(
        train_data_path,
        parameters.BATCH_SIZE,
        random_seed=parameters.DATA_LOADER_SEED,
        norm=parameters.NORM,
        scale=parameters.SCALE,
        include_boundaries=include_boundaries,
        shift_windows=parameters.SHIFT_WINDOWS)
    test_loader = get_loader_fuzzy(
        test_data_path,
        parameters.BATCH_SIZE,
        random_seed=parameters.DATA_LOADER_SEED,
        norm=parameters.NORM,
        scale=parameters.SCALE,
        include_boundaries=include_boundaries)

    # The save path is used for TensorBoard logs; the model itself is not saved for now.
    save_path = create_save_path(
        time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()), args.save_local)

    train_dataloaders = {'train': train_loader, 'valid': test_loader}

    # Load the full datasets - SET SHUFFLE = False
    full_train_loader = get_loader_fuzzy(
        full_train_path,
        parameters.BATCH_SIZE,
        shuffle=False,
        norm=parameters.NORM,
        scale=parameters.SCALE,
        include_boundaries=False,
        shift_windows=False,
        is_full_dataset=True)
    full_test_loader = get_loader_fuzzy(
        full_test_path,
        parameters.BATCH_SIZE,
        shuffle=False,
        norm=parameters.NORM,
        scale=parameters.SCALE,
        include_boundaries=False)
    full_dataloaders = {'train': full_train_loader, 'valid': full_test_loader}

    
    model = get_model(parameters.MODEL_ID)
    model.to(parameters.device)

    print(model)

    writer = SummaryWriter(save_path)
    writer.add_scalar('batch_size', parameters.BATCH_SIZE)
    writer.add_scalar(
        'weight_decay',
        parameters.HYPERPARAMETERS[parameters.MODEL_ID]['l2_reg'])

    # Focal loss is worth trying here next.
    loss_func, include_boundaries = get_loss()

    # Per-model hyper-parameters are likely unnecessary, but keep them for now.
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=parameters.HYPERPARAMETERS[parameters.MODEL_ID]['lr'],
        weight_decay=parameters.HYPERPARAMETERS[parameters.MODEL_ID]['l2_reg'])
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        parameters.HYPERPARAMETERS[parameters.MODEL_ID]['lr_decay_step'],
        gamma=parameters.HYPERPARAMETERS[parameters.MODEL_ID]['lr_decay'])

    start_time = time.time()

    curriculum_profiling(model, train_dataloaders, full_dataloaders, loss_func,
                         optimizer, scheduler, writer)

    print('Training time: {:10f} minutes'.format(
        (time.time() - start_time) / 60))

    writer.close()
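`create_save_path` appears throughout these PyTorch examples with a timestamp string, a `save_local` flag, and an optional `save_prefix` (the zero-argument variant in Example #4 is a different helper). A minimal sketch of what it plausibly does; the `LOCAL_SAVE_DIR` and `REMOTE_SAVE_DIR` roots are placeholders:

import os

def create_save_path(time_str, save_local, save_prefix=''):
    # Assumed sketch: pick a root directory and create a timestamped run folder.
    base = LOCAL_SAVE_DIR if save_local else REMOTE_SAVE_DIR  # placeholders
    save_path = os.path.join(base, save_prefix + time_str)
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    return save_path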
Example #7
def main():
    args = parser.parse_args()

    use_focal = parameters.LOSS.lower() in ("focal", "focal_chunk")

    if args.local_files:
        if use_focal:
            train_data_path = parameters.LOCAL_FULL_TRAIN
            test_data_path = parameters.LOCAL_FULL_TEST
        else:
            train_data_path = parameters.LOCAL_TRAIN_FILES
            test_data_path = parameters.LOCAL_TEST_FILES
    else:
        if parameters.DATASET.lower() == "noab":
            if use_focal:
                train_data_path = parameters.REMOTE_FULL_TRAIN
                test_data_path = parameters.REMOTE_FULL_TEST
            else:
                train_data_path = parameters.REMOTE_TRAIN_FILES
                test_data_path = parameters.REMOTE_TEST_FILES
        else:
            if use_focal:
                train_data_path = parameters.REMOTE_FULL_TRAIN_BAI
                test_data_path = parameters.REMOTE_FULL_TEST_BAI
            else:
                train_data_path = parameters.REMOTE_BAI_TRAIN_FILES
                test_data_path = parameters.REMOTE_BAI_TEST_FILES

    if use_focal:
        include_boundaries = False
    else:
        train_data_path, include_boundaries = create_dataset_path(
            train_data_path,
            neg_samples=parameters.NEG_SAMPLES,
            call_repeats=parameters.CALL_REPEATS,
            shift_windows=parameters.SHIFT_WINDOWS)
        test_data_path, _ = create_dataset_path(
            test_data_path,
            neg_samples=parameters.TEST_NEG_SAMPLES,
            call_repeats=1)

    train_loader = get_loader_fuzzy(train_data_path,
                                    parameters.BATCH_SIZE,
                                    random_seed=parameters.DATA_LOADER_SEED,
                                    norm=parameters.NORM,
                                    scale=parameters.SCALE,
                                    include_boundaries=include_boundaries,
                                    shift_windows=parameters.SHIFT_WINDOWS)
    test_loader = get_loader_fuzzy(test_data_path,
                                   parameters.BATCH_SIZE,
                                   random_seed=parameters.DATA_LOADER_SEED,
                                   norm=parameters.NORM,
                                   scale=parameters.SCALE,
                                   include_boundaries=include_boundaries)

    save_path = create_save_path(
        time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()), args.save_local)

    dloaders = {'train': train_loader, 'valid': test_loader}

    ## Training
    # Load a pre-trained model
    if parameters.PRE_TRAIN:
        model = torch.load(args.pre_train, map_location=parameters.device)
    else:
        model = get_model(parameters.MODEL_ID)
        model.to(parameters.device)

    print(model)

    writer = SummaryWriter(save_path)
    writer.add_scalar('batch_size', parameters.BATCH_SIZE)
    writer.add_scalar(
        'weight_decay',
        parameters.HYPERPARAMETERS[parameters.MODEL_ID]['l2_reg'])

    # Focal loss is worth trying here next.
    loss_func, include_boundaries = get_loss()

    # Per-model hyper-parameters are likely unnecessary, but keep them for now.
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=parameters.HYPERPARAMETERS[parameters.MODEL_ID]['lr'],
        weight_decay=parameters.HYPERPARAMETERS[parameters.MODEL_ID]['l2_reg'])
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        parameters.HYPERPARAMETERS[parameters.MODEL_ID]['lr_decay_step'],
        gamma=parameters.HYPERPARAMETERS[parameters.MODEL_ID]['lr_decay'])

    start_time = time.time()

    model_wts = train(dloaders,
                      model,
                      loss_func,
                      optimizer,
                      scheduler,
                      writer,
                      parameters.NUM_EPOCHS,
                      include_boundaries=include_boundaries)

    if model_wts:
        model.load_state_dict(model_wts)
        if not os.path.exists(save_path):
            os.makedirs(save_path)

        save_path = os.path.join(save_path, "model.pt")
        torch.save(model, save_path)
        print('Saved best model based on {} to path {}'.format(
            parameters.TRAIN_MODEL_SAVE_CRITERIA.upper(), save_path))
    else:
        print('No model weights were returned; nothing to save.')

    print('Training time: {:10f} minutes'.format(
        (time.time() - start_time) / 60))

    writer.close()
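`get_loss` returns both the loss function and an `include_boundaries` flag, and `parameters.LOSS` distinguishes at least "focal" and "focal_chunk" from the chunk-based losses. A sketch of the implied selection logic; `FocalLoss` and `BoundaryAwareLoss` are hypothetical names, and only the structure is implied by the source:

import torch.nn as nn

def get_loss():
    # Assumed sketch of the selection logic used by these training scripts.
    loss_name = parameters.LOSS.lower()
    if loss_name in ("focal", "focal_chunk"):
        return FocalLoss(), False            # hypothetical focal-loss module
    if loss_name == "boundary":
        return BoundaryAwareLoss(), True     # hypothetical; enables boundary masks
    return nn.BCEWithLogitsLoss(), False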