Ejemplo n.º 1
0
def main(cli_args=None):
    """Evaluate saved checkpoints for memorization of canonical examples.

    Runs the trial bookkeeping helpers, configures the logger and argument
    parser, builds the VocalImitation partitions and then, for every
    checkpoint in ``evaluated_epochs``, loads the weights into the model and
    records the two values returned by ``num_memorized_canonicals``.  After
    each checkpoint the results gathered so far are handed to
    ``num_canonical_memorized``.

    :param cli_args: already-parsed arguments; when ``None`` they are parsed
        from the command line.
    :raises SystemExit: with status 1 if any exception occurs during the
        evaluation (the exception and its traceback are logged first).
    """
    utilities.update_trial_number()
    utilities.create_output_directory()

    logger = logging.getLogger('logger')
    parser = argparse.ArgumentParser()
    utilities.configure_parser(parser)
    utilities.configure_logger(logger)
    if cli_args is None:
        cli_args = parser.parse_args()

    logger.info('Beginning trial #{0}...'.format(utilities.get_trial_number()))
    log_cli_args(cli_args)
    try:
        datafiles = VocalImitation(recalculate_spectrograms=cli_args.recalculate_spectrograms)
        data_split = PartitionSplit(*cli_args.partitions)
        partitions = Partitions(datafiles, data_split, cli_args.num_categories, regenerate=False)
        partitions.generate_partitions(PairPartition, no_test=True)
        partitions.save("./output/{0}/partition.pickle".format(utilities.get_trial_number()))

        # Decide the model kind exactly once.  Everything below (checkpoint
        # directory, which network is evaluated) follows the model that was
        # actually built, so passing both flags cannot mix the two kinds.
        if cli_args.triplet:
            model = Triplet(dropout=cli_args.dropout)
            model_kind = 'triplet'
        elif cli_args.pairwise:
            model = Siamese(dropout=cli_args.dropout)
            model_kind = 'pairwise'
        else:
            raise ValueError("You must specify the type of the model that is to be evaluated (triplet or pairwise)")

        if cli_args.cuda:
            model = model.cuda()

        evaluated_epochs = np.arange(0, 300, step=5)
        model_directory = './model_output/{0}'.format(model_kind) + '/model_{0}'
        model_paths = [model_directory.format(n) for n in evaluated_epochs]
        n_memorized = []
        memorized_var = []
        for model_path in model_paths:
            utils.network.load_model(model, model_path, cli_args.cuda)
            # A triplet model is evaluated through its ``siamese`` attribute.
            evaluated_network = model.siamese if model_kind == 'triplet' else model
            n, v = num_memorized_canonicals(evaluated_network, AllPairs(partitions.train),
                                            cli_args.cuda)
            logger.info("n = {0}\nv={1}".format(n, v))
            n_memorized.append(n)
            memorized_var.append(v)

            # Called after every checkpoint with the results gathered so far.
            num_canonical_memorized(memorized_var, n_memorized, evaluated_epochs[:len(n_memorized)], cli_args.num_categories)

    except Exception as e:
        logger.critical("Unhandled exception: {0}".format(str(e)))
        # format_exc() returns the traceback text; print_exc() returns None.
        logger.critical(traceback.format_exc())
        # Non-zero status so callers can tell that the run failed.
        sys.exit(1)
Ejemplo n.º 2
0
def title_to_filename(title, suffix):
    """Build the PNG path for a plot title inside the current trial's output.

    The name is ``suffix + '_' + title`` when ``suffix`` is truthy, otherwise
    just ``title``.  Spaces become underscores, periods and commas are
    dropped, ``.png`` is appended and the whole name is lower-cased.

    :param title: plot title to derive the file name from.
    :param suffix: optional text placed in front of the title.
    :return: ``./output/<trial number>/<name>.png`` joined with ``os.path``.
    """
    stem = suffix + '_' + title if suffix else title
    # One pass: ' ' -> '_', while '.' and ',' are removed.
    cleaned = stem.translate(str.maketrans(' ', '_', '.,'))
    png_name = (cleaned + '.png').lower()
    trial_directory = str(utilities.get_trial_number())
    return os.path.join('./output', trial_directory, png_name)
Ejemplo n.º 3
0
def initialize_siamese_params(regenerate, dropout):
    """Create the starting Siamese weights and store two copies of them.

    A new ``Siamese`` model is built; unless ``regenerate`` is truthy, the
    weights at the starting-weights path are loaded into it.  The model is
    then saved back to that path and also under the current trial's output
    directory.

    :param regenerate: when truthy, skip loading the stored weights.
    :param dropout: forwarded to ``Siamese(dropout=...)``.
    """
    log = logging.getLogger('logger')
    save_message = "Saving initial weights/biases at {0}..."
    shared_weights_path = "./model_output/siamese_init/starting_weights"

    network = Siamese(dropout=dropout)
    if not regenerate:
        load_model(network, shared_weights_path)

    log.debug(save_message.format(shared_weights_path))
    save_model(network, shared_weights_path)

    per_trial_path = "./output/{0}/init_weights".format(get_trial_number())
    log.debug(save_message.format(per_trial_path))
    save_model(network, per_trial_path)
