def test(dataset, data_split, label_split, model, logger, epoch):
    """Evaluate a federated model both locally (per client shard) and globally.

    Args:
        dataset: full test dataset shared by all clients.
        data_split: per-client index lists; data_split[m] selects client m's shard.
        label_split: per-client label subsets, attached to the input as 'label_split'.
        model: model to evaluate (switched to eval mode here).
        logger: experiment logger accumulating 'test' metrics.
        epoch: epoch number, used only in the log message.
    """
    with torch.no_grad():
        metric = Metric()
        model.train(False)  # eval mode for the whole evaluation pass
        # Local evaluation: one pass per client over that client's data shard.
        for m in range(cfg['num_users']):
            data_loader = make_data_loader({'test': SplitDataset(dataset, data_split[m])})['test']
            for i, input in enumerate(data_loader):
                input = collate(input)
                input_size = input['img'].size(0)
                # Restrict metrics/loss to this client's label subset.
                input['label_split'] = torch.tensor(label_split[m])
                input = to_device(input, cfg['device'])
                output = model(input)
                # Average per-device losses when running data-parallel.
                output['loss'] = output['loss'].mean() if cfg['world_size'] > 1 else output['loss']
                evaluation = metric.evaluate(cfg['metric_name']['test']['Local'], input, output)
                logger.append(evaluation, 'test', input_size)
        # Global evaluation: one pass over the whole test set (no label_split).
        data_loader = make_data_loader({'test': dataset})['test']
        for i, input in enumerate(data_loader):
            input = collate(input)
            input_size = input['img'].size(0)
            input = to_device(input, cfg['device'])
            output = model(input)
            output['loss'] = output['loss'].mean() if cfg['world_size'] > 1 else output['loss']
            evaluation = metric.evaluate(cfg['metric_name']['test']['Global'], input, output)
            logger.append(evaluation, 'test', input_size)
        info = {'info': ['Model: {}'.format(cfg['model_tag']),
                         'Test Epoch: {}({:.0f}%)'.format(epoch, 100.)]}
        logger.append(info, 'test', mean=False)
        logger.write('test', cfg['metric_name']['test']['Local'] + cfg['metric_name']['test']['Global'])
    return
def runExperiment():
    """Compute Inception Score / FID, either on raw data or saved generations.

    The RNG seed is recovered from the leading '_'-separated field of
    cfg['model_tag']. With cfg['raw'] set, metrics are computed directly on
    the training images and saved under ./output/result; otherwise previously
    generated images are loaded from ./output/npy and passed to test().
    """
    seed = int(cfg['model_tag'].split('_')[0])
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    dataset = fetch_dataset(cfg['data_name'], cfg['subset'])
    process_dataset(dataset['train'])
    if cfg['raw']:
        data_loader = make_data_loader(dataset)['train']
        metric = Metric()
        img = []
        # Materialize the whole training split in memory to score it at once.
        for i, input in enumerate(data_loader):
            input = collate(input)
            img.append(input['img'])
        img = torch.cat(img, dim=0)
        output = {'img': img}
        evaluation = metric.evaluate(cfg['metric_name']['test'], None, output)
        is_result, fid_result = evaluation['InceptionScore'], evaluation['FID']
        print('Inception Score ({}): {}'.format(cfg['data_name'], is_result))
        print('FID ({}): {}'.format(cfg['data_name'], fid_result))
        save(is_result, './output/result/is_generated_{}.npy'.format(cfg['data_name']), mode='numpy')
        save(fid_result, './output/result/fid_generated_{}.npy'.format(cfg['data_name']), mode='numpy')
    else:
        generated = np.load('./output/npy/generated_{}.npy'.format(
            cfg['model_tag']), allow_pickle=True)
        test(generated)
    return
def train(data_loader, model, optimizer, logger, epoch):
    """Train `model` for one epoch, logging progress at log intervals.

    NOTE(review): this variant reads config.PARAM while sibling functions in
    this file use cfg — presumably two generations of the same codebase;
    confirm which namespace is canonical.
    """
    metric = Metric()
    model.train(True)  # training mode
    for i, input in enumerate(data_loader):
        start_time = time.time()  # per-batch timer (reset every iteration)
        input = collate(input)
        input_size = len(input['img'])
        input = to_device(input, config.PARAM['device'])
        model.zero_grad()
        output = model(input)
        # Average per-device losses when running data-parallel.
        output['loss'] = output['loss'].mean() if config.PARAM['world_size'] > 1 else output['loss']
        output['loss'].backward()
        optimizer.step()
        # The "+ 1" keeps the modulus nonzero even for very small loaders.
        if i % int((len(data_loader) * config.PARAM['log_interval']) + 1) == 0:
            batch_time = time.time() - start_time
            lr = optimizer.param_groups[0]['lr']
            # ETA estimates extrapolate the current batch's wall time.
            epoch_finished_time = datetime.timedelta(seconds=round(batch_time * (len(data_loader) - i - 1)))
            exp_finished_time = epoch_finished_time + datetime.timedelta(
                seconds=round((config.PARAM['num_epochs'] - epoch) * batch_time * len(data_loader)))
            info = {'info': ['Model: {}'.format(config.PARAM['model_tag']),
                             'Train Epoch: {}({:.0f}%)'.format(epoch, 100. * i / len(data_loader)),
                             'Learning rate: {}'.format(lr),
                             'Epoch Finished Time: {}'.format(epoch_finished_time),
                             'Experiment Finished Time: {}'.format(exp_finished_time)]}
            logger.append(info, 'train', mean=False)
            # NOTE(review): metrics are evaluated/appended only on log-interval
            # batches in this variant (other train() variants do it every
            # iteration) — confirm this is intended.
            evaluation = metric.evaluate(config.PARAM['metric_names']['train'], input, output)
            logger.append(evaluation, 'train', n=input_size)
            logger.write('train', config.PARAM['metric_names']['train'])
    return
