Пример #1
0
def _testOscil(oscil=Sine, num_iterations=1,
              do_plot=False, *args, **kwargs):
    """Collect a few chunks from an oscillator and plot or print them.

    Args:
        oscil: callable returning an iterable of sample chunks; it is called
            as ``oscil(*args, **kwargs)``.
        num_iterations: number of chunks to collect. Values below 1 are
            clamped to 1 with a warning.
        do_plot: if true, pass the collected samples to ``plot_samples``;
            otherwise print them.
        *args, **kwargs: forwarded unchanged to ``oscil``.
    """
    import warnings

    if num_iterations < 1:
        num_iterations = 1
        warnings.warn('number of oscillator iterations is less than 1; '
                      'setting to {}'.format(num_iterations))

    # Gather the chunks first and join them once at the end; calling
    # np.append inside the loop re-copies the whole accumulated array on
    # every iteration.
    chunks = []
    for chunk in oscil(*args, **kwargs):
        chunks.append(chunk)
        if len(chunks) >= num_iterations:
            break

    if not chunks:
        # Nothing to show; len(None) on the plot path would otherwise crash.
        warnings.warn('oscillator produced no samples')
        return

    if len(chunks) == 1:
        # A single chunk is used as-is (not flattened), as before.
        samples = chunks[0]
    else:
        # np.append(a, b) flattens both operands, so concatenating the
        # raveled chunks produces the same 1-D result in a single pass.
        samples = np.concatenate([np.ravel(c) for c in chunks])

    # output
    if do_plot:
        msg = '(iterations={})'.format(num_iterations)
        plot_samples(samples, width=len(samples), msg=msg)
    else:
        print(samples)
Пример #2
0
def main():
    """Plot each stage of a linear attack/decay/sustain/release envelope.

    Each stage is built with ``linear_envelope(start, end, length)`` where the
    length is a fraction of a fixed duration, and every chunk the envelope
    yields is passed to ``plot.plot_samples`` with a descriptive title.
    """
    import plot

    dur = 999

    print("Note: The plots may look a little funky, since the areas outside"
          " the width of the envelopes are represented by 1's. This is normal.")

    # (start level, end level, fraction of dur, stage name)
    stages = [
        (0, .9, .2, "attack"),
        (.9, .8, .1, "decay"),
        (.8, .8, .5, "sustain"),
        (.8, 0, .2, "release"),
    ]
    # Build every envelope up front, in stage order, before plotting any.
    envelopes = [(linear_envelope(lo, hi, dur * frac), name)
                 for lo, hi, frac, name in stages]

    for env, name in envelopes:
        title = f'linear {name} envelope'
        for piece in env:
            plot.plot_samples(piece, msg=title)
Пример #3
0
def _do_op(op, op_name, fill, waves, do_plot_input=False):
    '''Performs op on waves.

    Args:
        op - function that acts on waves, must be callable with a single arg
        op_name - name of op function, for plotting
        waves - iterable of numpy ndarrays for each wave to operate on
        do_plot_input - if true, all waves are plotted prior to yield
    Returns:
        wave resulting from operation
    '''
    for wave_chunks in itertools.zip_longest(*waves, fillvalue=fill):
        if do_plot_input:
            msg = 'Plot of mixer.{}() (#/waves = {})'.format(
                op_name, len(wave_chunks))
            plot_samples(wave_chunks, len(wave_chunks[0]), msg=msg)
        yield op(wave_chunks)
Пример #4
0
def test_blocking(do_plot=False, instrument=None):
    """Render the test wave, then either plot it or play it on an audio stream.

    Args:
        do_plot: if true, concatenate the chunks from ``test_get_wave`` and
            plot them instead of playing audio.
        instrument: forwarded to ``test_get_wave``.
    """
    if do_plot:
        # stitch several chunks together and plot it
        from plot import plot_samples
        song = []
        for chunk in test_get_wave(DEF_SAMPLE_RATE, instrument):
            song.extend(chunk)
        plot_samples(song, width=2500, msg="testing wave instead of playing")
        return

    # play stream
    p, stream = open_stream(DEF_SAMPLE_RATE)
    try:
        for chunk in test_get_wave(DEF_SAMPLE_RATE, instrument):
            # The stream receives raw bytes, so the float32 conversion is
            # essential. NOTE(review): presumably open_stream opens a float32
            # stream; confirm. tobytes() replaces the deprecated tostring().
            stream.write(chunk.astype(numpy.float32).tobytes())
    finally:
        # Release the audio device even if wave generation or writing fails.
        close_stream(p, stream)
Пример #5
0
def _testAdd(do_plot=False, plot_input=False):
    """Build a pseudo-square wave by additive synthesis and show a few chunks.

    A 20 Hz sine is summed with sines at 3x, 5x and 7x that frequency, at
    volumes 1/3, 1/5 and 1/7, via ``add``. The first three output chunks are
    plotted or printed.

    Args:
        do_plot: if true, plot output chunks; otherwise print them.
        plot_input: passed as the second positional argument of ``add``
            (presumably its "plot inputs" flag). When set, the output itself
            is not plotted.
    """
    plot_output = not plot_input

    # Use additive synthesis to create a pseudo-square wave using sine waves.
    hz = 20
    oscil.set_chunk_size(5000)
    overtone_count = range(1, 4)
    fundamental = oscil.Sine(hz=hz)
    overtones = (oscil.Sine(hz=(hz * ((2 * n) + 1)), vol=(1 / ((2 * n) + 1)))
                 for n in overtone_count)
    harmonics = itertools.chain((fundamental,), overtones)
    square_tone = add(harmonics, plot_input)

    # Number of sine components: the fundamental plus one per overtone.
    # (Previously the range object itself was formatted into the title.)
    n_harmonics = 1 + len(overtone_count)
    msg = 'Testing mixer.add() (n_harmonics={}, fundamental_hz={})'.format(
        n_harmonics, hz)

    for num_iters, samples in enumerate(square_tone):
        if do_plot:
            if plot_output:
                plot_samples(samples, len(samples), msg=msg)
        else:  # show output
            print(samples)

        # Stop after three chunks.
        if num_iters >= 2:
            break