Ejemplo n.º 4
0
    def graph(self, trial_name, search_length):
        """Draw the 2x2 progress figure for this trial and save it as an image.

        Top row: mean rank and MRR per epoch.  Bottom row: loss per epoch,
        once with ``log=True`` and once with ``log=False``.

        :param trial_name: used in the figure title and passed to
            ``self.filename`` to obtain the output path.
        :param search_length: forwarded to the rank plot and, as
            ``n_categories``, to the MRR plot.
        """
        figure, axes = plt.subplots(2, 2)
        (rank_axis, mrr_axis), (log_loss_axis, linear_loss_axis) = axes
        figure.set_size_inches(16, 10)

        mean_rank_per_epoch(self.train_rank, self.val_rank, search_length, rank_axis)
        mrr_per_epoch(self.train_mrr, self.val_mrr, mrr_axis, n_categories=search_length)
        for axis, use_log in ((log_loss_axis, True), (linear_loss_axis, False)):
            loss_per_epoch(self.train_loss, self.val_loss, axis, log=use_log)

        heading = "{0}, Trial #{1}".format(trial_name, get_trial_number())
        figure.suptitle(heading)
        figure.savefig(self.filename(trial_name), dpi=200)
        plt.close()
Ejemplo n.º 5
0
def train(use_cuda: bool, n_epochs: int, validate_every: int,
          use_dropout: bool, partitions: Partitions, optimizer_name: str,
          lr: float, wd: float, momentum: bool):
    """Train a Siamese network with pairwise (BCE) loss and save the result.

    Every epoch's model is saved to ``./model_output/pairwise/model_<epoch>``
    and the progress graph is redrawn.  When validation is enabled, the
    weights of the epoch with the highest validation MRR are reloaded before
    the final model is saved as ``model_best`` and into the trial's output
    directory.

    :param use_cuda: move the model to CUDA; also forwarded to the training
        and inference helpers.
    :param n_epochs: forwarded to ``training.train_siamese_network``.
    :param validate_every: validate on epochs where
        ``epoch % validate_every == 0``; values ``<= 0`` disable validation.
    :param use_dropout: forwarded to ``Siamese(dropout=...)``.
    :param partitions: data partitions; ``generate_partitions`` is called on
        it and its ``train`` / ``val`` splits are used.
    :param optimizer_name: forwarded to ``get_optimizer``.
    :param lr: forwarded to ``get_optimizer``.
    :param wd: forwarded to ``get_optimizer``.
    :param momentum: forwarded to ``get_optimizer``.
    :return: the trained ``Siamese`` model.
    :raises SystemExit: with status 1 if an exception occurs while training
        (the exception is logged and a crash backup is saved first).
    """
    logger = logging.getLogger('logger')

    no_test = True
    model_path = "./model_output/pairwise/model_{0}"
    # Single source of truth for "is validation enabled"; a negative
    # validate_every therefore disables validation everywhere.
    validating = validate_every > 0

    partitions.generate_partitions(PairPartition, no_test=no_test)
    training_data = Balanced(partitions.train)

    if validating:
        balanced_validation = Balanced(partitions.val)
        training_pairs = AllPairs(partitions.train)
        search_length = training_pairs.n_references
        validation_pairs = AllPairs(partitions.val)
        testing_pairs = AllPairs(partitions.test) if not no_test else None
    else:
        balanced_validation = None
        training_pairs = None
        validation_pairs = None
        testing_pairs = None
        search_length = None

    # get a siamese network, see Siamese class for architecture
    siamese = Siamese(dropout=use_dropout)
    siamese = initialize_weights(siamese, use_cuda)

    if use_cuda:
        siamese = siamese.cuda()

    criterion = BCELoss()
    optimizer = get_optimizer(siamese, optimizer_name, lr, wd, momentum)

    try:
        logger.info("Training network with pairwise loss...")
        progress = TrainingProgress()
        models = training.train_siamese_network(siamese, training_data,
                                                criterion, optimizer, n_epochs,
                                                use_cuda)
        for epoch, (model, training_batch_losses) in enumerate(models):
            utils.network.save_model(model, model_path.format(epoch))

            training_loss = training_batch_losses.mean()
            if validating and epoch % validate_every == 0:
                validation_batch_losses = inference.siamese_loss(
                    model, balanced_validation, criterion, use_cuda)
                validation_loss = validation_batch_losses.mean()

                training_mrr, training_rank = inference.mean_reciprocal_ranks(
                    model, training_pairs, use_cuda)
                val_mrr, val_rank = inference.mean_reciprocal_ranks(
                    model, validation_pairs, use_cuda)

                progress.add_mrr(train=training_mrr, val=val_mrr)
                progress.add_rank(train=training_rank, val=val_rank)
                progress.add_loss(train=training_loss, val=validation_loss)
            else:
                # Epochs without validation are recorded as NaN placeholders.
                progress.add_mrr(train=np.nan, val=np.nan)
                progress.add_rank(train=np.nan, val=np.nan)
                progress.add_loss(train=training_loss, val=np.nan)

            progress.graph("Siamese", search_length)

        # load weights from best model if we validated throughout
        if validating:
            siamese = siamese.train()
            # nanargmax skips the NaN placeholders of non-validated epochs;
            # plain argmax treats NaN as the maximum and would pick the first
            # epoch that was *not* validated.
            best_epoch = np.nanargmax(progress.val_mrr)
            utils.network.load_model(siamese, model_path.format(best_epoch))

        # otherwise just save most recent model
        utils.network.save_model(siamese, model_path.format('best'))
        utils.network.save_model(
            siamese,
            './output/{0}/pairwise'.format(utilities.get_trial_number()))

        if not no_test:
            logger.info(
                "Results from best model generated during training, evaluated on test data:"
            )
            rrs = inference.reciprocal_ranks(siamese, testing_pairs, use_cuda)
            utilities.log_final_stats(rrs)

        progress.pearson(log=True)
        progress.save("./output/{0}/pairwise.pickle".format(
            utilities.get_trial_number()))
        return siamese
    except Exception as e:
        # Log before attempting the backup so the original error is recorded
        # even if saving the backup fails as well.
        logger.critical("Exception occurred while training: {0}".format(
            str(e)))
        # format_exc() returns the traceback text; print_exc() returns None.
        logger.critical(traceback.format_exc())
        utils.network.save_model(siamese, model_path.format('crash_backup'))
        # Non-zero status so callers can tell that training failed.
        sys.exit(1)