def test(data_loader, ae, model, logger, epoch):
    """Evaluate `model` on latent codes produced by a frozen autoencoder.

    Args:
        data_loader: test loader yielding dicts with an 'img' tensor.
        ae: autoencoder whose encode() provides the latents (kept in eval mode).
        model: model under evaluation (switched to eval mode here).
        logger: experiment logger accumulating 'test' metrics.
        epoch: epoch number, used only in the log message.
    """
    with torch.no_grad():
        metric = Metric()
        ae.train(False)  # encoder stays frozen in eval mode
        model.train(False)
        for i, input in enumerate(data_loader):
            input = collate(input)
            input_size = input['img'].size(0)
            input = to_device(input, cfg['device'])
            # Replace raw images with the AE's latent codes.
            _, _, input['img'] = ae.encode(input['img'])
            input['img'] = input['img'].detach()
            output = model(input)
            # Average per-device losses when running data-parallel.
            output['loss'] = output['loss'].mean() if cfg['world_size'] > 1 else output['loss']
            evaluation = metric.evaluate(cfg['metric_name']['test'], input, output)
            # Bug fix: the evaluation was previously appended twice per batch
            # (once weighted by input_size, once unweighted), double-counting
            # every batch and skewing the running averages. Append exactly
            # once, weighted by the batch size.
            logger.append(evaluation, 'test', input_size)
        info = {'info': ['Model: {}'.format(cfg['model_tag']),
                         'Test Epoch: {}({:.0f}%)'.format(epoch, 100.)]}
        logger.append(info, 'test', mean=False)
        logger.write('test', cfg['metric_name']['test'])
    return
def runExperiment():
    """Compute the Davies-Bouldin Index, on raw data or saved creations.

    The RNG seed is recovered from the leading '_'-separated field of
    cfg['model_tag']. With cfg['raw'] set, DBI is computed on the training
    images/labels and saved under ./output/result; otherwise previously
    created samples are loaded from ./output/npy and passed to test().
    """
    seed = int(cfg['model_tag'].split('_')[0])
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    dataset = fetch_dataset(cfg['data_name'], cfg['subset'])
    process_dataset(dataset['train'])
    if cfg['raw']:
        data_loader = make_data_loader(dataset)['train']
        metric = Metric()
        img, label = [], []
        # Materialize the entire training split to score it in one shot.
        for i, input in enumerate(data_loader):
            input = collate(input)
            img.append(input['img'])
            label.append(input['label'])
        img = torch.cat(img, dim=0)
        label = torch.cat(label, dim=0)
        output = {'img': img, 'label': label}
        evaluation = metric.evaluate(cfg['metric_name']['test'], None, output)
        dbi_result = evaluation['DBI']
        print('Davies-Bouldin Index ({}): {}'.format(cfg['data_name'], dbi_result))
        save(dbi_result, './output/result/dbi_created_{}.npy'.format(cfg['data_name']), mode='numpy')
    else:
        created = np.load('./output/npy/created_{}.npy'.format(
            cfg['model_tag']), allow_pickle=True)
        test(created)
    return