Пример #6
0
def misgan_impute(args,
                  data_gen,
                  mask_gen,
                  imputer,
                  data_critic,
                  mask_critic,
                  impu_critic,
                  data,
                  output_dir,
                  checkpoint=None):
    """Train data/mask generators and an imputer adversarially on incomplete data.

    Every batch updates the critics through ``CriticUpdater``: the data and
    mask critics only when all networks are trained, and the imputation
    critic always. Every ``args.n_critic`` batches, the generators and the
    imputer take one optimizer step. With ``args.imputeronly`` set, only the
    imputer and its critic are updated.

    At the end of every epoch, loss curves are written under
    ``output_dir/log`` and a rolling ``log/checkpoint.pth`` is saved.
    Samples are plotted every ``plot_interval`` epochs, and snapshots go to
    ``model/`` every ``save_interval`` epochs.

    Args:
        args: hyperparameter namespace. Reads n_critic, gp_lambda,
            batch_size, n_latent, epoch (total epochs), plot_interval,
            save_interval, alpha, beta, gamma, tau, imputeronly, workers,
            resume and pretrain.
        data_gen, mask_gen: generators applied to (batch_size, n_latent)
            noise to produce fake data / fake masks.
        imputer: called as ``imputer(real_data, real_mask, noise)``, where
            noise has the same shape as a data batch.
        data_critic, mask_critic, impu_critic: critics for masked data,
            masks and imputed data.
        data: dataset; ``data[0][0].shape`` gives the per-item data shape,
            and the loader yields 4-tuples (data, mask, _, index). If the
            dataset has a ``mask_info`` attribute, it is indexed by ``index``
            to get boxes for the imputation plot.
        output_dir: path-like directory (supports ``/``). img/, mask/,
            impute/, log/ and model/ live beneath it.
        checkpoint: dict in the format written by the inner ``save_model``.
            It must be provided when ``args.resume`` is set.
    """
    n_critic = args.n_critic
    gp_lambda = args.gp_lambda
    batch_size = args.batch_size
    nz = args.n_latent
    epochs = args.epoch
    plot_interval = args.plot_interval
    save_model_interval = args.save_interval
    alpha = args.alpha
    beta = args.beta
    gamma = args.gamma
    tau = args.tau
    # When False, only the imputer and its critic are trained.
    update_all_networks = not args.imputeronly

    # Output subdirectories (project mkdir helper; its return value is used
    # as a path below).
    gen_data_dir = mkdir(output_dir / 'img')
    gen_mask_dir = mkdir(output_dir / 'mask')
    impute_dir = mkdir(output_dir / 'impute')
    log_dir = mkdir(output_dir / 'log')
    model_dir = mkdir(output_dir / 'model')

    # drop_last=True keeps every batch at exactly batch_size, matching the
    # fixed-size noise/eps/ones buffers allocated below.
    data_loader = DataLoader(data,
                             batch_size=batch_size,
                             shuffle=True,
                             drop_last=True,
                             num_workers=args.workers)
    n_batch = len(data_loader)
    data_shape = data[0][0].shape

    # Reusable noise buffers, refilled in place (normal_/uniform_) before use.
    data_noise = torch.FloatTensor(batch_size, nz).to(device)
    mask_noise = torch.FloatTensor(batch_size, nz).to(device)
    impu_noise = torch.FloatTensor(batch_size, *data_shape).to(device)

    # Interpolation coefficient
    # NOTE(review): shape (batch_size, 1, 1, 1) presumably broadcasts over
    # 4-D image batches; confirm for other data shapes.
    eps = torch.FloatTensor(batch_size, 1, 1, 1).to(device)

    # For computing gradient penalty
    ones = torch.ones(batch_size).to(device)

    # The imputer and its critic use a higher learning rate than the GAN parts.
    lrate = 1e-4
    imputer_lrate = 2e-4
    data_gen_optimizer = optim.Adam(data_gen.parameters(),
                                    lr=lrate,
                                    betas=(.5, .9))
    mask_gen_optimizer = optim.Adam(mask_gen.parameters(),
                                    lr=lrate,
                                    betas=(.5, .9))
    imputer_optimizer = optim.Adam(imputer.parameters(),
                                   lr=imputer_lrate,
                                   betas=(.5, .9))

    data_critic_optimizer = optim.Adam(data_critic.parameters(),
                                       lr=lrate,
                                       betas=(.5, .9))
    mask_critic_optimizer = optim.Adam(mask_critic.parameters(),
                                       lr=lrate,
                                       betas=(.5, .9))
    impu_critic_optimizer = optim.Adam(impu_critic.parameters(),
                                       lr=imputer_lrate,
                                       betas=(.5, .9))

    # Callables that perform one critic update and expose .loss_value.
    update_data_critic = CriticUpdater(data_critic, data_critic_optimizer, eps,
                                       ones, gp_lambda)
    update_mask_critic = CriticUpdater(mask_critic, mask_critic_optimizer, eps,
                                       ones, gp_lambda)
    update_impu_critic = CriticUpdater(impu_critic, impu_critic_optimizer, eps,
                                       ones, gp_lambda)

    start_epoch = 0
    critic_updates = 0
    # Maps (display name, file stem) -> list of per-epoch mean losses.
    log = defaultdict(list)

    if args.resume:
        # Full resume: weights, optimizer state, epoch, counter and loss log.
        data_gen.load_state_dict(checkpoint['data_gen'])
        mask_gen.load_state_dict(checkpoint['mask_gen'])
        imputer.load_state_dict(checkpoint['imputer'])
        data_critic.load_state_dict(checkpoint['data_critic'])
        mask_critic.load_state_dict(checkpoint['mask_critic'])
        impu_critic.load_state_dict(checkpoint['impu_critic'])
        data_gen_optimizer.load_state_dict(checkpoint['data_gen_opt'])
        mask_gen_optimizer.load_state_dict(checkpoint['mask_gen_opt'])
        imputer_optimizer.load_state_dict(checkpoint['imputer_opt'])
        data_critic_optimizer.load_state_dict(checkpoint['data_critic_opt'])
        mask_critic_optimizer.load_state_dict(checkpoint['mask_critic_opt'])
        impu_critic_optimizer.load_state_dict(checkpoint['impu_critic_opt'])
        start_epoch = checkpoint['epoch']
        critic_updates = checkpoint['critic_updates']
        log = checkpoint['log']
    elif args.pretrain:
        # Weights only (no optimizer state) from a file at args.pretrain,
        # mapped to CPU first. Imputer weights are loaded only if present.
        pretrain = torch.load(args.pretrain, map_location='cpu')
        data_gen.load_state_dict(pretrain['data_gen'])
        mask_gen.load_state_dict(pretrain['mask_gen'])
        data_critic.load_state_dict(pretrain['data_critic'])
        mask_critic.load_state_dict(pretrain['mask_critic'])
        if 'imputer' in pretrain:
            imputer.load_state_dict(pretrain['imputer'])
            impu_critic.load_state_dict(pretrain['impu_critic'])

    with (log_dir / 'gpu.txt').open('a') as f:
        print(torch.cuda.device_count(), start_epoch, file=f)

    def save_model(path, epoch, critic_updates=0):
        # Stores epoch + 1 so that resuming starts at the following epoch.
        torch.save(
            {
                'data_gen': data_gen.state_dict(),
                'mask_gen': mask_gen.state_dict(),
                'imputer': imputer.state_dict(),
                'data_critic': data_critic.state_dict(),
                'mask_critic': mask_critic.state_dict(),
                'impu_critic': impu_critic.state_dict(),
                'data_gen_opt': data_gen_optimizer.state_dict(),
                'mask_gen_opt': mask_gen_optimizer.state_dict(),
                'imputer_opt': imputer_optimizer.state_dict(),
                'data_critic_opt': data_critic_optimizer.state_dict(),
                'mask_critic_opt': mask_critic_optimizer.state_dict(),
                'impu_critic_opt': impu_critic_optimizer.state_dict(),
                'epoch': epoch + 1,
                'critic_updates': critic_updates,
                'log': log,
                'args': args,
            }, str(path))

    sns.set()
    start = time.time()
    epoch_start = start

    for epoch in range(start_epoch, epochs):
        sum_data_loss, sum_mask_loss, sum_impu_loss = 0, 0, 0
        for real_data, real_mask, _, index in data_loader:
            # Assume real_data and real_mask have the same number of channels.
            # Could be modified to handle multi-channel images and
            # single-channel masks.
            real_mask = real_mask.float()[:, None]

            real_data = real_data.to(device)
            real_mask = real_mask.to(device)

            # NOTE(review): mask_data's third argument presumably fills the
            # unobserved entries (tau here, imputer output further down).
            masked_real_data = mask_data(real_data, real_mask, tau)

            # Update discriminators' parameters
            data_noise.normal_()
            fake_data = data_gen(data_noise)

            impu_noise.uniform_()
            imputed_data = imputer(real_data, real_mask, impu_noise)
            masked_imputed_data = mask_data(real_data, real_mask, imputed_data)

            if update_all_networks:
                mask_noise.normal_()
                fake_mask = mask_gen(mask_noise)
                masked_fake_data = mask_data(fake_data, fake_mask, tau)
                update_data_critic(masked_real_data, masked_fake_data)
                update_mask_critic(real_mask, fake_mask)

                sum_data_loss += update_data_critic.loss_value
                sum_mask_loss += update_mask_critic.loss_value

            # The imputation critic receives generator output (fake_data)
            # first and the masked imputation second.
            update_impu_critic(fake_data, masked_imputed_data)
            sum_impu_loss += update_impu_critic.loss_value

            critic_updates += 1

            # Generators/imputer step once every n_critic critic updates.
            if critic_updates == n_critic:
                critic_updates = 0

                # Update generators' parameters
                # Freeze critics so generator backward passes do not
                # accumulate gradients into them.
                if update_all_networks:
                    for p in data_critic.parameters():
                        p.requires_grad_(False)
                    for p in mask_critic.parameters():
                        p.requires_grad_(False)
                for p in impu_critic.parameters():
                    p.requires_grad_(False)

                impu_noise.uniform_()
                imputed_data = imputer(real_data, real_mask, impu_noise)
                masked_imputed_data = mask_data(real_data, real_mask,
                                                imputed_data)
                impu_loss = -impu_critic(masked_imputed_data).mean()

                if update_all_networks:
                    # Mask generator step on fresh samples.
                    data_noise.normal_()
                    fake_data = data_gen(data_noise)
                    mask_noise.normal_()
                    fake_mask = mask_gen(mask_noise)
                    masked_fake_data = mask_data(fake_data, fake_mask, tau)
                    data_loss = -data_critic(masked_fake_data).mean()
                    mask_loss = -mask_critic(fake_mask).mean()

                    mask_gen.zero_grad()
                    (mask_loss + data_loss * alpha).backward(retain_graph=True)
                    mask_gen_optimizer.step()

                    # Data generator step on another fresh batch of samples.
                    data_noise.normal_()
                    fake_data = data_gen(data_noise)
                    mask_noise.normal_()
                    fake_mask = mask_gen(mask_noise)
                    masked_fake_data = mask_data(fake_data, fake_mask, tau)
                    data_loss = -data_critic(masked_fake_data).mean()

                    data_gen.zero_grad()
                    # retain_graph=True: impu_loss's graph is backpropagated
                    # again in the imputer step below.
                    # NOTE(review): as written, impu_loss is computed only from
                    # the imputer's output, so the beta term appears to give no
                    # gradient to data_gen. Its gradient into the imputer is
                    # cleared by imputer.zero_grad() below. Confirm intent.
                    (data_loss + impu_loss * beta).backward(retain_graph=True)
                    data_gen_optimizer.step()

                imputer.zero_grad()
                if gamma > 0:
                    # Optional squared-error term between the imputation and
                    # the real data, reduced by mask_norm with real_mask
                    # (presumably restricted to observed entries; confirm).
                    imputer_mismatch_loss = mask_norm(
                        (imputed_data - real_data)**2, real_mask)
                    (impu_loss + imputer_mismatch_loss * gamma).backward()
                else:
                    impu_loss.backward()
                imputer_optimizer.step()

                # Unfreeze critics for the next round of critic updates.
                if update_all_networks:
                    for p in data_critic.parameters():
                        p.requires_grad_(True)
                    for p in mask_critic.parameters():
                        p.requires_grad_(True)
                for p in impu_critic.parameters():
                    p.requires_grad_(True)

        # Per-epoch mean critic losses. The data/mask means exist only when
        # all networks are trained, and are used only under the same guard.
        if update_all_networks:
            mean_data_loss = sum_data_loss / n_batch
            mean_mask_loss = sum_mask_loss / n_batch
            log['data loss', 'data_loss'].append(mean_data_loss)
            log['mask loss', 'mask_loss'].append(mean_mask_loss)
        mean_impu_loss = sum_impu_loss / n_batch
        log['imputer loss', 'impu_loss'].append(mean_impu_loss)

        if plot_interval > 0 and (epoch + 1) % plot_interval == 0:
            if update_all_networks:
                print('[{:4}] {:12.4f} {:12.4f} {:12.4f}'.format(
                    epoch, mean_data_loss, mean_mask_loss, mean_impu_loss))
            else:
                print('[{:4}] {:12.4f}'.format(epoch, mean_impu_loss))

            filename = f'{epoch:04d}.png'
            with torch.no_grad():
                # Switch to eval mode for sampling; restored to train below.
                data_gen.eval()
                mask_gen.eval()
                imputer.eval()

                data_noise.normal_()
                mask_noise.normal_()

                data_samples = data_gen(data_noise)
                plot_samples(data_samples, str(gen_data_dir / filename))

                mask_samples = mask_gen(mask_noise)
                plot_samples(mask_samples, str(gen_mask_dir / filename))

                # Plot imputation results
                # (uses the last batch seen in this epoch)
                impu_noise.uniform_()
                imputed_data = imputer(real_data, real_mask, impu_noise)
                imputed_data = mask_data(real_data, real_mask, imputed_data)
                if hasattr(data, 'mask_info'):
                    bbox = [data.mask_info[idx] for idx in index]
                else:
                    bbox = None
                plot_grid(imputed_data,
                          bbox,
                          gap=2,
                          save_file=str(impute_dir / filename))

                data_gen.train()
                mask_gen.train()
                imputer.train()

        # Redraw every loss curve each epoch; shortname is the file stem.
        for (name, shortname), trace in log.items():
            fig, ax = plt.subplots(figsize=(6, 4))
            ax.plot(trace)
            ax.set_ylabel(name)
            ax.set_xlabel('epoch')
            fig.savefig(str(log_dir / f'{shortname}.png'), dpi=300)
            plt.close(fig)

        if save_model_interval > 0 and (epoch + 1) % save_model_interval == 0:
            save_model(model_dir / f'{epoch:04d}.pth', epoch, critic_updates)

        epoch_end = time.time()
        time_elapsed = epoch_end - start
        epoch_time = epoch_end - epoch_start
        epoch_start = epoch_end
        with (log_dir / 'epoch-time.txt').open('a') as f:
            print(epoch, epoch_time, time_elapsed, file=f)
        # Rolling checkpoint, overwritten every epoch, for args.resume.
        save_model(log_dir / 'checkpoint.pth', epoch, critic_updates)

    print(output_dir)
