Example #1
    def init_offline(self, use_test_set, data_params, batch_size):
        self.is_online = False
        self.user_overwrite = False
        self.epsilon = 0 # Not used.

        # Load the two pertinent datasets into train_dataset and eval_dataset
        if use_test_set:
            train_dataset, eval_dataset = load_datasets('test', data_params)
        else:
            train_dataset, eval_dataset = load_datasets('dev', data_params)

        self.s_train, self.a_train, scores_train, h_train = train_dataset
        self.s_eval, self.a_eval, scores_test, h_test = eval_dataset

        # Compute the reward given scores and health. Currently, this just adds the two, weighting each one equally.
        self.r_train = np.add(scores_train, h_train)
        self.r_test = np.add(scores_test, h_test)

        self.batch_size = batch_size
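
For reference, a minimal sketch of the load_datasets(split, data_params) contract this snippet appears to rely on: a (train, eval) pair of (states, actions, scores, health) tuples. The array shapes and the make_split helper below are assumptions for illustration; the data_params keys are borrowed from get_data_params in Example #10.

import numpy as np

def load_datasets(split, data_params):
    # Hypothetical sketch: each split is a 4-tuple of (states, actions, scores, health)
    # arrays, matching how init_offline unpacks train_dataset and eval_dataset.
    # In the real code, `split` would select which files to read from disk.
    def make_split(n):
        states = np.zeros((n, data_params["height"], data_params["width"], 1))
        actions = np.zeros(n, dtype=np.int64)
        scores = np.zeros(n)
        health = np.zeros(n)
        return states, actions, scores, health

    n = data_params.get("num_images", 100)
    return make_split(n), make_split(n // 2)
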
Example #2
    def test_gan_steps(self):
        self.skipTest('legacy')
        ds_train, ds_val, ds_info = data.load_datasets(self.args)
        gan = models.make_model(self.args, ds_info['channels'])

        img = next(iter(ds_train))
        disc_vals = gan.disc_grad(img)
        gen_vals = gan.gen_grad(img)

        self.assertIsInstance(disc_vals, dict)
        self.assertIsInstance(gen_vals, dict)
        for k, v in list(disc_vals.items()) + list(gen_vals.items()):
            tf.debugging.assert_shapes([(v, [])])
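
As a side note on the assertion used above: tf.debugging.assert_shapes with an empty shape list checks that a tensor is a scalar. A standalone illustration:

import tensorflow as tf

scalar_metric = tf.constant(0.5)
tf.debugging.assert_shapes([(scalar_metric, [])])    # passes: rank-0 (scalar) tensor
# tf.debugging.assert_shapes([(tf.ones([2]), [])])   # would raise: shape mismatch
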
Example #3
    def test_img_format(self):
        args = '--data=mnist --imsize=32  --bsz=8 '
        self.args = utils.parser.parse_args(args.split())
        utils.setup(self.args)

        ds_train, ds_val, ds_info = data.load_datasets(self.args)
        train_sample = next(iter(ds_train))
        val_sample = next(iter(ds_val))

        for sample in [train_sample, val_sample]:
            tf.debugging.assert_type(sample, tf.uint8)

            # Probabilistic asserts
            min_val, max_val = tf.reduce_min(sample), tf.reduce_max(sample)
            tf.debugging.assert_greater(max_val, 127 * tf.ones_like(max_val))
Example #4
def run(args):
    # Setup
    strategy = setup(args)

    # Data
    ds_train, ds_val, ds_info = load_datasets(args)

    # Models
    with strategy.scope():
        model = make_model(args, ds_info['channels'])
        if args.model == 'gan':
            fid_model = fid.FID(args.debug)
        else:
            fid_model = None

    # Train
    train(args, model, ds_train, ds_val, ds_info, fid_model)
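
setup(args) is not shown here; a hedged sketch of a helper that returns a tf.distribute strategy (the body is an assumption — only the name and the fact that it returns a strategy usable as a scope come from the example):

import tensorflow as tf

def setup(args):
    # Hypothetical sketch: mirror across GPUs when several are visible,
    # otherwise fall back to the default (no-op) strategy.
    gpus = tf.config.list_physical_devices('GPU')
    if len(gpus) > 1:
        return tf.distribute.MirroredStrategy()
    return tf.distribute.get_strategy()
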
Example #5
def main():
    parser = argparse.ArgumentParser(description='Open Unmix Trainer')

    # which target do we want to train?
    parser.add_argument('--target',
                        type=str,
                        default='vocals',
                        help='target source (will be passed to the dataset)')

    # Dataset parameters
    parser.add_argument('--dataset',
                        type=str,
                        default="aligned",
                        choices=[
                            'musdb', 'aligned', 'sourcefolder',
                            'trackfolder_var', 'trackfolder_fix'
                        ],
                        help='Name of the dataset.')
    parser.add_argument('--root',
                        type=str,
                        help='root path of dataset',
                        default='../rec_data_new/')
    parser.add_argument('--output',
                        type=str,
                        default="../out_unmix/model_new_data_aug_tl",
                        help='provide output path base folder name')
    #parser.add_argument('--model', type=str, help='Path to checkpoint folder' , default='../out_unmix/model_new_data')
    #parser.add_argument('--model', type=str, help='Path to checkpoint folder')
    parser.add_argument('--model',
                        type=str,
                        help='Path to checkpoint folder',
                        default='umxhq')

    # Training Parameters
    parser.add_argument('--epochs', type=int, default=1000)
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--lr',
                        type=float,
                        default=0.0001,
                        help='learning rate, defaults to 1e-4')
    parser.add_argument(
        '--patience',
        type=int,
        default=140,
        help='early stopping patience in epochs (default: 140)')
    parser.add_argument('--lr-decay-patience',
                        type=int,
                        default=80,
                        help='lr decay patience for plateau scheduler')
    parser.add_argument('--lr-decay-gamma',
                        type=float,
                        default=0.3,
                        help='gamma of learning rate scheduler decay')
    parser.add_argument('--weight-decay',
                        type=float,
                        default=0.0000000001,
                        help='weight decay')
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        metavar='S',
                        help='random seed (default: 42)')

    # Model Parameters
    parser.add_argument('--seq-dur',
                        type=float,
                        default=6.0,
                        help='Sequence duration in seconds; '
                        'a value of <=0.0 will use full/variable length')
    parser.add_argument(
        '--unidirectional',
        action='store_true',
        default=False,
        help='Use unidirectional LSTM instead of bidirectional')
    parser.add_argument('--nfft',
                        type=int,
                        default=4096,
                        help='STFT fft size and window size')
    parser.add_argument('--nhop', type=int, default=1024, help='STFT hop size')
    parser.add_argument(
        '--hidden-size',
        type=int,
        default=512,
        help='hidden size parameter of dense bottleneck layers')
    parser.add_argument('--bandwidth',
                        type=int,
                        default=16000,
                        help='maximum model bandwidth in hertz')
    parser.add_argument('--nb-channels',
                        type=int,
                        default=2,
                        help='set number of channels for model (1, 2)')
    parser.add_argument('--nb-workers',
                        type=int,
                        default=4,
                        help='Number of workers for dataloader.')

    # Misc Parameters
    parser.add_argument('--quiet',
                        action='store_true',
                        default=False,
                        help='less verbose during training')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')

    args, _ = parser.parse_known_args()

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    print("Using GPU:", use_cuda)
    print("Using Torchaudio: ", utils._torchaudio_available())
    dataloader_kwargs = {
        'num_workers': args.nb_workers,
        'pin_memory': True
    } if use_cuda else {}

    repo_dir = os.path.abspath(os.path.dirname(__file__))
    repo = Repo(repo_dir)
    commit = repo.head.commit.hexsha[:7]

    # seed RNGs for reproducibility
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    train_dataset, valid_dataset, args = data.load_datasets(parser, args)
    print("TRAIN DATASET", train_dataset)
    print("VALID DATASET", valid_dataset)

    # create output dir if not exist
    target_path = Path(args.output)
    target_path.mkdir(parents=True, exist_ok=True)

    train_sampler = torch.utils.data.DataLoader(train_dataset,
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                **dataloader_kwargs)
    valid_sampler = torch.utils.data.DataLoader(valid_dataset,
                                                batch_size=1,
                                                **dataloader_kwargs)

    # =============================================================================
    #     if args.model:
    #         scaler_mean = None
    #         scaler_std = None
    #
    #     else:
    # =============================================================================
    scaler_mean, scaler_std = get_statistics(args, train_dataset)

    max_bin = utils.bandwidth_to_max_bin(train_dataset.sample_rate, args.nfft,
                                         args.bandwidth)

    unmix = model.OpenUnmix(input_mean=scaler_mean,
                            input_scale=scaler_std,
                            nb_channels=args.nb_channels,
                            hidden_size=args.hidden_size,
                            n_fft=args.nfft,
                            n_hop=args.nhop,
                            max_bin=max_bin,
                            sample_rate=train_dataset.sample_rate).to(device)

    optimizer = torch.optim.Adam(unmix.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        factor=args.lr_decay_gamma,
        patience=args.lr_decay_patience,
        cooldown=10)

    es = utils.EarlyStopping(patience=args.patience)

    # if a model is specified: warm-start from the pretrained umxhq weights (original resume logic is commented out below)
    if args.model:
        # disable progress bar
        err = io.StringIO()
        with redirect_stderr(err):
            unmix = torch.hub.load('sigsep/open-unmix-pytorch',
                                   'umxhq',
                                   target=args.target,
                                   device=device,
                                   pretrained=True)
# =============================================================================
#         model_path = Path(args.model).expanduser()
#         with open(Path(model_path, args.target + '.json'), 'r') as stream:
#             results = json.load(stream)
#
#         target_model_path = Path(model_path, args.target + ".chkpnt")
#         checkpoint = torch.load(target_model_path, map_location=device)
#         unmix.load_state_dict(checkpoint['state_dict'])
#         optimizer.load_state_dict(checkpoint['optimizer'])
#         scheduler.load_state_dict(checkpoint['scheduler'])
#         # train for another epochs_trained
#         t = tqdm.trange(
#             results['epochs_trained'],
#             results['epochs_trained'] + args.epochs + 1,
#             disable=args.quiet
#         )
#         train_losses = results['train_loss_history']
#         valid_losses = results['valid_loss_history']
#         train_times = results['train_time_history']
#         best_epoch = results['best_epoch']
#         es.best = results['best_loss']
#         es.num_bad_epochs = results['num_bad_epochs']
#     # else start from 0
# =============================================================================

    t = tqdm.trange(1, args.epochs + 1, disable=args.quiet)
    train_losses = []
    valid_losses = []
    train_times = []
    best_epoch = 0

    for epoch in t:
        t.set_description("Training Epoch")
        end = time.time()
        train_loss = train(args, unmix, device, train_sampler, optimizer)
        valid_loss = valid(args, unmix, device, valid_sampler)
        scheduler.step(valid_loss)
        train_losses.append(train_loss)
        valid_losses.append(valid_loss)

        t.set_postfix(train_loss=train_loss, val_loss=valid_loss)

        stop = es.step(valid_loss)

        from matplotlib import pyplot as plt

        plt.figure(figsize=(16, 12))
        plt.subplot(2, 2, 1)
        plt.title("Training loss")
        plt.plot(train_losses, label="Training")
        plt.xlabel("Iterations")
        plt.ylabel("Loss")
        plt.legend()
        plt.show()
        #plt.savefig(Path(target_path, "train_plot.pdf"))

        plt.figure(figsize=(16, 12))
        plt.subplot(2, 2, 2)
        plt.title("Validation loss")
        plt.plot(valid_losses, label="Validation")
        plt.xlabel("Iterations")
        plt.ylabel("Loss")
        plt.legend()
        plt.show()
        #plt.savefig(Path(target_path, "val_plot.pdf"))

        if valid_loss == es.best:
            best_epoch = epoch

        utils.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': unmix.state_dict(),
                'best_loss': es.best,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict()
            },
            is_best=valid_loss == es.best,
            path=target_path,
            target=args.target)

        # save params
        params = {
            'epochs_trained': epoch,
            'args': vars(args),
            'best_loss': es.best,
            'best_epoch': best_epoch,
            'train_loss_history': train_losses,
            'valid_loss_history': valid_losses,
            'train_time_history': train_times,
            'num_bad_epochs': es.num_bad_epochs,
            'commit': commit
        }

        with open(Path(target_path, args.target + '.json'), 'w') as outfile:
            outfile.write(json.dumps(params, indent=4, sort_keys=True))

        train_times.append(time.time() - end)

        if stop:
            print("Apply Early Stopping")
            break
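
utils.EarlyStopping is external to all of these snippets; below is a minimal sketch matching the interface used above (step() returns a stop flag, and best / num_bad_epochs are attributes). The real helper may differ in detail.

class EarlyStopping:
    """Sketch of an early-stopping helper with the interface assumed above."""

    def __init__(self, patience=140, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = None
        self.num_bad_epochs = 0

    def step(self, metric):
        # Lower is better; returns True when training should stop.
        if self.best is None or metric < self.best - self.min_delta:
            self.best = metric
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        return self.num_bad_epochs > self.patience
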
Example #6
                    selected_derivs[i].y_toks) + "\n")
        out.close()
    """

def render_ratio(numer, denom):
    return "%i / %i = %.3f" % (numer, denom, float(numer)/denom)


if __name__ == '__main__':
    args = _parse_args()
    print(args)
    random.seed(args.seed)
    np.random.seed(args.seed)

    train, dev, test = data.load_datasets(
            args.train_path_input, args.train_path_output,
            args.dev_path_input, args.dev_path_output,
            args.test_path_input, args.test_path_output)
    train_data_indexed, \
            dev_data_indexed, \
            test_data_indexed, \
            input_indexer, \
            output_indexer = data.index_datasets(
                    train, dev, test, args.decoder_len_limit)
    '''
    print("{} train exs, {} dev exs, {} input types, {} output types".format(
        len(train_data_indexed),
        len(dev_data_indexed),
        len(input_indexer),
        len(output_indexer)))
    print("{} train exs, {} dev exs, {} input types, {} output types".format(
        len(train_data_indexed),
Example #7
def main(_):

    start_time = datetime.now()
    tf.logging.info("Data path: {}".format(FLAGS.data_path))
    tf.logging.info("Output path: {}".format(FLAGS.output_path))
    tf.logging.info("Starting at: {}".format(start_time))
    tf.logging.info("Batch size: {} images per step".format(FLAGS.batch_size))

    last_epoch_start_time = start_time

    # Load datasets
    imgs_train, msks_train, imgs_test, msks_test = load_datasets(FLAGS)

    if not FLAGS.no_horovod:
        # Initialize Horovod.
        hvd.init()

    # Define model
    model = define_model(imgs_train.shape, msks_train.shape, FLAGS)

    if not FLAGS.no_horovod:
        # Horovod: optionally scale the learning rate by the number of workers (scaling is currently commented out)
        opt = tf.train.AdamOptimizer(learning_rate=FLAGS.learningrate,
                                     epsilon=tf.keras.backend.epsilon())
        #opt = tf.train.RMSPropOptimizer(0.0001 * hvd.size())
        # tf.logging.info("HOROVOD: New learning rate is {}".\
        #         format(FLAGS.learningrate * hvd.size()))
    else:
        opt = tf.train.AdamOptimizer(learning_rate=FLAGS.learningrate,
                                     epsilon=tf.keras.backend.epsilon())
        #opt = tf.train.RMSPropOptimizer(0.0001)

    # Wrap optimizer with Horovod Distributed Optimizer.
    if not FLAGS.no_horovod:
        tf.logging.info("HOROVOD: Wrapped optimizer")
        opt = hvd.DistributedOptimizer(opt)

    global_step = tf.train.get_or_create_global_step()
    train_op = opt.minimize(model["loss"], global_step=global_step)

    train_length = len(imgs_train)
    total_steps = (FLAGS.epochs * train_length) // FLAGS.batch_size
    if not FLAGS.no_horovod:
        last_step = total_steps // hvd.size()
        validation_steps = train_length // FLAGS.batch_size // hvd.size()
    else:
        last_step = total_steps
        validation_steps = train_length // FLAGS.batch_size

    def formatter_log(tensors):
        """
        Format the log output
        """
        if FLAGS.no_horovod:
            logstring = "Step {} of {}: " \
               " training Dice loss = {:.4f}," \
               " training Dice = {:.4f}".format(tensors["step"],
               last_step,
               tensors["loss"], tensors["dice"])
        else:
            logstring = "HOROVOD (Worker #{}), Step {} of {}: " \
               " training Dice loss = {:.4f}," \
               " training Dice = {:.4f}".format(
               hvd.rank(),
               tensors["step"],
               last_step,
               tensors["loss"], tensors["dice"])

        return logstring

    hooks = [
        tf.train.StopAtStepHook(last_step=last_step),

        # Prints the loss and step every log_steps steps
        tf.train.LoggingTensorHook(tensors={
            "step": global_step,
            "loss": model["loss"],
            "dice": model["metric_dice"]
        },
                                   every_n_iter=FLAGS.log_steps,
                                   formatter=formatter_log),
    ]

    # Horovod: BroadcastGlobalVariablesHook broadcasts
    # initial variable states from rank 0 to all other
    # processes. This is necessary to ensure consistent
    # initialization of all workers when training is
    # started with random weights
    # or restored from a checkpoint.
    if not FLAGS.no_horovod:
        hooks.append(hvd.BroadcastGlobalVariablesHook(0))

        # Horovod: save checkpoints only on worker 0 to prevent other workers from
        # corrupting them.
        if hvd.rank() == 0:
            checkpoint_dir = "{}/{}-workers/{}".format(
                FLAGS.output_path, hvd.size(),
                datetime.now().strftime("%Y%m%d-%H%M%S"))
            print(checkpoint_dir)
        else:
            checkpoint_dir = None

    else:
        checkpoint_dir = "{}/no_hvd/{}".format(
            FLAGS.output_path,
            datetime.now().strftime("%Y%m%d-%H%M%S"))

    # The MonitoredTrainingSession takes care of session initialization,
    # restoring from a checkpoint, saving to a checkpoint,
    # and closing when done or an error occurs.
    current_step = 0
    startidx = 0
    epoch_idx = 0

    with tf.train.MonitoredTrainingSession(
            checkpoint_dir=checkpoint_dir,
            hooks=hooks,
            save_summaries_steps=FLAGS.log_steps,
            log_step_count_steps=FLAGS.log_steps,
            config=config) as mon_sess:

        while not mon_sess.should_stop():

            # Run a training step synchronously.
            image_, mask_ = get_batch(imgs_train, msks_train, FLAGS.batch_size)

            # Do batch in order
            # stopidx = startidx + FLAGS.batch_size
            # if (stopidx > train_length):
            #     stopidx = train_length
            #
            # image_ = imgs_train[startidx:stopidx]
            # mask_  = msks_train[startidx:stopidx]

            mon_sess.run(train_op,
                         feed_dict={
                             model["input"]: image_,
                             model["label"]: mask_
                         })

            current_step += 1
            # # Get next batch (loop around if at end)
            # startidx += FLAGS.batch_size
            # if (startidx > train_length):
            #     startidx = 0

    stop_time = datetime.now()
    tf.logging.info("Stopping at: {}".format(stop_time))
    tf.logging.info("Elapsed time was: {}".format(stop_time - start_time))
Example #8
def main():
    parser = argparse.ArgumentParser(description='Open Unmix Trainer')
    # Loss parameters
    parser.add_argument('--loss',
                        type=str,
                        default="L2freq",
                        choices=[
                            'L2freq', 'L1freq', 'L2time', 'L1time', 'L2mask',
                            'L1mask', 'SISDRtime', 'SISDRfreq', 'MinSNRsdsdr',
                            'CrossEntropy', 'BinaryCrossEntropy', 'LogL2time',
                            'LogL1time', 'LogL2freq', 'LogL1freq', 'PSA',
                            'SNRPSA', 'Dissimilarity'
                        ],
                        help='kind of loss used during training')

    # Dataset parameters
    parser.add_argument('--dataset',
                        type=str,
                        default="musdb",
                        choices=[
                            'musdb', 'aligned', 'sourcefolder',
                            'trackfolder_var', 'trackfolder_fix'
                        ],
                        help='Name of the dataset.')

    parser.add_argument('--root', type=str, help='root path of dataset')
    parser.add_argument('--output',
                        type=str,
                        default="open-unmix",
                        help='provide output path base folder name')
    parser.add_argument('--model', type=str, help='Path to checkpoint folder')

    # Training Parameters
    parser.add_argument('--epochs', type=int, default=1000)
    parser.add_argument('--reduce-samples',
                        type=int,
                        default=1,
                        help="reduce training samples by factor n")

    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        help='learning rate, defaults to 1e-3')
    parser.add_argument(
        '--patience',
        type=int,
        default=140,
        help='early stopping patience in epochs (default: 140)')
    parser.add_argument('--lr-decay-patience',
                        type=int,
                        default=80,
                        help='lr decay patience for plateau scheduler')
    parser.add_argument('--lr-decay-gamma',
                        type=float,
                        default=0.3,
                        help='gamma of learning rate scheduler decay')
    parser.add_argument('--weight-decay',
                        type=float,
                        default=0.00001,
                        help='weight decay')
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        metavar='S',
                        help='random seed (default: 42)')

    # Model Parameters
    parser.add_argument('--seq-dur',
                        type=float,
                        default=6.0,
                        help='Sequence duration in seconds; '
                        'a value of <=0.0 will use full/variable length')
    parser.add_argument(
        '--unidirectional',
        action='store_true',
        default=False,
        help='Use unidirectional LSTM instead of bidirectional')
    parser.add_argument('--nfft',
                        type=int,
                        default=4096,
                        help='STFT fft size and window size')
    parser.add_argument('--nhop', type=int, default=1024, help='STFT hop size')
    parser.add_argument(
        '--hidden-size',
        type=int,
        default=512,
        help='hidden size parameter of dense bottleneck layers')
    parser.add_argument('--bandwidth',
                        type=int,
                        default=16000,
                        help='maximum model bandwidth in hertz')
    parser.add_argument('--nb-channels',
                        type=int,
                        default=2,
                        help='set number of channels for model (1, 2)')
    parser.add_argument('--nb-workers',
                        type=int,
                        default=0,
                        help='Number of workers for dataloader.')

    # Misc Parameters
    parser.add_argument('--quiet',
                        action='store_true',
                        default=False,
                        help='less verbose during training')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')

    args, _ = parser.parse_known_args()

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    print("Using GPU:", use_cuda)
    dataloader_kwargs = {
        'num_workers': args.nb_workers,
        'pin_memory': True
    } if use_cuda else {}

    repo_dir = os.path.abspath(os.path.dirname(__file__))
    repo = Repo(repo_dir)
    commit = repo.head.commit.hexsha[:7]

    # seed RNGs for reproducibility
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")
    train_dataset, valid_dataset, args = data.load_datasets(parser, args)

    num_train = len(train_dataset)
    indices = list(range(num_train))

    # shuffle train indices once and for all
    np.random.seed(args.seed)
    np.random.shuffle(indices)

    if args.reduce_samples > 1:
        split = int(np.floor(num_train / args.reduce_samples))
        train_idx = indices[:split]
    else:
        train_idx = indices
    sampler = SubsetRandomSampler(train_idx)
    # create output dir if not exist
    target_path = Path(args.output)
    target_path.mkdir(parents=True, exist_ok=True)

    train_sampler = torch.utils.data.DataLoader(train_dataset,
                                                batch_size=args.batch_size,
                                                sampler=sampler,
                                                **dataloader_kwargs)

    stats_sampler = torch.utils.data.DataLoader(train_dataset,
                                                batch_size=1,
                                                sampler=sampler,
                                                **dataloader_kwargs)

    valid_sampler = torch.utils.data.DataLoader(valid_dataset,
                                                batch_size=1,
                                                **dataloader_kwargs)

    if args.model:
        scaler_mean = None
        scaler_std = None
    else:
        scaler_mean, scaler_std = get_statistics(args, stats_sampler)

    max_bin = utils.bandwidth_to_max_bin(train_dataset.sample_rate, args.nfft,
                                         args.bandwidth)
    # SNRPSA: de-compress the scaler in order to avoid an exploding gradient from the uncompressed initial statistics
    if args.loss == 'SNRPSA':
        power = 2
    else:
        power = 1

    unmix = model.OpenUnmixSingle(
        n_fft=4096,
        n_hop=1024,
        input_is_spectrogram=False,
        hidden_size=args.hidden_size,
        nb_channels=args.nb_channels,
        sample_rate=train_dataset.sample_rate,
        nb_layers=3,
        input_mean=scaler_mean,
        input_scale=scaler_std,
        max_bin=max_bin,
        unidirectional=args.unidirectional,
        power=power,
    ).to(device)
    print('learning rate:')
    print(args.lr)
    optimizer = torch.optim.Adam(unmix.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        factor=args.lr_decay_gamma,
        patience=args.lr_decay_patience,
        cooldown=10)

    es = utils.EarlyStopping(patience=args.patience)

    # if a model is specified: resume training
    if args.model:
        print('LOADING MODEL')
        model_path = Path(args.model).expanduser()
        with open(Path(model_path,
                       str(len(args.targets)) + '.json'), 'r') as stream:
            results = json.load(stream)

        target_model_path = Path(model_path, "model.chkpnt")
        checkpoint = torch.load(target_model_path, map_location=device)
        unmix.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        # train for another epochs_trained
        t = tqdm.trange(results['epochs_trained'],
                        results['epochs_trained'] + args.epochs + 1,
                        disable=args.quiet)
        train_losses = results['train_loss_history']
        valid_losses = results['valid_loss_history']
        train_times = results['train_time_history']
        best_epoch = results['best_epoch']
        es.best = results['best_loss']
        es.num_bad_epochs = results['num_bad_epochs']
        print('Model loaded')
    # else start from 0
    else:
        t = tqdm.trange(1, args.epochs + 1, disable=args.quiet)
        train_losses = []
        valid_losses = []
        train_times = []
        best_epoch = 0

    for epoch in t:
        t.set_description("Training Epoch")
        end = time.time()
        train_loss = train(args, unmix, device, train_sampler, optimizer)
        valid_loss = valid(args, unmix, device, valid_sampler)
        scheduler.step(valid_loss)
        train_losses.append(train_loss)
        valid_losses.append(valid_loss)

        t.set_postfix(train_loss=train_loss, val_loss=valid_loss)

        stop = es.step(valid_loss)

        if valid_loss == es.best:
            best_epoch = epoch

        utils.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': unmix.state_dict(),
                'best_loss': es.best,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict()
            },
            is_best=valid_loss == es.best,
            path=target_path,
        )

        # save params
        params = {
            'epochs_trained': epoch,
            'args': vars(args),
            'best_loss': es.best,
            'best_epoch': best_epoch,
            'train_loss_history': train_losses,
            'valid_loss_history': valid_losses,
            'train_time_history': train_times,
            'num_bad_epochs': es.num_bad_epochs,
            'commit': commit
        }

        with open(Path(target_path,
                       str(len(args.targets)) + '.json'), 'w') as outfile:
            outfile.write(json.dumps(params, indent=4, sort_keys=True))

        train_times.append(time.time() - end)

        if stop:
            print("Apply Early Stopping")
            break
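
utils.bandwidth_to_max_bin presumably converts the bandwidth limit (in Hz) into the index of the highest STFT bin kept by the model; a sketch of that computation (the actual utility may differ):

import numpy as np

def bandwidth_to_max_bin(rate, n_fft, bandwidth):
    # Center frequencies of the one-sided STFT bins, then the count of bins
    # at or below the requested bandwidth.
    freqs = np.linspace(0, float(rate) / 2, n_fft // 2 + 1, endpoint=True)
    return np.max(np.where(freqs <= bandwidth)[0]) + 1
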
Example #9
if __name__ == "__main__":
    args = parser.parse_args()

    year_to_index = OrderedDict()
    model_to_index = OrderedDict()

    # dictionary from "years to obliteration" to number of samples
    dataset_sizes = {}
    # mapping from (year, model name) pairs to AUC
    aucs_dict = {}
    # consider each output target of years until AVM obliteration
    for obliteration_years in reversed(
            range(args.min_obliteration_years,
                  args.max_obliteration_years + 1)):
        print("\n### Obliteration Years = %d\n\n" % obliteration_years)
        dataset, full_df = data.load_datasets(
            obliteration_years=obliteration_years)
        X = dataset["X"][:args.num_samples]
        Y = dataset["Y"][:args.num_samples]
        dataset_sizes[obliteration_years] = len(X)
        assert len(X) == len(Y)
        VRAS = dataset["VRAS"][:args.num_samples]
        FP = dataset["FP"][:args.num_samples]
        SM = dataset["SM"][:args.num_samples]
        Y_binary = Y == 2
        grid = hyperparameter_grid(
            logistic_regression=args.logistic_regression,
            svm=args.svm,
            gradient_boosting=args.gradient_boosting,
            extra_trees=args.extra_trees,
            random_forest=args.random_forest)
        if len(grid) == 0:
Example #10
            )

sess.run(tf.global_variables_initializer())

# graph = tf.get_default_graph()
# for var in tf.global_variables():
#     print(var)
# print([var for var in tf.all_variables()])
# graph.get_tensor_by_name('X')

# Load data
def get_data_params():
    return {
        "data_dir": './data/data_053017/',
        "num_images": 1000,
        "width": 64,
        "height": 48,
        "multi_frame_state": False,
        "frames_per_state": 1,
        "actions": ACTIONS,
        "eval_proportion": .5,
        "image_size": 28,
    }

all_data, _ = load_datasets("test", get_data_params())
s, a, scores, h = all_data

print("##### SALIENCY MAPS #######################################")
# Generate 5 options
show_saliency_maps(foxnet, s, a)
Example #11
def main():
    parser = argparse.ArgumentParser(description='Open Unmix Trainer')

    # which target do we want to train?
    # =============================================================================
    #     parser.add_argument('--target', type=str, default='vocals',
    #                         help='target source (will be passed to the dataset)')
    #
    # =============================================================================
    parser.add_argument('--target',
                        type=str,
                        default='tabla',
                        help='target source (will be passed to the dataset)')

    # Dataset parameters
    parser.add_argument('--dataset',
                        type=str,
                        default="aligned",
                        choices=[
                            'musdb', 'aligned', 'sourcefolder',
                            'trackfolder_var', 'trackfolder_fix'
                        ],
                        help='Name of the dataset.')
    parser.add_argument('--root',
                        type=str,
                        help='root path of dataset',
                        default='../rec_data_final/')
    parser.add_argument('--output',
                        type=str,
                        default="../new_models/model_tabla_mtl_ourmix_1",
                        help='provide output path base folder name')
    #parser.add_argument('--model', type=str, help='Path to checkpoint folder' , default='../out_unmix/model_new_data_aug_tabla_mse_pretrain1')
    #parser.add_argument('--model', type=str, help='Path to checkpoint folder' , default="../out_unmix/model_new_data_aug_tabla_mse_pretrain8" )
    #parser.add_argument('--model', type=str, help='Path to checkpoint folder' , default='../out_unmix/model_new_data_aug_tabla_bce_finetune2')
    parser.add_argument('--model', type=str, help='Path to checkpoint folder')
    #parser.add_argument('--model', type=str, help='Path to checkpoint folder' , default='umxhq')
    parser.add_argument(
        '--onset-model',
        type=str,
        help='Path to onset detection model weights',
        default=
        "/media/Sharedata/rohit/cnn-onset-det/models/apr4/saved_model_0_80mel-0-16000_1ch_44100.pt"
    )

    # Training Parameters
    parser.add_argument('--epochs', type=int, default=1000)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        help='learning rate, defaults to 1e-3')
    parser.add_argument(
        '--patience',
        type=int,
        default=140,
        help='early stopping patience in epochs (default: 140)')
    parser.add_argument('--lr-decay-patience',
                        type=int,
                        default=80,
                        help='lr decay patience for plateau scheduler')
    parser.add_argument('--lr-decay-gamma',
                        type=float,
                        default=0.3,
                        help='gamma of learning rate scheduler decay')
    parser.add_argument('--weight-decay',
                        type=float,
                        default=0.00001,
                        help='weight decay')
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        metavar='S',
                        help='random seed (default: 42)')
    parser.add_argument('--gamma',
                        type=float,
                        default=0.0,
                        help='weighting of different loss components')
    parser.add_argument(
        '--finetune',
        type=int,
        default=0,
        help='If 1 (true), optimiser states from the checkpoint model are reset '
        '(required for BCE fine-tuning); if 0 (false), training resumes from where it left off'
    )
    parser.add_argument('--onset-thresh',
                        type=float,
                        default=0.3,
                        help='Threshold above which onset is said to occur')
    parser.add_argument(
        '--binarise',
        type=int,
        default=0,
        help='If 1 (true), the target novelty function is made binary; '
        'if 0 (false), it is left as is'
    )
    parser.add_argument(
        '--onset-trainable',
        type=int,
        default=0,
        help='If 1 (true), the onsetCNN is also trained during the finetuning stage; '
        'if 0 (false), it is kept fixed'
    )

    # Model Parameters
    parser.add_argument('--seq-dur',
                        type=float,
                        default=6.0,
                        help='Sequence duration in seconds; '
                        'a value of <=0.0 will use full/variable length')
    parser.add_argument(
        '--unidirectional',
        action='store_true',
        default=False,
        help='Use unidirectional LSTM instead of bidirectional')
    parser.add_argument('--nfft',
                        type=int,
                        default=4096,
                        help='STFT fft size and window size')
    parser.add_argument('--nhop', type=int, default=1024, help='STFT hop size')

    # =============================================================================
    #     parser.add_argument('--nfft', type=int, default=2048,
    #                         help='STFT fft size and window size')
    #     parser.add_argument('--nhop', type=int, default=512,
    #                         help='STFT hop size')
    # =============================================================================

    parser.add_argument('--n-mels',
                        type=int,
                        default=80,
                        help='Number of bins in mel spectrogram')

    parser.add_argument(
        '--hidden-size',
        type=int,
        default=512,
        help='hidden size parameter of dense bottleneck layers')
    parser.add_argument('--bandwidth',
                        type=int,
                        default=16000,
                        help='maximum model bandwidth in hertz')
    parser.add_argument('--nb-channels',
                        type=int,
                        default=2,
                        help='set number of channels for model (1, 2)')
    parser.add_argument('--nb-workers',
                        type=int,
                        default=4,
                        help='Number of workers for dataloader.')

    # Misc Parameters
    parser.add_argument('--quiet',
                        action='store_true',
                        default=False,
                        help='less verbose during training')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')

    args, _ = parser.parse_known_args()

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    print("Using GPU:", use_cuda)
    print("Using Torchaudio: ", utils._torchaudio_available())
    dataloader_kwargs = {
        'num_workers': args.nb_workers,
        'pin_memory': True
    } if use_cuda else {}

    repo_dir = os.path.abspath(os.path.dirname(__file__))
    repo = Repo(repo_dir)
    commit = repo.head.commit.hexsha[:7]

    # seed RNGs for reproducibility
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")
    torch.autograd.set_detect_anomaly(True)

    train_dataset, valid_dataset, args = data.load_datasets(parser, args)
    print("TRAIN DATASET", train_dataset)
    print("VALID DATASET", valid_dataset)

    # create output dir if not exist
    target_path = Path(args.output)
    target_path.mkdir(parents=True, exist_ok=True)

    train_sampler = torch.utils.data.DataLoader(train_dataset,
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                **dataloader_kwargs)
    valid_sampler = torch.utils.data.DataLoader(valid_dataset,
                                                batch_size=1,
                                                **dataloader_kwargs)

    if args.model:
        scaler_mean = None
        scaler_std = None
    else:
        scaler_mean, scaler_std = get_statistics(args, train_dataset)

    max_bin = utils.bandwidth_to_max_bin(train_dataset.sample_rate, args.nfft,
                                         args.bandwidth)

    unmix = model_mtl.OpenUnmix_mtl(
        input_mean=scaler_mean,
        input_scale=scaler_std,
        nb_channels=args.nb_channels,
        hidden_size=args.hidden_size,
        n_fft=args.nfft,
        n_hop=args.nhop,
        max_bin=max_bin,
        sample_rate=train_dataset.sample_rate).to(device)

    #Read trained onset detection network (Model through which target spectrogram is passed)
    detect_onset = model.onsetCNN().to(device)
    detect_onset.load_state_dict(
        torch.load(args.onset_model, map_location='cuda:0'))

    #Model through which separated output is passed
    # detect_onset_training = model.onsetCNN().to(device)
    # detect_onset_training.load_state_dict(torch.load(args.onset_model, map_location='cuda:0'))

    for child in detect_onset.children():
        for param in child.parameters():
            param.requires_grad = False

    #If onset trainable is false, then we want to keep the weights of this model fixed
    # if (args.onset_trainable == 0):
    #     for child in detect_onset_training.children():
    #         for param in child.parameters():
    #             param.requires_grad = False

    # #FOR CHECKING, REMOVE LATER
    # for child in detect_onset_training.children():
    #     for param in child.parameters():
    #         print(param.requires_grad)

    optimizer = torch.optim.Adam(unmix.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        factor=args.lr_decay_gamma,
        patience=args.lr_decay_patience,
        cooldown=10)

    es = utils.EarlyStopping(patience=args.patience)

    # if a model is specified: resume training
    if args.model:
        model_path = Path(args.model).expanduser()
        with open(Path(model_path, args.target + '.json'), 'r') as stream:
            results = json.load(stream)

        target_model_path = Path(model_path, args.target + ".chkpnt")
        checkpoint = torch.load(target_model_path, map_location=device)
        unmix.load_state_dict(checkpoint['state_dict'])

        #Only when the onset model is trainable and finetuning is resumed from where it left off do we read the onset state_dict
        # if ((args.onset_trainable==1)and(args.finetune==0)):
        #     detect_onset_training.load_state_dict(checkpoint['onset_state_dict'])
        #     print("Reading saved onset model")
        # else:
        #     print("Not reading saved onset model")

        if (args.finetune == 0):
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            # train for another epochs_trained
            t = tqdm.trange(results['epochs_trained'],
                            results['epochs_trained'] + args.epochs + 1,
                            disable=args.quiet)
            print("PICKUP WHERE LEFT OFF", args.finetune)
            train_losses = results['train_loss_history']
            train_mse_losses = results['train_mse_loss_history']
            train_bce_losses = results['train_bce_loss_history']
            valid_losses = results['valid_loss_history']
            valid_mse_losses = results['valid_mse_loss_history']
            valid_bce_losses = results['valid_bce_loss_history']
            train_times = results['train_time_history']
            best_epoch = results['best_epoch']

            es.best = results['best_loss']
            es.num_bad_epochs = results['num_bad_epochs']

        else:
            t = tqdm.trange(1, args.epochs + 1, disable=args.quiet)
            train_losses = []
            train_mse_losses = []
            train_bce_losses = []
            print("NOT PICKUP WHERE LEFT OFF", args.finetune)
            valid_losses = []
            valid_mse_losses = []
            valid_bce_losses = []

            train_times = []
            best_epoch = 0

        #es.best = results['best_loss']
        #es.num_bad_epochs = results['num_bad_epochs']
    # else start from 0
    else:
        t = tqdm.trange(1, args.epochs + 1, disable=args.quiet)
        train_losses = []
        train_mse_losses = []
        train_bce_losses = []

        valid_losses = []
        valid_mse_losses = []
        valid_bce_losses = []

        train_times = []
        best_epoch = 0

    for epoch in t:
        t.set_description("Training Epoch")
        end = time.time()
        train_loss, train_mse_loss, train_bce_loss = train(
            args,
            unmix,
            device,
            train_sampler,
            optimizer,
            detect_onset=detect_onset)
        #train_mse_loss = train(args, unmix, device, train_sampler, optimizer, detect_onset=detect_onset)[1]
        #train_bce_loss = train(args, unmix, device, train_sampler, optimizer, detect_onset=detect_onset)[2]

        valid_loss, valid_mse_loss, valid_bce_loss = valid(
            args, unmix, device, valid_sampler, detect_onset=detect_onset)
        #valid_mse_loss = valid(args, unmix, device, valid_sampler, detect_onset=detect_onset)[1]
        #valid_bce_loss = valid(args, unmix, device, valid_sampler, detect_onset=detect_onset)[2]

        scheduler.step(valid_loss)
        train_losses.append(train_loss)
        train_mse_losses.append(train_mse_loss)
        train_bce_losses.append(train_bce_loss)

        valid_losses.append(valid_loss)
        valid_mse_losses.append(valid_mse_loss)
        valid_bce_losses.append(valid_bce_loss)

        t.set_postfix(train_loss=train_loss, val_loss=valid_loss)

        stop = es.step(valid_loss)

        #from matplotlib import pyplot as plt

        # =============================================================================
        #         plt.figure(figsize=(16,12))
        #         plt.subplot(2, 2, 1)
        #         plt.title("Training loss")
        #         plt.plot(train_losses,label="Training")
        #         plt.xlabel("Iterations")
        #         plt.ylabel("Loss")
        #         plt.legend()
        #         plt.show()
        #         #plt.savefig(Path(target_path, "train_plot.pdf"))
        #
        #         plt.figure(figsize=(16,12))
        #         plt.subplot(2, 2, 2)
        #         plt.title("Validation loss")
        #         plt.plot(valid_losses,label="Validation")
        #         plt.xlabel("Iterations")
        #         plt.ylabel("Loss")
        #         plt.legend()
        #         plt.show()
        #         #plt.savefig(Path(target_path, "val_plot.pdf"))
        # =============================================================================

        if valid_loss == es.best:
            best_epoch = epoch

        utils.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': unmix.state_dict(),
                'best_loss': es.best,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'onset_state_dict': detect_onset.state_dict()
            },
            is_best=valid_loss == es.best,
            path=target_path,
            target=args.target)

        # save params
        params = {
            'epochs_trained': epoch,
            'args': vars(args),
            'best_loss': es.best,
            'best_epoch': best_epoch,
            'train_loss_history': train_losses,
            'train_mse_loss_history': train_mse_losses,
            'train_bce_loss_history': train_bce_losses,
            'valid_loss_history': valid_losses,
            'valid_mse_loss_history': valid_mse_losses,
            'valid_bce_loss_history': valid_bce_losses,
            'train_time_history': train_times,
            'num_bad_epochs': es.num_bad_epochs,
            'commit': commit
        }

        with open(Path(target_path, args.target + '.json'), 'w') as outfile:
            outfile.write(json.dumps(params, indent=4, sort_keys=True))

        train_times.append(time.time() - end)

        if stop:
            print("Apply Early Stopping")
            break


# =============================================================================
#     plt.figure(figsize=(16,12))
#     plt.subplot(2, 2, 1)
#     plt.title("Training loss")
#     #plt.plot(train_losses,label="Training")
#     plt.plot(train_losses,label="Training")
#     plt.xlabel("Iterations")
#     plt.ylabel("Loss")
#     plt.legend()
#     #plt.show()
#
#     plt.figure(figsize=(16,12))
#     plt.subplot(2, 2, 2)
#     plt.title("Validation loss")
#     plt.plot(valid_losses,label="Validation")
#     plt.xlabel("Iterations")
#     plt.ylabel("Loss")
#     plt.legend()
#     plt.show()
#     plt.savefig(Path(target_path, "train_val_plot.pdf"))
#     #plt.savefig(Path(target_path, "train_plot.pdf"))
# =============================================================================

    print("TRAINING DONE!!")

    plt.figure()
    plt.title("Training loss")
    plt.plot(train_losses, label="Training")
    plt.xlabel("Iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig(Path(target_path, "train_plot.pdf"))

    plt.figure()
    plt.title("Validation loss")
    plt.plot(valid_losses, label="Validation")
    plt.xlabel("Iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig(Path(target_path, "val_plot.pdf"))

    plt.figure()
    plt.title("Training BCE loss")
    plt.plot(train_bce_losses, label="Training")
    plt.xlabel("Iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig(Path(target_path, "train_bce_plot.pdf"))

    plt.figure()
    plt.title("Validation BCE loss")
    plt.plot(valid_bce_losses, label="Validation")
    plt.xlabel("Iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig(Path(target_path, "val_bce_plot.pdf"))

    plt.figure()
    plt.title("Training MSE loss")
    plt.plot(train_mse_losses, label="Training")
    plt.xlabel("Iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig(Path(target_path, "train_mse_plot.pdf"))

    plt.figure()
    plt.title("Validation MSE loss")
    plt.plot(valid_mse_losses, label="Validation")
    plt.xlabel("Iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig(Path(target_path, "val_mse_plot.pdf"))
Example #12
esv = read_esv(args.data, args.data_type)
# import pdb; pdb.set_trace()
train_path = os.path.join('data', args.data_type,
                          f'{args.translation}_{args.data_type}_train.csv')
dev_path = os.path.join('data', args.data_type,
                        f'{args.translation}_{args.data_type}_dev.csv')
test_path = os.path.join('data', args.data_type,
                         f'{args.translation}_{args.data_type}_test.csv')
train, dev, test = load_datasets(train_path, dev_path, test_path, esv)
train_data_indexed, dev_data_indexed, test_data_indexed, indexer = index_dataset(
    esv, train, dev, test, args.bptt)

train_data_indexed.sort(key=lambda data: len(data), reverse=True)
dev_data_indexed.sort(key=lambda data: len(data), reverse=True)
test_data_indexed.sort(key=lambda data: len(data), reverse=True)

# corpus = data.Corpus(args.data)

# Starting from sequential data, batchify arranges the dataset into columns.
# For instance, with the alphabet as the sequence and batch size 4, we'd get
# ┌ a g m s ┐
# │ b h n t │
# │ c i o u │
# │ d j p v │
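
The trailing comment describes the column-wise batch layout used by the standard batchify function from the PyTorch word-language-model example; for reference, that function is roughly (a sketch, not code taken from this repository):

def batchify(data, bsz):
    # Trim off any remainder so the data divides evenly into bsz columns,
    # then lay the sequence out column-wise as in the diagram above.
    nbatch = data.size(0) // bsz
    data = data.narrow(0, 0, nbatch * bsz)
    return data.view(bsz, -1).t().contiguous()
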
Example #13
    axes.set_ylim(0, 1)

    # axes.legend()
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Time Horizon = %d Years' % obliteration_years)
    figure = axes.figure
    figure.savefig(filename, dpi=dpi)
    print("\n >>> Average AUC score across all bootstrap samples: %0.4f" % (
        np.mean(auc_scores),))
    return models


if __name__ == "__main__":
    args = parser.parse_args()
    dataset, full_df = data.load_datasets(
        obliteration_years=args.obliteration_years)
    X = dataset["X"]
    Y = dataset["Y"]
    assert len(X) == len(Y)
    columns = dataset["df_filtered"].columns
    VRAS = dataset["VRAS"]
    FP = dataset["FP"]
    SM = dataset["SM"]

    Y_binary = Y == 2
    lr_hyperparameters = list(
        hyperparameter_grid(logistic_regression=True).values())[0]
    print(lr_hyperparameters)
    all_models = generate_roc_plot(
        X,
        Y_binary,
Example #14
def main():
    parser = argparse.ArgumentParser(description='Open Unmix Trainer')

    # which target do we want to train?
    parser.add_argument('--target', type=str, default='vocals',
                        help='target source (will be passed to the dataset)')

    # experiment tag which will determine output folder in trained models, tensorboard name, etc.
    parser.add_argument('--tag', type=str)


    # allow to pass a comment about the experiment
    parser.add_argument('--comment', type=str, help='comment about the experiment')

    args, _ = parser.parse_known_args()

    # Dataset parameters
    parser.add_argument('--dataset', type=str, default="musdb",
                        choices=[
                            'musdb_lyrics', 'timit_music', 'blended', 'nus', 'nus_train'
                        ],
                        help='Name of the dataset.')

    parser.add_argument('--root', type=str, help='root path of dataset')
    parser.add_argument('--output', type=str, default="trained_models/{}/".format(args.tag),
                        help='provide output path base folder name')

    parser.add_argument('--wst-model', type=str, help='Path to checkpoint folder for warmstart')

    # Training Parameters
    parser.add_argument('--epochs', type=int, default=1000)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--lr', type=float, default=0.001,
                        help='learning rate, defaults to 1e-3')
    parser.add_argument('--patience', type=int, default=140,
                        help='early stopping patience in epochs (default: 140)')
    parser.add_argument('--lr-decay-patience', type=int, default=80,
                        help='lr decay patience for plateau scheduler')
    parser.add_argument('--lr-decay-gamma', type=float, default=0.3,
                        help='gamma of learning rate scheduler decay')
    parser.add_argument('--weight-decay', type=float, default=0.00001,
                        help='weight decay')
    parser.add_argument('--seed', type=int, default=0, metavar='S',
                        help='random seed (default: 0)')

    parser.add_argument('--alignment-from', type=str, default=None)
    parser.add_argument('--fake-alignment', action='store_true', default=False)


    # Model Parameters
    parser.add_argument('--unidirectional', action='store_true', default=False,
                        help='Use unidirectional LSTM instead of bidirectional')
    parser.add_argument('--nfft', type=int, default=4096,
                        help='STFT fft size and window size')
    parser.add_argument('--nhop', type=int, default=1024,
                        help='STFT hop size')
    parser.add_argument('--hidden-size', type=int, default=512,
                        help='hidden size parameter of dense bottleneck layers')
    parser.add_argument('--bandwidth', type=int, default=16000,
                        help='maximum model bandwidth in hertz')
    parser.add_argument('--nb-channels', type=int, default=2,
                        help='set number of channels for model (1, 2)')
    parser.add_argument('--nb-workers', type=int, default=0,
                        help='Number of workers for dataloader.')
    parser.add_argument('--nb-audio-encoder-layers', type=int, default=2)
    parser.add_argument('--nb-layers', type=int, default=3)
    # name of the model class in model.py that should be used
    parser.add_argument('--architecture', type=str)
    # select attention type if applicable for selected model
    parser.add_argument('--attention', type=str)

    # Misc Parameters
    parser.add_argument('--quiet', action='store_true', default=False,
                        help='less verbose during training')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')

    args, _ = parser.parse_known_args()

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    print("Using GPU:", use_cuda)
    print("Using Torchaudio: ", utils._torchaudio_available())
    dataloader_kwargs = {'num_workers': args.nb_workers, 'pin_memory': True} if use_cuda else {}

    writer = SummaryWriter(logdir=os.path.join('tensorboard', args.tag))

    # seed all RNGs and make cuDNN deterministic for reproducibility
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    device = torch.device("cuda" if use_cuda else "cpu")

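    # data.load_datasets also registers dataset-specific CLI options via the parser and returns the fully parsed args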
    train_dataset, valid_dataset, args = data.load_datasets(parser, args)

    # create output dir if not exist
    target_path = Path(args.output)
    target_path.mkdir(parents=True, exist_ok=True)


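    # both loaders use the project's collate_fn; drop_last keeps every training batch at full size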
    train_sampler = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=data.collate_fn, drop_last=True,
        **dataloader_kwargs
    )
    valid_sampler = torch.utils.data.DataLoader(
        valid_dataset, batch_size=1, collate_fn=data.collate_fn, **dataloader_kwargs
    )

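    # compute input normalization statistics only when training from scratch; a warm-started run relies on the checkpoint loaded below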
    if args.wst_model:
        scaler_mean = None
        scaler_std = None
    else:
        scaler_mean, scaler_std = get_statistics(args, train_dataset)

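    # convert the bandwidth limit in Hz into the number of STFT bins the model operates on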
    max_bin = utils.bandwidth_to_max_bin(
        valid_dataset.sample_rate, args.nfft, args.bandwidth
    )

    train_args_dict = vars(args)
    train_args_dict['max_bin'] = int(max_bin)  # added to config
    train_args_dict['vocabulary_size'] = valid_dataset.vocabulary_size  # added to config

    train_params_dict = copy.deepcopy(vars(args))  # return args as dictionary with no influence on args

    # add to parameters for model loading but not to config file
    train_params_dict['scaler_mean'] = scaler_mean
    train_params_dict['scaler_std'] = scaler_std

    model_class = model_utls.ModelLoader.get_model(args.architecture)
    model_to_train = model_class.from_config(train_params_dict)
    model_to_train.to(device)

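    # Adam optimizer, plateau-based LR decay on the validation loss, and early stopping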
    optimizer = torch.optim.Adam(
        model_to_train.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay
    )

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        factor=args.lr_decay_gamma,
        patience=args.lr_decay_patience,
        cooldown=10
    )

    es = utils.EarlyStopping(patience=args.patience)

    # if a model is specified: resume training
    if args.wst_model:
        model_path = Path(os.path.join('trained_models', args.wst_model)).expanduser()
        with open(Path(model_path, args.target + '.json'), 'r') as stream:
            results = json.load(stream)

        target_model_path = Path(model_path, args.target + ".chkpnt")
        checkpoint = torch.load(target_model_path, map_location=device)


        model_to_train.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        # train for another arg.epochs
        t = tqdm.trange(
            results['epochs_trained'],
            results['epochs_trained'] + args.epochs + 1,
            disable=args.quiet
        )
        train_losses = results['train_loss_history']
        valid_losses = results['valid_loss_history']
        train_times = results['train_time_history']
        best_epoch = 0

    # else start from 0
    else:
        t = tqdm.trange(1, args.epochs + 1, disable=args.quiet)
        train_losses = []
        valid_losses = []
        train_times = []
        best_epoch = 0

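    # main loop: train one epoch, validate, step the scheduler, checkpoint, and dump the history to JSON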
    for epoch in t:
        t.set_description("Training Epoch")
        end = time.time()

        train_loss = train(args, model_to_train, device, train_sampler, optimizer)
        #valid_loss, sdr_val, sar_val, sir_val = valid(args, model_to_train, device, valid_sampler)
        valid_loss = valid(args, model_to_train, device, valid_sampler)

        writer.add_scalar("Training_cost", train_loss, epoch)
        writer.add_scalar("Validation_cost", valid_loss, epoch)

        scheduler.step(valid_loss)
        train_losses.append(train_loss)
        valid_losses.append(valid_loss)

        t.set_postfix(
            train_loss=train_loss, val_loss=valid_loss
        )

        stop = es.step(valid_loss)

        if valid_loss == es.best:
            best_epoch = epoch

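        # save the latest checkpoint every epoch and flag it as best when the validation loss improved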
        utils.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model_to_train.state_dict(),
                'best_loss': es.best,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict()
            },
            is_best=valid_loss == es.best,
            path=target_path,
            target=args.target
        )

        # save params
        params = {
            'epochs_trained': epoch,
            'args': vars(args),
            'best_loss': es.best,
            'best_epoch': best_epoch,
            'train_loss_history': train_losses,
            'valid_loss_history': valid_losses,
            'train_time_history': train_times,
            'num_bad_epochs': es.num_bad_epochs
        }

        with open(Path(target_path,  args.target + '.json'), 'w') as outfile:
            outfile.write(json.dumps(params, indent=4, sort_keys=True))

        train_times.append(time.time() - end)

        if stop:
            print("Apply Early Stopping")
            break
Example #15
0
    axes.set_xlim(0, 1)
    axes.set_ylim(0, 1)

    # axes.legend()
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Time Horizon = %d Years' % obliteration_years)
    figure = axes.figure
    figure.savefig(filename, dpi=dpi)
    return best_model

if __name__ == "__main__":
    args = parser.parse_args()

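    # the training set comes from AVM.xlsx and the held-out test set from AVM_NYU.xlsx, both built with the same obliteration horizon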
    train_dataset, train_df = data.load_datasets(
        filename="AVM.xlsx",
        obliteration_years=args.obliteration_years)
    test_dataset, test_df = data.load_datasets(
        filename="AVM_NYU.xlsx",
        obliteration_years=args.obliteration_years)
    X_train = train_dataset["X"]
    Y_train = train_dataset["Y"]

    X_test = test_dataset["X"]
    Y_test = test_dataset["Y"]

    test_columns = test_dataset["df_filtered"].columns
    VRAS_test = test_dataset["VRAS"]
    FP_test = test_dataset["FP"]
    SM_test = test_dataset["SM"]
Example #16
0

device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load(
    r"checkpoints/lr1e-5/masknet-epoch=37-val_loss=0.04.ckpt", map_location=device,
)


state_dict = checkpoint["state_dict"]
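# strip everything up to and including "_model." from each key so the names match MaskNet's own state dict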
state_dict = {k.partition("_model.")[2]: v for k, v in state_dict.items()}

model = MaskNet()
model.load_state_dict(state_dict)
model.eval()

train_dataset, val_dataset, test_dataset = load_datasets(r"dataset")
testloader = DataLoader(test_dataset, batch_size=64)


correct = 0
total = 0
prediction_labels = []
true_labels = []

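# run the test set through the model without gradients and count top-1 correct predictions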
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()