def test(data_loader, model, logger, epoch):
    """Evaluate `model` on `data_loader` and log 'test' metrics.

    When cfg['show'] is set, additionally reconstructs the last evaluated
    batch through model.reverse and saves input/output image grids under
    ./output/vis for visual inspection.
    """
    with torch.no_grad():
        metric = Metric()
        model.train(False)  # eval mode
        for i, input in enumerate(data_loader):
            input = collate(input)
            input_size = input['img'].size(0)
            input = to_device(input, cfg['device'])
            output = model(input)
            # Average per-device losses when running data-parallel.
            output['loss'] = output['loss'].mean() if cfg['world_size'] > 1 else output['loss']
            evaluation = metric.evaluate(cfg['metric_name']['test'], input, output)
            # Bug fix: the evaluation was previously appended twice per batch
            # (once weighted by input_size, once unweighted), double-counting
            # every batch and skewing the running averages. Append exactly
            # once, weighted by the batch size.
            logger.append(evaluation, 'test', input_size)
        info = {'info': ['Model: {}'.format(cfg['model_tag']),
                         'Test Epoch: {}({:.0f}%)'.format(epoch, 100.)]}
        logger.append(info, 'test', mean=False)
        logger.write('test', cfg['metric_name']['test'])
        if cfg['show']:
            # Visualize reconstructions of the last evaluated batch.
            input['reconstruct'] = True
            input['z'] = output['z']
            output = model.reverse(input)
            save_img(input['img'][:100], './output/vis/input_{}.png'.format(cfg['model_tag']), range=(-1, 1))
            save_img(output['img'][:100], './output/vis/output_{}.png'.format(cfg['model_tag']), range=(-1, 1))
    return
def stats(data_loader, model):
    """Run one gradient-free pass over `data_loader` with the model in train
    mode so that its train-mode running statistics are refreshed in place."""
    with torch.no_grad():
        model.train(True)  # train mode so running statistics update
        for batch in data_loader:
            batch = to_device(collate(batch), cfg['device'])
            model(batch)
    return
def stats(dataset, model):
    """Build a 'train' loader over `dataset` and sweep it once without
    gradients so the model's train-mode running statistics are refreshed."""
    with torch.no_grad():
        loader = make_data_loader({'train': dataset})['train']
        model.train(True)  # train mode so running statistics update
        for batch in loader:
            batch = to_device(collate(batch), cfg['device'])
            model(batch)
    return
def test(data_loader):
    """Collect every image batch from `data_loader`, rescale from [-1, 1]
    to [0, 255], and dump the result as a numpy archive under ./output/npy."""
    with torch.no_grad():
        batches = [collate(batch)['img'] for batch in data_loader]
        generated = torch.cat(batches)
        # Map pixel range [-1, 1] -> [0, 255].
        generated = (generated + 1) / 2 * 255
        save(generated.numpy(), './output/npy/generated_0_{}.npy'.format(cfg['data_name']), mode='numpy')
    return
def stats(dataset, model):
    """Build a full-rate, statistics-tracking copy of `model`, warm it up
    with one no-grad pass over `dataset`, and return it.

    Loading with strict=False lets the (possibly reduced-rate) source
    state_dict partially populate the full-rate test model.
    """
    with torch.no_grad():
        # Instantiate the model class named in cfg['model_name'];
        # track=True presumably enables running-statistic buffers — confirm
        # against the models package.
        test_model = eval('models.{}(model_rate=cfg["global_model_rate"], track=True).to(cfg["device"])'
                          .format(cfg['model_name']))
        test_model.load_state_dict(model.state_dict(), strict=False)
        data_loader = make_data_loader({'train': dataset})['train']
        test_model.train(True)  # train mode so running statistics update
        for i, input in enumerate(data_loader):
            input = collate(input)
            input = to_device(input, cfg['device'])
            test_model(input)
    return test_model
def train(data_loader, ae, model, optimizer, logger, epoch):
    """Train `model` for one epoch on latents from a frozen autoencoder.

    Images are encoded by `ae` without gradients before being fed to
    `model`; gradients are clipped to norm 1 and progress is logged at
    log intervals.
    """
    metric = Metric()
    ae.train(False)  # encoder stays frozen in eval mode
    model.train(True)
    start_time = time.time()
    for i, input in enumerate(data_loader):
        input = collate(input)
        input_size = input['img'].size(0)
        input = to_device(input, cfg['device'])
        # Replace raw images with the AE's latent codes (no grad through ae).
        with torch.no_grad():
            _, _, input['img'] = ae.encode(input['img'])
            input['img'] = input['img'].detach()
        optimizer.zero_grad()
        output = model(input)
        # Average per-device losses when running data-parallel.
        output['loss'] = output['loss'].mean() if cfg['world_size'] > 1 else output['loss']
        output['loss'].backward()
        # Clip gradients to stabilize training.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
        optimizer.step()
        evaluation = metric.evaluate(cfg['metric_name']['train'], input, output)
        logger.append(evaluation, 'train', n=input_size)
        # The "+ 1" keeps the modulus nonzero even for very small loaders.
        if i % int((len(data_loader) * cfg['log_interval']) + 1) == 0:
            batch_time = (time.time() - start_time) / (i + 1)  # running average per batch
            lr = optimizer.param_groups[0]['lr']
            epoch_finished_time = datetime.timedelta(
                seconds=round(batch_time * (len(data_loader) - i - 1)))
            exp_finished_time = epoch_finished_time + datetime.timedelta(
                seconds=round((cfg['num_epochs'] - epoch) * batch_time * len(data_loader)))
            info = {'info': ['Model: {}'.format(cfg['model_tag']),
                             'Train Epoch: {}({:.0f}%)'.format(epoch, 100. * i / len(data_loader)),
                             'Learning rate: {}'.format(lr),
                             'Epoch Finished Time: {}'.format(epoch_finished_time),
                             'Experiment Finished Time: {}'.format(exp_finished_time)]}
            logger.append(info, 'train', mean=False)
            logger.write('train', cfg['metric_name']['train'])
    return