Пример #7
0
def train_ssvae(args):
    """Train the semi-supervised VAE and save a checkpoint after every epoch.

    Each epoch trains on the "train_unsup"/"train_sup" loaders, measures
    accuracy on "dev", logs the losses, and saves the model. After the last
    epoch, the final model is saved and, if a "test" loader exists, test
    accuracy is reported.

    Args:
        args: namespace with at least visualize, log_dir, batch_size,
            num_workers, use_cuda and num_epochs. All of ``vars(args)`` is
            also passed to ``SsVae``.
    """
    if args.visualize:
        from plot import visualize_setup, plot_samples, plot_tsne
        visualize_setup(args.log_dir)

    ss_vae = SsVae(x_dim=p.NUM_PIXELS, y_dim=p.NUM_LABELS, **vars(args))

    # if you want to limit the datasets' entry size
    sizes = {"train_unsup": 200000, "train_sup": 1000, "dev": 1000}

    # prepare data loaders
    # batch_size: number of examples (and labels) considered per batch
    datasets, data_loaders = dict(), dict()
    for mode in ["train_unsup", "train_sup", "dev"]:
        datasets[mode] = Aspire(mode=mode, data_size=sizes[mode])
        data_loaders[mode] = AudioDataLoader(datasets[mode],
                                             batch_size=args.batch_size,
                                             num_workers=args.num_workers,
                                             shuffle=True,
                                             use_cuda=args.use_cuda,
                                             pin_memory=True)

    # best validation accuracy seen across epochs
    best_valid_acc = 0.0

    # run inference for a certain number of epochs; ss_vae.epoch tracks the
    # current epoch number (it is advanced at the end of each iteration)
    for _ in range(ss_vae.epoch, args.num_epochs):
        # get the losses for an epoch
        avg_losses_sup, avg_losses_unsup = ss_vae.train_epoch(data_loaders)
        # validate
        validation_accuracy = ss_vae.get_accuracy(data_loaders["dev"],
                                                  desc="validating")

        str_avg_loss_sup = ' '.join([f"{x:7.3f}" for x in avg_losses_sup])
        str_avg_loss_unsup = ' '.join([f"{x:7.3f}" for x in avg_losses_unsup])
        logger.info(f"epoch {ss_vae.epoch:03d}: "
                    f"avg_loss_sup {str_avg_loss_sup} "
                    f"avg_loss_unsup {str_avg_loss_unsup} "
                    f"val_accuracy {validation_accuracy:5.3f}")

        if best_valid_acc < validation_accuracy:
            best_valid_acc = validation_accuracy
        # save
        ss_vae.save(get_model_file_path(args, f"epoch_{ss_vae.epoch:04d}"))
        # visualize the conditional samples (plot_samples was imported above
        # under the same args.visualize condition)
        if args.visualize:
            plot_samples(ss_vae)
        # increase epoch num
        ss_vae.epoch += 1

    # Save the final model before evaluating, so a failure during evaluation
    # cannot lose it. Same call form as the per-epoch saves above.
    ss_vae.save(get_model_file_path(args, "final"))

    # test
    if "test" in data_loaders:
        test_accuracy = ss_vae.get_accuracy(data_loaders["test"])
        logger.info(f"best validation accuracy {best_valid_acc:5.3f} "
                    f"test accuracy {test_accuracy:5.3f}")
    else:
        # NOTE(review): no "test" split is loaded above, so indexing
        # data_loaders["test"] would raise KeyError. Add "test" to the
        # prepared modes to get a test accuracy here.
        logger.info(f"best validation accuracy {best_valid_acc:5.3f}")
