def best_auc_func(params):
    """Objective for the hyper-parameter search: train a fresh model and score it.

    A new model is built from ``pre_train_path``, meta-trained with
    ``train_ctw.meta_train(model, params)`` and evaluated via
    ``train.test(model, warm=False)``.  If its AUC exceeds the module-level
    ``best_auc``, the weights are saved to ``ours_path`` and the module-level
    ``best_auc`` / ``best_params`` are updated.  The current best AUC is
    printed on every call.

    Args:
        params: hyper-parameter mapping for this trial, forwarded unchanged to
            ``train_ctw.meta_train``.

    Returns:
        dict with ``'loss'`` = negated AUC (minimising it maximises AUC) and
        ``'status'`` = ``STATUS_OK``.
    """
    global best_auc, best_params

    candidate = train.new_model(pre_train_path)
    # NOTE(review): the return value of meta_train is discarded here; this
    # assumes it updates `candidate` in place — confirm.
    train_ctw.meta_train(candidate, params)
    _, score = train.test(candidate, warm=False)  # returns (loss, auc)

    if score > best_auc:
        # New best: persist weights before recording the winning config.
        torch.save(candidate.state_dict(), ours_path)
        best_auc, best_params = score, params
    print('best:{}'.format(best_auc))
    return {'loss': -score, 'status': STATUS_OK}
def hyper_param_analysis(rhos, alphas, path, *, overwrite=False):
    """Grid-search ``rho`` x ``alpha`` and record cold/warm results as CSV rows.

    For every (rho, alpha) pair a fresh model is built from ``pre_train_path``
    and meta-trained for ``const.train_n_epoch`` epochs with an otherwise fixed
    hyper-parameter set.  The trained weights are saved to
    ``models/mer-trained-rho(<rho>)-alpha(<alpha>).pth`` and evaluated in the
    cold setting with ``train.test(..., warm=False)``.  That checkpoint is then
    passed to ``train.meta_test`` for the warm evaluation.

    The file gets a ``# <conf.model_type>`` header line, followed by one row per
    pair: ``rho,alpha,cold_loss,cold_auc,<warm losses...>,<warm aucs...>``.
    Each row is also echoed to stdout.

    Args:
        rhos: iterable of ``rho`` values to try.
        alphas: iterable of ``alpha`` values to try.  It is materialised once,
            so one-shot iterators work for every ``rho``.
        path: results file.  It is appended to by default.
        overwrite: if True, truncate ``path`` instead of appending.
    """
    import os  # local import keeps this block self-contained

    base_params = {
        'alpha': None,
        'amsgrad': False,
        'batch_n_ID': 50,
        'gamma': 1.0,
        'lr': 0.01,
        'p_lr': 0.0001,
        'p_lr_decay': 1.0,
        'rho': None,
        'weight_decay': 1e-8,
    }
    # The inner loop runs once per rho. A generator passed as `alphas` would
    # otherwise be exhausted after the first rho, silently skipping all others.
    alphas = list(alphas)
    # torch.save does not create parent directories.
    os.makedirs('models', exist_ok=True)

    with open(path, 'w' if overwrite else 'a') as f:
        f.write('# {}\n'.format(conf.model_type))
        for rho in rhos:
            for alpha in alphas:
                # Use a fresh dict per run, so a config kept by an earlier run
                # (e.g. stored by meta_train) is never mutated by later iterations.
                ps = dict(base_params, rho=rho, alpha=alpha)
                cold_path = 'models/mer-trained-rho({})-alpha({}).pth'.format(
                    rho, alpha)
                warm_path = 'models/mer-tested-rho({})-alpha({}).pth'.format(
                    rho, alpha)

                # Cold setting: meta-train from the pre-trained weights, save, evaluate.
                model = train.new_model(pre_train_path)
                model = train_ctw.meta_train(model, ps, const.train_n_epoch)
                torch.save(model.state_dict(), cold_path)
                cold_loss, cold_auc = train.test(cold_path, warm=False)

                # Warm setting: evaluated starting from the saved cold checkpoint.
                ours_warm = train.meta_test(cold_path, 'ours', False)
                warm_losses, warm_aucs = zip(*ours_warm['ours'])
                # NOTE(review): meta_test receives cold_path rather than `model`,
                # so `model` is presumably unchanged. If so, this writes the
                # same weights as cold_path. Confirm whether warm-updated
                # weights were intended here.
                torch.save(model.state_dict(), warm_path)

                fields = [rho, alpha, cold_loss, cold_auc, *warm_losses, *warm_aucs]
                msg = ','.join(map(str, fields)) + '\n'
                print(msg, end='')  # msg already ends with a newline
                f.write(msg)
                # Flush per row so completed runs survive a crash mid-grid.
                f.flush()