def train(self, local_parameters, lr, logger):
    """Run local federated training on this client's data.

    Args:
        local_parameters: state_dict used to initialize the local model.
        lr: learning rate for the freshly built optimizer.
        logger: experiment logger accumulating 'train' metrics.

    Returns:
        The updated state_dict after cfg['num_epochs']['local'] local epochs.
    """
    metric = Metric()
    # Build a client-rate model instance from the configured class name.
    model = eval('models.{}(model_rate=self.model_rate).to(cfg["device"])'.format(cfg['model_name']))
    model.load_state_dict(local_parameters)
    model.train(True)
    optimizer = make_optimizer(model, lr)
    for local_epoch in range(1, cfg['num_epochs']['local'] + 1):
        for i, input in enumerate(self.data_loader):
            input = collate(input)
            input_size = input['img'].size(0)
            # Restrict metrics/loss to this client's label subset.
            input['label_split'] = torch.tensor(self.label_split)
            input = to_device(input, cfg['device'])
            optimizer.zero_grad()
            output = model(input)
            output['loss'].backward()
            # Clip gradients to stabilize heterogeneous local updates.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
            optimizer.step()
            evaluation = metric.evaluate(cfg['metric_name']['train']['Local'], input, output)
            logger.append(evaluation, 'train', n=input_size)
    local_parameters = model.state_dict()
    return local_parameters
def FID(img):
    """Compute the Frechet Inception Distance between `img` (generated
    images) and the real training set of cfg['data_name'].

    For COIL100/Omniglot a dataset-specific classifier checkpoint provides
    the features; for all other datasets pooled features from a pretrained
    Inception-v3 are used. FID = |mu1-mu2|^2 + Tr(S1 + S2 - 2*sqrt(S1*S2))
    over Gaussian fits of the two feature sets.

    NOTE(review): this mutates cfg['batch_size']['train'] as a side effect
    so the real-data loader uses batch_size=32 — confirm callers tolerate it.
    """
    with torch.no_grad():
        batch_size = 32
        cfg['batch_size']['train'] = batch_size
        dataset = fetch_dataset(cfg['data_name'], cfg['subset'], verbose=False)
        real_data_loader = make_data_loader(dataset)['train']
        generated_data_loader = DataLoader(img, batch_size=batch_size)
        if cfg['data_name'] in ['COIL100', 'Omniglot']:
            # Dataset-specific classifier checkpoint as feature extractor.
            model = models.classifier().to(cfg['device'])
            model_tag = ['0', cfg['data_name'], cfg['subset'], 'classifier']
            model_tag = '_'.join(filter(None, model_tag))
            checkpoint = load(
                './metrics_tf/res/classifier/{}_best.pt'.format(model_tag))
            model.load_state_dict(checkpoint['model_dict'])
            model.train(False)
            real_feature = []
            for i, input in enumerate(real_data_loader):
                input = collate(input)
                input = to_device(input, cfg['device'])
                real_feature_i = model.feature(input)
                real_feature.append(real_feature_i.cpu().numpy())
            real_feature = np.concatenate(real_feature, axis=0)
            generated_feature = []
            for i, input in enumerate(generated_data_loader):
                # Generated batches are raw tensors; wrap with dummy labels.
                input = {
                    'img': input,
                    'label': input.new_zeros(input.size(0)).long()
                }
                input = to_device(input, cfg['device'])
                generated_feature_i = model.feature(input)
                generated_feature.append(generated_feature_i.cpu().numpy())
            generated_feature = np.concatenate(generated_feature, axis=0)
        else:
            # Standard FID: Inception-v3 trunk up to a global average pool.
            model = inception_v3(pretrained=True, transform_input=False).to(cfg['device'])
            # Inception expects 299x299 inputs, so upsample first.
            up = nn.Upsample(size=(299, 299), mode='bilinear', align_corners=False)
            model.feature = nn.Sequential(*[
                up, model.Conv2d_1a_3x3, model.Conv2d_2a_3x3, model.Conv2d_2b_3x3,
                nn.MaxPool2d(kernel_size=3, stride=2), model.Conv2d_3b_1x1, model.Conv2d_4a_3x3,
                nn.MaxPool2d(kernel_size=3, stride=2), model.Mixed_5b, model.Mixed_5c, model.Mixed_5d,
                model.Mixed_6a, model.Mixed_6b, model.Mixed_6c, model.Mixed_6d, model.Mixed_6e,
                model.Mixed_7a, model.Mixed_7b, model.Mixed_7c,
                nn.AdaptiveAvgPool2d(1), nn.Flatten()
            ])
            model.train(False)
            real_feature = []
            for i, input in enumerate(real_data_loader):
                input = collate(input)
                input = to_device(input, cfg['device'])
                real_feature_i = model.feature(input['img'])
                real_feature.append(real_feature_i.cpu().numpy())
            real_feature = np.concatenate(real_feature, axis=0)
            generated_feature = []
            for i, input in enumerate(generated_data_loader):
                input = to_device(input, cfg['device'])
                generated_feature_i = model.feature(input)
                generated_feature.append(generated_feature_i.cpu().numpy())
            generated_feature = np.concatenate(generated_feature, axis=0)
        # Gaussian statistics (mean and covariance) of both feature sets.
        mu1 = np.mean(real_feature, axis=0)
        sigma1 = np.cov(real_feature, rowvar=False)
        mu2 = np.mean(generated_feature, axis=0)
        sigma2 = np.cov(generated_feature, rowvar=False)
        mu1 = np.atleast_1d(mu1)
        mu2 = np.atleast_1d(mu2)
        sigma1 = np.atleast_2d(sigma1)
        sigma2 = np.atleast_2d(sigma2)
        assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths"
        assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions"
        diff = mu1 - mu2
        # product might be almost singular
        covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
        if not np.isfinite(covmean).all():
            # Regularize with a tiny diagonal offset and retry.
            offset = np.eye(sigma1.shape[0]) * 1e-6
            covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
        # numerical error might give slight imaginary component
        if np.iscomplexobj(covmean):
            if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
                m = np.max(np.abs(covmean.imag))
                raise ValueError("Imaginary component {}".format(m))
            covmean = covmean.real
        tr_covmean = np.trace(covmean)
        # Final Frechet distance between the two Gaussians.
        fid = diff.dot(diff) + np.trace(sigma1) + np.trace(
            sigma2) - 2 * tr_covmean
        fid = fid.item()
    return fid