Пример #8
0
def plot(time_keeper, sampler):
    """Plot the sampler's current samples if the time keeper allows it.

    When ``time_keeper.can_plot()`` is true, sets ``time_keeper.plotting``,
    plots ``sampler.samples`` and then calls ``time_keeper.plotted()``.
    Otherwise does nothing.
    """
    if not time_keeper.can_plot():
        return
    time_keeper.plotting = True
    # Function-call form of print: the statement form is a SyntaxError on
    # Python 3, which the rest of this file targets (it uses f-strings).
    print("plotting..")
    plot_samples(sampler.samples)
    time_keeper.plotted()
Пример #9
0
def misgan(args,
           data_gen,
           mask_gen,
           data_critic,
           mask_critic,
           data,
           output_dir,
           checkpoint=None):
    """Train data and mask generators against their critics on incomplete data.

    Every batch updates both critics through ``CriticUpdater``. Every
    ``args.n_critic`` batches, the data generator and then the mask
    generator each take one optimizer step.

    At the end of every epoch, loss curves are written under
    ``output_dir/log`` and a rolling ``log/checkpoint.pth`` is saved.
    Samples are plotted every ``plot_interval`` epochs, and snapshots go to
    ``model/`` every ``save_interval`` epochs.

    Args:
        args: hyperparameter namespace. Reads n_critic, gp_lambda,
            batch_size, n_latent, epoch (total epochs), plot_interval,
            save_interval, alpha and tau.
        data_gen, mask_gen: generators applied to (batch_size, n_latent)
            noise to produce fake data / fake masks.
        data_critic, mask_critic: critics for masked data and for masks.
        data: dataset whose loader yields 4-tuples (data, mask, _, _).
        output_dir: path-like directory (supports ``/``). img/, mask/, log/
            and model/ live beneath it.
        checkpoint: optional dict in the format written by the inner
            ``save_model``. When truthy, training resumes from it.
    """
    n_critic = args.n_critic
    gp_lambda = args.gp_lambda
    batch_size = args.batch_size
    nz = args.n_latent
    epochs = args.epoch
    plot_interval = args.plot_interval
    save_interval = args.save_interval
    alpha = args.alpha
    tau = args.tau

    # Output subdirectories (project mkdir helper; its return value is used
    # as a path below).
    gen_data_dir = mkdir(output_dir / 'img')
    gen_mask_dir = mkdir(output_dir / 'mask')
    log_dir = mkdir(output_dir / 'log')
    model_dir = mkdir(output_dir / 'model')

    # drop_last=True keeps every batch at exactly batch_size, matching the
    # fixed-size noise/eps/ones buffers allocated below.
    data_loader = DataLoader(data,
                             batch_size=batch_size,
                             shuffle=True,
                             drop_last=True)
    n_batch = len(data_loader)

    # Reusable noise buffers, refilled in place with normal_() before use.
    data_noise = torch.FloatTensor(batch_size, nz).to(device)
    mask_noise = torch.FloatTensor(batch_size, nz).to(device)

    # Interpolation coefficient
    # NOTE(review): shape (batch_size, 1, 1, 1) presumably broadcasts over
    # 4-D image batches; confirm for other data shapes.
    eps = torch.FloatTensor(batch_size, 1, 1, 1).to(device)

    # For computing gradient penalty
    ones = torch.ones(batch_size).to(device)

    lrate = 1e-4
    # lrate = 1e-5
    data_gen_optimizer = optim.Adam(data_gen.parameters(),
                                    lr=lrate,
                                    betas=(.5, .9))
    mask_gen_optimizer = optim.Adam(mask_gen.parameters(),
                                    lr=lrate,
                                    betas=(.5, .9))

    data_critic_optimizer = optim.Adam(data_critic.parameters(),
                                       lr=lrate,
                                       betas=(.5, .9))
    mask_critic_optimizer = optim.Adam(mask_critic.parameters(),
                                       lr=lrate,
                                       betas=(.5, .9))

    # Callables that perform one critic update and expose .loss_value.
    update_data_critic = CriticUpdater(data_critic, data_critic_optimizer, eps,
                                       ones, gp_lambda)
    update_mask_critic = CriticUpdater(mask_critic, mask_critic_optimizer, eps,
                                       ones, gp_lambda)

    start_epoch = 0
    critic_updates = 0
    # Maps (display name, file stem) -> list of per-epoch mean losses.
    log = defaultdict(list)

    if checkpoint:
        # Full resume: weights, optimizer state, epoch, counter and loss log.
        data_gen.load_state_dict(checkpoint['data_gen'])
        mask_gen.load_state_dict(checkpoint['mask_gen'])
        data_critic.load_state_dict(checkpoint['data_critic'])
        mask_critic.load_state_dict(checkpoint['mask_critic'])
        data_gen_optimizer.load_state_dict(checkpoint['data_gen_opt'])
        mask_gen_optimizer.load_state_dict(checkpoint['mask_gen_opt'])
        data_critic_optimizer.load_state_dict(checkpoint['data_critic_opt'])
        mask_critic_optimizer.load_state_dict(checkpoint['mask_critic_opt'])
        start_epoch = checkpoint['epoch']
        critic_updates = checkpoint['critic_updates']
        log = checkpoint['log']

    with (log_dir / 'gpu.txt').open('a') as f:
        print(torch.cuda.device_count(), start_epoch, file=f)

    def save_model(path, epoch, critic_updates=0):
        # Stores epoch + 1 so that resuming starts at the following epoch.
        torch.save(
            {
                'data_gen': data_gen.state_dict(),
                'mask_gen': mask_gen.state_dict(),
                'data_critic': data_critic.state_dict(),
                'mask_critic': mask_critic.state_dict(),
                'data_gen_opt': data_gen_optimizer.state_dict(),
                'mask_gen_opt': mask_gen_optimizer.state_dict(),
                'data_critic_opt': data_critic_optimizer.state_dict(),
                'mask_critic_opt': mask_critic_optimizer.state_dict(),
                'epoch': epoch + 1,
                'critic_updates': critic_updates,
                'log': log,
                'args': args,
            }, str(path))

    sns.set()

    start = time.time()
    epoch_start = start

    for epoch in range(start_epoch, epochs):
        sum_data_loss, sum_mask_loss = 0, 0
        for real_data, real_mask, _, _ in data_loader:
            # Assume real_data and mask have the same number of channels.
            # Could be modified to handle multi-channel images and
            # single-channel masks.
            real_mask = real_mask.float()[:, None]

            real_data = real_data.to(device)
            real_mask = real_mask.to(device)

            # NOTE(review): mask_data presumably fills unobserved entries
            # with its third argument (tau); confirm.
            masked_real_data = mask_data(real_data, real_mask, tau)

            # Update discriminators' parameters
            data_noise.normal_()
            mask_noise.normal_()

            fake_data = data_gen(data_noise)
            fake_mask = mask_gen(mask_noise)

            masked_fake_data = mask_data(fake_data, fake_mask, tau)

            update_data_critic(masked_real_data, masked_fake_data)
            update_mask_critic(real_mask, fake_mask)

            sum_data_loss += update_data_critic.loss_value
            sum_mask_loss += update_mask_critic.loss_value

            critic_updates += 1

            # Generators step once every n_critic critic updates.
            if critic_updates == n_critic:
                critic_updates = 0

                # Update generators' parameters
                # Freeze critics so generator backward passes do not
                # accumulate gradients into them.
                for p in data_critic.parameters():
                    p.requires_grad_(False)
                for p in mask_critic.parameters():
                    p.requires_grad_(False)

                data_noise.normal_()
                mask_noise.normal_()

                fake_data = data_gen(data_noise)
                fake_mask = mask_gen(mask_noise)
                masked_fake_data = mask_data(fake_data, fake_mask, tau)

                # Data generator step. This backward also writes gradients
                # into mask_gen, which mask_gen.zero_grad() clears below.
                data_loss = -data_critic(masked_fake_data).mean()
                data_gen.zero_grad()
                data_loss.backward()
                data_gen_optimizer.step()

                # Mask generator step on fresh samples.
                data_noise.normal_()
                mask_noise.normal_()

                fake_data = data_gen(data_noise)
                fake_mask = mask_gen(mask_noise)
                masked_fake_data = mask_data(fake_data, fake_mask, tau)

                data_loss = -data_critic(masked_fake_data).mean()
                mask_loss = -mask_critic(fake_mask).mean()
                mask_gen.zero_grad()
                (mask_loss + data_loss * alpha).backward()
                mask_gen_optimizer.step()

                # Unfreeze critics for the next round of critic updates.
                for p in data_critic.parameters():
                    p.requires_grad_(True)
                for p in mask_critic.parameters():
                    p.requires_grad_(True)

        # Per-epoch mean critic losses.
        mean_data_loss = sum_data_loss / n_batch
        mean_mask_loss = sum_mask_loss / n_batch
        log['data loss', 'data_loss'].append(mean_data_loss)
        log['mask loss', 'mask_loss'].append(mean_mask_loss)

        # Redraw every loss curve each epoch; shortname is the file stem.
        for (name, shortname), trace in log.items():
            fig, ax = plt.subplots(figsize=(6, 4))
            ax.plot(trace)
            ax.set_ylabel(name)
            ax.set_xlabel('epoch')
            fig.savefig(str(log_dir / f'{shortname}.png'), dpi=300)
            plt.close(fig)

        if plot_interval > 0 and (epoch + 1) % plot_interval == 0:
            print(f'[{epoch:4}] {mean_data_loss:12.4f} {mean_mask_loss:12.4f}')

            filename = f'{epoch:04d}.png'

            # Switch to eval mode for sampling; restored to train below.
            data_gen.eval()
            mask_gen.eval()

            with torch.no_grad():
                data_noise.normal_()
                mask_noise.normal_()

                data_samples = data_gen(data_noise)
                plot_samples(data_samples, str(gen_data_dir / filename))

                mask_samples = mask_gen(mask_noise)
                plot_samples(mask_samples, str(gen_mask_dir / filename))

            data_gen.train()
            mask_gen.train()

        if save_interval > 0 and (epoch + 1) % save_interval == 0:
            save_model(model_dir / f'{epoch:04d}.pth', epoch, critic_updates)

        epoch_end = time.time()
        time_elapsed = epoch_end - start
        epoch_time = epoch_end - epoch_start
        epoch_start = epoch_end
        with (log_dir / 'time.txt').open('a') as f:
            print(epoch, epoch_time, time_elapsed, file=f)
        # Rolling checkpoint, overwritten every epoch, for resuming.
        save_model(log_dir / 'checkpoint.pth', epoch, critic_updates)

    print(output_dir)
