def train_test_once(ModelHelper, PreprocessHelper, preprocess_param):
    seed_everything(param['base_seed'])
    if not os.path.exists(param['output']):
        os.makedirs(param['output'])
    kfold = MultilabelStratifiedKFold(n_splits=param['n_folds'],
                                      random_state=param['base_seed'],
                                      shuffle=True)
    x_train, y_train, y_train_with_non_scored, x_test, _, _ = \
        PreprocessHelper(root_dir=param['root_dir'],
                         out_data_dir=param['out_data_dir'],
                         is_train=param['is_train_data'],
                         read_directly=param['read_directly']).process(
            preprocess_param=preprocess_param, base_seed=param['base_seed'])
    print("")
    helper = ModelHelper(kfold, param, x_train, y_train,
                         y_train_with_non_scored, x_test, False)
    if param['is_train_model']:
        helper.train_models()
    # Both branches of the original if/else ran the same predictions, so the
    # duplicate calls are collapsed here.
    total_train_preds = helper.test_models(is_predict_train_data=True)
    total_test_preds = helper.test_models()
    # Predictions on all train data must be written to file; we need each
    # model's val loss as well as the val loss across all models.
    return total_train_preds, total_test_preds
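# `seed_everything` is called at the top of every script in this collection but
# is never defined here. A minimal sketch of the usual implementation it
# presumably wraps (Python/NumPy/PyTorch; an assumption, not the repository's
# exact code):
import os
import random

import numpy as np
import torch


def seed_everything(seed: int) -> None:
    """Seed every RNG these training scripts rely on (typical variant)."""
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN speed for run-to-run determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False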
def main(): args = parse_args() seed_everything(args.seed) app_train = joblib.load('../data/05_onehot/application_train.joblib') app_test = joblib.load('../data/05_onehot/application_test.joblib') sequences = read_all('../data/06_onehot_seq/') dims = joblib.load('../data/07_dims/dims05.joblib') app_dims = {} app_dims['application_train'] = dims.pop('application_train') app_dims['application_test'] = dims.pop('application_test') app_data = {'application_train': app_train, 'application_test': app_test} loader_maker = LoaderMaker(app_data, sequences, args, onehot=True) skf = StratifiedKFold(n_splits=5) folds = skf.split(app_train['SK_ID_CURR'], app_train['TARGET']) best_models = [] for train_index, val_index in folds: encoders = pretrain(app_train, app_test, sequences, dims, train_index, val_index, args) train_dataloader = loader_maker.make(train_index) val_dataloader = loader_maker.make(val_index) model = LightningModel( PretrainedR2N(app_dims, args.n_hidden, args.n_main, encoders), nn.BCEWithLogitsLoss(), train_dataloader, val_dataloader, args) name = '82_vaelstm_fine' trainer = HomeCreditTrainer(name, args.n_epochs, args.patience) trainer.fit(model) best_model = load_model(model, name, trainer.logger.version) best_models.append(best_model) # Predict test_dataloader = loader_maker.make(index=None, train=False) df_submission = predict(best_models, test_dataloader) df_submission.to_csv(f'../submission/{name}.csv', index=False)
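# `predict(best_models, test_dataloader)` above is assumed to build the
# submission by averaging the per-fold best models' sigmoid outputs. A
# hypothetical sketch of that ensembling step (the real helper presumably also
# assembles the SK_ID_CURR/TARGET DataFrame):
import numpy as np
import torch


def average_fold_predictions(models, dataloader):
    """Return the mean sigmoid score of every fold's best model."""
    fold_preds = []
    with torch.no_grad():
        for model in models:
            model.eval()
            scores = [torch.sigmoid(model(batch)) for batch in dataloader]
            fold_preds.append(torch.cat(scores).cpu().numpy())
    return np.mean(fold_preds, axis=0)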
def main(): args = parse_args() seed_everything(args.seed) if args.onehot: all_data = read_all(directory='../data/05_onehot') sequences = read_sequences(directory='../data/06_onehot_seq') else: all_data = read_all(directory='../data/03_powertransform') sequences = read_sequences(directory='../data/04_sequence') dims = get_dims(all_data) loader_maker = LoaderMaker(all_data, sequences, args, onehot=args.onehot) # CV name = '15_cnn-onehot' if args.onehot else '15_cnn-label' skf = StratifiedKFold(n_splits=5) folds = skf.split(all_data['application_train']['SK_ID_CURR'], all_data['application_train']['TARGET']) best_models = [] for train_index, val_index in folds: train_dataloader = loader_maker.make(train_index) val_dataloader = loader_maker.make(val_index) model = LightningModel(R2NCNN(dims, args.n_hidden, args.n_main), nn.BCEWithLogitsLoss(), train_dataloader, val_dataloader, args) trainer = HomeCreditTrainer(name, args.n_epochs, args.patience) trainer.fit(model) best_model = load_model(model, name, trainer.logger.version) best_models.append(best_model) # Predict test_dataloader = loader_maker.make(index=None, train=False) df_submission = predict(best_models, test_dataloader) filename = '../submission/15_r2n-cnn-onehot.csv' if args.onehot else '../submission/15_r2n-cnn-label.csv' df_submission.to_csv(filename, index=False)
def main(): args = parse_args() seed_everything(args.seed) if args.onehot: app_train = joblib.load('../data/05_onehot/application_train.joblib') app_test = joblib.load('../data/05_onehot/application_test.joblib') dims = get_dims({'application_train': app_train}) _, _, cont_dim = dims['application_train'] n_input = cont_dim else: app_train = joblib.load( '../data/03_powertransform/application_train.joblib') app_test = joblib.load( '../data/03_powertransform/application_test.joblib') dims = get_dims({'application_train': app_train}) cat_dims, emb_dims, cont_dim = dims['application_train'] n_input = emb_dims.sum() + cont_dim n_hidden = args.n_hidden # CV skf = StratifiedKFold(n_splits=5) folds = skf.split(app_train['SK_ID_CURR'], app_train['TARGET']) best_models = [] for train_index, val_index in folds: train_dataloader = make_dataloader(app_train, train_index, args.batch_size, onehot=args.onehot) val_dataloader = make_dataloader(app_train, val_index, args.batch_size, onehot=args.onehot) if args.onehot: network = MLPOneHot(n_input, n_hidden) else: network = MLP(cat_dims, emb_dims, n_input, n_hidden) model = LightningModel(network, nn.BCEWithLogitsLoss(), train_dataloader, val_dataloader, args) name = '13_mlp-onehot' if args.onehot else '13_mlp-label' trainer = HomeCreditTrainer(name, args.n_epochs, args.patience) trainer.fit(model) best_model = load_model(model, name, trainer.logger.version) best_models.append(best_model) # Predict test_dataloader = make_dataloader(app_test, None, args.batch_size, train=False, onehot=args.onehot) df_submission = predict(best_models, test_dataloader) filename = '../submission/13_mlp-onehot.csv' if args.onehot else '../submission/13_mlp-label.csv' df_submission.to_csv(filename, index=False)
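# `get_dims` is assumed to return, per table, the categorical cardinalities,
# the per-column embedding widths, and the continuous-column count, which is
# how `n_input = emb_dims.sum() + cont_dim` is formed above. A hypothetical
# sketch using a common cardinality-based heuristic for embedding width:
import numpy as np


def get_dims_sketch(all_data):
    """For each DataFrame, return (cat_dims, emb_dims, cont_dim)."""
    dims = {}
    for name, df in all_data.items():
        cat_cols = [c for c in df.columns if str(df[c].dtype) == 'category']
        cat_dims = np.array([df[c].nunique() for c in cat_cols], dtype=np.int64)
        # Heuristic: embedding width grows with cardinality, capped at 50.
        emb_dims = np.minimum((cat_dims + 1) // 2, 50)
        cont_dim = df.shape[1] - len(cat_cols)
        dims[name] = (cat_dims, emb_dims, cont_dim)
    return dims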
def main(): args = parse_args() seed_everything(args.seed) app_train = joblib.load('../data/03_powertransform/application_train.joblib') app_test = joblib.load('../data/03_powertransform/application_test.joblib') sequences = read_all('../data/04_sequence/') dims = joblib.load('../data/07_dims/dims03.joblib') app_dims = {} app_dims['application_train'] = dims.pop('application_train') app_dims['application_test'] = dims.pop('application_test') mlflow.set_tracking_uri('../logs/mlruns') mlflow.set_experiment('HomeCredit') run_name = '91_dimlstm' params = vars(args) df_submission = app_test[['SK_ID_CURR']].copy() skf = StratifiedKFold(n_splits=5) folds = skf.split(app_train['SK_ID_CURR'], app_train['TARGET']) for i, (train_index, val_index) in enumerate(folds): # Train Encoder encoders = pretrain(app_train, sequences, dims, train_index, val_index, args) # Train LightGBM Model app_encoding_train = predict(app_train, encoders, sequences, args) x = app_encoding_train.drop(['SK_ID_CURR', 'TARGET'], axis=1) y = app_encoding_train['TARGET'] x_train, y_train = x.iloc[train_index], y.iloc[train_index] x_valid, y_valid = x.iloc[val_index], y.iloc[val_index] train_set = lgb.Dataset(x_train, y_train) valid_set = lgb.Dataset(x_valid, y_valid) model = lgb.train(params, train_set, valid_sets=[valid_set]) y_pred = model.predict(x_valid) auc = roc_auc_score(y_valid, y_pred) with mlflow.start_run(run_name=run_name): mlflow.log_params(params) mlflow.log_metric('auc', auc) # Predict app_encoding_test = predict(app_test, encoders, sequences, args) x_test = app_encoding_test.drop('SK_ID_CURR', axis=1) y_pred = model.predict(x_test) df_submission[f'pred_{i}'] = y_pred df_submission = df_submission.set_index('SK_ID_CURR').mean(axis=1).reset_index() df_submission.columns = ['SK_ID_CURR', 'TARGET'] df_submission.to_csv(f'../submission/{run_name}.csv', index=False)
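# The submission above stores the per-fold predictions as columns
# pred_0 ... pred_4 and averages them row-wise into TARGET. A self-contained
# toy run of the same pattern (hypothetical values):
import pandas as pd

df = pd.DataFrame({'SK_ID_CURR': [100001, 100002],
                   'pred_0': [0.10, 0.80],
                   'pred_1': [0.20, 0.60]})
out = df.set_index('SK_ID_CURR').mean(axis=1).reset_index()
out.columns = ['SK_ID_CURR', 'TARGET']
print(out)  # TARGET = row-wise mean of the fold columns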
def main(): args = parse_args() seed_everything(args.seed) app_train = joblib.load( '../data/03_powertransform/application_train.joblib') app_test = joblib.load('../data/03_powertransform/application_test.joblib') sequences = read_sequences('../data/04_sequence/') dims = joblib.load('../data/07_dims/dims03.joblib') dims.pop('application_train') dims.pop('application_test') for name, diminfo in dims.items(): cat = sequences[f'{name}_cat'] cont = sequences[f'{name}_cont'] train_loader = torch.utils.data.DataLoader( SequenceDataset(app_train, cat, cont), batch_size=args.batch_size, shuffle=True, num_workers=6, worker_init_fn=worker_init_fn) test_loader = torch.utils.data.DataLoader( SequenceDataset(app_test, cat, cont), batch_size=args.batch_size, shuffle=False, num_workers=6, worker_init_fn=worker_init_fn) model = DIMLSTMModule(diminfo, args.n_hidden, train_loader, test_loader, args) logdir = '../logs/21_dimlstm' path = pathlib.Path(logdir) / name if not path.exists(): path.mkdir(parents=True) logger = TensorBoardLogger(logdir, name=name) early_stopping = EarlyStopping(patience=args.patience, monitor='val_loss_main', mode='min') filepath = pathlib.Path( logdir) / name / f'version_{logger.version}' / 'checkpoints' model_checkpoint = ModelCheckpoint(str(filepath), monitor='val_loss_main', mode='min') trainer = pl.Trainer(default_save_path=logdir, gpus=-1, max_epochs=args.n_epochs, early_stop_callback=early_stopping, logger=logger, row_log_interval=100, checkpoint_callback=model_checkpoint) trainer.fit(model) best_model = load_model(model, name, trainer.logger.version, logdir=logdir) train_loader_no_shuffle = torch.utils.data.DataLoader( SequenceDataset(app_train, cat, cont), batch_size=args.batch_size, shuffle=False, num_workers=6, worker_init_fn=worker_init_fn) df_train = predict(name, best_model, train_loader_no_shuffle) df_test = predict(name, best_model, test_loader) df_encoding = pd.concat([df_train, df_test]) dump(df_encoding, f'../data/21_dimlstm/{name}.joblib')
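# `worker_init_fn` is passed to every DataLoader above so that NumPy-driven
# augmentation differs across workers rather than being duplicated. Its
# definition is not shown; a typical sketch (an assumption, not the
# repository's exact code):
import numpy as np
import torch


def worker_init_fn(worker_id: int) -> None:
    # Derive a distinct NumPy seed from PyTorch's per-worker base seed.
    np.random.seed((torch.initial_seed() + worker_id) % 2**32)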
def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    seed_everything(7)
    args = parse_args()
    Path(args.save_path).mkdir(parents=True, exist_ok=True)
    entity = "demiurge"
    project = "melgan"
    load_from_run_id = args.load_from_run_id
    resume_run_id = args.resume_run_id
    restore_run_id = load_from_run_id or resume_run_id
    batch_size = args.batch_size

    # Getting initial run steps and epoch;
    # if restoring a run, replace args with the stored config.
    steps = None
    if restore_run_id:
        api = wandb.Api()
        previous_run = api.run(f"{entity}/{project}/{restore_run_id}")
        steps = previous_run.lastHistoryStep
        prev_args = argparse.Namespace(**previous_run.config)
        args = vars(args)
        args.update(vars(prev_args))
        args = argparse.Namespace(**args)
        args.batch_size = batch_size

    load_initial_weights = bool(restore_run_id)
    sampling_rate = args.sampling_rate

    ratios = args.ratios
    if isinstance(ratios, str):
        ratios = ratios.replace(" ", "")
        ratios = ratios.strip("][").split(",")
        ratios = [int(i) for i in ratios]
        ratios = np.array(ratios)

    if load_from_run_id and resume_run_id:
        raise RuntimeError(
            "Specify either --load_from_run_id or --resume_run_id, not both.")

    if resume_run_id:
        print(f"Resuming run ID {resume_run_id}.")
    elif load_from_run_id:
        print(f"Starting new run with initial weights from run ID {load_from_run_id}.")
    else:
        print("Starting new run from scratch.")

    # read 1 line in train files to log dataset location
    train_files = Path(args.data_path) / "train_files.txt"
    with open(train_files, encoding="utf-8", mode="r") as f:
        file = f.readline()
    args.train_file_sample = str(file)

    wandb.init(
        entity=entity,
        project=project,
        id=resume_run_id,
        config=args,
        resume=True if resume_run_id else False,
        save_code=True,
        dir=args.save_path,
        notes=args.notes,
    )

    print("run id: " + str(wandb.run.id))
    print("run name: " + str(wandb.run.name))

    root = Path(wandb.run.dir)
    root.mkdir(parents=True, exist_ok=True)

    ####################################
    # Dump arguments and create logger #
    ####################################
    with open(root / "args.yml", "w") as f:
        yaml.dump(args, f)
    wandb.save("args.yml")

    ###############################################
    # The file modules.py is needed by the unagan #
    ###############################################
    wandb.save(mel2wav.modules.__file__, base_path=".")

    #######################
    # Load PyTorch Models #
    #######################
    netG = Generator(args.n_mel_channels, args.ngf,
                     args.n_residual_layers, ratios=ratios).to(device)
    netD = Discriminator(args.num_D, args.ndf, args.n_layers_D,
                         args.downsamp_factor).to(device)
    fft = Audio2Mel(
        n_mel_channels=args.n_mel_channels,
        pad_mode=args.pad_mode,
        sampling_rate=sampling_rate,
    ).to(device)

    for model in [netG, netD, fft]:
        wandb.watch(model)

    #####################
    # Create optimizers #
    #####################
    optG = torch.optim.Adam(netG.parameters(), lr=args.learning_rate, betas=(0.5, 0.9))
    optD = torch.optim.Adam(netD.parameters(), lr=args.learning_rate, betas=(0.5, 0.9))

    if load_initial_weights:
        for model, filenames in [
            (netG, ["netG.pt", "netG_prev.pt"]),
            (optG, ["optG.pt", "optG_prev.pt"]),
            (netD, ["netD.pt", "netD_prev.pt"]),
            (optD, ["optD.pt", "optD_prev.pt"]),
        ]:
            recover_model = False
            filepath = None
            for filename in filenames:
                try:
                    run_path = f"{entity}/{project}/{restore_run_id}"
                    print(f"Restoring {filename} from run path {run_path}")
                    restored_file = wandb.restore(filename, run_path=run_path)
                    filepath = restored_file.name
                    model = load_state_dict_handleDP(model, filepath)
                    recover_model = True
                    break
                except RuntimeError as e:
                    print("RuntimeError", e)
                    print(f"recover model weight file: '{filename}' failed")
            if not recover_model:
                raise RuntimeError(
                    f"Cannot load model weight files for component {filenames[0]}."
                )
            else:
                # store successfully recovered model weight file ("***_prev.pt")
                path_parent = Path(filepath).parent
                newfilepath = str(path_parent / filenames[1])
                os.rename(filepath, newfilepath)
                wandb.save(newfilepath)

    if torch.cuda.device_count() > 1:
        netG = DP(netG).to(device)
        netD = DP(netD).to(device)
        fft = DP(fft).to(device)
        print(f"We have {torch.cuda.device_count()} gpus. Use data parallel.")
    else:
        print(f"We have {torch.cuda.device_count()} gpu.")

    #######################
    # Create data loaders #
    #######################
    train_set = AudioDataset(
        Path(args.data_path) / "train_files.txt",
        args.seq_len,
        sampling_rate=sampling_rate,
    )
    test_set = AudioDataset(
        Path(args.data_path) / "test_files.txt",
        sampling_rate * 4,
        sampling_rate=sampling_rate,
        augment=False,
    )
    wandb.save(str(Path(args.data_path) / "train_files.txt"))
    wandb.save(str(Path(args.data_path) / "test_files.txt"))

    train_loader = DataLoader(train_set, batch_size=args.batch_size, num_workers=4)
    test_loader = DataLoader(test_set, batch_size=1)

    if len(train_loader) == 0:
        raise RuntimeError("Train dataset is empty.")
    if len(test_loader) == 0:
        raise RuntimeError("Test dataset is empty.")

    if not restore_run_id:
        steps = wandb.run.step
    start_epoch = steps // len(train_loader)
    print(f"Starting with epoch {start_epoch} and step {steps}.")

    ##########################
    # Dumping original audio #
    ##########################
    test_voc = []
    test_audio = []
    samples = []
    melImages = []
    num_fix_samples = args.n_test_samples - (args.n_test_samples // 2)
    cmap = cm.get_cmap("inferno")
    for i, x_t in enumerate(test_loader):
        x_t = x_t.to(device)
        s_t = fft(x_t).detach()
        test_voc.append(s_t.to(device))
        test_audio.append(x_t)

        audio = x_t.squeeze().cpu()
        save_sample(root / ("original_%d.wav" % i), sampling_rate, audio)
        samples.append(
            wandb.Audio(audio, caption=f"sample {i}", sample_rate=sampling_rate))

        melImage = s_t.squeeze().detach().cpu().numpy()
        melImage = (melImage - np.amin(melImage)) / (np.amax(melImage) - np.amin(melImage))
        # melImage = Image.fromarray(np.uint8(cmap(melImage)) * 255)
        # melImage = melImage.resize((melImage.width * 4, melImage.height * 4))
        melImages.append(wandb.Image(cmap(melImage), caption=f"sample {i}"))

        if i == num_fix_samples - 1:
            break

    # if not resume_run_id:
    wandb.log({"audio/original": samples}, step=start_epoch)
    wandb.log({"mel/original": melImages}, step=start_epoch)
    # else:
    #     print("We are resuming, skipping logging of original audio.")

    costs = []
    start = time.time()

    # enable cudnn autotuner to speed up training
    torch.backends.cudnn.benchmark = True

    best_mel_reconst = 1000000
    for epoch in range(start_epoch, start_epoch + args.epochs + 1):
        for iterno, x_t in enumerate(train_loader):
            x_t = x_t.to(device)
            s_t = fft(x_t).detach()
            x_pred_t = netG(s_t.to(device))

            with torch.no_grad():
                s_pred_t = fft(x_pred_t.detach())
                s_error = F.l1_loss(s_t, s_pred_t).item()

            #######################
            # Train Discriminator #
            #######################
            D_fake_det = netD(x_pred_t.to(device).detach())
            D_real = netD(x_t.to(device))

            loss_D = 0
            for scale in D_fake_det:
                loss_D += F.relu(1 + scale[-1]).mean()
            for scale in D_real:
                loss_D += F.relu(1 - scale[-1]).mean()

            netD.zero_grad()
            loss_D.backward()
            optD.step()

            ###################
            # Train Generator #
            ###################
            D_fake = netD(x_pred_t.to(device))

            loss_G = 0
            for scale in D_fake:
                loss_G += -scale[-1].mean()

            loss_feat = 0
            feat_weights = 4.0 / (args.n_layers_D + 1)
            D_weights = 1.0 / args.num_D
            wt = D_weights * feat_weights
            for i in range(args.num_D):
                for j in range(len(D_fake[i]) - 1):
                    loss_feat += wt * F.l1_loss(D_fake[i][j], D_real[i][j].detach())

            netG.zero_grad()
            (loss_G + args.lambda_feat * loss_feat).backward()
            optG.step()

            costs.append([loss_D.item(), loss_G.item(), loss_feat.item(), s_error])

            wandb.log(
                {
                    "loss/discriminator": costs[-1][0],
                    "loss/generator": costs[-1][1],
                    "loss/feature_matching": costs[-1][2],
                    "loss/mel_reconstruction": costs[-1][3],
                },
                step=steps,
            )
            steps += 1

            if steps % args.save_interval == 0:
                st = time.time()
                with torch.no_grad():
                    samples = []
                    melImages = []
                    # fix samples
                    for i, (voc, _) in enumerate(zip(test_voc, test_audio)):
                        pred_audio = netG(voc)
                        pred_audio = pred_audio.squeeze().cpu()
                        save_sample(root / ("generated_%d.wav" % i),
                                    sampling_rate, pred_audio)
                        samples.append(
                            wandb.Audio(
                                pred_audio,
                                caption=f"sample {i}",
                                sample_rate=sampling_rate,
                            ))
                        melImage = voc.squeeze().detach().cpu().numpy()
                        melImage = (melImage - np.amin(melImage)) / (
                            np.amax(melImage) - np.amin(melImage))
                        # melImage = Image.fromarray(np.uint8(cmap(melImage)) * 255)
                        # melImage = melImage.resize(
                        #     (melImage.width * 4, melImage.height * 4)
                        # )
                        melImages.append(
                            wandb.Image(cmap(melImage), caption=f"sample {i}"))
                    wandb.log(
                        {
                            "audio/generated": samples,
                            "mel/generated": melImages,
                            "epoch": epoch,
                        },
                        step=steps,
                    )

                    # var samples
                    source = []
                    pred = []
                    pred_mel = []
                    num_var_samples = args.n_test_samples - num_fix_samples
                    for i, x_t in enumerate(test_loader):
                        # source
                        x_t = x_t.to(device)
                        audio = x_t.squeeze().cpu()
                        source.append(
                            wandb.Audio(audio, caption=f"sample {i}",
                                        sample_rate=sampling_rate))
                        # pred
                        s_t = fft(x_t).detach()
                        voc = s_t.to(device)
                        pred_audio = netG(voc)
                        pred_audio = pred_audio.squeeze().cpu()
                        pred.append(
                            wandb.Audio(
                                pred_audio,
                                caption=f"sample {i}",
                                sample_rate=sampling_rate,
                            ))
                        melImage = voc.squeeze().detach().cpu().numpy()
                        melImage = (melImage - np.amin(melImage)) / (
                            np.amax(melImage) - np.amin(melImage))
                        # melImage = Image.fromarray(np.uint8(cmap(melImage)) * 255)
                        # melImage = melImage.resize(
                        #     (melImage.width * 4, melImage.height * 4)
                        # )
                        pred_mel.append(
                            wandb.Image(cmap(melImage), caption=f"sample {i}"))
                        # stop when we reach the number of logged samples
                        if i == num_var_samples - 1:
                            break
                    wandb.log(
                        {
                            "audio/var_original": source,
                            "audio/var_generated": pred,
                            "mel/var_generated": pred_mel,
                        },
                        step=steps,
                    )

                print("Saving models ...")
                torch.save(netG.state_dict(), root / "netG.pt")
                torch.save(optG.state_dict(), root / "optG.pt")
                wandb.save(str(root / "netG.pt"))
                wandb.save(str(root / "optG.pt"))
                torch.save(netD.state_dict(), root / "netD.pt")
                torch.save(optD.state_dict(), root / "optD.pt")
                wandb.save(str(root / "netD.pt"))
                wandb.save(str(root / "optD.pt"))

                if np.asarray(costs).mean(0)[-1] < best_mel_reconst:
                    best_mel_reconst = np.asarray(costs).mean(0)[-1]
                    torch.save(netD.state_dict(), root / "best_netD.pt")
                    torch.save(netG.state_dict(), root / "best_netG.pt")
                    wandb.save(str(root / "best_netD.pt"))
                    wandb.save(str(root / "best_netG.pt"))

                print("Took %5.4fs to generate samples" % (time.time() - st))
                print("-" * 100)

            if steps % args.log_interval == 0:
                print("Epoch {} | Iters {} / {} | ms/batch {:5.2f} | loss {}".format(
                    epoch,
                    iterno,
                    len(train_loader),
                    1000 * (time.time() - start) / args.log_interval,
                    np.asarray(costs).mean(0),
                ))
                costs = []
                start = time.time()
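# The discriminator/generator updates above are the MelGAN objectives: a hinge
# loss on the last feature map of each discriminator scale, plus an L1
# feature-matching term over the intermediate activations. A condensed
# restatement of the two losses (assuming, as here, that the discriminator
# returns a list of per-scale activation lists):
import torch.nn.functional as F


def hinge_d_loss(d_real_outs, d_fake_outs):
    """Hinge GAN loss summed over all discriminator scales."""
    loss = 0.0
    for real_scale, fake_scale in zip(d_real_outs, d_fake_outs):
        loss = loss + F.relu(1.0 - real_scale[-1]).mean()  # push real scores above 1
        loss = loss + F.relu(1.0 + fake_scale[-1]).mean()  # push fake scores below -1
    return loss


def feature_matching_loss(d_real_outs, d_fake_outs, num_d, n_layers_d):
    """Weighted L1 distance between real/fake activations, as in the loop above."""
    wt = (4.0 / (n_layers_d + 1)) * (1.0 / num_d)
    loss = 0.0
    for real_scale, fake_scale in zip(d_real_outs, d_fake_outs):
        for real_feat, fake_feat in zip(real_scale[:-1], fake_scale[:-1]):
            loss = loss + wt * F.l1_loss(fake_feat, real_feat.detach())
    return loss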
def main(): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") seed_everything(7) args = parse_args() Path(args.save_path).mkdir(parents=True, exist_ok=True) entity = "materialvision" project = "melganmv" load_from_run_id = args.load_from_run_id resume_run_id = args.resume_run_id restore_run_id = load_from_run_id or resume_run_id load_initial_weights = bool(restore_run_id) sampling_rate = args.sampling_rate if load_from_run_id and resume_run_id: raise RuntimeError("Specify either --load_from_id or --resume_run_id.") if resume_run_id: print(f"Resuming run ID {resume_run_id}.") elif load_from_run_id: print( f"Starting new run with initial weights from run ID {load_from_run_id}." ) else: print("Starting new run from scratch.") wandb.init( entity=entity, project=project, id=resume_run_id, config=args, resume=True if resume_run_id else False, save_code=True, dir=args.save_path, notes=args.notes, ) print("run id: " + str(wandb.run.id)) print("run name: " + str(wandb.run.name)) root = Path(wandb.run.dir) root.mkdir(parents=True, exist_ok=True) #################################### # Dump arguments and create logger # #################################### with open(root / "args.yml", "w") as f: yaml.dump(args, f) wandb.save("args.yml") ############################################### # The file modules.py is needed by the unagan # ############################################### wandb.save(mel2wav.modules.__file__, base_path=".") ####################### # Load PyTorch Models # ####################### netG = Generator(args.n_mel_channels, args.ngf, args.n_residual_layers).to(device) netD = Discriminator(args.num_D, args.ndf, args.n_layers_D, args.downsamp_factor).to(device) fft = Audio2Mel( n_mel_channels=args.n_mel_channels, pad_mode=args.pad_mode, sampling_rate=sampling_rate, ).to(device) for model in [netG, netD, fft]: wandb.watch(model) ##################### # Create optimizers # ##################### optG = torch.optim.Adam(netG.parameters(), lr=args.learning_rate, betas=(0.5, 0.9)) optD = torch.optim.Adam(netD.parameters(), lr=args.learning_rate, betas=(0.5, 0.9)) if load_initial_weights: for obj, filename in [ (netG, "netG.pt"), (optG, "optG.pt"), (netD, "netD.pt"), (optD, "optD.pt"), ]: run_path = f"{entity}/{project}/{restore_run_id}" print(f"Restoring {filename} from run path {run_path}") restored_file = wandb.restore(filename, run_path=run_path) obj.load_state_dict(torch.load(restored_file.name)) ####################### # Create data loaders # ####################### train_set = AudioDataset( Path(args.data_path) / "train_files.txt", args.seq_len, sampling_rate=sampling_rate, ) test_set = AudioDataset( Path(args.data_path) / "test_files.txt", sampling_rate * 4, sampling_rate=sampling_rate, augment=False, ) wandb.save(str(Path(args.data_path) / "train_files.txt")) wandb.save(str(Path(args.data_path) / "test_files.txt")) train_loader = DataLoader(train_set, batch_size=args.batch_size, num_workers=4) test_loader = DataLoader(test_set, batch_size=1) if len(train_loader) == 0: raise RuntimeError("Train dataset is empty.") if len(test_loader) == 0: raise RuntimeError("Test dataset is empty.") # Getting initial run steps and epoch if load_from_run_id: api = wandb.Api() previous_run = api.run(f"{entity}/{project}/{restore_run_id}") steps = previous_run.lastHistoryStep else: steps = wandb.run.step start_epoch = steps // len(train_loader) print(f"Starting with epoch {start_epoch} and step {steps}.") ########################## # Dumping original audio # ########################## 
test_voc = [] test_audio = [] samples = [] for i, x_t in enumerate(test_loader): x_t = x_t.to(device) s_t = fft(x_t).detach() test_voc.append(s_t.to(device)) test_audio.append(x_t) audio = x_t.squeeze().cpu() save_sample(root / ("original_%d.wav" % i), sampling_rate, audio) samples.append( wandb.Audio(audio, caption=f"sample {i}", sample_rate=sampling_rate)) if i == args.n_test_samples - 1: break if not resume_run_id: wandb.log({"audio/original": samples}, step=0) else: print("We are resuming, skipping logging of original audio.") costs = [] start = time.time() # enable cudnn autotuner to speed up training torch.backends.cudnn.benchmark = True best_mel_reconst = 1000000 for epoch in range(start_epoch, start_epoch + args.epochs + 1): for iterno, x_t in enumerate(train_loader): x_t = x_t.to(device) s_t = fft(x_t).detach() x_pred_t = netG(s_t.to(device)) with torch.no_grad(): s_pred_t = fft(x_pred_t.detach()) s_error = F.l1_loss(s_t, s_pred_t).item() ####################### # Train Discriminator # ####################### D_fake_det = netD(x_pred_t.to(device).detach()) D_real = netD(x_t.to(device)) loss_D = 0 for scale in D_fake_det: loss_D += F.relu(1 + scale[-1]).mean() for scale in D_real: loss_D += F.relu(1 - scale[-1]).mean() netD.zero_grad() loss_D.backward() optD.step() ################### # Train Generator # ################### D_fake = netD(x_pred_t.to(device)) loss_G = 0 for scale in D_fake: loss_G += -scale[-1].mean() loss_feat = 0 feat_weights = 4.0 / (args.n_layers_D + 1) D_weights = 1.0 / args.num_D wt = D_weights * feat_weights for i in range(args.num_D): for j in range(len(D_fake[i]) - 1): loss_feat += wt * F.l1_loss(D_fake[i][j], D_real[i][j].detach()) netG.zero_grad() (loss_G + args.lambda_feat * loss_feat).backward() optG.step() costs.append( [loss_D.item(), loss_G.item(), loss_feat.item(), s_error]) wandb.log( { "loss/discriminator": costs[-1][0], "loss/generator": costs[-1][1], "loss/feature_matching": costs[-1][2], "loss/mel_reconstruction": costs[-1][3], }, step=steps, ) steps += 1 if steps % args.save_interval == 0: st = time.time() with torch.no_grad(): samples = [] for i, (voc, _) in enumerate(zip(test_voc, test_audio)): pred_audio = netG(voc) pred_audio = pred_audio.squeeze().cpu() save_sample(root / ("generated_%d.wav" % i), sampling_rate, pred_audio) samples.append( wandb.Audio( pred_audio, caption=f"sample {i}", sample_rate=sampling_rate, )) wandb.log( { "audio/generated": samples, "epoch": epoch, }, step=steps, ) print("Saving models ...") torch.save(netG.state_dict(), root / "netG.pt") torch.save(optG.state_dict(), root / "optG.pt") wandb.save(str(root / "netG.pt")) wandb.save(str(root / "optG.pt")) torch.save(netD.state_dict(), root / "netD.pt") torch.save(optD.state_dict(), root / "optD.pt") wandb.save(str(root / "netD.pt")) wandb.save(str(root / "optD.pt")) if np.asarray(costs).mean(0)[-1] < best_mel_reconst: best_mel_reconst = np.asarray(costs).mean(0)[-1] torch.save(netD.state_dict(), root / "best_netD.pt") torch.save(netG.state_dict(), root / "best_netG.pt") wandb.save(str(root / "best_netD.pt")) wandb.save(str(root / "best_netG.pt")) print("Took %5.4fs to generate samples" % (time.time() - st)) print("-" * 100) if steps % args.log_interval == 0: print("Epoch {} | Iters {} / {} | ms/batch {:5.2f} | loss {}". format( epoch, iterno, len(train_loader), 1000 * (time.time() - start) / args.log_interval, np.asarray(costs).mean(0), )) costs = [] start = time.time()
def train(args): os.chdir(get_original_cwd()) run_name = 'each' if args.each else 'agg' run_name += '_submit' if args.submit else '_cv' logging.info('start ' + run_name) seed_everything(args.seed) if args.each: v_sales_dict = joblib.load( '../data/05_preprocess/each_item/v_sales_dict.joblib') data_count = joblib.load( '../data/05_preprocess/each_item/data_count.joblib') dims = joblib.load('../data/05_preprocess/each_item/dims.joblib') weight = joblib.load('../data/06_weight/weight_each.joblib') te = joblib.load('../data/07_te/each_te.joblib') else: v_sales_dict = joblib.load( '../data/05_preprocess/agg_item/v_sales_dict.joblib') data_count = joblib.load( '../data/05_preprocess/agg_item/data_count.joblib') dims = joblib.load('../data/05_preprocess/agg_item/dims.joblib') weight = joblib.load('../data/06_weight/weight_agg.joblib') te = joblib.load('../data/07_te/agg_te.joblib') v_sales = next(iter(v_sales_dict.values())) drop_columns = [ 'sort_key', 'id', 'cat_id', 'd', 'release_date', 'date', 'weekday', 'year', 'week_of_month', 'holidy' ] if not args.use_prices: drop_columns += [ 'release_ago', 'sell_price', 'diff_price', 'price_max', 'price_min', 'price_std', 'price_mean', 'price_trend', 'price_norm', 'diff_price_norm', 'price_nunique', 'dept_max', 'dept_min', 'dept_std', 'dept_mean', 'price_in_dept', 'mean_in_dept', 'cat_max', 'cat_min', 'cat_std', 'cat_mean', 'price_in_cat', 'mean_in_cat', 'price_in_month', 'price_in_year', ] cat_columns = [ 'aggregation_level', 'item_id', 'dept_id', 'store_id', 'state_id', 'month', 'event_name_1', 'event_type_1', 'event_name_2', 'event_type_2', 'day_of_week' ] features = [ col for col in v_sales.columns if col not in drop_columns + [TARGET] ] is_cats = [col in cat_columns for col in features] cat_dims = [] emb_dims = [] for col in features: if col in cat_columns: cat_dims.append(dims['cat_dims'][col]) emb_dims.append(dims['emb_dims'][col]) dims = pd.DataFrame({'cat_dims': cat_dims, 'emb_dims': emb_dims}) logging.info('data loaded') if args.submit: logging.info('train for submit') # train model for submission index = 1 if args.useval else 2 valid_term = 2 train_index = index if args.patience == 0 else (index + valid_term) trainset = M5Dataset(v_sales_dict, data_count, features, weight, te, remove_last4w=train_index, min_data_4w=0, over_sample=args.over_sample) validset = M5ValidationDataset(trainset.data_dict, weight, te, remove_last4w=index, term=valid_term) train_loader = torch.utils.data.DataLoader( trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, worker_init_fn=get_worker_init_fn(args.seed)) valid_loader = torch.utils.data.DataLoader( validset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers) model = M5MLPLSTMModel(is_cats, dims, n_hidden=args.n_hidden, dropout=args.dropout, use_te=args.use_te) criterion = M5Distribution(dist=args.dist, df=args.df) module = M5LightningModule(model, criterion, train_loader, valid_loader, None, args) trainer = M5Trainer(args.experiment, run_name, args.max_epochs, args.min_epochs, args.patience, args.val_check) trainer.fit(module) trainer.logger.experiment.log_artifact( trainer.logger.run_id, trainer.checkpoint_callback.kth_best_model) logging.info('predict') module.load_state_dict( torch.load( trainer.checkpoint_callback.kth_best_model)['state_dict']) # for reproducibility dmp_filename = '../data/cuda_rng_state_each.dmp' if args.each else '../data/cuda_rng_state_agg.dmp' torch.save(torch.cuda.get_rng_state(), dmp_filename) 
trainer.logger.experiment.log_artifact(trainer.logger.run_id, dmp_filename) val_acc, val_unc = predict(args, module, criterion, trainset.data_dict, weight, te, evaluation=False) eva_acc, eva_unc = predict(args, module, criterion, trainset.data_dict, weight, te, evaluation=True) submission_accuracy = pd.concat([val_acc, eva_acc]) submission_uncertainty = pd.concat([val_unc, eva_unc]) dump(submission_accuracy, submission_uncertainty, run_name) else: # local CV folds = list(range(3, -1, -1)) # [3, 2, 1, 0] for fold in folds: logging.info(f'train FOLD [{4-fold}/{len(folds)}]') valid_term = 2 if args.patience == 0: train_index = (fold + 1) * valid_term + 1 valid_index = (fold + 1) * valid_term + 1 test_index = fold * valid_term + 1 else: train_index = (fold + 2) * valid_term + 1 valid_index = (fold + 1) * valid_term + 1 test_index = fold * valid_term + 1 trainset = M5Dataset(v_sales_dict, data_count, features, weight, te, remove_last4w=train_index, over_sample=args.over_sample) validset = M5ValidationDataset(trainset.data_dict, weight, te, remove_last4w=valid_index, term=valid_term) testset = M5TestDataset(trainset.data_dict, weight, te, remove_last4w=test_index, term=valid_term) train_loader = torch.utils.data.DataLoader( trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, worker_init_fn=get_worker_init_fn(args.seed)) valid_loader = torch.utils.data.DataLoader( validset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers) test_loader = torch.utils.data.DataLoader( testset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers) model = M5MLPLSTMModel(is_cats, dims, n_hidden=args.n_hidden, dropout=args.dropout, use_te=args.use_te) criterion = M5Distribution(dist=args.dist, df=args.df) module = M5LightningModule(model, criterion, train_loader, valid_loader, test_loader, args) fold_name = f'_{4-fold}-{len(folds)}' trainer = M5Trainer(args.experiment, run_name + fold_name, args.max_epochs, args.min_epochs, args.patience, args.val_check) trainer.fit(module) trainer.logger.experiment.log_artifact( trainer.logger.run_id, trainer.checkpoint_callback.kth_best_model) logging.info(f'test FOLD [{4-fold}/{len(folds)}]') module.load_state_dict( torch.load( trainer.checkpoint_callback.kth_best_model)['state_dict']) trainer.test() del trainset, validset, testset, train_loader, valid_loader, test_loader, model, criterion, module, trainer gc.collect()
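# The fold arithmetic above slides a two-period validation/test window back
# through time; `remove_last4w=k` is assumed to drop the final k four-week
# blocks, so the test block always follows the validation block. A toy
# printout of the indices (illustrative only):
valid_term = 2
for fold in range(3, -1, -1):  # folds [3, 2, 1, 0]
    train_index = (fold + 2) * valid_term + 1  # early-stopping variant
    valid_index = (fold + 1) * valid_term + 1
    test_index = fold * valid_term + 1
    print(fold, train_index, valid_index, test_index)
# fold 3: train 11, valid 9, test 7 ... fold 0: train 5, valid 3, test 1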
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device):
    inception = real_mean = real_cov = mean_latent = None
    # if args.eval_every > 0:
    #     inception = nn.DataParallel(load_patched_inception_v3()).to(device)
    #     inception.eval()
    #     with open(args.inception, "rb") as f:
    #         embeds = pickle.load(f)
    #         real_mean = embeds["mean"]
    #         real_cov = embeds["cov"]
    if get_rank() == 0:
        if args.eval_every > 0:
            with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f:
                f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-'*50}\n")
        if args.log_every > 0:
            with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-'*50}\n")

    loader = sample_data(loader)
    pbar = range(args.iter)
    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    mean_path_length = 0
    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module
    else:
        g_module = generator
        d_module = discriminator

    # accum = 0.5 ** (32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0
    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, args.ada_every, device)

    sample_z = torch.randn(args.n_sample, args.latent, device=device)

    for idx in pbar:
        i = idx + args.start_iter
        if i > args.iter:
            print("Done!")
            break

        # Train Discriminator
        requires_grad(generator, False)
        requires_grad(discriminator, True)
        if args.debug:
            util.seed_everything(i)
        for step_index in range(args.n_step_d):
            real_img = next(loader).to(device)
            noise = mixing_noise(args.batch, args.latent, args.mixing, device)
            print("type of noise1", noise[0].shape)
            print("about to enter generator1")
            # print("noise1 shape is ", noise.shape)
            fake_img, _ = generator(noise)
            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)
                fake_img, _ = augment(fake_img, ada_aug_p)
            else:
                real_img_aug = real_img
            real_pred = discriminator(real_img_aug)
            fake_pred = discriminator(fake_img)
            d_loss_real = F.softplus(-real_pred).mean()
            d_loss_fake = F.softplus(fake_pred).mean()
            d_loss_fake2 = 0
            if args.lambda_nda < 1:
                fake_img2, _ = util.negative_augment(real_img, args.nda_type)
                if args.augment:
                    fake_img2, _ = augment(fake_img2, ada_aug_p)
                fake_pred2 = discriminator(fake_img2)
                d_loss_fake2 = F.softplus(fake_pred2).mean()
            d_loss = (d_loss_real
                      + d_loss_fake * args.lambda_nda
                      + d_loss_fake2 * (1 - args.lambda_nda))
            loss_dict["d"] = d_loss
            loss_dict["real_score"] = real_pred.mean()
            loss_dict["fake_score"] = fake_pred.mean()
            discriminator.zero_grad()
            d_loss.backward()
            d_optim.step()

        if args.augment and args.augment_p == 0:
            ada_aug_p = ada_augment.tune(real_pred)
            r_t_stat = ada_augment.r_t_stat

        d_regularize = args.d_reg_every > 0 and i % args.d_reg_every == 0
        if d_regularize:
            real_img.requires_grad = True
            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)
            else:
                real_img_aug = real_img
            real_pred = discriminator(real_img_aug)
            r1_loss = d_r1_loss(real_pred, real_img)
            discriminator.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()
            d_optim.step()
        loss_dict["r1"] = r1_loss

        # Train Generator
        requires_grad(generator, True)
        requires_grad(discriminator, False)
        if args.debug:
            util.seed_everything(i)
        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        # print("about to enter generator2")
        # print("noise2 shape is ", noise.shape)
        fake_img, _ = generator(noise)
        if args.augment:
            fake_img, _ = augment(fake_img, ada_aug_p)
        fake_pred = discriminator(fake_img)
        g_loss = g_nonsaturating_loss(fake_pred)
        loss_dict["g"] = g_loss
        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        g_regularize = args.g_reg_every > 0 and i % args.g_reg_every == 0
        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing, device)
            # print("about to enter generator3")
            # print("noise3 shape is ", noise.shape)
            fake_img, latents = generator(noise, return_latents=True)
            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)
            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss
            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]
            weighted_path_loss.backward()
            g_optim.step()
            mean_path_length_avg = (
                reduce_sum(mean_path_length).item() / get_world_size())
        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        # Update G_ema
        # G_ema = G * (1 - ema_beta) + G_ema * ema_beta
        ema_nimg = args.ema_kimg * 1000
        if args.ema_rampup is not None:
            ema_nimg = min(ema_nimg, i * args.batch * args.ema_rampup)
        accum = 0.5 ** (args.batch / max(ema_nimg, 1e-8))
        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)
        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                f"augment: {ada_aug_p:.4f}"))

            if i % args.log_every == 0:
                with torch.no_grad():
                    g_ema.eval()
                    print("about to enter g_ema1")
                    print("sample_z1 shape is ", sample_z.shape)
                    sample, _ = g_ema([sample_z])
                    utils.save_image(
                        sample,
                        os.path.join(args.log_dir, 'sample', f"{str(i).zfill(6)}.png"),
                        nrow=int(args.n_sample ** 0.5),
                        normalize=True,
                        value_range=(-1, 1),
                    )
                with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                    f.write((
                        f"{i:07d}; "
                        f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                        f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                        f"augment: {ada_aug_p:.4f}; "
                        f"\n"))

            if wandb and args.wandb:
                wandb.log({
                    "Generator": g_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                    "Path Length": path_length_val,
                })

            if args.eval_every > 0 and i % args.eval_every == 0:
                with torch.no_grad():
                    g_ema.eval()
                    # if args.truncation < 1:
                    #     mean_latent = g_ema.mean_latent(4096)
                    # features = extract_feature_from_samples(
                    #     g_ema, inception, args.truncation, mean_latent,
                    #     64, args.n_sample_fid, args.device
                    # ).numpy()
                    # sample_mean = np.mean(features, 0)
                    # sample_cov = np.cov(features, rowvar=False)
                    # fid = calc_fid(sample_mean, sample_cov, real_mean, real_cov)
                    # with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f:
                    #     f.write(f"{i:07d}; sample: {float(fid):.4f};\n")

            if i % args.save_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', f"{str(i).zfill(6)}.pt"),
                )

            if i % args.save_latest_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', "latest.pt"),
                )
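# `accumulate(g_ema, g_module, accum)` above keeps an exponential moving
# average of the generator weights, with `accum` ramped via ema_kimg and
# ema_rampup. The helper is not defined in this file; the standard
# StyleGAN2-style implementation looks roughly like this (a sketch):
import torch


def accumulate(model_ema, model, decay=0.999):
    """In-place EMA update: ema = decay * ema + (1 - decay) * model."""
    params_ema = dict(model_ema.named_parameters())
    params = dict(model.named_parameters())
    with torch.no_grad():
        for name, p_ema in params_ema.items():
            p_ema.mul_(decay).add_(params[name], alpha=1 - decay)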
parser.add_argument("--num_ng",
                    type=int,
                    default=4,
                    help="Number of negative samples for training set")
parser.add_argument("--num_ng_test",
                    type=int,
                    default=100,
                    help="Number of negative samples for test set")
parser.add_argument("--out",
                    default=True,
                    help="save model or not")

# set device and parameters
args = parser.parse_args()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
writer = SummaryWriter()

# seed for reproducibility
util.seed_everything(args.seed)

# load data
ml_1m = pd.read_csv(config.DATA_PATH,
                    sep="::",
                    names=['user_id', 'item_id', 'rating', 'timestamp'],
                    engine='python')

# set the number of users and items
num_users = ml_1m['user_id'].nunique() + 1
num_items = ml_1m['item_id'].nunique() + 1

# construct the train and test datasets
data = data_utils.NCF_Data(args, ml_1m)
train_loader = data.get_train_instance()
test_loader = data.get_test_instance()
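# `NCF_Data.get_train_instance` is expected to pair each observed user-item
# interaction with `args.num_ng` sampled negatives (items the user never
# rated), matching the leave-one-out protocol that --num_ng_test implies. A
# minimal sketch of that sampling step (hypothetical helper, not the
# library's code):
import random


def sample_negatives(user_items, num_items, num_ng):
    """Yield (user, pos_item, neg_item) triples with uniform negative sampling."""
    for user, positives in user_items.items():
        for pos in positives:
            for _ in range(num_ng):
                neg = random.randrange(num_items)
                while neg in positives:  # resample until truly unobserved
                    neg = random.randrange(num_items)
                yield user, pos, neg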
def main(): device = "cuda:0" if torch.cuda.is_available() else "cpu" parser = argparse.ArgumentParser() parser.add_argument('--image_path', type=str, default="./data/dirty_mnist_2nd/") parser.add_argument('--label_path', type=str, default="./data/dirty_mnist_2nd_answer.csv") parser.add_argument('--kfold_idx', type=int, default=0) parser.add_argument('--model', type=str, default='efficientnet-b8') parser.add_argument('--epochs', type=int, default=2000) parser.add_argument('--batch_size', type=int, default=16) parser.add_argument('--lr', type=float, default=1e-3) parser.add_argument('--patient', type=int, default=8) parser.add_argument('--device', type=str, default=device) parser.add_argument('--resume', type=str, default=None) parser.add_argument('--comments', type=str, default=None) args = parser.parse_args() print('=' * 50) print('[info msg] arguments\n') for key, value in vars(args).items(): print(key, ":", value) print('=' * 50) assert os.path.isdir(args.image_path), 'wrong path' assert os.path.isfile(args.label_path), 'wrong path' if (args.resume): assert os.path.isfile(args.resume), 'wrong path' assert args.kfold_idx < 5 util.seed_everything(777) data_set = pd.read_csv(args.label_path) valid_idx_nb = int(len(data_set) * (1 / 5)) valid_idx = np.arange(valid_idx_nb * args.kfold_idx, valid_idx_nb * (args.kfold_idx + 1)) print('[info msg] validation fold idx !!\n') print(valid_idx) print('=' * 50) train_data = data_set.drop(valid_idx) valid_data = data_set.iloc[valid_idx] train_set = util.DatasetMNIST( image_folder=args.image_path, label_df=train_data, transforms=util.mnist_transforms['new_train']) valid_set = util.DatasetMNIST(image_folder=args.image_path, label_df=valid_data, transforms=util.mnist_transforms['valid']) train_data_loader = torch.utils.data.DataLoader( train_set, batch_size=args.batch_size, shuffle=True, ) valid_data_loader = torch.utils.data.DataLoader( valid_set, batch_size=args.batch_size, shuffle=False, ) model = None if (args.resume): model = EfficientNet.from_name(args.model, in_channels=1, num_classes=26, dropout_rate=0.5) model.load_state_dict(torch.load(args.resume)) print('[info msg] pre-trained weight is loaded !!\n') print(args.resume) print('=' * 50) else: print('[info msg] {} model is created\n'.format(args.model)) model = EfficientNet.from_pretrained(args.model, in_channels=1, num_classes=26, dropout_rate=0.5, advprop=True) print('=' * 50) if args.device == 'cuda' and torch.cuda.device_count() > 1: model = torch.nn.DataParallel(model) ####### Wandb ###### wandb.init(project='dacon_dirty_mnist_new_aug') wandb.run.name = args.comments wandb.config.update(args) wandb.watch(model) #################### model.to(args.device) optimizer = torch.optim.Adam(model.parameters(), args.lr) criterion = torch.nn.MultiLabelSoftMarginLoss() scheduler = ReduceLROnPlateau(optimizer=optimizer, mode='min', patience=2, factor=0.5, verbose=True) train_loss = [] train_acc = [] valid_loss = [] valid_acc = [] best_loss = float("inf") patient = 0 date_time = datetime.now().strftime("%m%d%H%M%S") SAVE_DIR = os.path.join('./save', date_time) print('[info msg] training start !!\n') startTime = datetime.now() for epoch in range(args.epochs): print('Epoch {}/{}'.format(epoch + 1, args.epochs)) train_epoch_loss, train_epoch_acc = util.train( train_loader=train_data_loader, model=model, loss_func=criterion, device=args.device, optimizer=optimizer, ) train_loss.append(train_epoch_loss) train_acc.append(train_epoch_acc) valid_epoch_loss, valid_epoch_acc = util.validate( 
valid_loader=valid_data_loader, model=model, loss_func=criterion, device=args.device, scheduler=scheduler, ) valid_loss.append(valid_epoch_loss) valid_acc.append(valid_epoch_acc) wandb.log({ "Train Acc": train_epoch_acc, "Valid Acc": valid_epoch_acc, "Train Loss": train_epoch_loss, "Valid Loss": valid_epoch_loss, }) if best_loss > valid_epoch_loss: patient = 0 best_loss = valid_epoch_loss Path(SAVE_DIR).mkdir(parents=True, exist_ok=True) torch.save(model.state_dict(), os.path.join(SAVE_DIR, 'model_best.pth.tar')) print('MODEL IS SAVED TO {}!!!'.format(date_time)) else: patient += 1 if patient > args.patient - 1: print('=======' * 10) print("[Info message] Early stopper is activated") break elapsed_time = datetime.now() - startTime train_loss = np.array(train_loss) train_acc = np.array(train_acc) valid_loss = np.array(valid_loss) valid_acc = np.array(valid_acc) best_loss_pos = np.argmin(valid_loss) print('=' * 50) print('[info msg] training is done\n') print("Time taken: {}".format(elapsed_time)) print("best loss is {} w/ acc {} at epoch : {}".format( best_loss, valid_acc[best_loss_pos], best_loss_pos)) print('=' * 50) print('[info msg] {} model weight and log is save to {}\n'.format( args.model, SAVE_DIR)) with open(os.path.join(SAVE_DIR, 'log.txt'), 'w') as f: for key, value in vars(args).items(): f.write('{} : {}\n'.format(key, value)) f.write('\n') f.write('total ecpochs : {}\n'.format(str(train_loss.shape[0]))) f.write('time taken : {}\n'.format(str(elapsed_time))) f.write('best_train_loss {} w/ acc {} at epoch : {}\n'.format( np.min(train_loss), train_acc[np.argmin(train_loss)], np.argmin(train_loss))) f.write('best_valid_loss {} w/ acc {} at epoch : {}\n'.format( np.min(valid_loss), valid_acc[np.argmin(valid_loss)], np.argmin(valid_loss))) plt.figure(figsize=(15, 5)) plt.subplot(1, 2, 1) plt.plot(train_loss, label='train loss') plt.plot(valid_loss, 'o', label='valid loss') plt.axvline(x=best_loss_pos, color='r', linestyle='--', linewidth=1.5) plt.legend() plt.subplot(1, 2, 2) plt.plot(train_acc, label='train acc') plt.plot(valid_acc, 'o', label='valid acc') plt.axvline(x=best_loss_pos, color='r', linestyle='--', linewidth=1.5) plt.legend() plt.savefig(os.path.join(SAVE_DIR, 'history.png'))
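# `util.train`/`util.validate` return a loss and an accuracy for this 26-letter
# multi-label problem (MultiLabelSoftMarginLoss over sigmoid logits). The
# accuracy is presumably computed by thresholding each letter's sigmoid at 0.5;
# a sketch of that metric (an assumption about util's internals):
import torch


def multilabel_accuracy(logits, targets):
    """Fraction of correct (image, letter) entries at a 0.5 threshold."""
    preds = (torch.sigmoid(logits) > 0.5).float()
    return (preds == targets).float().mean().item()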
def train(args, loader, loader2, generator, encoder, discriminator, discriminator2, vggnet, g_optim, e_optim, d_optim, d2_optim, g_ema, e_ema, device): inception = real_mean = real_cov = mean_latent = None if args.eval_every > 0: inception = nn.DataParallel(load_patched_inception_v3()).to(device) inception.eval() with open(args.inception, "rb") as f: embeds = pickle.load(f) real_mean = embeds["mean"] real_cov = embeds["cov"] if get_rank() == 0: if args.eval_every > 0: with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f: f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-'*50}\n") if args.log_every > 0: with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f: f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-'*50}\n") loader = sample_data(loader) pbar = range(args.iter) if get_rank() == 0: pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01) mean_path_length = 0 d_loss_val = 0 r1_loss = torch.tensor(0.0, device=device) g_loss_val = 0 path_loss = torch.tensor(0.0, device=device) path_lengths = torch.tensor(0.0, device=device) mean_path_length_avg = 0 loss_dict = { 'recx_score': torch.tensor(0.0, device=device), 'ae_fake': torch.tensor(0.0, device=device), 'ae_real': torch.tensor(0.0, device=device), 'pix': torch.tensor(0.0, device=device), 'vgg': torch.tensor(0.0, device=device), } avg_pix_loss = util.AverageMeter() avg_vgg_loss = util.AverageMeter() if args.distributed: g_module = generator.module e_module = encoder.module d_module = discriminator.module else: g_module = generator e_module = encoder d_module = discriminator d2_module = None if discriminator2 is not None: if args.distributed: d2_module = discriminator2.module else: d2_module = discriminator2 # When joint training enabled, d_weight balances reconstruction loss and adversarial loss on # recontructed real images. This does not balance the overall AE loss and GAN loss. d_weight = torch.tensor(1.0, device=device) last_layer = None if args.use_adaptive_weight: if args.distributed: last_layer = generator.module.get_last_layer() else: last_layer = generator.get_last_layer() g_scale = 1 ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0 r_t_stat = 0 r_t_dict = {'real': 0, 'fake': 0, 'recx': 0} # r_t stat if args.augment and args.augment_p == 0: ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, args.ada_every, device) sample_z = torch.randn(args.n_sample, args.latent, device=device) sample_x = load_real_samples(args, loader) if sample_x.ndim > 4: sample_x = sample_x[:, 0, ...] input_is_latent = args.latent_space != 'z' # Encode in z space? 
n_step_max = max(args.n_step_d, args.n_step_e) requires_grad(g_ema, False) requires_grad(e_ema, False) for idx in pbar: i = idx + args.start_iter if i > args.iter: print("Done!") break if args.debug: util.seed_everything(i) real_imgs = [next(loader).to(device) for _ in range(n_step_max)] # Train Discriminator and Encoder requires_grad(generator, False) requires_grad(encoder, True) requires_grad(discriminator, True) requires_grad(discriminator2, True) for step_index in range(args.n_step_d): real_img = real_imgs[step_index] noise = mixing_noise(args.batch, args.latent, args.mixing, device) fake_img, _ = generator(noise) if args.augment: real_img_aug, _ = augment(real_img, ada_aug_p) fake_img_aug, _ = augment(fake_img, ada_aug_p) else: real_img_aug = real_img fake_img_aug = fake_img real_pred = discriminator(encoder(real_img_aug)[0]) fake_pred = discriminator(encoder(fake_img_aug)[0]) d_loss_real = F.softplus(-real_pred).mean() d_loss_fake = F.softplus(fake_pred).mean() loss_dict["real_score"] = real_pred.mean() loss_dict["fake_score"] = fake_pred.mean() d_loss_rec = 0. if args.lambda_rec_d > 0: latent_real, _ = encoder(real_img) rec_img, _ = generator([latent_real], input_is_latent=input_is_latent) if args.augment: rec_img, _ = augment(rec_img, ada_aug_p) rec_pred = discriminator(encoder(rec_img)[0]) d_loss_rec = F.softplus(rec_pred).mean() loss_dict["recx_score"] = rec_pred.mean() r_t_dict['recx'] = torch.sign( rec_pred).sum().item() / args.batch d_loss = d_loss_real + d_loss_fake * args.lambda_fake_d + d_loss_rec * args.lambda_rec_d loss_dict["d"] = d_loss discriminator.zero_grad() encoder.zero_grad() d_loss.backward() d_optim.step() e_optim.step() if args.augment and args.augment_p == 0: ada_aug_p = ada_augment.tune(real_pred) r_t_stat = ada_augment.r_t_stat r_t_dict['real'] = torch.sign(real_pred).sum().item() / args.batch r_t_dict['fake'] = torch.sign(fake_pred).sum().item() / args.batch d_regularize = args.d_reg_every > 0 and i % args.d_reg_every == 0 if d_regularize: real_img.requires_grad = True if args.augment: real_img_aug, _ = augment(real_img, ada_aug_p) else: real_img_aug = real_img real_pred = discriminator(encoder(real_img_aug)[0]) r1_loss = d_r1_loss(real_pred, real_img) discriminator.zero_grad() encoder.zero_grad() (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward() d_optim.step() e_optim.step() loss_dict["r1"] = r1_loss # Train Generator requires_grad(generator, True) requires_grad(encoder, False) requires_grad(discriminator, False) requires_grad(discriminator2, False) noise = mixing_noise(args.batch, args.latent, args.mixing, device) fake_img, _ = generator(noise) if args.augment: fake_img_aug, _ = augment(fake_img, ada_aug_p) else: fake_img_aug = fake_img fake_pred = discriminator(encoder(fake_img_aug)[0]) g_loss_fake = g_nonsaturating_loss(fake_pred) loss_dict["g"] = g_loss_fake generator.zero_grad() (g_loss_fake * args.lambda_fake_g).backward() g_optim.step() g_regularize = args.g_reg_every > 0 and i % args.g_reg_every == 0 if g_regularize: path_batch_size = max(1, args.batch // args.path_batch_shrink) noise = mixing_noise(path_batch_size, args.latent, args.mixing, device) fake_img, latents = generator(noise, return_latents=True) path_loss, mean_path_length, path_lengths = g_path_regularize( fake_img, latents, mean_path_length) generator.zero_grad() weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss if args.path_batch_shrink: weighted_path_loss += 0 * fake_img[0, 0, 0, 0] weighted_path_loss.backward() g_optim.step() 
            mean_path_length_avg = (reduce_sum(mean_path_length).item() /
                                    get_world_size())

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        # Train Encoder (and Generator)
        joint = (not args.no_joint) and (g_scale > 1e-6)

        # Train AE on fake samples (latent reconstruction)
        if args.lambda_rec_w + (args.lambda_pix_fake + args.lambda_vgg_fake +
                                args.lambda_adv_fake) > 0:
            requires_grad(encoder, True)
            requires_grad(generator, joint)
            requires_grad(discriminator, False)
            requires_grad(discriminator2, False)
            for step_index in range(args.n_step_e):
                # mixing_prob = 0 if args.which_latent == 'w_tied' else args.mixing
                # noise = mixing_noise(args.batch, args.latent, mixing_prob, device)
                # fake_img, latent_fake = generator(noise, return_latents=True,
                #                                   detach_style=not args.no_detach_style)
                # if args.which_latent == 'w_tied':
                #     latent_fake = latent_fake[:, 0, :]
                # else:
                #     latent_fake = latent_fake.view(args.batch, -1)
                # latent_pred, _ = encoder(fake_img)
                # ae_loss_fake = torch.mean((latent_pred - latent_fake.detach()) ** 2)
                ae_loss_fake = 0
                mixing_prob = 0 if args.which_latent == 'w_tied' else args.mixing
                if args.lambda_rec_w > 0:
                    noise = mixing_noise(args.batch, args.latent, mixing_prob,
                                         device)
                    fake_img, latent_fake = generator(
                        noise,
                        return_latents=True,
                        detach_style=not args.no_detach_style)
                    if args.which_latent == 'w_tied':
                        latent_fake = latent_fake[:, 0, :]
                    else:
                        latent_fake = latent_fake.view(args.batch, -1)
                    latent_pred, _ = encoder(fake_img)
                    ae_loss_fake = torch.mean(
                        (latent_pred - latent_fake.detach())**2)
                if args.lambda_pix_fake + args.lambda_vgg_fake + args.lambda_adv_fake > 0:
                    pix_loss = vgg_loss = adv_loss = torch.tensor(
                        0., device=device)
                    noise = mixing_noise(args.batch, args.latent, mixing_prob,
                                         device)
                    fake_img, _ = generator(noise, detach_style=False)
                    fake_img = fake_img.detach()
                    latent_pred, _ = encoder(fake_img)
                    rec_img, _ = generator([latent_pred],
                                           input_is_latent=input_is_latent)
                    if args.lambda_pix_fake > 0:
                        if args.pix_loss == 'l2':
                            pix_loss = torch.mean((rec_img - fake_img)**2)
                        elif args.pix_loss == 'l1':
                            pix_loss = F.l1_loss(rec_img, fake_img)
                    if args.lambda_vgg_fake > 0:
                        vgg_loss = torch.mean(
                            (vggnet(fake_img) - vggnet(rec_img))**2)
                    ae_loss_fake = (ae_loss_fake +
                                    pix_loss * args.lambda_pix_fake +
                                    vgg_loss * args.lambda_vgg_fake)
                loss_dict["ae_fake"] = ae_loss_fake
                if joint:
                    encoder.zero_grad()
                    generator.zero_grad()
                    (ae_loss_fake * args.lambda_rec_w).backward()
                    e_optim.step()
                    if args.g_decay is not None:
                        scale_grad(generator, g_scale)
                    # Do NOT update F (or generator.style). Grad should be
                    # zero when style is detached in generator, but we
                    # explicitly zero it, just in case.
                    if not args.no_detach_style:
                        generator.style.zero_grad()
                    g_optim.step()
                else:
                    encoder.zero_grad()
                    (ae_loss_fake * args.lambda_rec_w).backward()
                    e_optim.step()

        # Train AE on real samples (image reconstruction)
        if args.lambda_pix + args.lambda_vgg + args.lambda_adv > 0:
            requires_grad(encoder, True)
            requires_grad(generator, joint)
            requires_grad(discriminator, False)
            requires_grad(discriminator2, False)
            pix_loss = vgg_loss = adv_loss = torch.tensor(0., device=device)
            for step_index in range(args.n_step_e):
                real_img = real_imgs[step_index]
                latent_real, _ = encoder(real_img)
                rec_img, _ = generator([latent_real],
                                       input_is_latent=input_is_latent)
                if args.lambda_pix > 0:
                    if args.pix_loss == 'l2':
                        pix_loss = torch.mean((rec_img - real_img)**2)
                    elif args.pix_loss == 'l1':
                        pix_loss = F.l1_loss(rec_img, real_img)
                if args.lambda_vgg > 0:
                    vgg_loss = torch.mean(
                        (vggnet(real_img) - vggnet(rec_img))**2)
                if args.lambda_adv > 0:
                    if args.augment:
                        rec_img_aug, _ = augment(rec_img, ada_aug_p)
                    else:
                        rec_img_aug = rec_img
                    rec_pred = discriminator(encoder(rec_img_aug)[0])
                    adv_loss = g_nonsaturating_loss(rec_pred)
                if args.use_adaptive_weight and i >= args.disc_iter_start:
                    nll_loss = (pix_loss * args.lambda_pix +
                                vgg_loss * args.lambda_vgg)
                    g_loss = adv_loss * args.lambda_adv
                    d_weight = calculate_adaptive_weight(nll_loss,
                                                         g_loss,
                                                         last_layer=last_layer)
                ae_loss_real = (pix_loss * args.lambda_pix +
                                vgg_loss * args.lambda_vgg +
                                d_weight * adv_loss * args.lambda_adv)
                loss_dict["ae_real"] = ae_loss_real
                loss_dict["pix"] = pix_loss
                loss_dict["vgg"] = vgg_loss
                loss_dict["adv"] = adv_loss
                if joint:
                    encoder.zero_grad()
                    generator.zero_grad()
                    ae_loss_real.backward()
                    e_optim.step()
                    if args.g_decay is not None:
                        scale_grad(generator, g_scale)
                    g_optim.step()
                else:
                    encoder.zero_grad()
                    ae_loss_real.backward()
                    e_optim.step()

        if args.g_decay is not None:
            g_scale *= args.g_decay

        # Update EMA
        ema_nimg = args.ema_kimg * 1000
        if args.ema_rampup is not None:
            ema_nimg = min(ema_nimg, i * args.batch * args.ema_rampup)
        accum = 0.5**(args.batch / max(ema_nimg, 1e-8))
        accumulate(g_ema, g_module, 0 if args.no_ema_g else accum)
        accumulate(e_ema, e_module, 0 if args.no_ema_e else accum)

        loss_reduced = reduce_loss_dict(loss_dict)
        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        ae_real_val = loss_reduced["ae_real"].mean().item()
        ae_fake_val = loss_reduced["ae_fake"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        recx_score_val = loss_reduced["recx_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()
        pix_loss_val = loss_reduced["pix"].mean().item()
        vgg_loss_val = loss_reduced["vgg"].mean().item()
        avg_pix_loss.update(pix_loss_val, real_img.shape[0])
        avg_vgg_loss.update(vgg_loss_val, real_img.shape[0])

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"ae_fake: {ae_fake_val:.4f}; ae_real: {ae_real_val:.4f}; "
                f"g: {g_loss_val:.4f}; path: {path_loss_val:.4f}; "
                f"mean path: {mean_path_length_avg:.4f}; "
                f"augment: {ada_aug_p:.4f}; "
                f"d_weight: {d_weight.item():.4f}; "))

            if i % args.log_every == 0:
                with torch.no_grad():
                    g_ema.eval()
                    e_ema.eval()
                    nrow = int(args.n_sample**0.5)
                    nchw = list(sample_x.shape)[1:]
                    # Reconstruction of real images
                    latent_x, _ = e_ema(sample_x)
                    rec_real, _ = g_ema([latent_x],
                                        input_is_latent=input_is_latent)
                    sample = torch.cat(
                        (sample_x.reshape(args.n_sample // nrow, nrow, *nchw),
                         rec_real.reshape(args.n_sample // nrow, nrow, *nchw)),
                        1)
                    utils.save_image(
                        sample.reshape(2 * args.n_sample, *nchw),
                        os.path.join(args.log_dir, 'sample',
                                     f"{str(i).zfill(6)}-recon.png"),
                        nrow=nrow,
                        normalize=True,
                        value_range=(-1, 1),
                    )
                    ref_pix_loss = torch.sum(torch.abs(sample_x - rec_real))
                    ref_vgg_loss = torch.mean(
                        (vggnet(sample_x) -
                         vggnet(rec_real))**2) if vggnet is not None else 0
                    # Fixed fake samples and reconstructions
                    sample_gz, _ = g_ema([sample_z])
                    latent_gz, _ = e_ema(sample_gz)
                    rec_fake, _ = g_ema([latent_gz],
                                        input_is_latent=input_is_latent)
                    sample = torch.cat(
                        (sample_gz.reshape(args.n_sample // nrow, nrow,
                                           *nchw),
                         rec_fake.reshape(args.n_sample // nrow, nrow,
                                          *nchw)), 1)
                    utils.save_image(
                        sample.reshape(2 * args.n_sample, *nchw),
                        os.path.join(args.log_dir, 'sample',
                                     f"{str(i).zfill(6)}-sample.png"),
                        nrow=nrow,
                        normalize=True,
                        value_range=(-1, 1),
                    )
                with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                    f.write((
                        f"{i:07d}; "
                        f"d: {d_loss_val:.4f}; r1: {r1_val:.4f}; "
                        f"ae_fake: {ae_fake_val:.4f}; ae_real: {ae_real_val:.4f}; "
                        f"g: {g_loss_val:.4f}; path: {path_loss_val:.4f}; mean_path: {mean_path_length_avg:.4f}; "
                        f"augment: {ada_aug_p:.4f}; {'; '.join([f'{k}: {r_t_dict[k]:.4f}' for k in r_t_dict])}; "
                        f"real_score: {real_score_val:.4f}; fake_score: {fake_score_val:.4f}; recx_score: {recx_score_val:.4f}; "
                        f"pix: {avg_pix_loss.avg:.4f}; vgg: {avg_vgg_loss.avg:.4f}; "
                        f"ref_pix: {ref_pix_loss.item():.4f}; ref_vgg: {float(ref_vgg_loss):.4f}; "
                        f"d_weight: {d_weight.item():.4f}; "
                        f"\n"))
                if wandb and args.wandb:
                    wandb.log({
                        "Generator": g_loss_val,
                        "Discriminator": d_loss_val,
                        "Augment": ada_aug_p,
                        "Rt": r_t_stat,
                        "R1": r1_val,
                        "Path Length Regularization": path_loss_val,
                        "Mean Path Length": mean_path_length,
                        "Real Score": real_score_val,
                        "Fake Score": fake_score_val,
                        "Path Length": path_length_val,
                    })

            if args.eval_every > 0 and i % args.eval_every == 0:
                with torch.no_grad():
                    fid_sa = fid_re = fid_sr = 0
                    g_ema.eval()
                    e_ema.eval()
                    if args.truncation < 1:
                        mean_latent = g_ema.mean_latent(4096)
                    # Sample FID
                    if 'fid_sample' in args.which_metric:
                        features = extract_feature_from_samples(
                            g_ema, inception, args.truncation, mean_latent,
                            64, args.n_sample_fid, args.device).numpy()
                        sample_mean = np.mean(features, 0)
                        sample_cov = np.cov(features, rowvar=False)
                        fid_sa = calc_fid(sample_mean, sample_cov, real_mean,
                                          real_cov)
                    # Sample reconstruction FID
                    if 'fid_sample_recon' in args.which_metric:
                        features = extract_feature_from_samples(
                            g_ema,
                            inception,
                            args.truncation,
                            mean_latent,
                            64,
                            args.n_sample_fid,
                            args.device,
                            mode='recon',
                            encoder=e_ema,
                            input_is_latent=input_is_latent,
                        ).numpy()
                        sample_mean = np.mean(features, 0)
                        sample_cov = np.cov(features, rowvar=False)
                        fid_sr = calc_fid(sample_mean, sample_cov, real_mean,
                                          real_cov)
                    # Real reconstruction FID
                    if 'fid_recon' in args.which_metric:
                        features = extract_feature_from_reconstruction(
                            e_ema,
                            g_ema,
                            inception,
                            args.truncation,
                            mean_latent,
                            loader2,
                            args.device,
                            input_is_latent=input_is_latent,
                            mode='recon',
                        ).numpy()
                        sample_mean = np.mean(features, 0)
                        sample_cov = np.cov(features, rowvar=False)
                        fid_re = calc_fid(sample_mean, sample_cov, real_mean,
                                          real_cov)
                with open(os.path.join(args.log_dir, 'log_fid.txt'),
                          'a+') as f:
                    f.write(
                        f"{i:07d}; sample: {float(fid_sa):.4f}; rec_fake: {float(fid_sr):.4f}; rec_real: {float(fid_re):.4f};\n"
                    )

            if i % args.save_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "d2":
                        d2_module.state_dict() if args.decouple_d else None,
                        "g_ema": g_ema.state_dict(),
                        "e_ema": e_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "d2_optim":
                        d2_optim.state_dict() if args.decouple_d else None,
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight',
                                 f"{str(i).zfill(6)}.pt"),
                )

            if i % args.save_latest_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "d2":
                        d2_module.state_dict() if args.decouple_d else None,
                        "g_ema": g_ema.state_dict(),
                        "e_ema": e_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "d2_optim":
                        d2_optim.state_dict() if args.decouple_d else None,
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', "latest.pt"),
                )
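# ---------------------------------------------------------------------------
# Note on the EMA update above: `accumulate(g_ema, g_module, accum)` comes
# from the training utilities and is not shown in this file. As a reference,
# here is a minimal sketch of the usual parameter-lerp implementation (an
# assumption; the repo's own helper may differ, e.g. in buffer handling):
import torch
from torch import nn


def accumulate_sketch(model_ema: nn.Module, model: nn.Module,
                      decay: float = 0.999) -> None:
    """Blend `model` into `model_ema`: ema <- decay * ema + (1 - decay) * new.

    Passing decay=0 copies the live weights verbatim, which is how the
    `no_ema_g` / `no_ema_e` flags above disable averaging.
    """
    ema_params = dict(model_ema.named_parameters())
    live_params = dict(model.named_parameters())
    with torch.no_grad():
        for name, ema_p in ema_params.items():
            # lerp_(src, w) computes ema_p + w * (src - ema_p) in place.
            ema_p.lerp_(live_params[name], 1.0 - decay)
# ---------------------------------------------------------------------------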
fold = 5
batch_size = 64
epochs = 6
y_target = np.zeros((y_train.shape[0], len(ALL), y_train.shape[1]))
y_test_pred = np.zeros((X_test.shape[0], len(ALL) * fold, y_train.shape[1]))
n = 0
for name in DL:
    # print('=' * 80)
    seed = SEED * (n + 1)
    kfold = list(
        KFold(n_splits=fold, random_state=seed,
              shuffle=True).split(X_train, y_train))
    for i, (train_index, val_index) in enumerate(kfold):
        X, y, val_X, val_y = (X_train[train_index], y_train[train_index],
                              X_train[val_index], y_train[val_index])
        util.seed_everything(seed + i)
        model = models.getModel(param, name)
        # if i == 0: print(model.summary())
        filepath = param['subject_ckp_path'] + name + '-' + str(i + 1)
        if not os.path.exists(filepath):
            reduce_lr = ReduceLROnPlateau(monitor='val_loss',
                                          factor=0.5,
                                          patience=1,
                                          min_lr=0.0001,
                                          verbose=2)
            checkpoint = ModelCheckpoint(filepath,
                                         monitor='val_loss',
                                         verbose=2,
                                         save_best_only=True,
                                         mode='min')
            model.fit(X,
def main(dataset_name,
         model_name,
         epochs,
         batch_size,
         noise_type,
         noise_ratio,
         verbose=1,
         alpha=util.ALPHA,
         temperature=16,
         is_dropout=False,
         percentage=1):
    K.clear_session()
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    seed_everything()

    # folders to be used
    noise_base_path = 'nr_{}'.format(
        noise_ratio) if not is_dropout else 'nr_{}_do'.format(noise_ratio)
    folders = {
        'logbase': '{}/logs_{}/'.format(dataset_name, percentage),
        'logbase_nr': '{}/logs_{}/{}/{}/'.format(dataset_name, percentage,
                                                 noise_base_path, model_name),
        'logdir': '{}/logs_{}/{}/{}/{}/'.format(dataset_name, percentage,
                                                noise_base_path, model_name,
                                                noise_type),
        'modelbase': '{}/models/'.format(dataset_name),
        'noisebase': '{}/noisylabels/'.format(dataset_name),
        'noisedir': '{}/noisylabels/{}/'.format(dataset_name, noise_type),
        'dataset': '{}/dataset'.format(dataset_name)
    }

    # if the log file already exists, don't run it again
    if isfile(folders['logdir'] + 'model/model.h5') and isfile(
            folders['logdir'] + 'model/model.json'):
        print('Logs exist, skipping run...')
        return

    # clean empty logs if there are any
    clean_empty_logs()
    # create necessary folders
    create_folders(folders['dataset'], folders['logdir'])
    # generate noisy labels
    y_train_noisy, y_test_noisy = prep_noisylabels(dataset_name, folders,
                                                   noise_type, noise_ratio,
                                                   verbose, alpha,
                                                   temperature, is_dropout)
    # load dataset with noisy labels
    dataset = get_data(dataset_name,
                       y_noisy=y_train_noisy,
                       y_noisy_test=y_test_noisy)
    dataset.get_percentage(percentage)

    # stats before training
    print(
        'Dataset: {}, model: {}, noise_type: {}, noise_ratio: {}, epochs: {}, batch: {}, dropout: {}'
        .format(dataset.name, model_name, noise_type, noise_ratio, epochs,
                batch_size, is_dropout))
    dataset.get_stats()
    dataset.save_cm_train(folders['logdir'] + 'corrupted_data.png')

    # train model
    if model_name == 'coteaching':
        model1 = get_model(dataset, model_name, is_dropout=is_dropout)
        model2 = get_model(dataset, model_name, is_dropout=is_dropout)
        model = train_coteaching(dataset, model1, model2, epochs, batch_size,
                                 folders['logdir'])
    else:
        # cm = np.load('{}/models/xy/npy/test_cm.npy'.format(dataset_name))
        cm = dataset.get_cm_train()
        # row-normalize the confusion matrix so each row sums to 1
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        model = get_model(dataset, model_name, cm, is_dropout=is_dropout)
        model = train(dataset,
                      model,
                      epochs,
                      batch_size,
                      folders['logdir'],
                      verbose=verbose)

    # performance analysis
    postprocess(dataset, model, noise_type, noise_ratio, folders,
                y_test_noisy)
    K.clear_session()
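# ---------------------------------------------------------------------------
# `seed_everything` is called throughout these scripts but defined in the
# shared utilities. A minimal sketch of a typical implementation (an
# assumption; the repo's version may seed more or fewer frameworks):
import os
import random

import numpy as np


def seed_everything_sketch(seed: int = 42) -> None:
    """Seed Python, NumPy, and the hash seed so reruns are reproducible."""
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    try:
        # Seed PyTorch too, if it is installed.
        import torch
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    except ImportError:
        pass
# ---------------------------------------------------------------------------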
def train(args, loader, loader2, generator, encoder, discriminator, vggnet,
          g_optim, e_optim, d_optim, g_ema, e_ema, device):
    inception = real_mean = real_cov = mean_latent = None
    if args.eval_every > 0:
        inception = nn.DataParallel(load_patched_inception_v3()).to(device)
        inception.eval()
        with open(args.inception, "rb") as f:
            embeds = pickle.load(f)
            real_mean = embeds["mean"]
            real_cov = embeds["cov"]
    if get_rank() == 0:
        if args.eval_every > 0:
            with open(os.path.join(args.log_dir, 'log_fid.txt'), 'a+') as f:
                f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-'*50}\n")
        if args.log_every > 0:
            with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                f.write(f"Name: {getattr(args, 'name', 'NA')}\n{'-'*50}\n")

    loader = sample_data(loader)
    pbar = range(args.iter)
    if get_rank() == 0:
        pbar = tqdm(pbar,
                    initial=args.start_iter,
                    dynamic_ncols=True,
                    smoothing=0.01)

    mean_path_length = 0
    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    d_loss_val = r1_val = real_score_val = recx_score_val = 0
    loss_dict = {
        "d": torch.tensor(0.0, device=device),
        "r1": torch.tensor(0.0, device=device)
    }
    avg_pix_loss = util.AverageMeter()
    avg_vgg_loss = util.AverageMeter()

    if args.distributed:
        g_module = generator.module
        e_module = encoder.module
        d_module = discriminator.module
    else:
        g_module = generator
        e_module = encoder
        d_module = discriminator

    d_weight = torch.tensor(1.0, device=device)
    last_layer = None
    if args.use_adaptive_weight:
        if args.distributed:
            last_layer = generator.module.get_last_layer()
        else:
            last_layer = generator.get_last_layer()

    # accum = 0.5 ** (32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0
    r_t_dict = {'real': 0, 'recx': 0}  # r_t stat
    g_scale = 1

    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length,
                                      args.ada_every, device)

    sample_z = torch.randn(args.n_sample, args.latent, device=device)
    sample_x = load_real_samples(args, loader)
    if sample_x.ndim > 4:
        sample_x = sample_x[:, 0, ...]

    n_step_max = max(args.n_step_d, args.n_step_e)

    requires_grad(g_ema, False)
    requires_grad(e_ema, False)

    for idx in pbar:
        i = idx + args.start_iter
        if i > args.iter:
            print("Done!")
            break

        if args.debug:
            util.seed_everything(i)
        real_imgs = [next(loader).to(device) for _ in range(n_step_max)]

        # Train Discriminator
        if args.lambda_adv > 0:
            requires_grad(generator, False)
            requires_grad(encoder, False)
            requires_grad(discriminator, True)
            for step_index in range(args.n_step_d):
                real_img = real_imgs[step_index]
                latent_real, _ = encoder(real_img)
                rec_img, _ = generator([latent_real], input_is_latent=True)
                if args.augment:
                    real_img_aug, _ = augment(real_img, ada_aug_p)
                    rec_img_aug, _ = augment(rec_img, ada_aug_p)
                else:
                    real_img_aug = real_img
                    rec_img_aug = rec_img
                real_pred = discriminator(real_img_aug)
                rec_pred = discriminator(rec_img_aug)
                d_loss_real = F.softplus(-real_pred).mean()
                d_loss_rec = F.softplus(rec_pred).mean()
                loss_dict["real_score"] = real_pred.mean()
                loss_dict["recx_score"] = rec_pred.mean()

                d_loss = d_loss_real + d_loss_rec * args.lambda_rec_d
                loss_dict["d"] = d_loss

                discriminator.zero_grad()
                d_loss.backward()
                d_optim.step()

            if args.augment and args.augment_p == 0:
                ada_aug_p = ada_augment.tune(real_pred)
                r_t_stat = ada_augment.r_t_stat

            # Compute batchwise r_t
            r_t_dict['real'] = torch.sign(real_pred).sum().item() / args.batch

            d_regularize = i % args.d_reg_every == 0
            if d_regularize:
                real_img.requires_grad = True
                if args.augment:
                    real_img_aug, _ = augment(real_img, ada_aug_p)
                else:
                    real_img_aug = real_img
                real_pred = discriminator(real_img_aug)
                r1_loss = d_r1_loss(real_pred, real_img)
                discriminator.zero_grad()
                (args.r1 / 2 * r1_loss * args.d_reg_every +
                 0 * real_pred[0]).backward()
                d_optim.step()
            loss_dict["r1"] = r1_loss
            r_t_dict['recx'] = torch.sign(rec_pred).sum().item() / args.batch

        # Train AutoEncoder
        requires_grad(encoder, True)
        requires_grad(generator, True)
        requires_grad(discriminator, False)
        if args.debug:
            util.seed_everything(i)
        pix_loss = vgg_loss = adv_loss = torch.tensor(0., device=device)
        for step_index in range(args.n_step_e):
            real_img = real_imgs[step_index]
            latent_real, _ = encoder(real_img)
            rec_img, _ = generator([latent_real], input_is_latent=True)
            if args.lambda_pix > 0:
                if args.pix_loss == 'l2':
                    pix_loss = torch.mean((rec_img - real_img)**2)
                elif args.pix_loss == 'l1':
                    pix_loss = F.l1_loss(rec_img, real_img)
            if args.lambda_vgg > 0:
                vgg_loss = torch.mean((vggnet(real_img) - vggnet(rec_img))**2)
            if args.lambda_adv > 0:
                if args.augment:
                    rec_img_aug, _ = augment(rec_img, ada_aug_p)
                else:
                    rec_img_aug = rec_img
                rec_pred = discriminator(rec_img_aug)
                adv_loss = g_nonsaturating_loss(rec_pred)

            if args.use_adaptive_weight and i >= args.disc_iter_start:
                nll_loss = (pix_loss * args.lambda_pix +
                            vgg_loss * args.lambda_vgg)
                g_loss = adv_loss * args.lambda_adv
                d_weight = calculate_adaptive_weight(nll_loss,
                                                     g_loss,
                                                     last_layer=last_layer)

            ae_loss = (pix_loss * args.lambda_pix +
                       vgg_loss * args.lambda_vgg +
                       d_weight * adv_loss * args.lambda_adv)
            loss_dict["ae"] = ae_loss
            loss_dict["pix"] = pix_loss
            loss_dict["vgg"] = vgg_loss
            loss_dict["adv"] = adv_loss

            encoder.zero_grad()
            generator.zero_grad()
            ae_loss.backward()
            e_optim.step()
            if args.g_decay is not None:
                scale_grad(generator, g_scale)
                g_scale *= args.g_decay
            g_optim.step()

        g_regularize = args.g_reg_every > 0 and i % args.g_reg_every == 0
        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing,
                                 device)
            fake_img, latents = generator(noise, return_latents=True)
            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)
            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss
            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]
            weighted_path_loss.backward()
            g_optim.step()
            mean_path_length_avg = (reduce_sum(mean_path_length).item() /
                                    get_world_size())
        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        # Update EMA
        ema_nimg = args.ema_kimg * 1000
        if args.ema_rampup is not None:
            ema_nimg = min(ema_nimg, i * args.batch * args.ema_rampup)
        accum = 0.5**(args.batch / max(ema_nimg, 1e-8))
        accumulate(g_ema, g_module, 0 if args.no_ema_g else accum)
        accumulate(e_ema, e_module, 0 if args.no_ema_e else accum)

        loss_reduced = reduce_loss_dict(loss_dict)
        ae_loss_val = loss_reduced["ae"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()
        pix_loss_val = loss_reduced["pix"].mean().item()
        vgg_loss_val = loss_reduced["vgg"].mean().item()
        adv_loss_val = loss_reduced["adv"].mean().item()
        if args.lambda_adv > 0:
            d_loss_val = loss_reduced["d"].mean().item()
            r1_val = loss_reduced["r1"].mean().item()
            real_score_val = loss_reduced["real_score"].mean().item()
            recx_score_val = loss_reduced["recx_score"].mean().item()
        avg_pix_loss.update(pix_loss_val, real_img.shape[0])
        avg_vgg_loss.update(vgg_loss_val, real_img.shape[0])

        if get_rank() == 0:
            pbar.set_description((
                f"d: {d_loss_val:.4f}; ae: {ae_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                f"augment: {ada_aug_p:.4f}; "
                f"d_weight: {d_weight.item():.4f}; "
                f"pix: {pix_loss_val:.4f}; vgg: {vgg_loss_val:.4f}; adv: {adv_loss_val:.4f}"
            ))

            if i % args.log_every == 0:
                with torch.no_grad():
                    g_ema.eval()
                    e_ema.eval()
                    nrow = int(args.n_sample**0.5)
                    nchw = list(sample_x.shape)[1:]
                    # Reconstruction of real images
                    latent_x, _ = e_ema(sample_x)
                    rec_real, _ = g_ema([latent_x], input_is_latent=True)
                    sample = torch.cat(
                        (sample_x.reshape(args.n_sample // nrow, nrow, *nchw),
                         rec_real.reshape(args.n_sample // nrow, nrow, *nchw)),
                        1)
                    utils.save_image(
                        sample.reshape(2 * args.n_sample, *nchw),
                        os.path.join(args.log_dir, 'sample',
                                     f"{str(i).zfill(6)}-recon.png"),
                        nrow=nrow,
                        normalize=True,
                        value_range=(-1, 1),
                    )
                    ref_pix_loss = torch.sum(torch.abs(sample_x - rec_real))
                    ref_vgg_loss = torch.mean(
                        (vggnet(sample_x) -
                         vggnet(rec_real))**2) if vggnet is not None else 0
                    # Fixed fake samples
                    sample, _ = g_ema([sample_z])
                    utils.save_image(
                        sample,
                        os.path.join(args.log_dir, 'sample',
                                     f"{str(i).zfill(6)}-sample.png"),
                        nrow=int(args.n_sample**0.5),
                        normalize=True,
                        value_range=(-1, 1),
                    )
                with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                    f.write((
                        f"{i:07d}; "
                        f"d: {d_loss_val:.4f}; r1: {r1_val:.4f}; "
                        f"path: {path_loss_val:.4f}; mean_path: {mean_path_length_avg:.4f}; "
                        f"augment: {ada_aug_p:.4f}; {'; '.join([f'{k}: {r_t_dict[k]:.4f}' for k in r_t_dict])}; "
                        f"real_score: {real_score_val:.4f}; recx_score: {recx_score_val:.4f}; "
                        f"pix: {avg_pix_loss.avg:.4f}; vgg: {avg_vgg_loss.avg:.4f}; "
                        f"ref_pix: {ref_pix_loss.item():.4f}; ref_vgg: {float(ref_vgg_loss):.4f}; "
                        f"d_weight: {d_weight.item():.4f}; "
                        f"\n"))
                if wandb and args.wandb:
                    wandb.log({
                        "Discriminator": d_loss_val,
                        "Augment": ada_aug_p,
                        "Rt": r_t_stat,
                        "R1": r1_val,
                        "Path Length Regularization": path_loss_val,
                        "Mean Path Length": mean_path_length,
                        "Real Score": real_score_val,
                        "Path Length": path_length_val,
                    })

            if args.eval_every > 0 and i % args.eval_every == 0:
                with torch.no_grad():
                    fid_sa = fid_re = fid_sr = 0
                    g_ema.eval()
                    e_ema.eval()
                    if args.truncation < 1:
                        mean_latent = g_ema.mean_latent(4096)
                    # Real reconstruction FID
                    if 'fid_recon' in args.which_metric:
                        features = extract_feature_from_reconstruction(
                            e_ema,
                            g_ema,
                            inception,
                            args.truncation,
                            mean_latent,
                            loader2,
                            args.device,
                            mode='recon',
                        ).numpy()
                        sample_mean = np.mean(features, 0)
                        sample_cov = np.cov(features, rowvar=False)
                        fid_re = calc_fid(sample_mean, sample_cov, real_mean,
                                          real_cov)
                with open(os.path.join(args.log_dir, 'log_fid.txt'),
                          'a+') as f:
                    f.write(f"{i:07d}; rec_real: {float(fid_re):.4f};\n")

            if i % args.save_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "e_ema": e_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight',
                                 f"{str(i).zfill(6)}.pt"),
                )

            if i % args.save_latest_every == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "e": e_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "e_ema": e_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "e_optim": e_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                        "iter": i,
                    },
                    os.path.join(args.log_dir, 'weight', "latest.pt"),
                )
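# ---------------------------------------------------------------------------
# Both train loops above call `calculate_adaptive_weight(nll_loss, g_loss,
# last_layer=last_layer)` to balance the reconstruction and adversarial
# terms. The helper is defined elsewhere; a minimal sketch in the style of
# the VQGAN / taming-transformers heuristic (an assumption about the exact
# implementation used here):
import torch


def calculate_adaptive_weight_sketch(nll_loss: torch.Tensor,
                                     g_loss: torch.Tensor,
                                     last_layer: torch.Tensor,
                                     max_weight: float = 1e4) -> torch.Tensor:
    """Scale the adversarial term by ||grad nll|| / ||grad g|| at the last layer."""
    nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
    g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
    d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
    # Clamp and detach so the weight acts as a constant during backprop.
    return torch.clamp(d_weight, 0.0, max_weight).detach()
# ---------------------------------------------------------------------------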
def average(preds_test_on_all_models, model_val_loss_dict):
    losses = np.array(
        [val_loss for k, val_loss in model_val_loss_dict.items()])
    scores = 1. - losses
    sum_scores = sum(scores)
    weights = [x / sum_scores for x in scores]
    return np.average(np.array(preds_test_on_all_models),
                      axis=0,
                      weights=weights)


if __name__ == '__main__':
    seed_everything(param['base_seed'])

    # copy models to work directory
    input_model_dir = "../input/data-and-model"
    if not os.path.exists(param['output_root_dir']):
        os.makedirs(param['output_root_dir'])
    if os.path.exists(input_model_dir):
        shutil.rmtree(param['output_root_dir'])
        shutil.copytree(input_model_dir, param['output_root_dir'])

    start_time = time.time()
    if param['compute_val_loss_only']:
        compuate_val_loss()
    else:
        train_test_all()
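# ---------------------------------------------------------------------------
# Hypothetical usage sketch for `average` above (names and numbers are
# illustrative only): blend three models' test predictions, weighting each
# by (1 - val_loss) normalized across models, so lower-loss models count more.
import numpy as np

preds_demo = [
    np.array([[0.2, 0.8]]),  # predictions from a first model
    np.array([[0.4, 0.6]]),  # predictions from a second model
    np.array([[0.3, 0.7]]),  # predictions from a third model
]
val_losses_demo = {'model_a': 0.015, 'model_b': 0.020, 'model_c': 0.018}
blended = average(preds_demo, val_losses_demo)
print(blended.shape)  # (1, 2): weighted element-wise average
# ---------------------------------------------------------------------------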
parser.add_argument(
    "--ada_length",
    type=int,
    default=500 * 1000,
    help="target duration to reach augmentation probability for adaptive augmentation",
)
parser.add_argument(
    "--ada_every",
    type=int,
    default=256,
    help="probability update interval of the adaptive augmentation",
)

args = parser.parse_args()
util.seed_everything()

n_gpu = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
args.distributed = n_gpu > 1

if args.distributed:
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend="nccl",
                                         init_method="env://")
    synchronize()

args.n_latent = int(np.log2(args.size)) * 2 - 2  # used in Generator
args.latent = 512  # fixed, dim of w or z (same size)
if args.which_latent == 'w_plus':
    args.latent_full = args.latent * args.n_latent
elif args.which_latent == 'w_tied':
    default=1.0,
    help='weight for shoe class loss')
parser.add_argument('--output_path',
                    type=str,
                    default='snapshot/',
                    help='save file path')
parser.add_argument('--log_file',
                    type=str,
                    default='shoe_adapt_seg_source(train_val)_target(test)',
                    help='log file')
parser.add_argument('--is_writer',
                    action='store_true',
                    help='whether to use SummaryWriter or not')

args = parser.parse_args()
seed_everything(args.seed)
os.environ['PYTHONHASHSEED'] = str(args.seed)
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

config = {
    'output_path': args.output_path,
    'path': {
        'log': args.output_path + '/log/',
        'scalar': args.output_path + '/scalar/',
        'model': args.output_path + '/model/'
    },
    'is_writer': args.is_writer
}

# Create output dir
mkdir_if_not_exist(config['path']['log'])
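# ---------------------------------------------------------------------------
# `mkdir_if_not_exist` is a small project utility; a minimal sketch of the
# obvious implementation (an assumption):
import os


def mkdir_if_not_exist_sketch(path: str) -> None:
    """Create `path` (including parent directories) if it is missing."""
    os.makedirs(path, exist_ok=True)
# ---------------------------------------------------------------------------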