def summarize(data_loader, model):
    """Trace one forward pass and summarize parameter and FLOP counts.

    Registers forward hooks on every non-container module, runs a single
    batch (image datasets) or a single BPTT batch (WikiText2), and totals
    the number of touched parameters, FLOPs, and parameter space in MB.
    """
    def register_hook(module):
        def hook(module, input, output):
            # Count instances of each module class to build unique names.
            module_name = str(module.__class__.__name__)
            if module_name not in summary['count']:
                summary['count'][module_name] = 1
            else:
                summary['count'][module_name] += 1
            key = str(hash(module))
            if key not in summary['module']:
                # First visit: initialize bookkeeping, including the FLOPs
                # computed from the observed input/output shapes.
                summary['module'][key] = OrderedDict()
                summary['module'][key]['module_name'] = '{}_{}'.format(module_name, summary['count'][module_name])
                summary['module'][key]['input_size'] = []
                summary['module'][key]['output_size'] = []
                summary['module'][key]['params'] = {}
                summary['module'][key]['flops'] = make_flops(module, input, output)
            input_size, output_size = make_size(input, output)
            summary['module'][key]['input_size'].append(input_size)
            summary['module'][key]['output_size'].append(output_size)
            for name, param in module.named_parameters():
                if param.requires_grad:
                    # in_proj/out_proj names cover attention-style modules.
                    if name in ['weight', 'in_proj_weight', 'out_proj.weight']:
                        if name not in summary['module'][key]['params']:
                            summary['module'][key]['params'][name] = {}
                            summary['module'][key]['params'][name]['size'] = list(param.size())
                            summary['module'][key]['coordinates'] = []
                            summary['module'][key]['params'][name]['mask'] = torch.zeros(
                                summary['module'][key]['params'][name]['size'], dtype=torch.long,
                                device=cfg['device'])
                    elif name in ['bias', 'in_proj_bias', 'out_proj.bias']:
                        if name not in summary['module'][key]['params']:
                            summary['module'][key]['params'][name] = {}
                            summary['module'][key]['params'][name]['size'] = list(param.size())
                            summary['module'][key]['params'][name]['mask'] = torch.zeros(
                                summary['module'][key]['params'][name]['size'], dtype=torch.long,
                                device=cfg['device'])
                    else:
                        # Ignore parameters other than the tracked names.
                        continue
            if len(summary['module'][key]['params']) == 0:
                return
            # Every tracked parameter counts as fully used in this pass.
            for name in summary['module'][key]['params']:
                summary['module'][key]['params'][name]['mask'] += 1
            return

        # Hook only non-container modules, skipping the root model itself.
        if not isinstance(module, nn.Sequential) and not isinstance(module, nn.ModuleList) \
                and not isinstance(module, nn.ModuleDict) and module != model:
            hooks.append(module.register_forward_hook(hook))
        return

    run_mode = True
    summary = OrderedDict()
    summary['module'] = OrderedDict()
    summary['count'] = OrderedDict()
    hooks = []
    model.train(run_mode)
    model.apply(register_hook)
    if cfg['data_name'] in ['MNIST', 'CIFAR10']:
        for i, input in enumerate(data_loader):
            input = collate(input)
            input = to_device(input, cfg['device'])
            model(input)
            break  # a single batch is enough to trace the graph
    elif cfg['data_name'] in ['WikiText2']:
        dataset = BatchDataset(data_loader.dataset, cfg['bptt'])
        for i, input in enumerate(dataset):
            input = to_device(input, cfg['device'])
            model(input)
            break
    else:
        raise ValueError('Not valid data name')
    for h in hooks:
        h.remove()
    summary['total_num_params'] = 0
    summary['total_num_flops'] = 0
    for key in summary['module']:
        num_params = 0
        num_flops = 0
        for name in summary['module'][key]['params']:
            num_params += (summary['module'][key]['params'][name]['mask'] > 0).sum().item()
        # NOTE(review): FLOPs counted once per module here, independent of
        # how many tracked parameters it has — the collapsed original is
        # ambiguous about this nesting; confirm against upstream.
        num_flops += summary['module'][key]['flops']
        summary['total_num_params'] += num_params
        summary['total_num_flops'] += num_flops
    # Parameter space in MB assuming 32-bit values.
    summary['total_space'] = summary['total_num_params'] * 32. / 8 / (1024 ** 2.)
    return summary