Пример #10
0
def _log_prior_sum(ps, par_dict, num_param):
    """Return the summed log prior of the first ``num_param`` entries of ``ps``.

    ``par_dict`` maps parameter names to indices into ``ps``; ``prior_dist`` is
    called with ``(value, name)``, so each index is translated back to the
    first name that maps to it. Raises ``IndexError`` if no name maps to an
    index in ``range(num_param)``.
    """
    ln_prior = 0
    for n in range(num_param):
        names = [name for name, idx in par_dict.items() if idx == n]
        ln_prior += np.log(prior_dist(ps[n], names[0]))
    return ln_prior


def Simulation(ps,
               par_dict,
               num_param,
               redshifts=None,
               luminosities=None,
               durations=None,
               obs_pf=None,
               obs_t90=None,
               detector=None,
               options=None,
               vol_arr=None,
               kc=None,
               dl=None,
               plot_GRB=None,
               prior=False,
               sim=False,
               dsim=False,
               file=None):
    """Simulate collapsar and merger GRB populations and return their log posterior.

    Each population is simulated separately: redshifts are drawn from a
    ``source_rate_density`` pdf, rest-frame luminosities from a
    ``get_luminosity`` pdf, and peak fluxes are computed with ``Peak_flux``.
    The combined model peak fluxes are compared with ``obs_pf`` through
    ``combine_data`` and ``log_likelihood``, and the log prior is added.

    Args:
        ps: Parameter vector, indexed via ``par_dict``.
        par_dict: Maps parameter names (e.g. ``"coll rho0"``) to indices in ``ps``.
        num_param: Number of leading entries of ``ps`` included in the prior.
        redshifts: Redshift grid for the rate pdfs. Must be sorted ascending,
            because it is searched with ``np.searchsorted``.
        luminosities: Luminosity grid; its first and last values are passed as
            ``lmin``/``lmax`` to ``get_luminosity``.
        durations: Currently unused (duration model disabled, see TODO below).
        obs_pf: Observed peak fluxes, passed to ``combine_data`` as ``data``.
        obs_t90: Currently unused (duration model disabled).
        detector: Currently unused.
        options: Object with a ``getboolean`` method supplying the
            ``'plot_GRB'`` and ``'plotting'`` flags (presumably a configparser
            section -- confirm with callers).
        vol_arr: Forwarded to ``source_rate_density``.
        kc: K-correction per redshift grid point (indexed like ``redshifts``).
        dl: Luminosity distance per redshift grid point (indexed like ``redshifts``).
        plot_GRB: Plot flag for ``Peak_flux``/``combine_data``; read from
            ``options`` when None.
        prior: If not False, skip the simulation and return only the log prior.
        sim: Forwarded to ``Peak_flux`` and ``combine_data``; when not False
            and ``file`` is given, the model peak fluxes are saved.
        dsim: Forwarded to ``Peak_flux``.
        file: Base name of the ``.npy`` file written to ``'../' + file + '.npy'``.

    Returns:
        Log likelihood plus log prior. Returns ``-np.inf`` when the expected
        number of sources cannot be converted to an integer (e.g. inf/nan), or
        when the likelihood/prior evaluation raises.
    """
    # Prior-only mode (used for testing the prior): skip the simulation.
    if prior is not False:
        return _log_prior_sum(ps, par_dict, num_param)

    # Plotting options
    if plot_GRB is None:
        plot_GRB = options.getboolean('plot_GRB')
    plot_func = options.getboolean('plotting')

    ## Redshifts
    # Collapsar redshift pdf [number / yr] all-sky
    redshift_pdf_coll, N_coll = source_rate_density(
        redshifts,
        rho0=ps[par_dict["coll rho0"]],
        z_star=ps[par_dict["coll z*"]],
        n1=ps[par_dict["coll z1"]],
        n2=ps[par_dict["coll z2"]],
        vol_arr=vol_arr,
        plot=plot_func)
    # Merger redshift pdf [number / yr] all-sky
    redshift_pdf_merg, N_merg = source_rate_density(
        redshifts,
        rho0=ps[par_dict["merg rho0"]],
        z_star=ps[par_dict["merg z*"]],
        n1=ps[par_dict["merg z1"]],
        n2=ps[par_dict["merg z2"]],
        vol_arr=vol_arr,
        plot=plot_func)

    # Correct for 11.5 years of GBM data.
    # Some parameter sets make N inf (OverflowError) or nan (ValueError);
    # treat those as impossible. The builtin int replaces np.int, which was
    # only an alias for it and no longer exists in NumPy >= 1.24.
    try:
        N_coll = int(N_coll * 11.5)
        N_merg = int(N_merg * 11.5)
    except (OverflowError, ValueError, TypeError):
        return -np.inf
    print(N_coll, N_merg)

    # Draw random redshifts from collapsar pdf
    redshift_sample_coll = sample_distribution(redshifts,
                                               redshift_pdf_coll,
                                               xlabel='Redshift',
                                               ylog=False,
                                               num_draw=N_coll,
                                               plot=plot_func)
    # Draw random redshifts from merger pdf
    redshift_sample_merg = sample_distribution(redshifts,
                                               redshift_pdf_merg,
                                               xlabel='Redshift',
                                               ylog=False,
                                               num_draw=N_merg,
                                               plot=plot_func)

    # Plot randomly drawn samples and pdf to ensure sim is done correctly
    if plot_func is not False:
        plot_samples(redshift_sample_coll, redshifts, redshift_pdf_coll)
        plot_samples(redshift_sample_merg, redshifts, redshift_pdf_merg)

    # Map each sampled redshift to its grid index once, then look up the
    # k-correction and luminosity distance tabulated on that grid.
    coll_idx = np.searchsorted(redshifts, redshift_sample_coll)
    merg_idx = np.searchsorted(redshifts, redshift_sample_merg)
    coll_kc = kc[coll_idx]
    merg_kc = kc[merg_idx]
    coll_dl = dl[coll_idx]
    merg_dl = dl[merg_idx]

    ## Luminosity
    # Collapsar GRB rest-frame luminosity pdf
    lum_pdf_coll, const_coll = get_luminosity(luminosities,
                                              lstar=1E52,
                                              alpha=ps[par_dict["coll alpha"]],
                                              beta=ps[par_dict["coll beta"]],
                                              lmin=luminosities[0],
                                              lmax=luminosities[-1],
                                              plot=plot_func)
    # Merger GRB rest-frame luminosity pdf
    lum_pdf_merg, const_merg = get_luminosity(luminosities,
                                              lstar=1E52,
                                              alpha=ps[par_dict["merg alpha"]],
                                              beta=ps[par_dict["merg beta"]],
                                              lmin=luminosities[0],
                                              lmax=luminosities[-1],
                                              plot=plot_func)

    # Draw random collapsar luminosities from pdf
    lum_sample_coll = sample_distribution(
        luminosities,
        lum_pdf_coll,
        xlabel='Peak Luminosity (1-10,000 keV)[ergs/s]',
        num_draw=N_coll,
        xlog=True,
        ylog=False,
        plot=plot_func)
    # Draw random merger luminosities from pdf
    lum_sample_merg = sample_distribution(
        luminosities,
        lum_pdf_merg,
        xlabel='Peak Luminosity (1-10,000 keV)[ergs/s]',
        num_draw=N_merg,
        xlog=True,
        ylog=False,
        plot=plot_func)

    # Plot randomly drawn samples and pdf to ensure sim is done correctly
    if plot_func is not False:
        bins = np.logspace(50, 54, 60)
        plot_samples(lum_sample_coll,
                     luminosities,
                     lum_pdf_coll,
                     bins=bins,
                     xlog=True,
                     ylog=True)
        plot_samples(lum_sample_merg,
                     luminosities,
                     lum_pdf_merg,
                     bins=bins,
                     xlog=True,
                     ylog=True)

    ## Peak flux
    # Get model peak flux for simulated collapsar GRBs
    coll_pf_all, coll_pf, coll_z = Peak_flux(L=lum_sample_coll,
                                             z=redshift_sample_coll,
                                             kcorr=coll_kc,
                                             dl=coll_dl,
                                             plotting=plot_GRB,
                                             sim=sim,
                                             dsim=dsim,
                                             title='Collapsars')
    # Get model peak flux for simulated merger GRBs
    merg_pf_all, merg_pf, merg_z = Peak_flux(L=lum_sample_merg,
                                             z=redshift_sample_merg,
                                             kcorr=merg_kc,
                                             dl=merg_dl,
                                             plotting=plot_GRB,
                                             sim=sim,
                                             dsim=dsim,
                                             title='Mergers')

    # TODO: the duration model is disabled. The previous draft drew intrinsic
    # durations per population with intrinsic_duration(durations, mu, sigma)
    # using the "coll/merg mu" and "coll/merg sigma" parameters (num_draw =
    # number of surviving peak fluxes), stretched them by (1 + z) for the
    # observed frame, and compared them with obs_t90 via combine_data and
    # log_likelihood. The duration energy range still needs to be matched to
    # the T90 definition. Until restored, `durations` and `obs_t90` are unused
    # and the duration log-likelihood is 0.

    # Combine collapsar and merger model counts
    pf_model, pf_data = combine_data(coll_model=coll_pf,
                                     merg_model=merg_pf,
                                     data=obs_pf,
                                     coll_all=coll_pf_all,
                                     merg_all=merg_pf_all,
                                     coll_model_label='Collapsar Model',
                                     merg_model_label='Merger Model',
                                     data_label='GBM Data',
                                     show_plot=plot_GRB,
                                     sim=sim)

    # Save peak flux data
    if sim is not False and file is not None:
        print('Saving simulation data: ../' + file + '.npy')
        tot_model = np.concatenate([coll_pf, merg_pf])
        np.save('../' + file + '.npy', tot_model)

    # Calculate the likelihood for this model (i.e., these parameters).
    # Any failure in the likelihood/prior evaluation marks the parameter set
    # as impossible; `except Exception` (not a bare except) still lets
    # KeyboardInterrupt/SystemExit through.
    try:
        pf_llr = log_likelihood(model_counts=pf_model, data_counts=pf_data)
        dur_llr = 0  # duration likelihood disabled (see TODO above)
        ln_prior = _log_prior_sum(ps, par_dict, num_param)
        return pf_llr + dur_llr + ln_prior
    except Exception:
        return -np.inf