def main():
    """
    Starting point of the application
    """
    hvd.init()
    params = parse_args(PARSER.parse_args())
    set_flags(params)
    model_dir = prepare_model_dir(params)
    params.model_dir = model_dir
    logger = get_logger(params)

    model = Unet()

    dataset = Dataset(data_dir=params.data_dir,
                      batch_size=params.batch_size,
                      fold=params.crossvalidation_idx,
                      augment=params.augment,
                      gpu_id=hvd.rank(),
                      num_gpus=hvd.size(),
                      seed=params.seed)

    if 'train' in params.exec_mode:
        train(params, model, dataset, logger)

    if 'evaluate' in params.exec_mode:
        if hvd.rank() == 0:
            evaluate(params, model, dataset, logger)

    if 'predict' in params.exec_mode:
        if hvd.rank() == 0:
            predict(params, model, dataset, logger)
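
These examples omit their module-level context. For the first script, a typical entry point, assuming PARSER is an argparse.ArgumentParser defined in the same module, is the usual guard:

if __name__ == '__main__':
    main()

Multi-GPU runs are then launched through Horovod's CLI, e.g. horovodrun -np 8 python main.py ... (the exact flags depend on what parse_args defines).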
Example 2
def main(args):
    # Setting up an experiment
    config, params = setup(args)

    # Setting up logger
    logger = get_logger(config['model_name'], config['dirs']['logs_dir'])

    # Extracting configurations
    data_config = config['data']
    logs_config = config['logs']
    training_config = config['training']
    sampling_config = config['sampling']
    dirs_config = config['dirs']
    logger.info('[SETUP] Experiment configurations')
    logger.info(
        f'[SETUP] Experiment directory: {os.path.abspath(dirs_config["exp_dir"])}'
    )

    # Loading the dataset
    (X_train, len_train), (X_valid, len_valid), (_, _) = load_data(
        data_config=data_config, step_size=training_config['num_pixels'])
    logger.info(f'[LOAD]  Dataset (shape: {X_train[0].shape})')

    # Computing beat size in time steps
    beat_size = float(data_config['beat_resolution'] /
                      training_config['num_pixels'])

    # Preparing inputs for sampling
    intro_songs, save_ids, song_labels = prepare_sampling_inputs(
        X_train, X_valid, sampling_config, beat_size)
    num_save_intro = len(save_ids) // sampling_config['num_save']
    logger.info('[SETUP] Inputs for sampling')

    # Creating the MultINN model
    tf.reset_default_graph()
    model = MultINN(config,
                    params,
                    mode=params['mode'],
                    name=config['model_name'])
    logger.info('[BUILT] Model')

    # Building the sampler and evaluator
    sampler = model.sampler(num_beats=sampling_config['sample_beats'])
    logger.info('[BUILT] Sampler')
    evaluator = model.evaluator()
    logger.info('[BUILT] Evaluator')

    # Building optimizer and training ops
    if args.sgd:
        optimizer = tf.train.GradientDescentOptimizer(
            learning_rate=training_config['learning_rate'])
    else:
        optimizer = tf.train.AdamOptimizer(
            learning_rate=training_config['learning_rate'], epsilon=1e-4)

    init_ops, update_ops, metrics, metrics_upd, summaries = model.train_generators(
        optimizer=optimizer, lr=training_config['learning_rate'])
    logger.info('[BUILT] Optimizer and update ops')

    # Extracting placeholders, metrics and summaries
    placeholders = model.placeholders
    x = placeholders['x']
    lengths = placeholders['lengths']
    is_train = placeholders['is_train']

    loss = metrics['batch/loss']
    loglik = metrics['log_likelihood']
    global_loglik = metrics['global']['log_likelihood']

    weights_sum = summaries['weights']
    metrics_sum = summaries['metrics']
    gradients_sum = summaries['gradients']

    # TensorFlow Session set up
    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True
    tf.set_random_seed(training_config['random_seed'])
    np.random.seed(training_config['random_seed'])

    with tf.Session(config=tf_config) as sess:
        logger.info('[START] TF Session')
        with tf.variable_scope('init_global'):
            init_global = tf.global_variables_initializer()
        with tf.variable_scope('init_local'):
            init_local = tf.local_variables_initializer()
        sess.run([init_global, init_local])

        stats = TrainingStats()

        # Loading the model's weights or using initial weights
        if not args.from_init:
            if args.from_last:
                if model.load(sess, dirs_config['model_last_dir']):
                    last_stats_file = os.path.join(
                        dirs_config['model_last_dir'], 'steps')
                    if os.path.isfile(last_stats_file):
                        stats.load(last_stats_file)
                        logger.info('[LOAD]  Training stats file')

                    logger.info(
                        f'[LOAD]  Pre-trained weights (last, epoch={stats.epoch})'
                    )
                else:
                    logger.info('[LOAD]  Initial weights')
            elif model.load(sess, dirs_config['model_dir']):
                if os.path.isfile(dirs_config['model_stats_file']):
                    stats.load(dirs_config['model_stats_file'])
                    logger.info('[LOAD]  Training stats file')

                logger.info(
                    f'[LOAD]  Pre-trained weights (best, epoch={stats.epoch})')
            else:
                logger.info('[LOAD]  Initial weights')

                # Run initialization ops if the model defines any
                if init_ops:
                    sess.run(init_ops, feed_dict={x: X_train[:1600]})
                    logger.info('[END]   Run initialization ops')
        else:
            logger.info('[LOAD]  Initial weights')

        if args.encoders and params['encoder']['type'] != 'Pass':
            encoder_dir = os.path.join(args.encoders, 'ckpt', 'encoders')
            if model.load_encoders(sess, encoder_dir):
                logger.info('[LOAD]  Encoders\' weights')
            else:
                logger.info('[WARN]  Failed to load encoders\' weights')

        stats.new_run()

        # Preparing for training
        graph = sess.graph if logs_config['save_graph'] else None
        writer_train = tf.summary.FileWriter(
            f'{dirs_config["logs_dir"]}/Graph/run_{stats.run}/train', graph)
        writer_valid = tf.summary.FileWriter(
            f'{dirs_config["logs_dir"]}/Graph/run_{stats.run}/valid')

        batch_size = training_config['batch_size']
        piece_size = int(training_config['piece_size'] * beat_size)

        logger.info(f'[START] Training, RUN={stats.run}')
        ids = np.arange(X_train.shape[0])

        # Logging initial weights
        if logs_config['log_weights_steps'] > 0:
            writer_train.add_summary(sess.run(weights_sum), stats.steps)
            logger.info('[LOG]   Initial weights')

        loss_accum = LossAccumulator()

        # Defaults for epochs on which evaluation or sampling does not run,
        # so the checks below never hit undefined names
        loglik_val = stats.metric_best
        samples_to_save = None

        # Training on all of the songs for the configured number of epochs
        past_epochs = stats.epoch
        for epoch in range(past_epochs + 1,
                           past_epochs + training_config['epochs'] + 1):
            stats.new_epoch()
            tf.set_random_seed(epoch)
            np.random.seed(epoch)

            start = time.time()

            np.random.shuffle(ids)
            loss_accum.clear()
            base_info = f'\r epoch: {epoch:3d} '

            for i in range(0, X_train.shape[0], batch_size):
                for j in range(0, X_train.shape[1], piece_size):
                    len_batch = len_train[ids[i:i + batch_size]] - j
                    non_empty = np.where(len_batch > 0)[0]

                    if len(non_empty) > 0:
                        len_batch = np.minimum(len_batch[non_empty],
                                               piece_size)
                        max_length = len_batch.max()

                        songs_batch = X_train[ids[i:i + batch_size],
                                              j:j + max_length, ...][non_empty]

                        if logs_config['log_weights_steps'] > 0 \
                                and (stats.steps + 1) % logs_config['log_weights_steps'] == 0 \
                                and j + piece_size >= X_train.shape[1]:
                            _, loss_i, summary = sess.run(
                                [update_ops, loss, weights_sum],
                                feed_dict={
                                    x: songs_batch,
                                    lengths: len_batch,
                                    is_train: True
                                })

                            writer_train.add_summary(summary, stats.steps + 1)
                            del summary
                        else:
                            _, loss_i = sess.run([update_ops, loss],
                                                 feed_dict={
                                                     x: songs_batch,
                                                     lengths: len_batch,
                                                     is_train: True
                                                 })

                        del songs_batch
                        loss_accum.update(loss_i)

                stats.new_step()

                # Log the progress during training
                if (logs_config['log_loss_steps'] > 0
                        and stats.steps % logs_config['log_loss_steps'] == 0):
                    info = f' (steps: {stats.steps:5d}) time: {time_to_str(time.time() - start)}' + str(
                        loss_accum)
                    sys.stdout.write(base_info + info)
                    sys.stdout.flush()

            info = f' (steps: {stats.steps:5d})  time: {time_to_str(time.time() - start)}\n' + str(
                loss_accum)
            logger.info(base_info + info)
            logger.info(
                f'[END]   Epoch training time {time_to_str(time.time() - start)}'
            )

            # Evaluating the model on the training and validation data
            if (logs_config['evaluate_epochs'] > 0
                    and epoch % logs_config['evaluate_epochs'] == 0):
                num_eval = X_valid.shape[0]

                collect_metrics(sess,
                                metrics_upd,
                                data=X_train[:num_eval, ...],
                                data_lengths=len_train[:num_eval, ...],
                                placeholders=placeholders,
                                batch_size=batch_size * 2,
                                piece_size=piece_size)
                summary, loglik_val, gl_loglik_val = sess.run(
                    [metrics_sum, loglik, global_loglik])
                writer_train.add_summary(summary, epoch)
                del summary
                logger.info(
                    f'[EVAL]  Training   set log-likelihood:  '
                    f'gen.={loglik_val:7.3f}  enc.={gl_loglik_val:7.3f}')

                collect_metrics(sess,
                                metrics_upd,
                                data=X_valid,
                                data_lengths=len_valid,
                                placeholders=placeholders,
                                batch_size=batch_size * 2,
                                piece_size=piece_size)
                summary, loglik_val, gl_loglik_val = sess.run(
                    [metrics_sum, loglik, global_loglik])
                writer_valid.add_summary(summary, epoch)
                del summary
                logger.info(
                    f'[EVAL]  Validation set log-likelihood:  '
                    f'gen.={loglik_val:7.3f}  enc.={gl_loglik_val:7.3f}')

            # Generating music samples with the model
            if (logs_config['generate_epochs'] > 0
                    and epoch % logs_config['generate_epochs'] == 0):
                samples = generate_music(
                    sess,
                    sampler,
                    intro_songs,
                    placeholders,
                    num_songs=sampling_config['num_songs'])
                logger.info('[EVAL]  Generated music samples')

                summary_sample = sess.run(evaluator,
                                          feed_dict={
                                              x: samples,
                                              is_train: False
                                          })
                writer_train.add_summary(summary_sample, epoch)
                del summary_sample
                logger.info('[EVAL]  Evaluated music samples')

                samples_to_save = samples[save_ids]
                del samples
                samples_to_save = pad_to_midi(samples_to_save, data_config)

                # Saving the music
                if (logs_config['save_samples_epochs'] > 0
                        and epoch % logs_config['save_samples_epochs'] == 0):
                    save_music(samples_to_save,
                               num_intro=num_save_intro,
                               data_config=data_config,
                               base_path=f'{model.name}_e{epoch}',
                               save_dir=dirs_config['samples_dir'],
                               song_labels=song_labels)
                    logger.info('[SAVE]  Saved music samples')

            # Saving the model if the monitored metric decreased
            if loglik_val < stats.metric_best:
                stats.update_metric_best(loglik_val)
                stats.reset_idle_epochs()

                if (logs_config['generate_epochs'] > 0
                        and epoch % logs_config['generate_epochs'] == 0):
                    save_music(samples_to_save,
                               num_intro=num_save_intro,
                               data_config=data_config,
                               base_path=f'{model.name}_best',
                               save_dir=dirs_config['samples_dir'],
                               song_labels=song_labels)

                if (logs_config['save_checkpoint_epochs'] > 0
                        and epoch % logs_config['save_checkpoint_epochs'] == 0):
                    model.save(sess,
                               dirs_config['model_dir'],
                               global_step=stats.steps)
                    stats.save(dirs_config['model_stats_file'])

                    logger.info(
                        f'[SAVE]  Saved model after {epoch} epoch(-s) ({stats.steps} steps)'
                    )
            else:
                stats.new_idle_epoch()

                if stats.idle_epochs >= training_config['early_stopping']:
                    # Early stopping after no improvement
                    logger.info(
                        f'[WARN]  No improvement after {training_config["early_stopping"]} epochs, quitting'
                    )

                    if samples_to_save is not None:
                        save_music(samples_to_save,
                                   num_intro=num_save_intro,
                                   data_config=data_config,
                                   base_path=f'{model.name}_last',
                                   save_dir=dirs_config['samples_dir'],
                                   song_labels=song_labels)

                    break

            samples_to_save = None
            logger.info(
                f'[END]   Epoch time {time_to_str(time.time() - start)}')

        if not args.save_best_only:
            model.save(sess,
                       dirs_config['model_last_dir'],
                       global_step=stats.steps)
            stats.save(os.path.join(dirs_config['model_last_dir'], 'steps'))
            logger.info(
                f'[SAVE]  Saved model after {epoch} epoch(-s) ({stats.steps} steps)'
            )

        writer_train.close()
        writer_valid.close()
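
Note: these scripts target the TensorFlow 1.x graph API (tf.Session, placeholders, tf.reset_default_graph). One way to keep them running under TensorFlow 2 is the v1 compatibility layer:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()  # restores graph mode, sessions, and placeholders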
Example 3
def main():
    # Setup.
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="./config/example.yaml")
    parser.add_argument("--gpu", default="0", type=str)
    # Path to checkpoint (empty string means the latest checkpoint)
    # or False (means training from scratch).
    parser.add_argument("--resume", default="", type=str)
    args = parser.parse_args()
    config, inner_dir, config_name = load_config(args.config)
    saved_dir = get_saved_dir(config, inner_dir, config_name, args.resume)
    storage_dir, ckpt_dir = get_storage_dir(config, inner_dir, config_name,
                                            args.resume)
    logger = get_logger(saved_dir, "adv_training.log", args.resume)

    # Prepare data.
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    test_transform = transforms.Compose([transforms.ToTensor()])
    train_data = cifar.CIFAR10(root=config["dataset_dir"],
                               transform=train_transform)
    test_data = cifar.CIFAR10(root=config["dataset_dir"],
                              train=False,
                              transform=test_transform)
    train_loader = DataLoader(train_data,
                              batch_size=config["batch_size"],
                              shuffle=True,
                              num_workers=4)
    test_loader = DataLoader(test_data,
                             batch_size=config["batch_size"],
                             num_workers=4)

    # Resume training state.
    model = resnet_cifar.ResNet18()
    gpu = int(args.gpu)
    logger.info("Set GPU to {}".format(args.gpu))
    model = model.cuda(gpu)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(),
                                **config["optimizer"]["SGD"])
    scheduler = lr_scheduler.MultiStepLR(
        optimizer, **config["lr_scheduler"]["MultiStepLR"])
    resumed_epoch = resume_state(model, optimizer, args.resume, ckpt_dir,
                                 scheduler)

    # Build the attack first, then prepend a normalization layer.
    pgd_config = {}
    for k, v in config["pgd_attack"].items():
        if k == "eps" or k == "alpha":
            pgd_config[k] = eval(v)
        else:
            pgd_config[k] = v
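    # Note: eval() above assumes trusted config values written as string
    # expressions (e.g. "8/255"). A safer sketch for simple ratios would be:
    #   num, den = v.split("/"); pgd_config[k] = float(num) / float(den)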
    attacker = PGD(model, **pgd_config)
    normalize_net = NormalizeByChannelMeanStd((0.4914, 0.4822, 0.4465),
                                              (0.2023, 0.1994, 0.2010))
    normalize_net.cuda(gpu)
    model = nn.Sequential(normalize_net, model)

    for epoch in range(config["num_epochs"] - resumed_epoch):
        logger.info("===Epoch: {}/{}===".format(epoch + resumed_epoch + 1,
                                                config["num_epochs"]))
        logger.info("Adversarial training...")
        adv_train_result = train(model,
                                 train_loader,
                                 criterion,
                                 optimizer,
                                 logger,
                                 attacker=attacker)
        if scheduler is not None:
            scheduler.step()
            logger.info("Adjust learning rate to {}".format(
                optimizer.param_groups[0]["lr"]))
        logger.info("Test model on clean data...")
        clean_test_result = test(model, test_loader, criterion, logger)
        logger.info("Test model on adversarial data...")
        adv_test_result = test(model,
                               test_loader,
                               criterion,
                               logger,
                               attacker=attacker)
        result = {
            "adv_train": adv_train_result,
            "clean_test": clean_test_result,
            "adv_test": adv_test_result,
        }

        # Save checkpoint
        saved_dict = {
            "epoch": epoch + resumed_epoch + 1,
            "result": result,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        }
        if scheduler is not None:
            saved_dict["scheduler_state_dict"] = scheduler.state_dict()
        torch.save(
            saved_dict,
            os.path.join(ckpt_dir,
                         "epoch{}.pt".format(epoch + resumed_epoch + 1)),
        )
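
The normalization module is prepended after the attack is constructed, so the PGD attacker perturbs raw [0, 1] pixels while the classifier still receives normalized inputs. A minimal sketch of what such a wrapper typically looks like (the project's actual NormalizeByChannelMeanStd may differ):

import torch
import torch.nn as nn

class NormalizeByChannelMeanStd(nn.Module):
    """Applies (x - mean) / std per channel in the forward pass."""

    def __init__(self, mean, std):
        super().__init__()
        # Buffers move with .cuda()/.to() but are not trained.
        self.register_buffer('mean', torch.tensor(mean).view(1, -1, 1, 1))
        self.register_buffer('std', torch.tensor(std).view(1, -1, 1, 1))

    def forward(self, x):
        return (x - self.mean) / self.std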
Example 4
def main(args):
    # Setting up an experiment
    config, params = setup(args, impute_args=False)

    # Setting up logger
    logger = get_logger(config['model_name'], config['dirs']['logs_dir'])

    # Extracting configurations
    data_config = config['data']
    logs_config = config['logs']
    training_config = config['training']
    sampling_config = config['sampling']
    dirs_config = config['dirs']
    logger.info('[SETUP] Experiment configurations')
    logger.info(
        f'[SETUP] Experiment directory: {os.path.abspath(dirs_config["exp_dir"])}'
    )

    # Loading the dataset
    (X_train, len_train), (X_valid, len_valid), (X_test, len_test) = load_data(
        data_config=data_config, step_size=training_config['num_pixels'])

    # Computing beat size in time steps
    beat_size = float(data_config['beat_resolution'] /
                      training_config['num_pixels'])

    # Creating the MultINN model
    tf.reset_default_graph()
    model = MultINN(config,
                    params,
                    mode=params['mode'],
                    name=config['model_name'])
    logger.info('[BUILT] Model')

    # Extracting placeholders, metrics and summaries
    placeholders = model.placeholders

    metrics, metrics_upd, summaries = model.metrics, model.metrics_upd, model.summaries
    loglik = metrics['log_likelihood']

    metrics_sum = summaries['metrics']

    # TensorFlow Session set up
    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True
    tf.set_random_seed(training_config['random_seed'])
    np.random.seed(training_config['random_seed'])

    with tf.Session(config=tf_config) as sess:
        logger.info('[START] TF Session')
        with tf.variable_scope('init_global'):
            init_global = tf.global_variables_initializer()
        with tf.variable_scope('init_local'):
            init_local = tf.local_variables_initializer()
        sess.run([init_global, init_local])

        stats = TrainingStats()

        # Loading the model's weights or using initial weights
        if args.from_last:
            if model.load(sess, dirs_config['model_last_dir']):
                last_stats_file = os.path.join(dirs_config['model_last_dir'],
                                               'steps')
                if os.path.isfile(last_stats_file):
                    stats.load(last_stats_file)
                    logger.info('[LOAD]  Training stats file')

                logger.info(
                    f'[LOAD]  Pre-trained weights (last, epoch={stats.epoch})')
            else:
                logger.info('[LOAD]  Initial weights')
        elif model.load(sess, dirs_config['model_dir']):
            if os.path.isfile(dirs_config['model_stats_file']):
                stats.load(dirs_config['model_stats_file'])
                logger.info('[LOAD]  Training stats file')

            logger.info(
                f'[LOAD]  Pre-trained weights (best, epoch={stats.epoch})')
        else:
            logger.error('[ERROR]  No checkpoint found')
            raise ValueError('No checkpoint found')

        # Preparing for evaluation
        writer_train = tf.summary.FileWriter(
            f'{dirs_config["logs_dir"]}/Graph/run_{stats.run}/train')
        writer_valid = tf.summary.FileWriter(
            f'{dirs_config["logs_dir"]}/Graph/run_{stats.run}/valid')
        writer_test = tf.summary.FileWriter(
            f'{dirs_config["logs_dir"]}/Graph/run_{stats.run}/test')

        batch_size = training_config['batch_size']
        piece_size = int(training_config['piece_size'] * beat_size)

        num_eval = X_valid.shape[0]

        collect_metrics(sess,
                        metrics_upd,
                        data=X_train[:num_eval, ...],
                        data_lengths=len_train[:num_eval, ...],
                        placeholders=placeholders,
                        batch_size=batch_size * 2,
                        piece_size=piece_size)
        summary, loglik_val = sess.run([metrics_sum, loglik])
        writer_train.add_summary(summary, stats.epoch)
        del summary
        logger.info(
            f'[EVAL]  Training   set log-likelihood:  {loglik_val:7.3f}')

        collect_metrics(sess,
                        metrics_upd,
                        data=X_valid,
                        data_lengths=len_valid,
                        placeholders=placeholders,
                        batch_size=batch_size * 2,
                        piece_size=piece_size)
        summary, loglik_val = sess.run([metrics_sum, loglik])
        writer_valid.add_summary(summary, stats.epoch)
        del summary
        logger.info(
            f'[EVAL]  Validation set log-likelihood:  {loglik_val:7.3f}')

        collect_metrics(sess,
                        metrics_upd,
                        data=X_test[:num_eval, ...],
                        data_lengths=len_test[:num_eval, ...],
                        placeholders=placeholders,
                        batch_size=batch_size * 2,
                        piece_size=piece_size)
        summary, loglik_val = sess.run([metrics_sum, loglik])
        writer_test.add_summary(summary, stats.epoch)
        del summary
        logger.info(
            f'[EVAL]  Test set       log-likelihood:  {loglik_val:7.3f}')
Example 5
def main(args):
    # Setting up an experiment
    config, params = setup(args, impute_args=False)

    # Setting up logger
    logger = get_logger(config['model_name'], config['dirs']['logs_dir'])

    # Extracting configurations
    data_config = config['data']
    logs_config = config['logs']
    training_config = config['training']
    sampling_config = config['sampling']
    dirs_config = config['dirs']
    logger.info('[SETUP] Experiment configurations')
    logger.info(f'[SETUP] Experiment directory: {os.path.abspath(dirs_config["exp_dir"])}')

    # Loading the dataset
    (X_train, len_train), (X_valid, len_valid), (_, _) = load_data(
        data_config=data_config,
        step_size=training_config['num_pixels']
    )

    # Computing beat size in time steps
    beat_size = float(data_config['beat_resolution'] / training_config['num_pixels'])

    # Preparing inputs for sampling
    intro_songs, save_ids, song_labels = prepare_sampling_inputs(X_train, X_valid, sampling_config, beat_size)
    num_save_intro = len(save_ids) // sampling_config['num_save']
    logger.info('[SETUP] Inputs for sampling')

    # Creating the MultINN model
    tf.reset_default_graph()
    model = MultINN(config, params, mode=params['mode'], name=config['model_name'])
    logger.info('[BUILT] Model')

    # Building the sampler and evaluator
    sampler = model.sampler(num_beats=sampling_config['sample_beats'])
    logger.info('[BUILT] Sampler')

    # Extracting placeholders, metrics and summaries
    placeholders = model.placeholders

    # TensorFlow Session set up
    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True
    tf.set_random_seed(training_config['random_seed'])
    np.random.seed(training_config['random_seed'])

    with tf.Session(config=tf_config) as sess:
        logger.info('[START] TF Session')
        with tf.variable_scope('init_global'):
            init_global = tf.global_variables_initializer()
        with tf.variable_scope('init_local'):
            init_local = tf.local_variables_initializer()
        sess.run([init_global, init_local])

        stats = TrainingStats()

        # Loading the model's weights or using initial weights
        if args.from_last:
            if model.load(sess, dirs_config['model_last_dir']):
                last_stats_file = os.path.join(dirs_config['model_last_dir'], 'steps')
                if os.path.isfile(last_stats_file):
                    stats.load(last_stats_file)
                    logger.info('[LOAD]  Training stats file')

                logger.info(f'[LOAD]  Pre-trained weights (last, epoch={stats.epoch})')
            else:
                logger.info('[LOAD]  Initial weights')
        elif model.load(sess, dirs_config['model_dir']):
            if os.path.isfile(dirs_config['model_stats_file']):
                stats.load(dirs_config['model_stats_file'])
                logger.info('[LOAD]  Training stats file')

            logger.info(f'[LOAD]  Pre-trained weights (best, epoch={stats.epoch})')
        else:
            logger.error('[ERROR]  No checkpoint found')
            raise ValueError('No checkpoint found')

        samples = generate_music(sess, sampler, intro_songs, placeholders,
                                 num_songs=sampling_config['num_songs'])
        logger.info('[EVAL]  Generated music samples')

        samples = pad_to_midi(samples, data_config)
        samples_to_save = samples[save_ids]

        # Saving the music
        if logs_config['save_samples_epochs'] > 0 and stats.epoch % logs_config['save_samples_epochs'] == 0:
            save_music(samples_to_save, num_intro=num_save_intro, data_config=data_config,
                       base_path=f'eval_{model.name}_e{stats.epoch}', save_dir=dirs_config['samples_dir'],
                       song_labels=song_labels)
            logger.info('[SAVE]  Saved music samples')

        if args.eval_samples:
            logger.info('[EVAL]  Evaluating music samples')
            # Group the time axis into four-beat segments (one bar in 4/4)
            # before computing per-sample metrics
            samples = np.reshape(
                samples,
                (samples.shape[0], -1, data_config['beat_resolution'] * 4,)
                + samples.shape[-2:]
            )
            compute_sample_metrics(samples)
Example 6
def main(_):
    """
    Starting point of the application
    """
    hvd.init()
    set_flags()
    params = parse_args(PARSER.parse_args())
    model_dir = prepare_model_dir(params)
    logger = get_logger(params)

    estimator = build_estimator(params, model_dir)

    dataset = Dataset(data_dir=params.data_dir,
                      batch_size=params.batch_size,
                      fold=params.crossvalidation_idx,
                      augment=params.augment,
                      gpu_id=hvd.rank(),
                      num_gpus=hvd.size(),
                      seed=params.seed)

    if 'train' in params.exec_mode:
        max_steps = params.max_steps // (1 if params.benchmark else hvd.size())
        hooks = [hvd.BroadcastGlobalVariablesHook(0),
                 TrainingHook(logger,
                              max_steps=max_steps,
                              log_every=params.log_every)]

        if params.benchmark and hvd.rank() == 0:
            hooks.append(ProfilingHook(logger,
                                       batch_size=params.batch_size,
                                       log_every=params.log_every,
                                       warmup_steps=params.warmup_steps,
                                       mode='train'))

        estimator.train(
            input_fn=dataset.train_fn,
            steps=max_steps,
            hooks=hooks)

    if 'evaluate' in params.exec_mode:
        if hvd.rank() == 0:
            results = estimator.evaluate(input_fn=dataset.eval_fn, steps=dataset.eval_size)
            logger.log(step=(),
                       data={"eval_ce_loss": float(results["eval_ce_loss"]),
                             "eval_dice_loss": float(results["eval_dice_loss"]),
                             "eval_total_loss": float(results["eval_total_loss"]),
                             "eval_dice_score": float(results["eval_dice_score"])})

    if 'predict' in params.exec_mode:
        if hvd.rank() == 0:
            predict_steps = dataset.test_size
            hooks = None
            if params.benchmark:
                hooks = [ProfilingHook(logger,
                                       batch_size=params.batch_size,
                                       log_every=params.log_every,
                                       warmup_steps=params.warmup_steps,
                                       mode="test")]
                predict_steps = params.warmup_steps * 2 * params.batch_size

            predictions = estimator.predict(
                input_fn=lambda: dataset.test_fn(count=math.ceil(predict_steps / dataset.test_size)),
                hooks=hooks)
            binary_masks = [np.argmax(p['logits'], axis=-1).astype(np.uint8) * 255 for p in predictions]

            if not params.benchmark:
                multipage_tif = [Image.fromarray(mask).resize(size=(512, 512), resample=Image.BILINEAR)
                                 for mask in binary_masks]

                output_dir = os.path.join(params.model_dir, 'pred')

                if not os.path.exists(output_dir):
                    os.makedirs(output_dir)

                multipage_tif[0].save(os.path.join(output_dir, 'test-masks.tif'),
                                      compression="tiff_deflate",
                                      save_all=True,
                                      append_images=multipage_tif[1:])
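
The predictions are written as a single multipage TIFF. A short sketch for reading the masks back with Pillow (path as saved by the script above):

import numpy as np
from PIL import Image, ImageSequence

with Image.open('pred/test-masks.tif') as tif:
    masks = [np.array(page) for page in ImageSequence.Iterator(tif)]
print(len(masks), masks[0].shape)  # one 512x512 uint8 mask per test image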