def summarize(data_loader, model, ae=None):
    """Trace one forward pass of `model` and summarize parameter coverage.

    Registers forward hooks on every non-container module, runs a single
    batch (optionally encoded by `ae` first), and accumulates per-parameter
    coverage masks, returning the total of covered parameters and their
    size in MB.
    """
    def register_hook(module):
        def hook(module, input, output):
            # Count instances of each module class to build unique names.
            module_name = str(module.__class__.__name__)
            if module_name not in summary['count']:
                summary['count'][module_name] = 1
            else:
                summary['count'][module_name] += 1
            key = str(hash(module))
            if key not in summary['module']:
                # First visit: initialize this module's bookkeeping entry.
                summary['module'][key] = OrderedDict()
                summary['module'][key]['module_name'] = '{}_{}'.format(module_name, summary['count'][module_name])
                summary['module'][key]['input_size'] = []
                summary['module'][key]['output_size'] = []
                summary['module'][key]['params'] = {}
            input_size = make_size(input)
            output_size = make_size(output)
            summary['module'][key]['input_size'].append(input_size)
            summary['module'][key]['output_size'].append(output_size)
            for name, param in module.named_parameters():
                if param.requires_grad:
                    # 'weight_orig' appears when a reparametrization (e.g.
                    # pruning/spectral norm) wraps the weight; both are
                    # recorded under the canonical 'weight' key.
                    if name in ['weight', 'weight_orig']:
                        if name not in summary['module'][key]['params']:
                            summary['module'][key]['params']['weight'] = {}
                            summary['module'][key]['params']['weight']['size'] = list(param.size())
                            summary['module'][key]['coordinates'] = []
                            summary['module'][key]['params']['weight']['mask'] = torch.zeros(
                                summary['module'][key]['params']['weight']['size'], dtype=torch.long,
                                device=cfg['device'])
                    elif name == 'bias':
                        if name not in summary['module'][key]['params']:
                            summary['module'][key]['params']['bias'] = {}
                            summary['module'][key]['params']['bias']['size'] = list(param.size())
                            summary['module'][key]['params']['bias']['mask'] = torch.zeros(
                                summary['module'][key]['params']['bias']['size'], dtype=torch.long,
                                device=cfg['device'])
                    else:
                        # Ignore parameters other than weight/bias.
                        continue
            if len(summary['module'][key]['params']) == 0:
                return
            if 'weight' in summary['module'][key]['params']:
                # Full coordinate ranges: every weight entry is touched.
                weight_size = summary['module'][key]['params']['weight']['size']
                summary['module'][key]['coordinates'].append(
                    [torch.arange(weight_size[i], device=cfg['device']) for i in range(len(weight_size))])
            else:
                raise ValueError('Not valid parametrized module')
            for name in summary['module'][key]['params']:
                coordinates = summary['module'][key]['coordinates'][-1]
                if name == 'weight':
                    if len(coordinates) == 1:
                        summary['module'][key]['params'][name]['mask'][coordinates[0]] += 1
                    elif len(coordinates) >= 2:
                        # Mark the 2-D block spanned by the first two axes.
                        summary['module'][key]['params'][name]['mask'][
                            coordinates[0].view(-1, 1), coordinates[1].view(1, -1),] += 1
                    else:
                        raise ValueError('Not valid coordinates dimension')
                elif name == 'bias':
                    if len(coordinates) == 1:
                        summary['module'][key]['params'][name]['mask'] += 1
                    elif len(coordinates) >= 2:
                        summary['module'][key]['params'][name]['mask'] += 1
                    else:
                        raise ValueError('Not valid coordinates dimension')
                else:
                    raise ValueError('Not valid parameters type')
            return

        # Hook only non-container modules, skipping the root model itself.
        if not isinstance(module, nn.Sequential) and not isinstance(module, nn.ModuleList) \
                and not isinstance(module, nn.ModuleDict) and module != model:
            hooks.append(module.register_forward_hook(hook))
        return

    run_mode = True
    summary = OrderedDict()
    summary['module'] = OrderedDict()
    summary['count'] = OrderedDict()
    hooks = []
    model.train(run_mode)
    model.apply(register_hook)
    for i, input in enumerate(data_loader):
        input = collate(input)
        input = to_device(input, cfg['device'])
        if ae is not None:
            # Summarize the model on AE latents rather than raw images.
            with torch.no_grad():
                _, _, input['img'] = ae.encode(input['img'])
                input['img'] = input['img'].detach()
        model(input)
        break  # a single batch is enough to trace the graph
    for h in hooks:
        h.remove()
    summary['total_num_param'] = 0
    for key in summary['module']:
        num_params = 0
        for name in summary['module'][key]['params']:
            num_params += (summary['module'][key]['params'][name]['mask'] > 0).sum().item()
        summary['total_num_param'] += num_params
    # Parameter space in MB assuming 32-bit values.
    summary['total_space_param'] = abs(summary['total_num_param'] * 32. / 8 / (1024 ** 2.))
    return summary
