def epoch(self, epoch_index):
    """Store SWD-style descriptors comparing generated and real samples.

    Runs only when ``epoch_index`` is a multiple of ``output_snapshot_ticks``.
    Gathers roughly ``max_items`` generated samples, then roughly
    ``max_items`` real samples (a batch shorter than ``patch_size`` along
    dim 2 is kept but stops the gathering), hands both to
    ``self.get_descriptors`` and stores the result under
    ``self.trainer.stats['swd']``. When the descriptor list is non-empty,
    its mean is appended as a final entry.
    (The 'swd' key presumably means sliced Wasserstein distance — confirm
    against ``get_descriptors``.)
    """
    if epoch_index % self.output_snapshot_ticks != 0:
        return
    gc.collect()

    def drain(next_batch):
        # Pull batches until the max_items budget is spent; a batch that is
        # too short along dim 2 is still kept but ends the collection.
        batches = []
        budget = self.max_items
        while budget > 0:
            batch = next_batch()
            batches.append(batch)
            if batch.size(2) < self.patch_size:
                break
            budget -= batch.size(0)
        return torch.cat(batches, dim=0)

    def next_fake():
        latents = cudize(next(self.trainer.random_latents_generator))
        return self.trainer.generator(latents)[0]['x'].data.cpu()

    def next_real():
        return next(self.trainer.dataiter)['x']

    with torch.no_grad():
        fakes = drain(next_fake)
        reals = drain(next_real)
    swd = self.get_descriptors(fakes, reals)
    if len(swd) > 0:
        swd.append(np.array(swd).mean())
    swd_stats = self.trainer.stats['swd']
    swd_stats['val'] = swd
    swd_stats['epoch'] = epoch_index
