Example #1
File: plugins.py Project: Separius/EEG-GAN
 def epoch(self, epoch_index):
     if epoch_index % self.output_snapshot_ticks != 0:
         return
     gc.collect()
     all_fakes = []
     all_reals = []
     with torch.no_grad():
         remaining_items = self.max_items
         while remaining_items > 0:
             z = next(self.trainer.random_latents_generator)
             fake_latents_in = cudize(z)
             all_fakes.append(
                 self.trainer.generator(fake_latents_in)[0]['x'].data.cpu())
             if all_fakes[-1].size(2) < self.patch_size:
                 break
             remaining_items -= all_fakes[-1].size(0)
         all_fakes = torch.cat(all_fakes, dim=0)
         remaining_items = self.max_items
         while remaining_items > 0:
             all_reals.append(next(self.trainer.dataiter)['x'])
             if all_reals[-1].size(2) < self.patch_size:
                 break
             remaining_items -= all_reals[-1].size(0)
         all_reals = torch.cat(all_reals, dim=0)
     swd = self.get_descriptors(all_fakes, all_reals)
     if len(swd) > 0:
         swd.append(np.array(swd).mean())
     self.trainer.stats['swd']['val'] = swd
     self.trainer.stats['swd']['epoch'] = epoch_index
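Every example on this page passes tensors and modules through a cudize helper that the listing never defines. A minimal sketch of what such a helper typically looks like (an assumption for readability, not the project's actual implementation):

import torch

def cudize(x):
    # Assumed behavior: move a tensor or module to the GPU when one is
    # available; otherwise return it unchanged so the code also runs on CPU.
    return x.cuda() if torch.cuda.is_available() else x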
Example #2
File: plugins.py Project: Separius/EEG-GAN
 def epoch(self, epoch_index):
     for p, avg_p in zip(self.trainer.generator.parameters(),
                         self.my_g_clone):
         avg_p.mul_(self.old_weight).add_((1.0 - self.old_weight) * p.data)
     if epoch_index % self.output_snapshot_ticks == 0:
         z = next(self.sample_fn(self.samples_count))
         gen_input = cudize(z)
         original_param = self.flatten_params(self.trainer.generator)
         self.load_params(self.my_g_clone, self.trainer.generator)
         dest = os.path.join(
             self.checkpoints_dir,
             SaverPlugin.last_pattern.format(
                 'smooth_generator',
                 '{:06}'.format(self.trainer.cur_nimg // 1000)))
         torch.save(
             {
                 'cur_nimg': self.trainer.cur_nimg,
                 'model': self.trainer.generator.state_dict()
             }, dest)
         out = generate_samples(self.trainer.generator, gen_input)
         self.load_params(original_param, self.trainer.generator)
         frequency = self.max_freq * out.shape[2] / self.seq_len
         images = self.get_images(frequency, epoch_index, out)
         for i, image in enumerate(images):
             imwrite(
                 os.path.join(self.checkpoints_dir,
                              '{}_{}.png'.format(epoch_index, i)), image)
Example #3
 def train(self):
     if self.lr_scheduler_g is not None:
         self.lr_scheduler_g.step(self.cur_nimg / self.d_training_repeats)
     fake_latents_in = cudize(next(self.random_latents_generator))
     for i in range(self.d_training_repeats):
         if self.lr_scheduler_d is not None:
             self.lr_scheduler_d.step(self.cur_nimg)
         real_images_expr = cudize(next(self.dataiter))
         self.cur_nimg += real_images_expr['x'].size(0)
         d_loss = self.d_loss(self.discriminator, self.generator,
                              real_images_expr, fake_latents_in)
         self.optimizer_d.zero_grad()  # clear accumulated gradients (assuming they are not reset elsewhere)
         d_loss.backward()
         self.optimizer_d.step()
         fake_latents_in = cudize(next(self.random_latents_generator))
     g_loss = self.g_loss(self.discriminator, self.generator,
                          real_images_expr, fake_latents_in)
     self.optimizer_g.zero_grad()  # clear gradients left over from the discriminator passes
     g_loss.backward()
     self.optimizer_g.step()
     self.iterations += 1
     self.call_plugins('iteration', self.iterations, *(g_loss, d_loss))
Example #4
File: plugins.py Project: Separius/EEG-GAN
 def epoch(self, epoch_index):
     if epoch_index % self.output_snapshot_ticks != 0:
         return
     values = []
     with torch.no_grad():
         i = 0
         for data in self.create_dataloader_fun(
                 min(self.trainer.stats['minibatch_size'],
                     1024), False, self.trainer.dataset.model_depth,
                 self.trainer.dataset.alpha):
             d_real, _, _ = self.trainer.discriminator(cudize(data))
             values.append(d_real.mean().item())
             i += 1
     values = np.array(values).mean()
     self.trainer.stats['memorization']['val'] = values
     self.trainer.stats['memorization']['epoch'] = epoch_index
Example #5
File: plugins.py Project: Separius/EEG-GAN
 def epoch(self, epoch_index):
     if epoch_index % self.output_snapshot_ticks != 0:
         return
     if self.last_stage != (self.trainer.dataset.model_depth +
                            self.trainer.dataset.alpha):
         self.last_stage = self.trainer.dataset.model_depth + self.trainer.dataset.alpha
         values = []
         i = 0
         for data in self.create_dataloader_fun(
                 min(self.trainer.stats['minibatch_size'],
                     1024), False, self.trainer.dataset.model_depth,
                 self.trainer.dataset.alpha):
             x = data['x']
             x = x.view(x.size(0), -1).numpy()
             values.append(x)
             i += x.shape[0]
             if i >= self.num_samples:
                 break
         values = np.concatenate(values, axis=0)  # concatenate batches into an (N, features) array
         self.ndb = NDB(values,
                        self.num_bins,
                        cache_folder=self.output_dir,
                        stage=self.last_stage)
     with torch.no_grad():
         values = []
         i = 0
         while i < self.num_samples:
             fake_latents_in = cudize(
                 next(self.trainer.random_latents_generator))
             x = self.trainer.generator(fake_latents_in)[0]['x'].cpu()
             i += x.size(0)
             x = x.view(x.size(0), -1).numpy()
             values.append(x)
     values = np.concatenate(values, axis=0)  # concatenate batches into an (N, features) array
     result = self.ndb.evaluate(values)
     self.trainer.stats['ndb']['ndb'] = result[0]
     self.trainer.stats['ndb']['js'] = result[1]
     self.trainer.stats['ndb']['epoch'] = epoch_index
Example #6
def train_virtual_inception_network(x,
                                    num_virtual_classes=32,
                                    num_epochs=1000):
    model = cudize(
        torch.nn.Sequential(torch.nn.Linear(x.size(1), num_virtual_classes),
                            torch.nn.Softmax(dim=1)))
    optim = torch.optim.Adam(model.parameters(), lr=0.001)
    for i in range(num_epochs):
        pred = model(x)
        # negative entropy of the mean prediction; minimizing it spreads samples across classes
        entropy = (pred.mean(dim=0) * torch.log(pred.mean(dim=0))).sum()
        # push each sample's top class probability toward 1.0 (confident assignments)
        sparsity = ((pred.max(dim=1)[0] - 1.0)**2).mean()
        loss = 1.0 * entropy + 2.5 * sparsity
        optim.zero_grad()
        loss.backward()
        optim.step()
    if test_mode:  # module-level debug flag, defined elsewhere in the project
        with torch.no_grad():
            print(model(x[:128]).max(dim=1)[0])
            print(model(x).max(dim=1)[0].mean().item())
            plt.hist(model(x).max(dim=1)[1].cpu().numpy(),
                     bins=num_virtual_classes)
            plt.show()
    return model
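A hedged usage sketch for the function above (the name features is a hypothetical stand-in for an (N, D) embedding tensor already passed through cudize): the returned classifier softly assigns each sample to a virtual class, which supports inception-score-style diversity measurements.

# Hypothetical usage; `features` is an assumption, not from the project.
vin = train_virtual_inception_network(features, num_virtual_classes=32)
with torch.no_grad():
    assignments = vin(features).argmax(dim=1)  # one virtual class per sample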
Example #7
File: plugins.py Project: Separius/EEG-GAN
 def __init__(self,
              num_channels,
              create_dataloader_fun,
              target_seq_len,
              num_samples=1024 * 16,
              output_snapshot_ticks=25,
              calc_for_z=True,
              calc_for_zp=True,
              calc_for_c=True,
              calc_for_cp=True):
     super().__init__([(1, 'epoch')])
     self.create_dataloader_fun = create_dataloader_fun
     self.output_snapshot_ticks = output_snapshot_ticks
     self.last_depth = -1
     self.last_alpha = -1
     self.target_seq_len = target_seq_len
     self.calc_z = calc_for_z
     self.calc_zp = calc_for_zp
     self.calc_c = calc_for_c
     self.calc_cp = calc_for_cp
     self.num_samples = num_samples
     hp = cpc_hp
     self.network = cudize(
         CpcNetwork(num_channels,
                    generate_long_sequence=hp.generate_long_sequence,
                    pooling=hp.pool_or_stride == 'pool',
                    encoder_dropout=hp.encoder_dropout,
                    use_sinc_encoder=hp.use_sinc_encoder,
                    use_shared_sinc=hp.use_shared_sinc,
                    bidirectional=hp.bidirectional,
                    contextualizer_num_layers=hp.contextualizer_num_layers,
                    contextualizer_dropout=hp.contextualizer_dropout,
                    use_transformer=hp.use_transformer,
                    causal_prediction=hp.causal_prediction,
                    prediction_k=hp.prediction_k,
                    encoder_activation=hp.encoder_activation,
                    tiny_encoder=hp.tiny_encoder)).eval()
Example #8
File: plugins.py Project: Separius/EEG-GAN
 def epoch(self, epoch_index):
     if epoch_index % self.output_snapshot_ticks != 0:
         return
     if not (self.last_depth == self.trainer.dataset.model_depth
             and self.last_alpha == self.trainer.dataset.alpha):
         self.last_depth = self.trainer.dataset.model_depth
         self.last_alpha = self.trainer.dataset.alpha
         with torch.no_grad():
             all_z = []
             all_c = []
             all_zp = []
             all_cp = []
             i = 0
             for data in self.create_dataloader_fun(
                     min(self.trainer.stats['minibatch_size'],
                         1024), False, self.trainer.dataset.model_depth,
                     self.trainer.dataset.alpha):
                 x = cudize(data['x'])
                 x = resample_signal(x, x.size(2), self.target_seq_len,
                                     True)
                 batch_size = x.size(0)
                 z, c, zp, cp = self.network.inference_forward(x)
                 if self.calc_z:
                     all_z.append(z.view(-1, z.size(1)).cpu())
                 if self.calc_c:
                     all_c.append(c.contiguous().view(-1, c.size(1)).cpu())
                 if self.calc_zp:
                     all_zp.append(zp.cpu())
                 if self.calc_cp:
                     all_cp.append(cp.cpu())
                 i += batch_size
                 if i >= self.num_samples:
                     break
             if self.calc_z:
                 all_z = torch.cat(all_z, dim=0)
                 self.z_mu, self.z_std = torch.mean(
                     all_z, 0), self.torch_cov(all_z, rowvar=False)
             if self.calc_c:
                 all_c = torch.cat(all_c, dim=0)
                 self.c_mu, self.c_std = torch.mean(
                     all_c, 0), self.torch_cov(all_c, rowvar=False)
             if self.calc_zp:
                 all_zp = torch.cat(all_zp, dim=0)
                 self.zp_mu, self.zp_std = torch.mean(
                     all_zp, 0), self.torch_cov(all_zp, rowvar=False)
             if self.calc_cp:
                 all_cp = torch.cat(all_cp, dim=0)
                 self.cp_mu, self.cp_std = torch.mean(
                     all_cp, 0), self.torch_cov(all_cp, rowvar=False)
     with torch.no_grad():
         i = 0
         all_z = []
         all_c = []
         all_zp = []
         all_cp = []
         while i < self.num_samples:
             fake_latents_in = cudize(
                 next(self.trainer.random_latents_generator))
             x = self.trainer.generator(fake_latents_in)[0]['x']
             x = resample_signal(x, x.size(2), self.target_seq_len, True)
             i += x.size(0)  # count generated samples so the loop terminates
             z, c, zp, cp = self.network.inference_forward(x)
             if self.calc_z:
                 all_z.append(z.view(-1, z.size(1)).cpu())
             if self.calc_c:
                 all_c.append(c.contiguous().view(-1, c.size(1)).cpu())
             if self.calc_zp:
                 all_zp.append(zp.cpu())
             if self.calc_cp:
                 all_cp.append(cp.cpu())
         if self.calc_z:
             all_z = torch.cat(all_z, dim=0)
             fz_mu, fz_std = torch.mean(all_z,
                                        0), self.torch_cov(all_z,
                                                           rowvar=False)
             self.trainer.stats['FID']['z_fake'] = self.calc_fid(
                 fz_mu, fz_std, self.z_mu, self.z_std)
         if self.calc_c:
             all_c = torch.cat(all_c, dim=0)
             fc_mu, fc_std = torch.mean(all_c,
                                        0), self.torch_cov(all_c,
                                                           rowvar=False)
             self.trainer.stats['FID']['c_fake'] = self.calc_fid(
                 fc_mu, fc_std, self.c_mu, self.c_std)
         if self.calc_zp:
             all_zp = torch.cat(all_zp, dim=0)
             fzp_mu, fzp_std = torch.mean(all_zp,
                                          0), self.torch_cov(all_zp,
                                                             rowvar=False)
             self.trainer.stats['FID']['zp_fake'] = self.calc_fid(
                 fzp_mu, fzp_std, self.zp_mu, self.zp_std)
         if self.calc_cp:
             all_cp = torch.cat(all_cp, dim=0)
             fcp_mu, fcp_std = torch.mean(all_cp,
                                          0), self.torch_cov(all_cp,
                                                             rowvar=False)
             self.trainer.stats['FID']['cp_fake'] = self.calc_fid(
                 fcp_mu, fcp_std, self.cp_mu, self.cp_std)
         self.trainer.stats['FID']['epoch'] = epoch_index
Example #9
def main(summary):
    train_dataset, val_dataset = EEGDataset.from_config(
        validation_ratio=hp.validation_ratio,
        validation_seed=hp.validation_seed,
        dir_path='./data/prepared_eegs_mat_th5',
        data_sampling_freq=220,
        start_sampling_freq=1,
        end_sampling_freq=60,
        start_seq_len=32,
        num_channels=17,
        return_long=False)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=hp.batch_size,
                                  num_workers=0,
                                  drop_last=True)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=hp.batch_size,
                                num_workers=0,
                                drop_last=False,
                                pin_memory=True)
    network = cudize(
        Network(train_dataset.num_channels,
                bidirectional=False,
                contextualizer_num_layers=hp.contextualizer_num_layers,
                contextualizer_dropout=hp.contextualizer_dropout,
                use_transformer=hp.use_transformer,
                prediction_k=hp.prediction_k *
                (hp.prediction_loss_weight != 0.0),
                have_global=(hp.global_loss_weight != 0.0),
                have_local=(hp.local_loss_weight != 0.0),
                residual_encoder=hp.residual_encoder,
                sinc_encoder=hp.sinc_encoder))
    num_parameters = num_params(network)
    print('num_parameters', num_parameters)
    if hp.use_bert_adam:
        network_optimizer = BertAdam(network.parameters(),
                                     lr=hp.lr,
                                     weight_decay=hp.weight_decay,
                                     warmup=0.2,
                                     t_total=hp.epochs * len(train_dataloader),
                                     schedule='warmup_linear')
    else:
        network_optimizer = Adam(network.parameters(),
                                 lr=hp.lr,
                                 weight_decay=hp.weight_decay)
    if hp.use_scheduler:
        scheduler = ReduceLROnPlateau(network_optimizer,
                                      patience=3,
                                      verbose=True)
    best_val_loss = float('inf')
    for epoch in trange(hp.epochs):
        for training, data_loader in zip((False, True),
                                         (val_dataloader, train_dataloader)):
            if training:
                if epoch == hp.epochs - 1:
                    break
                network.train()
            else:
                network.eval()
            total_network_loss = 0.0
            total_prediction_loss = 0.0
            total_global_loss = 0.0
            total_local_loss = 0.0
            total_global_accuracy = 0.0
            total_local_accuracy = 0.0
            total_k_pred_acc = {}
            total_pred_acc = 0.0
            total_count = 0
            with torch.set_grad_enabled(training):
                for batch in data_loader:
                    x = cudize(batch['x'])
                    network_return = network.forward(x)
                    network_loss = hp.prediction_loss_weight * network_return.losses.prediction_
                    network_loss = network_loss + hp.global_loss_weight * network_return.losses.global_
                    network_loss = network_loss + hp.local_loss_weight * network_return.losses.local_

                    bs = x.size(0)
                    total_count += bs
                    total_network_loss += network_loss.item() * bs
                    total_prediction_loss += network_return.losses.prediction_.item() * bs
                    total_global_loss += network_return.losses.global_.item() * bs
                    total_local_loss += network_return.losses.local_.item() * bs

                    total_global_accuracy += network_return.accuracies.global_ * bs
                    total_local_accuracy += network_return.accuracies.local_ * bs
                    dict_add(total_k_pred_acc,
                             network_return.accuracies.prediction_, bs)
                    len_pred = len(network_return.accuracies.prediction_)
                    if len_pred > 0:
                        total_pred_acc += sum(
                            network_return.accuracies.prediction_.values()) / len_pred * bs

                    if training:
                        network_optimizer.zero_grad()
                        network_loss.backward()
                        network_optimizer.step()

            metrics = dict(net_loss=total_network_loss)
            if network.prediction_loss_network.k > 0 and hp.prediction_loss_weight != 0:
                metrics.update(
                    dict(avg_prediction_acc=total_pred_acc,
                         prediction_loss=total_prediction_loss,
                         k_prediction_acc=total_k_pred_acc))
            if hp.global_loss_weight != 0:
                metrics.update(
                    dict(global_loss=total_global_loss,
                         global_acc=total_global_accuracy))
            if hp.local_loss_weight != 0:
                metrics.update(
                    dict(local_loss=total_local_loss,
                         local_acc=total_local_accuracy))
            divide_dict(metrics, total_count)

            if not training and hp.use_scheduler:
                scheduler.step(metrics['net_loss'])
            if summary:
                print('train' if training else 'validation', epoch,
                      metrics['net_loss'])
            else:
                print('train' if training else 'validation', epoch)
                print(json.dumps(metrics, indent=4))
            if not training and (metrics['net_loss'] < best_val_loss):
                best_val_loss = metrics['net_loss']
                print('update best to', best_val_loss)
                torch.save(network.state_dict(), 'best_network.pth')
Example #10
def main(num_samples):
    # NOTE: these values are model-dependent; it's better to read them from a YAML file
    skip_depth = 6 if test_mode else 0
    progression_scale_up = EEGDataset.progression_scale_up
    progression_scale_down = EEGDataset.progression_scale_down
    train_dataset, val_dataset = ThinEEGDataset.from_config(
        validation_ratio=hp.validation_ratio,
        stride=hp.ds_stride,
        dir_path='./data/tuh1/',
        num_channels=hp.num_channels)
    real_ndb = None
    final_result = {}
    if test_mode:
        models_zip = zip([hp], ['default_-7.531810902716695'])
    else:
        models_zip = zip([hp, prediction_hp, local_hp], [
            'default_-7.531810902716695', 'prediction_-2.450593529493725',
            'local_-0.5012809535156667'
        ])
    for current_hp, model_address in models_zip:
        network = Network(
            train_dataset.num_channels,
            generate_long_sequence=current_hp.generate_long_sequence,
            pooling=current_hp.pool_or_stride == 'pool',
            encoder_dropout=current_hp.encoder_dropout,
            use_sinc_encoder=current_hp.use_sinc_encoder,
            use_shared_sinc=current_hp.use_shared_sinc,
            bidirectional=current_hp.bidirectional,
            contextualizer_num_layers=current_hp.contextualizer_num_layers,
            contextualizer_dropout=current_hp.contextualizer_dropout,
            use_transformer=current_hp.use_transformer,
            causal_prediction=current_hp.causal_prediction,
            prediction_k=current_hp.prediction_k,
            encoder_activation=current_hp.encoder_activation,
            tiny_encoder=current_hp.tiny_encoder)
        network.load_state_dict(
            torch.load('./results/cpc_trained/' + model_address + '.pth',
                       map_location='cpu'))
        network = cudize(network.eval())
        collected_results = []
        print('loaded', model_address)
        for i in range(2 if test_mode else 10):  # for stability checks
            if test_mode: print(model_address, 'run #{}'.format(i))
            real_stats, real_ndb = calculate_stats(train_dataset, network,
                                                   progression_scale_up,
                                                   progression_scale_down,
                                                   skip_depth, num_samples,
                                                   'normal', current_hp,
                                                   real_ndb, None)
            if test_mode: print('real stats calculated')
            val_stats, _ = calculate_stats(val_dataset, network,
                                           progression_scale_up,
                                           progression_scale_down, skip_depth,
                                           num_samples, 'validation',
                                           current_hp, real_ndb, real_stats)
            if test_mode: print('val stats calculated')
            permuted_stats, _ = calculate_stats(train_dataset, network,
                                                progression_scale_up,
                                                progression_scale_down,
                                                skip_depth, num_samples,
                                                'permute', current_hp,
                                                real_ndb, real_stats)
            if test_mode: print('permuted stats calculated')
            shifted_stats, _ = calculate_stats(train_dataset, network,
                                               progression_scale_up,
                                               progression_scale_down,
                                               skip_depth, num_samples,
                                               'shift', current_hp, real_ndb,
                                               real_stats)
            if test_mode: print('shifted stats calculated')
            concatenated_stats, _ = calculate_stats(train_dataset, network,
                                                    progression_scale_up,
                                                    progression_scale_down,
                                                    skip_depth, num_samples,
                                                    'concat', current_hp,
                                                    real_ndb, real_stats)
            if test_mode: print('concatenated stats calculated')
            tiny_stats, _ = calculate_stats(train_dataset, network,
                                            progression_scale_up,
                                            progression_scale_down, skip_depth,
                                            num_samples, 'tiny', current_hp,
                                            real_ndb, real_stats)
            if test_mode: print('tiny stats calculated')
            zeroed_stats, _ = calculate_stats(train_dataset, network,
                                              progression_scale_up,
                                              progression_scale_down,
                                              skip_depth, num_samples, 'zero',
                                              current_hp, real_ndb, real_stats)
            if test_mode: print('zeroed stats calculated')
            noised_stats, _ = calculate_stats(train_dataset, network,
                                              progression_scale_up,
                                              progression_scale_down,
                                              skip_depth, num_samples, 'noise',
                                              current_hp, real_ndb, real_stats)
            if test_mode: print('noised stats calculated')
            collected_results.append({
                **real_stats,
                **shifted_stats,
                **concatenated_stats,
                **tiny_stats,
                **zeroed_stats,
                **noised_stats
            })
            # TODO (over time and different truncation threshold)
            # normal_generated_stats, _ = calculate_stats(train_dataset, network, progression_scale_up,
            #                                             progression_scale_down, skip_depth, num_samples, 'generated',
            #                                             current_hp, real_ndb, real_stats)
            # averaged_generated_stats, _ = calculate_stats(train_dataset, network, progression_scale_up,
            #                                               progression_scale_down, skip_depth, num_samples, 'averaged',
            #                                               current_hp, real_ndb, real_stats)
            # truncated_generated_stats, _ = calculate_stats(train_dataset, network, progression_scale_up,
            #                                                progression_scale_down, skip_depth, num_samples,
            #                                                'truncation', current_hp, real_ndb, real_stats)
        # calc std for each stats ([{key(mode_n): {key(seq_len): {meter: value}}}])
        final_result[model_address] = {
            mode: {
                seq_len: {
                    meter: (np.mean([
                        collected_results[j][mode][seq_len][meter]
                        for j in range(2 if test_mode else 10)
                    ]),
                            np.std([
                                collected_results[j][mode][seq_len][meter]
                                for j in range(2 if test_mode else 10)
                            ]))
                    for meter in {
                        'prediction_loss', 'prediction_acc', 'global_loss',
                        'global_acc', 'local_loss', 'local_acc',
                        'c_fid_max_seq_len', 'z_fid_max_seq_len',
                        'cp_fid_max_seq_len', 'zp_fid_max_seq_len', 'net_loss',
                        'ndb_score', 'ndb_js', 'c_fid', 'z_fid', 'zp_fid',
                        'cp_fid', 'z_vis_max_seq_len', 'z_vis',
                        'c_vis_max_seq_len', 'c_vis', 'zp_vis_max_seq_len',
                        'zp_vis', 'cp_vis_max_seq_len', 'cp_vis'
                    }
                }
                for seq_len in collected_results[0][mode].keys()
            }
            for mode in collected_results[0].keys()
        }
    return final_result
Example #11
def calculate_network_stats(x, net: Network, scale_up, scale_down, skip_depth,
                            current_hp, real_ndb, real_stats):
    max_seq_len = x.size(2)
    seq_lens = [max_seq_len]
    for i in reversed(range(skip_depth, len(scale_up))):
        seq_lens = [int(seq_lens[0] * scale_down[i] / scale_up[i])] + seq_lens
    seq_lens = reversed(seq_lens)
    stats = {}
    with torch.no_grad():
        for seq_len in seq_lens:
            this_x = resample_signal(x, max_seq_len, seq_len, True)
            ndb, js = real_ndb[seq_len].evaluate(
                this_x.view(this_x.size(0), -1).numpy())
            this_x = resample_signal(this_x, seq_len, max_seq_len,
                                     True)  # resample to give to the net
            total_network_loss = 0.0
            total_prediction_loss = 0.0
            total_global_discriminator_loss = 0.0
            total_local_discriminator_loss = 0.0
            total_global_accuracy_one = 0.0
            total_global_accuracy_two = 0.0
            total_local_accuracy_one = 0.0
            total_local_accuracy_two = 0.0
            total_pred_acc = {}
            total_count = 0
            all_z = []
            all_c = []
            all_zp = []
            all_cp = []
            for i in range(x.size(0) // 128):
                (prediction_loss, global_discriminator_loss,
                 local_discriminator_loss, cp, global_accuracy, local_accuracy,
                 pred_accuracy, z, c, zp) = net.complete_forward(
                     cudize(this_x[i * 128:(i + 1) * 128]))
                global_accuracy_one, global_accuracy_two = global_accuracy
                local_accuracy_one, local_accuracy_two = local_accuracy
                network_loss = (
                    current_hp.prediction_loss_weight * prediction_loss +
                    current_hp.global_loss_weight * global_discriminator_loss +
                    current_hp.local_loss_weight * local_discriminator_loss)
                this_batch_size = this_x[i * 128:(i + 1) * 128].size(0)
                total_count += this_batch_size
                total_network_loss += network_loss.item() * this_batch_size
                total_prediction_loss += prediction_loss.item() * this_batch_size
                total_global_discriminator_loss += global_discriminator_loss.item() * this_batch_size
                total_local_discriminator_loss += local_discriminator_loss.item() * this_batch_size
                total_global_accuracy_one += global_accuracy_one * this_batch_size
                total_global_accuracy_two += global_accuracy_two * this_batch_size
                total_local_accuracy_one += local_accuracy_one * this_batch_size
                total_local_accuracy_two += local_accuracy_two * this_batch_size
                dict_add(total_pred_acc, pred_accuracy, this_batch_size)
                # subsample z and c (assumes T = 32)
                all_z.append(z[:8].contiguous().view(-1, z.size(1)).cpu())
                all_c.append(c[:8].contiguous().view(-1, c.size(1)).cpu())
                all_zp.append(zp.cpu())
                all_cp.append(cp.cpu())

            total_global_accuracy_one /= total_count
            total_global_accuracy_two /= total_count
            total_local_accuracy_one /= total_count
            total_local_accuracy_two /= total_count
            divide_dict(total_pred_acc, total_count)

            total_prediction_loss /= total_count
            total_pred_acc = merge_pred_accs(
                total_pred_acc, net.prediction_loss_network.k,
                net.prediction_loss_network.bidirectional)
            total_global_discriminator_loss /= total_count
            total_global_accuracy = (total_global_accuracy_one +
                                     total_global_accuracy_two) / 2
            total_local_discriminator_loss /= total_count
            total_local_accuracy = (total_local_accuracy_one +
                                    total_local_accuracy_two) / 2
            total_network_loss /= total_count

            metrics = dict(prediction_loss=total_prediction_loss,
                           prediction_acc=total_pred_acc,
                           global_loss=total_global_discriminator_loss,
                           global_acc=total_global_accuracy,
                           local_loss=total_local_discriminator_loss,
                           local_acc=total_local_accuracy,
                           net_loss=total_network_loss,
                           ndb_score=ndb,
                           ndb_js=js)
            if test_mode:
                print(metrics)
            all_z = torch.cat(all_z, dim=0)
            all_c = torch.cat(all_c, dim=0)
            all_zp = torch.cat(all_zp, dim=0)
            all_cp = torch.cat(all_cp, dim=0)
            for (name, all_name) in zip(['z', 'c', 'zp', 'cp'],
                                        [all_z, all_c, all_zp, all_cp]):
                all_name = cudize(all_name)
                mean_cov = calc_mean_cov(name, all_name)
                if real_stats is None:
                    trained_vin = train_virtual_inception_network(all_name)
                    metrics.update({
                        **mean_cov,
                        name + '_fid': 0.0,
                        name + '_virtual_inception_network': trained_vin
                    })
                else:
                    trained_vin = real_stats['normal'][seq_len][
                        name + '_virtual_inception_network']
                    metrics.update({
                        name + '_fid': FidCalculator.calc_fid(
                            real_stats['normal'][seq_len][name + '_mean'],
                            real_stats['normal'][seq_len][name + '_cov'],
                            mean_cov[name + '_mean'], mean_cov[name + '_cov'])
                    })
                metrics.update({
                    name + '_vis': calculate_inception_score(all_name, trained_vin)
                })
                if seq_len == max_seq_len:
                    metrics.update({
                        name + '_fid_max_seq_len': metrics[name + '_fid'],
                        name + '_vis_max_seq_len': metrics[name + '_vis']
                    })
                else:
                    if real_stats is None:
                        working_real_stats = stats[max_seq_len]
                    else:
                        working_real_stats = real_stats['normal'][max_seq_len]
                    max_mean = working_real_stats[name + '_mean']
                    max_cov = working_real_stats[name + '_cov']
                    max_vin = working_real_stats[name + '_virtual_inception_network']
                    metrics.update({
                        name + '_fid_max_seq_len': FidCalculator.calc_fid(
                            max_mean, max_cov, mean_cov[name + '_mean'],
                            mean_cov[name + '_cov']),
                        name + '_vis_max_seq_len': calculate_inception_score(
                            all_name, max_vin)
                    })
            stats[seq_len] = metrics
    return stats
Example #12
File: dataset.py Project: Separius/EEG-GAN
 def collate_real(batch):
     return cudize(default_collate(batch))
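A hedged usage sketch (an assumption, not shown on this page): collate_real is the kind of callable handed to a DataLoader as its collate_fn, so each batch is collated and moved to the GPU in one step.

from torch.utils.data import DataLoader

# Hypothetical usage: `dataset` stands in for any map-style dataset here.
loader = DataLoader(dataset, batch_size=32, shuffle=True,
                    collate_fn=collate_real)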
Example #13
def get_zero(batch_size):
    global zero
    if zero is None or batch_size != zero.size(0):
        zero = cudize(torch.zeros(batch_size))
    return zero
Example #14
def get_one(batch_size):
    global one
    if one is None or batch_size != one.size(0):
        one = cudize(torch.ones(batch_size))
    return one
Example #15
def get_mixing_factor(batch_size):
    global mixing_factors
    if mixing_factors is None or batch_size != mixing_factors.size(0):
        mixing_factors = cudize(torch.FloatTensor(batch_size, 1, 1))
    mixing_factors.uniform_()
    return mixing_factors
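The three cached helpers above (get_zero, get_one, get_mixing_factor) are typical of WGAN-style training loops. A hedged sketch of how the mixing factor is commonly used to build the interpolated samples of a gradient penalty (an assumption; the project's actual loss code is not shown here):

# Hypothetical gradient-penalty interpolation; `real` and `fake` are
# assumed batches of shape (batch, channels, seq_len).
alpha = get_mixing_factor(real.size(0))      # (batch, 1, 1), uniform in [0, 1)
mixed = alpha * real + (1.0 - alpha) * fake  # broadcasts over channels and time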