def runExperiment():
    """Full training driver for a model needing data-dependent initialization.

    Seeds RNGs from cfg['model_tag'], runs an initialization forward pass on
    the first cfg['num_init_batches'] batches (e.g. for ActNorm-style layers —
    confirm against the models package), then trains/evaluates for
    cfg['num_epochs'] epochs with checkpointing and best-model tracking.
    """
    seed = int(cfg['model_tag'].split('_')[0])
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    dataset = fetch_dataset(cfg['data_name'], cfg['subset'])
    process_dataset(dataset['train'])
    data_loader = make_data_loader(dataset)
    # Instantiate the model class named in cfg['model_name'].
    model = eval('models.{}().to(cfg["device"])'.format(cfg['model_name']))
    # Data-dependent initialization: one no-grad forward pass on the first
    # num_init_batches batches, before the optimizer is created.
    init_batches = {'img': [], 'label': []}
    with torch.no_grad():
        for input in islice(data_loader['train'], None, cfg['num_init_batches']):
            for k in init_batches:
                init_batches[k].extend(input[k])
        init_batches = collate(init_batches)
        init_batches = to_device(init_batches, cfg['device'])
        model(init_batches)
    optimizer = make_optimizer(model)
    scheduler = make_scheduler(optimizer)
    if cfg['resume_mode'] == 1:
        # Resume full training state: epoch, optimizer, scheduler, logger.
        last_epoch, model, optimizer, scheduler, logger = resume(
            model, cfg['model_tag'], optimizer, scheduler)
    elif cfg['resume_mode'] == 2:
        # Load only the model weights; restart bookkeeping from epoch 1.
        last_epoch = 1
        _, model, _, _, _ = resume(model, cfg['model_tag'])
        logger_path = 'output/runs/{}_{}'.format(
            cfg['model_tag'], datetime.datetime.now().strftime('%b%d_%H-%M-%S'))
        logger = Logger(logger_path)
    else:
        # Fresh run.
        last_epoch = 1
        logger_path = 'output/runs/train_{}_{}'.format(
            cfg['model_tag'], datetime.datetime.now().strftime('%b%d_%H-%M-%S'))
        logger = Logger(logger_path)
    if cfg['world_size'] > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(
            cfg['world_size'])))
    for epoch in range(last_epoch, cfg['num_epochs'] + 1):
        logger.safe(True)
        train(data_loader['train'], model, optimizer, logger, epoch)
        # NOTE(review): evaluation runs on the train split here — confirm
        # this is intentional (e.g. likelihood models without a test split).
        test(data_loader['train'], model, logger, epoch)
        if cfg['scheduler_name'] == 'ReduceLROnPlateau':
            scheduler.step(
                metrics=logger.mean['test/{}'.format(cfg['pivot_metric'])])
        else:
            scheduler.step()
        logger.safe(False)
        # Unwrap DataParallel before saving so checkpoints stay portable.
        model_state_dict = model.module.state_dict(
        ) if cfg['world_size'] > 1 else model.state_dict()
        save_result = {
            'cfg': cfg,
            'epoch': epoch + 1,
            'model_dict': model_state_dict,
            'optimizer_dict': optimizer.state_dict(),
            'scheduler_dict': scheduler.state_dict(),
            'logger': logger
        }
        save(save_result,
             './output/model/{}_checkpoint.pt'.format(cfg['model_tag']))
        # Track the best model by the (lower-is-better) pivot metric.
        if cfg['pivot'] > logger.mean['test/{}'.format(cfg['pivot_metric'])]:
            cfg['pivot'] = logger.mean['test/{}'.format(cfg['pivot_metric'])]
            shutil.copy(
                './output/model/{}_checkpoint.pt'.format(cfg['model_tag']),
                './output/model/{}_best.pt'.format(cfg['model_tag']))
        logger.reset()
    logger.safe(False)
    return