def epoch(self, epoch_index):
    """Maintain an exponential moving average of the generator and snapshot it.

    Every call blends the live generator parameters into ``self.my_g_clone``
    (aligned element-wise with ``generator.parameters()``)::

        avg = old_weight * avg + (1 - old_weight) * p

    When ``epoch_index`` is a multiple of ``output_snapshot_ticks`` the
    averaged weights are loaded into ``self.trainer.generator``. They are
    saved as a ``'smooth_generator'`` checkpoint in ``self.checkpoints_dir``
    and used to generate samples. The live weights are then loaded back.
    The samples are converted with ``self.get_images`` and written as
    ``{epoch_index}_{i}.png`` into ``self.checkpoints_dir``.

    The live weights are restored in a ``finally`` clause, so a failure while
    saving the checkpoint or sampling does not leave the trainer's generator
    holding the averaged weights.
    """
    old_weight = self.old_weight
    for p, avg_p in zip(self.trainer.generator.parameters(), self.my_g_clone):
        # p.data: update the average without recording an autograd graph.
        avg_p.mul_(old_weight).add_((1.0 - old_weight) * p.data)
    if epoch_index % self.output_snapshot_ticks != 0:
        return
    gen_input = cudize(next(self.sample_fn(self.samples_count)))
    # NOTE(review): assumes flatten_params returns a copy of the live weights
    # (not a view that load_params would overwrite) — confirm.
    original_param = self.flatten_params(self.trainer.generator)
    self.load_params(self.my_g_clone, self.trainer.generator)
    try:
        dest = os.path.join(
            self.checkpoints_dir,
            SaverPlugin.last_pattern.format(
                'smooth_generator',
                '{:06}'.format(self.trainer.cur_nimg // 1000)))
        torch.save(
            {
                'cur_nimg': self.trainer.cur_nimg,
                'model': self.trainer.generator.state_dict()
            }, dest)
        out = generate_samples(self.trainer.generator, gen_input)
    finally:
        # Always hand the trainer back its live (non-averaged) weights.
        self.load_params(original_param, self.trainer.generator)
    frequency = self.max_freq * out.shape[2] / self.seq_len
    images = self.get_images(frequency, epoch_index, out)
    for i, image in enumerate(images):
        imwrite(
            os.path.join(self.checkpoints_dir,
                         '{}_{}.png'.format(epoch_index, i)), image)
def train(self):
    """Run one GAN iteration: ``d_training_repeats`` D steps, then one G step.

    Before the steps, the generator LR scheduler (if any) is stepped with
    ``cur_nimg / d_training_repeats``. Each discriminator step:

    * steps the discriminator LR scheduler (if any) with ``cur_nimg``;
    * draws a fresh real batch and advances ``cur_nimg`` by its size;
    * optimizes ``d_loss`` on that batch and the current latent batch;
    * draws a new latent batch.

    The generator step reuses the last real batch with the most recent
    latent batch. Afterwards ``iterations`` is incremented and the
    ``'iteration'`` plugins receive ``(iterations, g_loss, d_loss)``.

    Gradients are cleared explicitly before each loss is built. Otherwise
    the D backward would leave gradients on G's parameters that the G step
    would apply (and vice versa), unless the loss callables happened to
    zero them. When they already do, the extra clear is a no-op.
    """
    if self.lr_scheduler_g is not None:
        self.lr_scheduler_g.step(self.cur_nimg / self.d_training_repeats)
    fake_latents_in = cudize(next(self.random_latents_generator))
    for _ in range(self.d_training_repeats):
        if self.lr_scheduler_d is not None:
            self.lr_scheduler_d.step(self.cur_nimg)
        real_images_expr = cudize(next(self.dataiter))
        self.cur_nimg += real_images_expr['x'].size(0)
        self.optimizer_d.zero_grad()
        d_loss = self.d_loss(self.discriminator, self.generator,
                             real_images_expr, fake_latents_in)
        d_loss.backward()
        self.optimizer_d.step()
        fake_latents_in = cudize(next(self.random_latents_generator))
    self.optimizer_g.zero_grad()
    g_loss = self.g_loss(self.discriminator, self.generator, real_images_expr,
                         fake_latents_in)
    g_loss.backward()
    self.optimizer_g.step()
    self.iterations += 1
    self.call_plugins('iteration', self.iterations, g_loss, d_loss)
def epoch(self, epoch_index):
    """Record the discriminator's average response to real data.

    Runs only when ``epoch_index`` is a multiple of ``output_snapshot_ticks``.
    Iterates the loader from ``create_dataloader_fun`` at the dataset's
    current depth/alpha (batch size capped at 1024, second argument False).
    It takes the mean of the discriminator's first output per batch and
    stores the unweighted mean of those batch means in
    ``self.trainer.stats['memorization']``.
    """
    if epoch_index % self.output_snapshot_ticks != 0:
        return
    dataset = self.trainer.dataset
    batch_means = []
    with torch.no_grad():
        loader = self.create_dataloader_fun(
            min(self.trainer.stats['minibatch_size'], 1024), False,
            dataset.model_depth, dataset.alpha)
        for data in loader:
            d_real, _, _ = self.trainer.discriminator(cudize(data))
            batch_means.append(d_real.mean().item())
    memorization = self.trainer.stats['memorization']
    memorization['val'] = np.array(batch_means).mean()
    memorization['epoch'] = epoch_index
def epoch(self, epoch_index):
    """Report NDB score and JS divergence of generated vs. real samples.

    Runs only when ``epoch_index`` is a multiple of ``output_snapshot_ticks``.
    The stage is ``model_depth + alpha``. Whenever it differs from
    ``self.last_stage``, about ``num_samples`` real samples are flattened to
    ``(N, features)`` and used to build a new ``NDB`` reference with
    ``num_bins`` bins, cached under ``output_dir``. At least ``num_samples``
    generated samples are then flattened the same way and evaluated. The two
    returned values are stored as ``stats['ndb']['ndb']`` and
    ``stats['ndb']['js']``.
    """
    if epoch_index % self.output_snapshot_ticks != 0:
        return
    dataset = self.trainer.dataset
    stage = dataset.model_depth + dataset.alpha
    if self.last_stage != stage:
        real_chunks = []
        seen = 0
        for data in self.create_dataloader_fun(
                min(self.trainer.stats['minibatch_size'], 1024), False,
                dataset.model_depth, dataset.alpha):
            x = data['x']
            x = x.view(x.size(0), -1).numpy()
            real_chunks.append(x)
            seen += x.shape[0]
            if seen >= self.num_samples:
                break
        # Join along the sample axis -> (N, features). np.stack would add a
        # batch axis (N_batches, batch, features) and fails on a short
        # final batch.
        self.ndb = NDB(np.concatenate(real_chunks, axis=0),
                       self.num_bins,
                       cache_folder=self.output_dir,
                       stage=stage)
        # Only mark the stage as done once the reference was built.
        self.last_stage = stage
    fake_chunks = []
    generated = 0
    with torch.no_grad():
        while generated < self.num_samples:
            fake_latents_in = cudize(
                next(self.trainer.random_latents_generator))
            x = self.trainer.generator(fake_latents_in)[0]['x'].cpu()
            generated += x.size(0)
            fake_chunks.append(x.view(x.size(0), -1).numpy())
    result = self.ndb.evaluate(np.concatenate(fake_chunks, axis=0))
    self.trainer.stats['ndb']['ndb'] = result[0]
    self.trainer.stats['ndb']['js'] = result[1]
    self.trainer.stats['ndb']['epoch'] = epoch_index
def train_virtual_inception_network(x, num_virtual_classes=32, num_epochs=1000):
    """Fit an unsupervised linear + softmax "virtual inception" classifier.

    The model maps each row of ``x`` (``x.size(1)`` features) to
    ``num_virtual_classes`` probabilities. It is trained with Adam
    (lr=0.001) for ``num_epochs`` full-batch steps on::

        loss = 1.0 * sum(m * log(m)) + 2.5 * mean((max_prob - 1) ** 2)

    where ``m`` is the batch-averaged class distribution. The first term is
    the negative entropy of ``m``, so minimizing it spreads samples over the
    classes. The second term pushes each sample's top probability toward 1.

    When the module-level ``test_mode`` is set, prints confidence
    diagnostics and shows a histogram of predicted classes.

    Returns:
        The trained ``torch.nn.Sequential`` model (passed through ``cudize``).
    """
    classifier = cudize(
        torch.nn.Sequential(
            torch.nn.Linear(x.size(1), num_virtual_classes),
            torch.nn.Softmax(dim=1),
        ))
    optimizer = torch.optim.Adam(classifier.parameters(), lr=0.001)
    for _ in range(num_epochs):
        probs = classifier(x)
        neg_entropy = (probs.mean(dim=0) * torch.log(probs.mean(dim=0))).sum()
        confidence_gap = ((probs.max(dim=1)[0] - 1.0)**2).mean()
        loss = 1.0 * neg_entropy + 2.5 * confidence_gap
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    if test_mode:
        with torch.no_grad():
            print(classifier(x[:128]).max(dim=1)[0])
            print(classifier(x).max(dim=1)[0].mean().item())
            predicted = classifier(x).max(dim=1)[1].cpu().numpy()
            plt.hist(predicted, bins=num_virtual_classes)
            plt.show()
    return classifier
def __init__(self,
             num_channels,
             create_dataloader_fun,
             target_seq_len,
             num_samples=1024 * 16,
             output_snapshot_ticks=25,
             calc_for_z=True,
             calc_for_zp=True,
             calc_for_c=True,
             calc_for_cp=True):
    """Store the plugin settings and build the CPC feature network.

    Args:
        num_channels: input channel count for ``CpcNetwork``.
        create_dataloader_fun: factory used later to obtain real-data loaders.
        target_seq_len: stored as ``self.target_seq_len``.
        num_samples: stored as ``self.num_samples``.
        output_snapshot_ticks: stored as ``self.output_snapshot_ticks``.
        calc_for_z, calc_for_zp, calc_for_c, calc_for_cp: per-feature flags
            stored as ``calc_z``, ``calc_zp``, ``calc_c`` and ``calc_cp``.

    The network is configured from the module-level ``cpc_hp`` settings,
    passed through ``cudize`` and put in eval mode.
    """
    super().__init__([(1, 'epoch')])
    self.create_dataloader_fun = create_dataloader_fun
    self.output_snapshot_ticks = output_snapshot_ticks
    # -1 sentinels: presumably "no real-data statistics computed yet".
    self.last_depth = -1
    self.last_alpha = -1
    self.target_seq_len = target_seq_len
    self.calc_z, self.calc_zp = calc_for_z, calc_for_zp
    self.calc_c, self.calc_cp = calc_for_c, calc_for_cp
    self.num_samples = num_samples
    hp = cpc_hp
    network_kwargs = dict(
        generate_long_sequence=hp.generate_long_sequence,
        pooling=hp.pool_or_stride == 'pool',
        encoder_dropout=hp.encoder_dropout,
        use_sinc_encoder=hp.use_sinc_encoder,
        use_shared_sinc=hp.use_shared_sinc,
        bidirectional=hp.bidirectional,
        contextualizer_num_layers=hp.contextualizer_num_layers,
        contextualizer_dropout=hp.contextualizer_dropout,
        use_transformer=hp.use_transformer,
        causal_prediction=hp.causal_prediction,
        prediction_k=hp.prediction_k,
        encoder_activation=hp.encoder_activation,
        tiny_encoder=hp.tiny_encoder,
    )
    self.network = cudize(CpcNetwork(num_channels, **network_kwargs)).eval()
def epoch(self, epoch_index):
    """Compute FID-style distances between real and generated CPC features.

    Runs only when ``epoch_index`` is a multiple of ``output_snapshot_ticks``.
    The enabled features come from ``self.network.inference_forward``; the
    ``calc_z``/``calc_c``/``calc_zp``/``calc_cp`` flags select which.
    Each is summarized by its mean and ``self.torch_cov`` covariance.

    Real-data statistics are cached on the instance as ``<name>_mu`` and
    ``<name>_std`` (the latter holds the covariance matrix). They are rebuilt
    only when the dataset's ``model_depth``/``alpha`` differ from
    ``last_depth``/``last_alpha``. Generated-data statistics are recomputed
    on every call, from at least ``num_samples`` samples, and compared with
    ``self.calc_fid``. Results go to
    ``self.trainer.stats['FID']['<name>_fake']``.
    """
    if epoch_index % self.output_snapshot_ticks != 0:
        return
    depth = self.trainer.dataset.model_depth
    alpha = self.trainer.dataset.alpha
    enabled = [
        name for name, flag in (('z', self.calc_z), ('c', self.calc_c),
                                ('zp', self.calc_zp), ('cp', self.calc_cp))
        if flag
    ]

    def collect(x, features):
        # Resample to the network's length, run it, and keep the enabled
        # features on CPU. Returns the number of samples processed.
        x = resample_signal(x, x.size(2), self.target_seq_len, True)
        z, c, zp, cp = self.network.inference_forward(x)
        for name, tensor in (('z', z), ('c', c), ('zp', zp), ('cp', cp)):
            if name not in features:
                continue
            if name in ('z', 'c'):
                # Flatten each with its OWN dim-1 size (c previously used
                # z.size(1) for generated data).
                tensor = tensor.contiguous().view(-1, tensor.size(1))
            features[name].append(tensor.cpu())
        return x.size(0)

    def mean_and_cov(chunks):
        samples = torch.cat(chunks, dim=0)
        return torch.mean(samples, 0), self.torch_cov(samples, rowvar=False)

    with torch.no_grad():
        if not (self.last_depth == depth and self.last_alpha == alpha):
            real_features = {name: [] for name in enabled}
            seen = 0
            for data in self.create_dataloader_fun(
                    min(self.trainer.stats['minibatch_size'], 1024), False,
                    depth, alpha):
                seen += collect(cudize(data['x']), real_features)
                if seen >= self.num_samples:
                    break
            for name in enabled:
                mu, cov = mean_and_cov(real_features[name])
                setattr(self, name + '_mu', mu)
                setattr(self, name + '_std', cov)
            # Remember which stage the cached statistics belong to so they
            # are reused until depth/alpha change.
            self.last_depth = depth
            self.last_alpha = alpha

        fake_features = {name: [] for name in enabled}
        generated = 0
        # The counter must advance, otherwise this loop never terminates.
        while generated < self.num_samples:
            fake_latents_in = cudize(
                next(self.trainer.random_latents_generator))
            x = self.trainer.generator(fake_latents_in)[0]['x']
            generated += collect(x, fake_features)
        for name in enabled:
            mu, cov = mean_and_cov(fake_features[name])
            self.trainer.stats['FID'][name + '_fake'] = self.calc_fid(
                mu, cov, getattr(self, name + '_mu'),
                getattr(self, name + '_std'))
    self.trainer.stats['FID']['epoch'] = epoch_index
def main(summary):
    """Train the CPC network on EEG data and keep the best validation checkpoint.

    Each epoch runs a validation pass first and then a training pass. So
    epoch 0 is validated before any training, and the training pass is
    skipped in the final epoch. Per pass, losses and accuracies are
    accumulated weighted by batch size, then normalized by the sample count
    via ``divide_dict``. Whenever the validation ``net_loss`` improves, the
    weights are saved to ``best_network.pth``.

    Args:
        summary: if truthy, print only the ``net_loss`` of each pass;
            otherwise print the full metrics dict as JSON.
    """
    train_dataset, val_dataset = EEGDataset.from_config(
        validation_ratio=hp.validation_ratio,
        validation_seed=hp.validation_seed,
        dir_path='./data/prepared_eegs_mat_th5',
        data_sampling_freq=220,
        start_sampling_freq=1,
        end_sampling_freq=60,
        start_seq_len=32,
        num_channels=17,
        return_long=False)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=hp.batch_size,
                                  num_workers=0,
                                  drop_last=True)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=hp.batch_size,
                                num_workers=0,
                                drop_last=False,
                                pin_memory=True)
    # Sub-losses with zero weight are disabled in the network itself
    # (prediction_k becomes 0, have_global / have_local become False).
    network = cudize(
        Network(train_dataset.num_channels,
                bidirectional=False,
                contextualizer_num_layers=hp.contextualizer_num_layers,
                contextualizer_dropout=hp.contextualizer_dropout,
                use_transformer=hp.use_transformer,
                prediction_k=hp.prediction_k *
                (hp.prediction_loss_weight != 0.0),
                have_global=(hp.global_loss_weight != 0.0),
                have_local=(hp.local_loss_weight != 0.0),
                residual_encoder=hp.residual_encoder,
                sinc_encoder=hp.sinc_encoder))
    num_parameters = num_params(network)
    print('num_parameters', num_parameters)
    if hp.use_bert_adam:
        # t_total = total number of optimizer steps over all epochs.
        network_optimizer = BertAdam(network.parameters(),
                                     lr=hp.lr,
                                     weight_decay=hp.weight_decay,
                                     warmup=0.2,
                                     t_total=hp.epochs * len(train_dataloader),
                                     schedule='warmup_linear')
    else:
        network_optimizer = Adam(network.parameters(),
                                 lr=hp.lr,
                                 weight_decay=hp.weight_decay)
    if hp.use_scheduler:
        # Stepped below with the validation net_loss.
        scheduler = ReduceLROnPlateau(network_optimizer,
                                      patience=3,
                                      verbose=True)
    best_val_loss = float('inf')
    for epoch in trange(hp.epochs):
        # Validation (training=False) runs before training in every epoch.
        for training, data_loader in zip((False, True),
                                         (val_dataloader, train_dataloader)):
            if training:
                # Last epoch: validate only.
                if epoch == hp.epochs - 1:
                    break
                network.train()
            else:
                network.eval()
            total_network_loss = 0.0
            total_prediction_loss = 0.0
            total_global_loss = 0.0
            total_local_loss = 0.0
            total_global_accuracy = 0.0
            total_local_accuracy = 0.0
            total_k_pred_acc = {}
            total_pred_acc = 0.0
            total_count = 0
            with torch.set_grad_enabled(training):
                for batch in data_loader:
                    x = cudize(batch['x'])
                    network_return = network.forward(x)
                    # Weighted sum of the three sub-losses.
                    network_loss = hp.prediction_loss_weight * network_return.losses.prediction_
                    network_loss = network_loss + hp.global_loss_weight * network_return.losses.global_
                    network_loss = network_loss + hp.local_loss_weight * network_return.losses.local_
                    bs = x.size(0)
                    total_count += bs
                    # All totals are sums of (batch value * batch size).
                    total_network_loss += network_loss.item() * bs
                    total_prediction_loss += network_return.losses.prediction_.item() * bs
                    total_global_loss += network_return.losses.global_.item() * bs
                    total_local_loss += network_return.losses.local_.item() * bs
                    total_global_accuracy += network_return.accuracies.global_ * bs
                    total_local_accuracy += network_return.accuracies.local_ * bs
                    dict_add(total_k_pred_acc,
                             network_return.accuracies.prediction_, bs)
                    len_pred = len(network_return.accuracies.prediction_)
                    if len_pred > 0:
                        # Mean over the per-k prediction accuracies of this batch.
                        total_pred_acc += sum(
                            network_return.accuracies.prediction_.values()) / len_pred * bs
                    if training:
                        network_optimizer.zero_grad()
                        network_loss.backward()
                        network_optimizer.step()
            # Only report sub-metrics for sub-losses that are enabled.
            metrics = dict(net_loss=total_network_loss)
            if network.prediction_loss_network.k > 0 and hp.prediction_loss_weight != 0:
                metrics.update(
                    dict(avg_prediction_acc=total_pred_acc,
                         prediction_loss=total_prediction_loss,
                         k_prediction_acc=total_k_pred_acc))
            if hp.global_loss_weight != 0:
                metrics.update(
                    dict(global_loss=total_global_loss,
                         global_acc=total_global_accuracy))
            if hp.local_loss_weight != 0:
                metrics.update(
                    dict(local_loss=total_local_loss,
                         local_acc=total_local_accuracy))
            # NOTE(review): presumably divides every value (including the
            # nested k_prediction_acc dict) by total_count in place — confirm.
            divide_dict(metrics, total_count)
            if not training and hp.use_scheduler:
                scheduler.step(metrics['net_loss'])
            if summary:
                print('train' if training else 'validation', epoch,
                      metrics['net_loss'])
            else:
                print('train' if training else 'validation', epoch)
                print(json.dumps(metrics, indent=4))
            # Checkpoint on validation improvement only.
            if not training and (metrics['net_loss'] < best_val_loss):
                best_val_loss = metrics['net_loss']
                print('update best to', best_val_loss)
                torch.save(network.state_dict(), 'best_network.pth')
def main(num_samples):
    """Evaluate pretrained CPC networks as feature extractors for EEG metrics.

    For every (hyper-parameter set, checkpoint) pair, the network is loaded
    from ``./results/cpc_trained/<model_address>.pth``. ``calculate_stats``
    is then run ``num_runs`` times (2 in test mode, 10 otherwise).

    Each run computes statistics on:

    * real training data (``'normal'``), used as the reference;
    * the validation split (``'validation'``);
    * several modified variants of the training data (``'permute'``,
      ``'shift'``, ``'concat'``, ``'tiny'``, ``'zero'``, ``'noise'``).

    All computed modes are kept in the result. Previously 'validation' and
    'permute' were computed and then discarded.

    Args:
        num_samples: forwarded unchanged to every ``calculate_stats`` call.

    Returns:
        ``{model_address: {mode: {seq_len: {meter: (mean, std)}}}}`` with the
        mean and standard deviation of each meter across the runs.
    """
    # NOTE: these are model dependent and it's far better to read them from a yml file
    skip_depth = 6 if test_mode else 0
    num_runs = 2 if test_mode else 10  # repeated runs for stability checks
    progression_scale_up = EEGDataset.progression_scale_up
    progression_scale_down = EEGDataset.progression_scale_down
    train_dataset, val_dataset = ThinEEGDataset.from_config(
        validation_ratio=hp.validation_ratio,
        stride=hp.ds_stride,
        dir_path='./data/tuh1/',
        num_channels=hp.num_channels)
    # (dataset, calculate_stats mode, label for progress messages), evaluated
    # against the 'normal' reference statistics in this order.
    comparison_modes = (
        (val_dataset, 'validation', 'val'),
        (train_dataset, 'permute', 'permuted'),
        (train_dataset, 'shift', 'shifted'),
        (train_dataset, 'concat', 'concatenated'),
        (train_dataset, 'tiny', 'tiny'),
        (train_dataset, 'zero', 'zeroed'),
        (train_dataset, 'noise', 'noised'),
    )
    meters = (
        'prediction_loss', 'prediction_acc', 'global_loss', 'global_acc',
        'local_loss', 'local_acc', 'c_fid_max_seq_len', 'z_fid_max_seq_len',
        'cp_fid_max_seq_len', 'zp_fid_max_seq_len', 'net_loss', 'ndb_score',
        'ndb_js', 'c_fid', 'z_fid', 'zp_fid', 'cp_fid', 'z_vis_max_seq_len',
        'z_vis', 'c_vis_max_seq_len', 'c_vis', 'zp_vis_max_seq_len', 'zp_vis',
        'cp_vis_max_seq_len', 'cp_vis',
    )
    real_ndb = None
    final_result = {}
    if test_mode:
        models_zip = zip([hp], ['default_-7.531810902716695'])
    else:
        models_zip = zip([hp, prediction_hp, local_hp], [
            'default_-7.531810902716695', 'prediction_-2.450593529493725',
            'local_-0.5012809535156667'
        ])
    for current_hp, model_address in models_zip:
        network = Network(
            train_dataset.num_channels,
            generate_long_sequence=current_hp.generate_long_sequence,
            pooling=current_hp.pool_or_stride == 'pool',
            encoder_dropout=current_hp.encoder_dropout,
            use_sinc_encoder=current_hp.use_sinc_encoder,
            use_shared_sinc=current_hp.use_shared_sinc,
            bidirectional=current_hp.bidirectional,
            contextualizer_num_layers=current_hp.contextualizer_num_layers,
            contextualizer_dropout=current_hp.contextualizer_dropout,
            use_transformer=current_hp.use_transformer,
            causal_prediction=current_hp.causal_prediction,
            prediction_k=current_hp.prediction_k,
            encoder_activation=current_hp.encoder_activation,
            tiny_encoder=current_hp.tiny_encoder)
        network.load_state_dict(
            torch.load('./results/cpc_trained/' + model_address + '.pth',
                       map_location='cpu'))
        network = cudize(network.eval())
        collected_results = []
        print('loaded', model_address)
        for i in range(num_runs):
            if test_mode:
                print(model_address, 'run #{}'.format(i))
            # real_ndb is built on the first call and reused afterwards.
            real_stats, real_ndb = calculate_stats(
                train_dataset, network, progression_scale_up,
                progression_scale_down, skip_depth, num_samples, 'normal',
                current_hp, real_ndb, None)
            if test_mode:
                print('real stats calculated')
            run_result = dict(real_stats)
            for dataset, mode, label in comparison_modes:
                mode_stats, _ = calculate_stats(
                    dataset, network, progression_scale_up,
                    progression_scale_down, skip_depth, num_samples, mode,
                    current_hp, real_ndb, real_stats)
                if test_mode:
                    print('{} stats calculated'.format(label))
                run_result.update(mode_stats)
            collected_results.append(run_result)
        # TODO: also evaluate generated samples ('generated', 'averaged' and
        # 'truncation' modes) over time and different truncation thresholds.

        # Aggregate across runs:
        # [{mode: {seq_len: {meter: value}}}] -> {mode: {seq_len: {meter: (mean, std)}}}
        summary = {}
        for mode, per_seq_len in collected_results[0].items():
            summary[mode] = {}
            for seq_len in per_seq_len:
                summary[mode][seq_len] = {}
                for meter in meters:
                    values = [
                        run[mode][seq_len][meter] for run in collected_results
                    ]
                    summary[mode][seq_len][meter] = (np.mean(values),
                                                     np.std(values))
        final_result[model_address] = summary
    return final_result
def calculate_network_stats(x, net: Network, scale_up, scale_down, skip_depth,
                            current_hp, real_ndb, real_stats):
    """Compute CPC-network metrics for ``x`` at each progressive sequence length.

    Sequence lengths start at ``max_seq_len = x.size(2)``. Each shorter one
    is ``int(previous * scale_down[i] / scale_up[i])``, for ``i`` from
    ``len(scale_up) - 1`` down to ``skip_depth``. They are processed longest
    first. For each length:

    * ``x`` is resampled to it and scored with ``real_ndb[seq_len]`` (NDB
      score and JS divergence). It is then resampled back to
      ``max_seq_len`` for the network.
    * ``net.complete_forward`` runs on consecutive 128-sample batches.
      Losses and accuracies are averaged, weighted by batch size.
    * Features ``z``, ``c``, ``zp``, ``cp`` are pooled and summarized by
      ``calc_mean_cov``.
    * If ``real_stats is None`` (the reference pass), a virtual inception
      network is trained per feature and FID is reported as 0.0. Otherwise
      FID is computed against ``real_stats['normal'][seq_len]`` and its
      stored network is reused. The inception score (``*_vis``) is always
      computed.
    * The ``*_max_seq_len`` metrics compare against the reference
      statistics at ``max_seq_len``.

    Args:
        x: input tensor with the sequence length in dim 2.
        net: CPC network providing ``complete_forward``.
        scale_up, scale_down: per-depth progression factors.
        skip_depth: depths below this index are not evaluated.
        current_hp: supplies the loss weights used for ``net_loss``.
        real_ndb: mapping seq_len -> fitted NDB object.
        real_stats: reference stats from a previous call, or None.

    Returns:
        dict mapping each sequence length to its metrics dict.
    """
    max_seq_len = x.size(2)
    seq_lens = [max_seq_len]
    for i in reversed(range(skip_depth, len(scale_up))):
        seq_lens = [int(seq_lens[0] * scale_down[i] / scale_up[i])] + seq_lens
    # Longest first: when real_stats is None, shorter lengths read
    # stats[max_seq_len] for their *_max_seq_len metrics.
    seq_lens = reversed(seq_lens)
    eval_batch_size = 128
    stats = {}
    with torch.no_grad():
        for seq_len in seq_lens:
            this_x = resample_signal(x, max_seq_len, seq_len, True)
            ndb, js = real_ndb[seq_len].evaluate(
                this_x.view(this_x.size(0), -1).numpy())
            this_x = resample_signal(this_x, seq_len, max_seq_len,
                                     True)  # resample to give to the net
            total_network_loss = 0.0
            total_prediction_loss = 0.0
            total_global_discriminator_loss = 0.0
            total_local_discriminator_loss = 0.0
            total_global_accuracy_one = 0.0
            total_global_accuracy_two = 0.0
            total_local_accuracy_one = 0.0
            total_local_accuracy_two = 0.0
            total_pred_acc = {}
            total_count = 0
            all_z = []
            all_c = []
            all_zp = []
            all_cp = []
            # NOTE(review): only full batches are evaluated. Trailing samples
            # are ignored, and fewer than eval_batch_size samples leaves
            # total_count at 0 (division by zero below).
            for i in range(x.size(0) // eval_batch_size):
                batch = this_x[i * eval_batch_size:(i + 1) * eval_batch_size]
                (prediction_loss, global_discriminator_loss,
                 local_discriminator_loss, cp, global_accuracy, local_accuracy,
                 pred_accuracy, z, c, zp) = net.complete_forward(cudize(batch))
                global_accuracy_one, global_accuracy_two = global_accuracy
                local_accuracy_one, local_accuracy_two = local_accuracy
                network_loss = (
                    current_hp.prediction_loss_weight * prediction_loss +
                    current_hp.global_loss_weight * global_discriminator_loss +
                    current_hp.local_loss_weight * local_discriminator_loss)
                this_batch_size = batch.size(0)
                total_count += this_batch_size
                total_network_loss += network_loss.item() * this_batch_size
                total_prediction_loss += prediction_loss.item() * this_batch_size
                total_global_discriminator_loss += global_discriminator_loss.item(
                ) * this_batch_size
                total_local_discriminator_loss += local_discriminator_loss.item(
                ) * this_batch_size
                total_global_accuracy_one += global_accuracy_one * this_batch_size
                total_global_accuracy_two += global_accuracy_two * this_batch_size
                total_local_accuracy_one += local_accuracy_one * this_batch_size
                total_local_accuracy_two += local_accuracy_two * this_batch_size
                dict_add(total_pred_acc, pred_accuracy, this_batch_size)
                # subsample z and c (assumes T = 32)
                all_z.append(z[:8].contiguous().view(-1, z.size(1)).cpu())
                all_c.append(c[:8].contiguous().view(-1, c.size(1)).cpu())
                all_zp.append(zp.cpu())
                all_cp.append(cp.cpu())
            total_global_accuracy_one /= total_count
            total_global_accuracy_two /= total_count
            total_local_accuracy_one /= total_count
            total_local_accuracy_two /= total_count
            divide_dict(total_pred_acc, total_count)
            total_prediction_loss /= total_count
            total_pred_acc = merge_pred_accs(
                total_pred_acc, net.prediction_loss_network.k,
                net.prediction_loss_network.bidirectional)
            total_global_discriminator_loss /= total_count
            total_global_accuracy = (total_global_accuracy_one +
                                     total_global_accuracy_two) / 2
            total_local_discriminator_loss /= total_count
            total_local_accuracy = (total_local_accuracy_one +
                                    total_local_accuracy_two) / 2
            total_network_loss /= total_count
            metrics = dict(prediction_loss=total_prediction_loss,
                           prediction_acc=total_pred_acc,
                           global_loss=total_global_discriminator_loss,
                           global_acc=total_global_accuracy,
                           local_loss=total_local_discriminator_loss,
                           local_acc=total_local_accuracy,
                           net_loss=total_network_loss,
                           ndb_score=ndb,
                           ndb_js=js)
            if test_mode:
                print(metrics)
            all_z = torch.cat(all_z, dim=0)
            all_c = torch.cat(all_c, dim=0)
            all_zp = torch.cat(all_zp, dim=0)
            all_cp = torch.cat(all_cp, dim=0)
            for (name, all_name) in zip(['z', 'c', 'zp', 'cp'],
                                        [all_z, all_c, all_zp, all_cp]):
                all_name = cudize(all_name)
                mean_cov = calc_mean_cov(name, all_name)
                if real_stats is None:
                    # This loop runs under torch.no_grad(), but training the
                    # virtual inception network calls loss.backward(), which
                    # needs grad mode re-enabled.
                    with torch.enable_grad():
                        trained_vin = train_virtual_inception_network(all_name)
                    metrics.update({
                        **mean_cov, name + '_fid': 0.0,
                        name + '_virtual_inception_network': trained_vin
                    })
                else:
                    trained_vin = real_stats['normal'][seq_len][
                        name + '_virtual_inception_network']
                    metrics.update({
                        name + '_fid':
                        FidCalculator.calc_fid(
                            real_stats['normal'][seq_len][name + '_mean'],
                            real_stats['normal'][seq_len][name + '_cov'],
                            mean_cov[name + '_mean'], mean_cov[name + '_cov'])
                    })
                metrics.update({
                    name + '_vis':
                    calculate_inception_score(all_name, trained_vin)
                })
                if seq_len == max_seq_len:
                    metrics.update({
                        name + '_fid_max_seq_len': metrics[name + '_fid'],
                        name + '_vis_max_seq_len': metrics[name + '_vis']
                    })
                else:
                    if real_stats is None:
                        working_real_stats = stats[max_seq_len]
                    else:
                        working_real_stats = real_stats['normal'][max_seq_len]
                    max_mean = working_real_stats[name + '_mean']
                    max_cov = working_real_stats[name + '_cov']
                    max_vin = working_real_stats[name +
                                                 '_virtual_inception_network']
                    metrics.update({
                        name + '_fid_max_seq_len':
                        FidCalculator.calc_fid(max_mean, max_cov,
                                               mean_cov[name + '_mean'],
                                               mean_cov[name + '_cov']),
                        name + '_vis_max_seq_len':
                        calculate_inception_score(all_name, max_vin)
                    })
            stats[seq_len] = metrics
    return stats
def collate_real(batch):
    """Collate ``batch`` with ``default_collate`` and pass the result through ``cudize``."""
    collated = default_collate(batch)
    return cudize(collated)
def get_zero(batch_size):
    """Return a cached zeros tensor of shape ``(batch_size,)``, passed through ``cudize``.

    The module-level ``zero`` tensor is rebuilt only when it is missing or
    its length differs from ``batch_size``. Otherwise the same object is
    returned, so callers should not modify it in place.
    """
    global zero
    stale = zero is None or zero.size(0) != batch_size
    if stale:
        zero = cudize(torch.zeros(batch_size))
    return zero
def get_one(batch_size):
    """Return a cached ones tensor of shape ``(batch_size,)``, passed through ``cudize``.

    The module-level ``one`` tensor is rebuilt only when it is missing or its
    length differs from ``batch_size``. Otherwise the same object is
    returned, so callers should not modify it in place.
    """
    global one
    stale = one is None or one.size(0) != batch_size
    if stale:
        one = cudize(torch.ones(batch_size))
    return one
def get_mixing_factor(batch_size):
    """Return a ``(batch_size, 1, 1)`` tensor refilled with U[0, 1) samples.

    The module-level ``mixing_factors`` buffer is reallocated only when it is
    missing or its first dimension differs from ``batch_size``. It is
    refilled on every call, and the same tensor object is returned each time.
    """
    global mixing_factors
    reusable = mixing_factors is not None and mixing_factors.size(0) == batch_size
    if not reusable:
        mixing_factors = cudize(torch.FloatTensor(batch_size, 1, 1))
    # Fresh random factors on every call, including when the buffer is reused.
    mixing_factors.uniform_()
    return mixing_factors