Ejemplo n.º 6
0
def main(cli_args=None):
    """Run one training trial and recurse for any remaining trials.

    Runs the trial bookkeeping helpers, configures the logger and argument
    parser, selects the dataset class from ``cli_args.dataset``, builds and
    saves the partitions, initializes the Siamese starting weights and then
    trains the triplet and/or pairwise experiment depending on the flags.
    ``cli_args.trials`` is decremented afterwards and ``main`` calls itself
    again while it is still positive.

    :param cli_args: already-parsed arguments; when ``None`` they are parsed
        from the command line.
    :raises SystemExit: with status 1 if any exception occurs (the exception
        and its traceback are logged first).
    """
    utilities.update_trial_number()
    utilities.create_output_directory()

    logger = logging.getLogger('logger')
    parser = argparse.ArgumentParser()
    utilities.configure_parser(parser)
    utilities.configure_logger(logger)
    if cli_args is None:
        cli_args = parser.parse_args()

    logger.info('Beginning trial #{0}...'.format(utilities.get_trial_number()))
    log_cli_args(cli_args)
    try:
        if cli_args.dataset == 'vs1.0':
            dataset = VocalSketch_1_0
        elif cli_args.dataset == 'vs1.1':
            dataset = VocalSketch_1_1
        elif cli_args.dataset == 'vi':
            dataset = VocalImitation
        else:
            # Report the attribute that was actually checked above.
            raise ValueError("Invalid dataset ({0}) chosen.".format(
                cli_args.dataset))

        datafiles = dataset(
            recalculate_spectrograms=cli_args.recalculate_spectrograms,
            imitation_augmentations=None,
            reference_augmentations=None)

        data_split = PartitionSplit(*cli_args.partitions)
        partitions = Partitions(datafiles,
                                data_split,
                                cli_args.num_categories,
                                regenerate=cli_args.regenerate_splits
                                or cli_args.recalculate_spectrograms)
        partitions.save("./output/{0}/partition.pickle".format(
            utilities.get_trial_number()))

        utils.network.initialize_siamese_params(cli_args.regenerate_weights,
                                                cli_args.dropout)

        # The two flags are independent: both experiments run if both are set.
        if cli_args.triplet:
            experiments.triplet.train(cli_args.cuda, cli_args.epochs,
                                      cli_args.validation_frequency,
                                      cli_args.dropout, partitions,
                                      cli_args.optimizer,
                                      cli_args.learning_rate,
                                      cli_args.weight_decay, cli_args.momentum)

        if cli_args.pairwise:
            experiments.pairwise.train(
                cli_args.cuda, cli_args.epochs, cli_args.validation_frequency,
                cli_args.dropout, partitions, cli_args.optimizer,
                cli_args.learning_rate, cli_args.weight_decay,
                cli_args.momentum)

        # Remaining trials reuse the same (mutated) argument namespace.
        cli_args.trials -= 1
        if cli_args.trials > 0:
            main(cli_args)
    except Exception as e:
        logger.critical("Unhandled exception: {0}".format(str(e)))
        # format_exc() returns the traceback text; print_exc() returns None.
        logger.critical(traceback.format_exc())
        # Non-zero status so callers can tell that the run failed.
        sys.exit(1)
Ejemplo n.º 7
0
 def filename(title):
     """Return the PNG path for ``title`` inside the current trial's output.

     Spaces become underscores, periods and commas are dropped, ``.png`` is
     appended and the resulting name is lower-cased.
     """
     cleaned = title
     for old, new in ((' ', '_'), ('.', ''), (',', '')):
         cleaned = cleaned.replace(old, new)
     png_name = (cleaned + '.png').lower()
     trial_directory = str(get_trial_number())
     return os.path.join('./output', trial_directory, png_name)