def train(data_loader, model, optimizer, logger, epoch):
    """One epoch of adversarial (conditional GAN) training.

    `optimizer` is a dict with 'discriminator' and 'generator' entries; each
    network is stepped cfg['iter'][...] times per batch, conditioned on
    input[cfg['subset']]. Supported cfg['loss_type'] values: 'BCE', 'Hinge'.
    """
    metric = Metric()
    model.train(True)
    start_time = time.time()
    for i, input in enumerate(data_loader):
        input = collate(input)
        input_size = input['img'].size(0)
        input = to_device(input, cfg['device'])
        ############################
        # (1) Update D network
        ###########################
        for _ in range(cfg['iter']['discriminator']):
            # train with real
            optimizer['discriminator'].zero_grad()
            optimizer['generator'].zero_grad()
            D_x = model.discriminate(input['img'], input[cfg['subset']])
            # train with fake
            z1 = torch.randn(input['img'].size(0), cfg['gan']['latent_size'], device=cfg['device'])
            generated = model.generate(input[cfg['subset']], z1)
            # detach: no generator gradients flow during the D step
            D_G_z1 = model.discriminate(generated.detach(), input[cfg['subset']])
            if cfg['loss_type'] == 'BCE':
                # Real labeled 1, fake labeled 0.
                D_loss = torch.nn.functional.binary_cross_entropy_with_logits(
                    D_x, torch.ones((input['img'].size(0), 1), device=cfg['device'])) + \
                         torch.nn.functional.binary_cross_entropy_with_logits(
                             D_G_z1, torch.zeros((input['img'].size(0), 1), device=cfg['device']))
            elif cfg['loss_type'] == 'Hinge':
                D_loss = torch.nn.functional.relu(1.0 - D_x).mean() + torch.nn.functional.relu(1.0 + D_G_z1).mean()
            else:
                raise ValueError('Not valid loss type')
            D_loss.backward()
            optimizer['discriminator'].step()
        ############################
        # (2) Update G network
        ###########################
        for _ in range(cfg['iter']['generator']):
            optimizer['discriminator'].zero_grad()
            optimizer['generator'].zero_grad()
            z2 = torch.randn(input['img'].size(0), cfg['gan']['latent_size'], device=cfg['device'])
            generated = model.generate(input[cfg['subset']], z2)
            D_G_z2 = model.discriminate(generated, input[cfg['subset']])
            if cfg['loss_type'] == 'BCE':
                # Generator tries to make the discriminator output 1 on fakes.
                G_loss = torch.nn.functional.binary_cross_entropy_with_logits(
                    D_G_z2, torch.ones((input['img'].size(0), 1), device=cfg['device']))
            elif cfg['loss_type'] == 'Hinge':
                G_loss = -D_G_z2.mean()
            else:
                raise ValueError('Not valid loss type')
            G_loss.backward()
            optimizer['generator'].step()
        # Scalar summary of the adversarial balance for logging.
        output = {'loss': abs(D_loss - G_loss), 'loss_D': D_loss, 'loss_G': G_loss}
        evaluation = metric.evaluate(cfg['metric_name']['train'], input, output)
        logger.append(evaluation, 'train', n=input_size)
        # The "+ 1" keeps the modulus nonzero even for very small loaders.
        if i % int((len(data_loader) * cfg['log_interval']) + 1) == 0:
            batch_time = (time.time() - start_time) / (i + 1)  # running average per batch
            generator_lr, discriminator_lr = optimizer['generator'].param_groups[0]['lr'], \
                                             optimizer['discriminator'].param_groups[0]['lr']
            epoch_finished_time = datetime.timedelta(seconds=round(batch_time * (len(data_loader) - i - 1)))
            exp_finished_time = epoch_finished_time + datetime.timedelta(
                seconds=round((cfg['num_epochs'] - epoch) * batch_time * len(data_loader)))
            info = {'info': ['Model: {}'.format(cfg['model_tag']),
                             'Train Epoch: {}({:.0f}%)'.format(epoch, 100. * i / len(data_loader)),
                             'Learning rate : (G: {}, D: {})'.format(generator_lr, discriminator_lr),
                             'Epoch Finished Time: {}'.format(epoch_finished_time),
                             'Experiment Finished Time: {}'.format(exp_finished_time)]}
            logger.append(info, 'train', mean=False)
            logger.write('train', cfg['metric_name']['train'])
    return