def load_gan_model(base_model_name, noise_size=100, feat_size=200, epoch=30,
                   ndh=256, ngh=512, critic_iters=5, lr=2e-5, reg_ratio=0.,
                   decoder='linear', top_k=30, add_zero=False, pool_mode='last'):
    gan_hyper = get_gan_model_name(lr, ndh, ngh, critic_iters, reg_ratio,
                                   decoder, top_k, add_zero, pool_mode)
    model_path = f'epoch{epoch}_{gan_hyper}_{base_model_name}'
    log(f'Loading GAN from {model_path}...')
    gan_states = torch.load(f'{MODEL_DIR}/{model_path}',
                            map_location=lambda storage, loc: storage)
    generator = ConditionalGenerator(noise_size, noise_size, ngh, feat_size)
    discriminator = ConditionalDiscriminator(feat_size, noise_size, ndh, 1)
    word_emb = gan_states['label_rnn']['encoder.weight'].data.numpy()
    label_rnn = RNNLabelEncoder(word_emb)
    label_rnn.load_state_dict(gan_states['label_rnn'])
    generator.load_state_dict(gan_states['generator'])
    discriminator.load_state_dict(gan_states['discriminator'])
    for m in [generator, discriminator, label_rnn]:
        for p in m.parameters():
            p.requires_grad = False
    return generator, discriminator, label_rnn, model_path

def prepare_code_data(model, device, to_torch=False):
    code_desc = model.code_idx_matrix.data.cpu().numpy()
    code_mask = model.code_idx_mask.data.cpu().numpy()
    if os.path.exists(ICD_CODE_DESC_DATA_PATH):
        with np.load(ICD_CODE_DESC_DATA_PATH) as f:
            all_words_indices, word_emb = f['arr_0'], f['arr_1']
    else:
        all_words_indices = np.sort(np.unique(code_desc))
        word_emb = model.emb.embed.weight[all_words_indices].data.cpu().numpy()
        np.savez(ICD_CODE_DESC_DATA_PATH, all_words_indices, word_emb)
    log(f'In total {len(all_words_indices)} words for labels...')
    word_idx_to_keyword_idx = dict(
        zip(all_words_indices, np.arange(len(all_words_indices))))
    assert word_idx_to_keyword_idx[0] == 0, "Padding index should match"
    code_desc = np.asarray([[word_idx_to_keyword_idx[w] for w in kw]
                            for kw in code_desc], dtype=int)
    if to_torch:
        code_desc = torch.from_numpy(code_desc).to(device)
        code_mask = torch.from_numpy(code_mask).to(device)
    return code_desc, code_mask, word_emb

def load_adj_matrix(codes_to_targets):
    hier_codes = read_hier_codes()
    codes_to_extended_targets = copy.copy(codes_to_targets)
    for code in hier_codes:
        if code not in codes_to_extended_targets:
            codes_to_extended_targets[code] = len(codes_to_extended_targets)
    n = len(codes_to_extended_targets)
    adj_matrix = np.zeros((n, n))
    missed = 0
    codes_to_parents = defaultdict(list)
    for code in codes_to_extended_targets:
        if code in hier_codes:
            for neighbor in hier_codes[code]:
                c_idx = codes_to_extended_targets[code]
                n_idx = codes_to_extended_targets[neighbor]
                if neighbor in code:  # neighbor is a prefix of code, i.e. its parent
                    codes_to_parents[c_idx].append(n_idx)
                adj_matrix[c_idx, n_idx] = 1
                adj_matrix[n_idx, c_idx] = 1
        else:
            missed += 1
    log(f'Graph has {n} nodes, {missed} codes have no hierarchy...')
    return codes_to_extended_targets, adj_matrix, codes_to_parents

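# Toy illustration (hypothetical codes, not from the dataset): ICD-9 style
# identifiers encode the hierarchy as string prefixes, so with
# hier_codes = {'250.1': ['250'], '250': ['250.1']} the test `'250' in '250.1'`
# marks '250' as the parent of '250.1', and both directions of the edge are
# set in adj_matrix while only the child->parent link enters codes_to_parents.
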
def run(self, tile_loader_config_file):
    """
    Executes the GenomicsDB loader
    """
    if self.debug:
        helper.log("[Loader:Run] Starting mpirun subprocess")
    processArgs = list()
    if self.NUM_PROCESSES > 1:
        processArgs.extend([
            self.MPIRUN, "-np",
            str(self.NUM_PROCESSES), self.HOSTFLAG, self.HOSTS
        ])
        if self.IF_INCLUDE is not None:
            processArgs.extend(
                ["--mca", "btl_tcp_if_include", self.IF_INCLUDE])
        if self.ENV is not None:
            for env in self.ENV.split(","):
                processArgs.extend(["-x", env])
    processArgs.extend([self.EXEC, tile_loader_config_file])
    if self.debug:
        helper.log("Args: {0} ".format(processArgs))
    pipe = subprocess.Popen(processArgs,
                            stdout=subprocess.PIPE,
                            stderr=subprocess.PIPE)
    output, error = pipe.communicate()
    if pipe.returncode != 0:
        raise Exception(
            "subprocess run: {0}\nFailed with stdout: \n-- \n{1} \n--\nstderr: \n--\n{2} \n--"
            .format(" ".join(processArgs), output, error))

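# Example of the assembled command (values hypothetical): with NUM_PROCESSES=4,
# HOSTFLAG='--hostfile', HOSTS='hosts.txt' and ENV='TILEDB_HOME', the spawned
# process resembles:
#   mpirun -np 4 --hostfile hosts.txt -x TILEDB_HOME <EXEC> loader.json
# With NUM_PROCESSES=1 only `<EXEC> loader.json` is run, without mpirun.
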
def extract_text_files(mimic_dir, save_dir):
    patient_dict = get_patient_data(mimic_dir)
    text_save_dir = f'{save_dir}/text_files/'
    make_folder(text_save_dir)
    label_save_dir = f'{save_dir}/label_files/'
    make_folder(label_save_dir)
    total_txt_count = 0
    for (subject_id, hadm_id) in tqdm.tqdm(patient_dict,
                                           desc='Extracting text files'):
        icd9_dict = patient_dict[(subject_id, hadm_id)][1]
        all_descriptions = []
        for category, description in patient_dict[(subject_id, hadm_id)][0].keys():
            notes = patient_dict[(subject_id, hadm_id)][0][(category, description)]
            all_descriptions.extend(notes)
        # writing description notes
        text_save_path = f'{text_save_dir}/{subject_id}_{hadm_id}_notes.txt'
        concat_and_write(all_descriptions, text_save_path)
        # writing icd labels
        label_save_path = f'{label_save_dir}/{subject_id}_{hadm_id}_labels.txt'
        with open(label_save_path, 'w') as f:
            for key in icd9_dict:
                f.write('{}, {}\n'.format(key, icd9_dict[key]))
        total_txt_count += 1
    log(f'Written {total_txt_count} text files to {save_dir}')

def tokenize_raw_text(save_dir):
    text_save_dir = os.path.join(save_dir, 'text_files')
    numpy_vectors_save_dir = os.path.join(save_dir, 'numpy_vectors')
    remove_folder(numpy_vectors_save_dir)
    make_folder(numpy_vectors_save_dir)
    hadms = []
    for filename in os.listdir(text_save_dir):
        if ".txt" in filename:
            hadm = filename.replace(".txt", "")
            hadms.append(hadm)
    log(f"Total number of text files in set: {len(hadms)}")
    log(f'Loading vocab dict from {VOCAB_DICT_PATH}')
    with open(VOCAB_DICT_PATH, 'rb') as f:
        vocab = pickle.load(f)
    tokenizer = Tokenizer(vocab)
    for hadm in tqdm.tqdm(hadms, desc='Tokenizing raw patient notes'):
        with open(os.path.join(text_save_dir, str(hadm) + ".txt"), "r") as f:
            text = f.read()
        words = tokenizer.process(text)
        vector = []
        for word in words:
            if word in vocab:
                vector.append(vocab[word])
            elif tokenizer.only_numerals(word) and (
                    len(vector) == 0 or vector[-1] != vocab["<NUM>"]):
                # collapse runs of out-of-vocab numbers into one <NUM> token
                vector.append(vocab["<NUM>"])
        mat = np.array(vector)
        # saving word indices to file
        write_file = os.path.join(numpy_vectors_save_dir, f"{hadm}.npy")
        np.save(write_file, mat)

def load_word_embedding():
    if os.path.exists(EMBEDDING_PATH):
        with open(EMBEDDING_PATH, 'rb') as f:
            word_embedding = pickle.load(f)
            code_to_idx = pickle.load(f)
        log(f'W emb size {word_embedding.shape}')
        return word_embedding, code_to_idx
    else:
        raise ValueError(f'Please download embedding file to {EMBEDDING_PATH}')

def load_pretrained_state_dict(self, pretrained_dict):
    model_dict = self.state_dict()
    for k in pretrained_dict:
        if k in model_dict:
            model_dict[k] = pretrained_dict[k]
    self.load_state_dict(model_dict)
    for name, param in self.named_parameters():
        if name in pretrained_dict:
            param.requires_grad = False
            log(f'Freeze {name} in training...')

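# Usage sketch of the partial-load-and-freeze pattern above (key name
# hypothetical): only keys present in both state dicts are copied, and those
# copied parameters are then frozen, e.g. a pretrained {'proj.weight': ...}
# overwrites model.proj.weight and removes it from the trainable set while
# any key absent from the model's state dict is silently ignored.
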
def load_data(train_notes, train_labels, dev_notes, dev_labels,
              codes_to_targets, max_note_len):
    log('Preloading data in memory...')
    train_x, train_y, train_masks = preload_data(train_notes,
                                                 train_labels,
                                                 codes_to_targets,
                                                 max_note_len,
                                                 save_path=CACHE_PATH)
    code_data_list = get_code_data_list(train_x, train_y, train_masks)
    dev_x, dev_y, dev_masks = preload_data(dev_notes, dev_labels,
                                           codes_to_targets, max_note_len)
    return code_data_list, train_x, train_y, train_masks, dev_x, dev_y, dev_masks

def load_first_stage_model(model_path, device=None):
    log(f'Loading pretrained model from {model_path}')
    state_dict = torch.load(model_path,
                            map_location=lambda storage, loc: storage)
    pretrained_dict = OrderedDict()
    for k in state_dict:
        if k.split('.')[0] in {
                'emb', 'conv_modules', 'proj', 'output_proj',
                'graph_label_encoder'
        }:
            pretrained_dict[k] = state_dict[k].to(device)
    return pretrained_dict

def get_nearest_for_zero(eval_code_size, code_feat_list, label_emb, data,
                         step_size=8):
    fs, ls, kws, kwms = data
    zeroshot_codes = []
    rest_codes = set(code_feat_list.keys())
    for i in range(eval_code_size):
        if i not in code_feat_list:
            zeroshot_codes.append(i)
    label_emb = label_emb.data.cpu().numpy()
    zeroshot_emb = label_emb[zeroshot_codes]
    cos_dist = cosine_similarity(zeroshot_emb, label_emb)
    n = len(fs)
    zeroshot_neighbors = dict()
    sims = np.ones(n)
    for z, dist in zip(zeroshot_codes, cos_dist):
        z_fs = []
        kw = []
        m = []
        sim = []
        for nei in np.argsort(-dist)[1:]:
            if nei in rest_codes:
                zeroshot_neighbors[z] = nei
                n_aug = len(code_feat_list[nei])
                nei_idx = code_feat_list[nei].l
                z_fs.append(fs[nei_idx])
                kw.append(kws[nei_idx])
                m.append(kwms[nei_idx])
                sim.append(np.ones(n_aug) * dist[nei])
                break
        s = len(fs)
        fs = np.vstack([fs] + z_fs)
        e = len(fs)
        kws = np.vstack([kws] + kw)
        kwms = np.vstack([kwms] + m)
        sims = np.concatenate([sims] + sim)
        data_idx = list(range(s, e))
        ls = np.concatenate([ls, np.ones(len(data_idx)) * z])
        code_feat_list[z] = inf_list_iterator(data_idx, step_size=step_size)
    log(f'Added {len(fs) - n} data for zero-shot labels')
    return fs, ls, kws, kwms, sims.astype(np.float32)

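# Borrowing sketch: each zero-shot label z copies the feature rows of its most
# similar seen label (highest cosine similarity in label-embedding space) and
# tags them with per-row weights equal to that similarity; real rows keep
# weight 1, and the returned `sims` later scales the discriminator loss.
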
def get_gan_model_name(lr, ndh, ngh, critic_iters, reg_ratio, decoder, top_k,
                       add_zero, pool_mode):
    gan_hyper = f'gan_lr{lr}_ndh{ndh}_ngh{ngh}_diter{critic_iters}'
    if reg_ratio > 0.:
        gan_hyper += f'_dec{decoder}{reg_ratio}_kw{top_k}'
    if add_zero:
        gan_hyper += '_z'
    if pool_mode in RNN_POOL:
        gan_hyper += f'_{RNN_POOL[pool_mode]}'
    log(f'{gan_hyper}')
    return gan_hyper

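# Example: the defaults used by load_gan_model (lr=2e-5, ndh=256, ngh=512,
# critic_iters=5, reg_ratio=0., add_zero=False) produce the name
# 'gan_lr2e-05_ndh256_ngh512_diter5', optionally followed by a pooling suffix
# when pool_mode is in RNN_POOL.
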
def preload_data(train_notes, train_labels, codes_to_targets,
                 max_note_len=2000, idx_offset=1, save_path=None):
    if save_path is not None and os.path.exists(save_path):  # if cache exists
        with np.load(save_path, allow_pickle=True) as f:
            return f['arr_0'], f['arr_1'], f['arr_2']
    x = []
    y = []
    mask = []
    row = 0
    for note_path, labels in zip(train_notes, train_labels):
        code_idx = [codes_to_targets[code] for code in labels]
        y.append(code_idx)
        m = np.ones(max_note_len)
        xx = np.zeros(max_note_len)
        note = np.load(note_path) + idx_offset  # shift 1 for padding
        if len(note) < max_note_len:
            m[len(note):] = 0
            xx[:len(note)] = note
        else:
            xx = note[:max_note_len]
        x.append(xx)
        mask.append(m)
        row += 1
    x = np.vstack(x).astype(int)
    y = np.asarray(y)
    mask = np.vstack(mask).astype(np.float32)
    if save_path is not None:
        log(f'Saving training data cache to {save_path}...')
        np.savez(save_path, x, y, mask)
    return x, y, mask

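# Shape sketch: with max_note_len=5 and a stored note [2, 6] (so [3, 7] after
# the +1 padding offset), xx becomes [3, 7, 0, 0, 0] with mask [1, 1, 0, 0, 0];
# longer notes are truncated to the first max_note_len tokens with an all-ones
# mask, so x is (num_notes, max_note_len) and y stays a ragged array of
# label-index lists.
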
def init_emb_and_code_input(codes_to_targets):
    word_emb, code_to_idx = load_word_embedding()
    code_idx_matrix = []
    code_idx_mask = []
    max_code_len = max([len(idx) for idx in code_to_idx.values()])
    log(f'Max code description {max_code_len}')
    targets_to_code = dict((v, k) for k, v in codes_to_targets.items())
    for target in targets_to_code:
        code_idx = code_to_idx[targets_to_code[target]]
        mask = np.zeros(max_code_len)
        mask[:len(code_idx)] = 1
        if len(code_idx) < max_code_len:
            code_idx = code_idx + [0] * (max_code_len - len(code_idx))
        code_idx_matrix.append(code_idx)
        code_idx_mask.append(mask)
    code_idx_mask = np.asarray(code_idx_mask, dtype=np.float32)
    code_idx_mask = torch.from_numpy(code_idx_mask)
    code_idx_matrix = np.asarray(code_idx_matrix, dtype=int)
    code_idx_matrix = torch.from_numpy(code_idx_matrix)
    return word_emb, code_idx_matrix, code_idx_mask

def get_features_for_labels(model, label_indices):
    feat_path = model.pretrain_name.replace('.model', '.npz')
    feat_path = f'{FEATURE_DIR}/{feat_path}'
    assert os.path.exists(feat_path), f'Features should be at {feat_path}'
    log(f'Loading pretrained features from {feat_path}')
    with np.load(feat_path) as f:
        fs, ls = f['arr_0'], f['arr_1']
    fs = np.maximum(fs, 0)
    label_indices = set(label_indices)
    label_feats = defaultdict(list)
    for x, y in zip(fs, ls):
        if y in label_indices:
            label_feats[y].append(x)
    _, code_centroids = get_code_feat_list(fs, ls)
    centroids = np.zeros((model.eval_code_size, fs.shape[1]))
    for i, code in enumerate(code_centroids):
        centroids[code] = code_centroids[code]
    return label_feats, centroids

def get_patient_data(mimic_dir):
    read_file = f'{mimic_dir}/NOTEEVENTS.csv'
    log(f'Reading {read_file} ...')
    df_notes = pandas.read_csv(read_file, low_memory=False, dtype=str)
    read_file = f'{mimic_dir}/DIAGNOSES_ICD.csv'
    log(f'Reading {read_file} ...')
    df_icds = pandas.read_csv(read_file, low_memory=False, dtype=str)
    all_notes = df_notes['TEXT']
    all_note_types = df_notes['CATEGORY']
    all_note_descriptions = df_notes['DESCRIPTION']
    subject_ids_notes = df_notes['SUBJECT_ID']
    hadm_ids_notes = df_notes['HADM_ID']
    subject_ids_icd = df_icds['SUBJECT_ID']
    hadm_ids_icd = df_icds['HADM_ID']
    seq_nums_icd = df_icds['SEQ_NUM']
    icd9_codes = df_icds['ICD9_CODE']
    patient_dict = {
        (subject_id, hadm_id): [{}, {}]
        for subject_id, hadm_id in zip(subject_ids_notes, hadm_ids_notes)
    }
    # starting with icd code labels and collecting only those subject_id,
    # hadm_id pairs with at least one non-nan icd label
    for (subject_id, hadm_id, seq_num,
         icd9_code) in zip(subject_ids_icd, hadm_ids_icd, seq_nums_icd,
                           icd9_codes):
        try:
            # there are cases where subject id, hadm id pairs are present in
            # icd code data but not in noteevents data.
            # checking for nan will fail for a string; then go to except and
            # put it in patient_dict
            if not math.isnan(seq_num):
                patient_dict[(subject_id, hadm_id)][1][seq_num] = icd9_code
        except TypeError:
            try:
                patient_dict[(subject_id, hadm_id)][1][seq_num] = icd9_code
            except KeyError:
                # if not in admissions data
                pass
    for (subject_id, hadm_id, note, note_type,
         note_description) in zip(subject_ids_notes, hadm_ids_notes,
                                  all_notes, all_note_types,
                                  all_note_descriptions):
        if is_discharge_summary(note_type):
            if (note_type, note_description) in patient_dict[(subject_id,
                                                              hadm_id)][0]:
                patient_dict[(subject_id, hadm_id)][0][(
                    note_type, note_description)].append(note)
            else:
                patient_dict[(subject_id, hadm_id)][0][(
                    note_type, note_description)] = [note]
    to_remove = []
    for (subject_id, hadm_id) in patient_dict:
        if len(patient_dict[(subject_id, hadm_id)][0]) == 0 or len(
                patient_dict[(subject_id, hadm_id)][1]) == 0:
            to_remove.append((subject_id, hadm_id))
    for key in to_remove:
        patient_dict.pop(key)
    log(f'Total number of (subject_id, hadm_id) with discharge summary, with at least 1 code: {len(patient_dict)}')
    return patient_dict

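# Resulting structure: patient_dict maps (subject_id, hadm_id) to a
# two-element list [notes, labels], where notes maps (category, description)
# to a list of note texts and labels maps seq_num to an ICD-9 code; entries
# missing either a discharge summary or at least one code are dropped.
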
if __name__ == '__main__':
    config = get_base_config()
    if config.evaluate:
        log('Evaluating base model...')
        eval_trained(eval_batch_size=config.eval_batch_size,
                     max_note_len=config.max_note_len,
                     loss=config.loss,
                     gpu=config.gpu,
                     save_model=config.save_model,
                     graph_encoder=config.graph_encoder,
                     class_margin=config.class_margin,
                     C=config.C)
    else:
        log('Training base model...')
        train(lr=config.lr,
              batch_size=config.batch_size,
              eval_batch_size=config.eval_batch_size,
              num_epochs=config.num_epochs,
              max_note_len=config.max_note_len,
              # the call was truncated in the source; the remaining keywords
              # are assumed to mirror the eval branch above
              loss=config.loss,
              gpu=config.gpu,
              save_model=config.save_model,
              graph_encoder=config.graph_encoder,
              class_margin=config.class_margin,
              C=config.C)

def train_generative(lr=1e-4, num_epochs=30, critic_iters=1, max_note_len=2000,
                     gpu="cuda:0", loss='bce', graph_encoder='conv',
                     batch_size=64, C=0., class_margin=False, ndh=256, ngh=512,
                     save_every=10, reg_ratio=0., top_k=10, decoder='linear',
                     add_zero=False, pool_mode='last'):
    pprint.pprint(locals(), stream=sys.stderr)
    gan_hyper = get_gan_model_name(lr, ndh, ngh, critic_iters, reg_ratio,
                                   decoder, top_k, add_zero, pool_mode)
    device = torch.device(gpu if torch.cuda.is_available() else "cpu")
    train_data = Dataset('train')
    dev_data = Dataset('dev')
    test_data = Dataset('test')
    train_notes, train_labels = train_data.get_data()
    dev_notes, dev_labels = dev_data.get_data()
    log(f'Loaded {len(train_notes)} train data, {len(dev_notes)} dev data...')
    n_train_data = len(train_notes)
    train_codes = train_data.get_all_codes()
    dev_codes = dev_data.get_all_codes()
    test_codes = test_data.get_all_codes()
    all_codes = train_codes.union(dev_codes).union(test_codes)
    all_codes = sorted(all_codes)
    codes_to_targets = codes_to_index_labels(all_codes, False)
    extended_codes_to_targets, adj_matrix, _ = load_adj_matrix(codes_to_targets)
    eval_code_size = len(codes_to_targets)
    frequent_codes, few_shot_codes, zero_shot_codes, codes_counter = split_codes_by_count(
        train_labels, dev_labels, train_codes, dev_codes)
    eval_indices = code_to_indices(dev_codes, codes_to_targets)
    frequent_indices = code_to_indices(frequent_codes, codes_to_targets)
    few_shot_indices = code_to_indices(few_shot_codes, codes_to_targets)
    zero_shot_indices = code_to_indices(zero_shot_codes, codes_to_targets)
    log(f'Evaluating on {len(eval_indices)} codes, {len(frequent_indices)} frequent codes, '
        f'{len(few_shot_indices)} few shot codes and {len(zero_shot_indices)} zero shot codes...')
    target_count = targets_to_count(codes_to_targets, codes_counter)
    word_emb, code_idx_matrix, code_idx_mask = init_emb_and_code_input(
        extended_codes_to_targets)
    num_neighbors = torch.from_numpy(adj_matrix.sum(axis=1).astype(np.float32))
    adj_matrix = torch.from_numpy(adj_matrix.astype(np.float32))
    loss_fn = get_loss_fn(loss, reduction='sum')
    # init model
    log(f'Building model on {device}...')
    model = ConvLabelAttnGAN(
        word_emb,
        code_idx_matrix,
        code_idx_mask,
        adj_matrix,
        num_neighbors,
        loss_fn,
        eval_code_size=eval_code_size,
        graph_encoder=graph_encoder,
        target_count=target_count if class_margin else None,
        C=C)
    # load stage 1 base feature extractor model
    pretrained_model_path = f"{MODEL_DIR}/{model.pretrain_name}"
    pretrained_state_dict = load_first_stage_model(
        model_path=pretrained_model_path, device=device)
    model.load_pretrained_state_dict(pretrained_state_dict)
    # set to sparse here
    model.adj_matrix = model.adj_matrix.to_sparse()
    model.to(device)
    _, label_emb = model.get_froze_label_emb()
    clf_emb = label_emb
    graph_label_emb = label_emb[:, int(label_emb.shape[1] // 2):]
    label_emb = graph_label_emb
    feat_path = model.pretrain_name.replace('.model', '.npz')
    if not os.path.exists(f'{FEATURE_DIR}/{feat_path}'):
        log(f'Saving features to {feat_path}...')
        train_x, train_y, train_masks = preload_data(train_notes,
                                                     train_labels,
                                                     codes_to_targets,
                                                     max_note_len,
                                                     save_path=CACHE_PATH)
        train_keywords = load_note_keywords(train_notes)
        save_features(model, train_x, train_y, train_masks, eval_code_size,
                      device, train_keywords, save_path=feat_path)
    fs, ls, kws, kwms = load_features(feat_path)
    code_feat_list, _ = get_code_feat_list(fs, ls)
    if add_zero:
        fs, ls, kws, kwms, sims = get_nearest_for_zero(eval_code_size,
                                                       code_feat_list, clf_emb,
                                                       (fs, ls, kws, kwms))
    else:
        sims = None
    fs, ls, kws, kwms = (fs.astype(np.float32), ls.astype(int),
                         kws.astype(int), kwms.astype(np.float32))
    if reg_ratio > 0.:
        log(f'Predicting top {top_k} words...')
        kws = kws[:, :top_k]
        kwms = kwms[:, :top_k]
        all_keywords_indices = sorted(np.unique(kws))
        log(f'In total {len(all_keywords_indices)} keywords...')
        word_idx_to_keyword_idx = dict(
            zip(all_keywords_indices, np.arange(len(all_keywords_indices))))
        all_keywords_indices = torch.LongTensor(all_keywords_indices).to(device)
        kws = np.asarray([[word_idx_to_keyword_idx[w] for w in kw]
                          for kw in kws], dtype=int)
    log('Activating features...')
    fs = np.maximum(fs, 0)
    label_size = label_emb.size(1) * 2  # concat with repr from RNN
    noise_size = label_size
    generator = ConditionalGenerator(noise_size, noise_size, ngh,
                                     model.feat_size)
    discriminator = ConditionalDiscriminator(model.feat_size, noise_size,
                                             ndh, 1)
    log(f'Encoding label using RNN, label hidden size {label_size}...')
    code_desc, code_mask, word_emb = prepare_code_data(model, device,
                                                       to_torch=True)
    label_rnn = RNNLabelEncoder(word_emb)
    rnn_params = list(label_rnn.parameters())[1:]
    if reg_ratio > 0.:
        keyword_emb = model.emb.embed.weight.detach()[all_keywords_indices]
        keyword_predictor = load_decoder(decoder, model.feat_size,
                                         model.embed_size, keyword_emb)
    else:
        keyword_predictor = nn.Identity()
    generator.to(device)
    discriminator.to(device)
    keyword_predictor.to(device)
    label_rnn.to(device)
    d_params = list(discriminator.parameters()) + rnn_params
    optimizer_d = optim.Adam(d_params, lr=lr, betas=(0.5, 0.999))
    g_params = list(generator.parameters()) + list(
        keyword_predictor.parameters())
    optimizer_g = optim.Adam(g_params, lr=lr, betas=(0.5, 0.999))
    one = torch.ones([]).to(device)
    mone = one * -1
    n_data = len(fs)

    def inf_data_sampler(b):
        while True:
            batches = simple_iterate_minibatch(fs, ls, b, shuffle=True)
            for batch in batches:
                yield batch

    def get_rnn_emb(label_indices, labels=None):
        desc_x = code_desc[label_indices][:, :20]
        desc_m = code_mask[label_indices][:, :20]
        labels_rnn_emb = label_rnn.forward_enc(desc_x,
                                               desc_m,
                                               None,
                                               training=False,
                                               pool_mode=pool_mode)
        return torch.cat([labels, labels_rnn_emb], dim=1)

    def generate(label_indices=None, labels=None):
        if labels is None:
            assert label_indices is not None
            labels = label_emb[label_indices]
            labels = get_rnn_emb(label_indices, labels)
        labels = Variable(labels)
        b = labels.size(0)
        noises = torch.randn(b, noise_size).to(labels.device)
        noises = Variable(noises)
        feats = generator.forward(noises, labels)
        return feats

    log(f'Start training WGAN-GP with {n_data} pretrained features')
    it = 0
    num_batches = n_data // batch_size
    disc_sampler = inf_data_sampler(batch_size)
    gen_sampler = inf_data_sampler(batch_size)
    for epoch in range(num_epochs):
        train_g_losses = []
        train_d_losses = []
        train_r_losses = []
        train_k_losses = []
        # train one epoch
        with torch.set_grad_enabled(True):
            model.eval()
            keyword_predictor.train()
            discriminator.train()
            generator.train()
            label_rnn.train()
            for p in model.parameters():
                p.requires_grad = False
            for _ in range(num_batches):
                # train discriminator
                for p in discriminator.parameters():
                    p.requires_grad = True
                for it_c in range(critic_iters):
                    sampled_idx = next(disc_sampler)
                    real_feats, label_indices = torch.from_numpy(
                        fs[sampled_idx]), torch.from_numpy(ls[sampled_idx])
                    real_feats, label_indices = real_feats.to(
                        device), label_indices.to(device)
                    sim_weight = 1 if sims is None else torch.from_numpy(
                        sims[sampled_idx]).to(device)
                    optimizer_d.zero_grad()
                    labels = label_emb[label_indices]
                    labels = get_rnn_emb(label_indices, labels)
                    real_feats_v = Variable(real_feats)
                    labels_v = Variable(labels)
                    real_logits = discriminator.forward(real_feats_v, labels_v)
                    critic_d_real = (real_logits * sim_weight).mean()
                    critic_d_real.backward(
                        mone, retain_graph=True if add_zero else False)
                    fake_feats = generate(label_indices, labels=labels_v)
                    fake_feats = torch.relu(fake_feats)
                    fake_logits = discriminator.forward(
                        fake_feats.detach(), labels_v)
                    critic_d_fake = (fake_logits * sim_weight).mean()
                    critic_d_fake.backward(
                        one, retain_graph=True if add_zero else False)
                    gp = calc_gradient_penalty(discriminator, real_feats,
                                               fake_feats.data, labels)
                    gp.backward()
                    d_cost = critic_d_fake - critic_d_real  # + gp
                    train_d_losses.append(d_cost.data.cpu().numpy())
                    optimizer_d.step()
                # train generator
                for p in discriminator.parameters():  # reset requires_grad
                    p.requires_grad = False  # avoid computation
                sampled_idx = next(gen_sampler)
                real_feats, label_indices = torch.from_numpy(
                    fs[sampled_idx]), torch.from_numpy(ls[sampled_idx])
                real_feats, label_indices = real_feats.to(
                    device), label_indices.to(device)
                sim_weight = 1 if sims is None else torch.from_numpy(
                    sims[sampled_idx]).to(device)
                optimizer_g.zero_grad()
                # Generate a batch of data
                labels = label_emb[label_indices]
                labels = get_rnn_emb(label_indices, labels)
                labels_v = Variable(labels)
                fake_feats = generate(label_indices, labels=labels_v)
                fake_feats = torch.relu(fake_feats)
                recon_loss = F.mse_loss(fake_feats, real_feats,
                                        reduction='mean')
                train_r_losses.append(recon_loss.data.cpu().numpy())
                if reg_ratio > 0:
                    keyword_indices, keyword_masks = torch.from_numpy(
                        kws[sampled_idx]), torch.from_numpy(kwms[sampled_idx])
                    keyword_indices, keyword_masks = keyword_indices.to(
                        device), keyword_masks.to(device)
                    keyword_loss = keyword_predictor(fake_feats,
                                                     keyword_indices,
                                                     keyword_masks, labels_v)
                    if not isinstance(sim_weight, int):
                        keywords_weight = sim_weight.masked_fill(
                            sim_weight != 1, 0.)
                    else:
                        keywords_weight = sim_weight
                    keyword_loss = torch.mean(keyword_loss * keywords_weight)
                    train_k_losses.append(keyword_loss.data.cpu().numpy())
                else:
                    keyword_loss = 0
                fake_logits = discriminator.forward(fake_feats, labels_v)
                critic_g_fake = (fake_logits * sim_weight).mean()
                g_cost = -critic_g_fake
                train_g_losses.append(g_cost.data.cpu().numpy())
                g_loss = g_cost + reg_ratio * keyword_loss
                g_loss.backward()
                optimizer_g.step()
                it += 1
        log(f'Epoch {epoch}, disc loss={np.mean(train_d_losses):.4f}, '
            f'gen loss={np.mean(train_g_losses):.4f}, '
            f'mse loss={np.mean(train_r_losses):.4f}, '
            f'key loss={np.mean(train_k_losses):.4f}')
        # eval on few / zero shot examples
        with torch.set_grad_enabled(False):
            model.eval()
            discriminator.eval()
            generator.eval()
            label_rnn.eval()
            keyword_predictor.eval()
            sample_num = 100
            dev_scores = []
            for code in few_shot_indices + zero_shot_indices:
                syn_codes = [code] * sample_num
                gen_feats = generate(syn_codes)
                scores = torch.sigmoid(
                    torch.mul(torch.relu(gen_feats), clf_emb[code]).sum(-1))
                dev_scores.append(scores.data.cpu().numpy())
            dev_scores = np.concatenate(dev_scores)
            dev_preds = np.round(dev_scores)
            log(f'\tF/Z: gen probs={np.mean(dev_scores) * 100:.2f}, '
                f'gen acc={np.mean(dev_preds == 1) * 100:.2f} ')
        start_saving = 20
        if (epoch + 1) % save_every == 0 and epoch + 1 >= start_saving:
            gan_model = {
                'generator': generator.state_dict(),
                'discriminator': discriminator.state_dict(),
                'label_rnn': label_rnn.state_dict()
            }
            torch.save(
                gan_model,
                f'{MODEL_DIR}/epoch{epoch + 1}_{gan_hyper}_{model.pretrain_name}')

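# Reference sketch only: `calc_gradient_penalty` used above is defined
# elsewhere in the repo. This hypothetical `_sketch_gradient_penalty` shows
# the standard WGAN-GP penalty (Gulrajani et al., 2017) that such a call
# typically computes; the penalty weight 10 and the exact signature are
# assumptions, not the repo's implementation.
def _sketch_gradient_penalty(disc, real_feats, fake_feats, labels, weight=10.):
    b = real_feats.size(0)
    # random interpolation points between real and generated features
    alpha = torch.rand(b, 1, device=real_feats.device)
    interp = (alpha * real_feats +
              (1 - alpha) * fake_feats).requires_grad_(True)
    logits = disc.forward(interp, labels)
    # gradient of the critic output w.r.t. the interpolated inputs
    grads = torch.autograd.grad(outputs=logits.sum(),
                                inputs=interp,
                                create_graph=True)[0]
    # penalize deviation of the gradient norm from 1
    return weight * ((grads.norm(2, dim=1) - 1)**2).mean()
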
if __name__ == '__main__':
    gan_config = get_gan_config()
    log('Training GAN model...')
    train_generative(gpu=gan_config.gpu,
                     graph_encoder=gan_config.graph_encoder,
                     class_margin=gan_config.class_margin,
                     C=gan_config.C,
                     num_epochs=gan_config.num_epochs,
                     batch_size=gan_config.batch_size,
                     add_zero=gan_config.add_zero,
                     critic_iters=gan_config.critic_iters,
                     lr=gan_config.lr,
                     ndh=gan_config.ndh,
                     ngh=gan_config.ngh,
                     reg_ratio=gan_config.reg_ratio,
                     decoder=gan_config.decoder,
                     top_k=gan_config.top_k,
                     save_every=gan_config.save_every,
                     # the source call was truncated after save_every;
                     # pool_mode is assumed to close it, matching the other
                     # entry points that read it from the config
                     pool_mode=gan_config.pool_mode)

def test_printers():
    helper.log("test")
    helper.progressPrint("test")

def eval_trained(eval_batch_size=16, max_note_len=2000, loss='bce',
                 gpu='cuda:1', save_model=True, graph_encoder='conv',
                 class_margin=False, C=0.):
    pprint.pprint(locals(), stream=sys.stderr)
    device = torch.device(gpu if torch.cuda.is_available() else "cpu")
    train_data = Dataset('train')
    dev_data = Dataset('dev')
    test_data = Dataset('test')
    train_notes, train_labels = train_data.get_data()
    dev_notes, dev_labels = dev_data.get_data()
    test_notes, test_labels = test_data.get_data()
    log(f'Loaded {len(train_notes)} train data, {len(dev_notes)} dev data...')
    n_train_data = len(train_notes)
    train_codes = train_data.get_all_codes()
    dev_codes = dev_data.get_all_codes()
    test_codes = test_data.get_all_codes()
    all_codes = train_codes.union(dev_codes).union(test_codes)
    all_codes = sorted(all_codes)
    codes_to_targets = codes_to_index_labels(all_codes, False)
    dev_eval_code_size = len(codes_to_targets)
    frequent_codes, few_shot_codes, zero_shot_codes, codes_counter = split_codes_by_count(
        train_labels, dev_labels, train_codes, dev_codes)
    eval_code_size = len(codes_to_targets)
    dev_freq_codes, dev_few_shot_codes, dev_zero_shot_codes, codes_counter = \
        split_codes_by_count(train_labels, dev_labels, train_codes, dev_codes, 5)
    test_freq_codes, test_few_shot_codes, test_zero_shot_codes, _ = \
        split_codes_by_count(train_labels, test_labels, train_codes, test_codes, 5)
    dev_eval_indices = code_to_indices(dev_codes, codes_to_targets)
    dev_freq_indices = code_to_indices(dev_freq_codes, codes_to_targets)
    dev_few_shot_indices = code_to_indices(dev_few_shot_codes,
                                           codes_to_targets)
    dev_zero_shot_indices = code_to_indices(dev_zero_shot_codes,
                                            codes_to_targets)
    test_eval_indices = code_to_indices(test_codes, codes_to_targets)
    test_freq_indices = code_to_indices(test_freq_codes, codes_to_targets)
    test_few_shot_indices = code_to_indices(test_few_shot_codes,
                                            codes_to_targets)
    test_zero_shot_indices = code_to_indices(test_zero_shot_codes,
                                             codes_to_targets)
    extended_codes_to_targets, adj_matrix, codes_to_parents = load_adj_matrix(
        codes_to_targets)
    target_count = targets_to_count(codes_to_targets, codes_counter)
    eval_code_size = len(codes_to_targets)
    log(f'Building model on {device}...')
    word_emb, code_idx_matrix, code_idx_mask = init_emb_and_code_input(
        extended_codes_to_targets)
    num_neighbors = torch.from_numpy(
        adj_matrix.sum(axis=1).astype(np.float32)).to(device)
    adj_matrix = torch.from_numpy(adj_matrix.astype(np.float32)).to(
        device)  # L x L
    loss_fn = get_loss_fn(loss, reduction='sum')
    model = ConvLabelAttnModel(
        word_emb,
        code_idx_matrix,
        code_idx_mask,
        adj_matrix,
        num_neighbors,
        loss_fn,
        graph_encoder=graph_encoder,
        eval_code_size=eval_code_size,
        target_count=target_count if class_margin else None,
        C=C)
    pretrained_model_path = f"{MODEL_DIR}/{model.name}"
    pretrained_dict = load_first_stage_model(pretrained_model_path, device)
    model_dict = model.state_dict()
    for k in pretrained_dict:
        if k in model_dict:
            model_dict[k] = pretrained_dict[k]
    model.load_state_dict(model_dict)
    # set to sparse here
    model.adj_matrix = model.adj_matrix.to_sparse()
    model.to(device)
    log('Preloading data in memory...')
    dev_x, dev_y, dev_masks = preload_data(dev_notes, dev_labels,
                                           codes_to_targets, max_note_len)
    test_x, test_y, test_masks = preload_data(test_notes, test_labels,
                                              codes_to_targets, max_note_len)

    def eval_wrapper(x, y, masks):
        y_true = []
        y_score = []
        with torch.set_grad_enabled(False):
            model.eval()
            for batch in iterate_minibatch(x, y, masks, eval_code_size,
                                           batch_size=eval_batch_size,
                                           shuffle=False):
                x, y, mask, _ = batch
                x, y, mask = x.to(device), y.to(device), mask.to(device)
                y_true.append(y.cpu().numpy()[:, :dev_eval_code_size])
                # forward pass
                logits, _ = model.forward(x, y, mask)
                probs = torch.sigmoid(logits[:, :dev_eval_code_size])
                # eval stats
                y_score.append(probs.cpu().numpy())
        y_score = np.vstack(y_score)
        y_true = np.vstack(y_true).astype(int)
        return y_true, y_score

    log('Evaluating on dev set...')
    dev_true, dev_score = eval_wrapper(dev_x, dev_y, dev_masks)
    log_eval_metrics(0, dev_score, dev_true, dev_freq_indices,
                     dev_few_shot_indices, dev_zero_shot_indices)
    log('Evaluating on test set...')
    test_true, test_score = eval_wrapper(test_x, test_y, test_masks)
    log_eval_metrics(0, test_score, test_true, test_freq_indices,
                     test_few_shot_indices, test_zero_shot_indices)

def finetune_on_gan(eval_zero=True, lr=1e-5, batch_size=8, eval_batch_size=16,
                    neg_iters=1, max_note_len=2000, gpu="cuda:0", loss='bce',
                    graph_encoder='conv', l2_ratio=5e-4, top_k=10,
                    finetune_epochs=10, gan_batch_size=64, syn_num=20, C=0.,
                    class_margin=False, gan_epoch=30, ndh=256, ngh=512,
                    critic_iters=5, gan_lr=2e-5, reg_ratio=0.,
                    decoder='linear', add_zero=False, pool_mode='last'):
    seed()
    pprint.pprint(locals(), stream=sys.stderr)
    device = torch.device(gpu if torch.cuda.is_available() else "cpu")
    train_data = Dataset('train')
    dev_data = Dataset('dev')
    test_data = Dataset('test')
    train_notes, train_labels = train_data.get_data()
    dev_notes, dev_labels = dev_data.get_data()
    test_notes, test_labels = test_data.get_data()
    log(f'Loaded {len(train_notes)} train data, {len(dev_notes)} dev data...')
    n_train_data = len(train_notes)
    train_codes = train_data.get_all_codes()
    dev_codes = dev_data.get_all_codes()
    test_codes = test_data.get_all_codes()
    all_codes = train_codes.union(dev_codes).union(test_codes)
    all_codes = sorted(all_codes)
    codes_to_targets = codes_to_index_labels(all_codes, False)
    extended_codes_to_targets, adj_matrix, codes_to_parents = load_adj_matrix(
        codes_to_targets)
    eval_code_size = len(codes_to_targets)
    dev_freq_codes, dev_few_shot_codes, dev_zero_shot_codes, codes_counter = \
        split_codes_by_count(train_labels, dev_labels, train_codes, dev_codes)
    test_freq_codes, test_few_shot_codes, eval_zero_shot_codes, _ = \
        split_codes_by_count(train_labels, test_labels, train_codes, test_codes)
    dev_eval_indices = code_to_indices(dev_codes, codes_to_targets)
    dev_freq_indices = code_to_indices(dev_freq_codes, codes_to_targets)
    dev_few_shot_indices = code_to_indices(dev_few_shot_codes,
                                           codes_to_targets)
    dev_zero_shot_indices = code_to_indices(dev_zero_shot_codes,
                                            codes_to_targets)
    test_eval_indices = code_to_indices(test_codes, codes_to_targets)
    test_freq_indices = code_to_indices(test_freq_codes, codes_to_targets)
    test_few_shot_indices = code_to_indices(test_few_shot_codes,
                                            codes_to_targets)
    eval_zero_shot_indices = code_to_indices(eval_zero_shot_codes,
                                             codes_to_targets)
    log(f'Developing on {len(dev_eval_indices)} codes, {len(dev_freq_indices)} frequent codes, '
        f'{len(dev_few_shot_indices)} few shot codes and {len(dev_zero_shot_indices)} zero shot codes...')
    target_count = targets_to_count(codes_to_targets,
                                    codes_counter)  # if class_margin else None
    zero_shot_indices = np.union1d(dev_zero_shot_indices,
                                   eval_zero_shot_indices)
    few_shot_indices = np.union1d(dev_few_shot_indices, test_few_shot_indices)
    syn_indices = zero_shot_indices if eval_zero else few_shot_indices
    log(f'Synthesizing {len(syn_indices)} codes...')
    new_target_count = copy.copy(target_count)
    new_target_count[syn_indices] += syn_num
    word_emb, code_idx_matrix, code_idx_mask = init_emb_and_code_input(
        extended_codes_to_targets)
    num_neighbors = torch.from_numpy(
        adj_matrix.sum(axis=1).astype(np.float32))  # .to(device)
    adj_matrix = torch.from_numpy(
        adj_matrix.astype(np.float32))  # .to(device)  # L x L
    loss_fn = get_loss_fn(loss, reduction='sum')
    # init model
    log(f'Building model on {device}...')
    model = ConvLabelAttnGAN(
        word_emb,
        code_idx_matrix,
        code_idx_mask,
        adj_matrix,
        num_neighbors,
        loss_fn,
        eval_code_size=eval_code_size,
        graph_encoder=graph_encoder,
        target_count=target_count if class_margin else None,
        C=C)
    # load stage 1 model
    pretrained_model_path = f"{MODEL_DIR}/{model.pretrain_name}"
    pretrained_state_dict = load_first_stage_model(
        model_path=pretrained_model_path, device=device)
    model.load_pretrained_state_dict(pretrained_state_dict)
    # set to sparse here
    model.adj_matrix = model.adj_matrix.to_sparse()
    model.to(device)
    _, pretrain_label_emb = model.get_froze_label_emb()
    label_emb = pretrain_label_emb[:, int(pretrain_label_emb.shape[1]) // 2:]
    label_size = label_emb.size(1)
    label_size *= 2
    noise_size = label_size
    code_desc, code_mask, _ = prepare_code_data(model, device, to_torch=True)
    generator, discriminator, label_rnn, gan_model_path = load_gan_model(
        model.pretrain_name,
        noise_size,
        model.feat_size,
        top_k=top_k,
        pool_mode=pool_mode,
        epoch=gan_epoch,
        ngh=ngh,
        ndh=ndh,
        critic_iters=critic_iters,
        lr=gan_lr,
        add_zero=add_zero,
        reg_ratio=reg_ratio,
        decoder=decoder)
    generator.to(device)
    discriminator.to(device)
    label_rnn.to(device)
    label_real_feats, _ = get_features_for_labels(model, syn_indices)
    model.init_output_fc(pretrain_label_emb)

    def get_rnn_emb(label_indices, labels=None):
        desc_x = code_desc[label_indices][:, :top_k]
        desc_m = code_mask[label_indices][:, :top_k]
        labels_rnn_emb = label_rnn.forward_enc(desc_x,
                                               desc_m,
                                               None,
                                               training=False,
                                               pool_mode=pool_mode)
        return torch.cat([labels, labels_rnn_emb], dim=1)

    def generate(labels):
        b = labels.size(0)
        noises = torch.randn(b, noise_size).to(labels.device)
        noises = Variable(noises)
        fake_feats = generator.forward(noises, labels)
        fake_feats = torch.relu(fake_feats)
        return fake_feats

    finetune_params = model.output_fc.parameters()
    optimizer = AdamW(finetune_params,
                      lr=lr,
                      weight_decay=l2_ratio,
                      betas=(0.5, 0.999))
    scheduler = get_scheduler(optimizer, finetune_epochs, ratios=(0.6, 0.9))

    def synthesize(m, n):
        feats = []
        labels = []
        sample_num = n
        with torch.set_grad_enabled(False):
            m.eval()
            for code in syn_indices:
                syn_codes = [code] * sample_num
                syn_codes = torch.LongTensor(syn_codes).to(device)
                label = label_emb[syn_codes]
                label = get_rnn_emb(syn_codes, label)
                gen_feats = generate(labels=label)
                code_feats = gen_feats.data.cpu().numpy()
                if code in label_real_feats:
                    code_feats = np.vstack([code_feats] +
                                           label_real_feats[code])
                feats.append(code_feats)
                labels.append(np.ones(len(code_feats)) * code)
        feats = np.vstack(feats)
        labels = np.concatenate(labels)
        return feats.astype(np.float32), labels.astype(int)

    def inf_syn_data_sampler(x, y, b):
        while True:
            batches = simple_iterate_minibatch(x, y, b, shuffle=True)
            for batch in batches:
                yield batch

    code_data_list, train_x, train_y, train_masks, dev_x, dev_y, dev_masks = \
        load_data(train_notes, train_labels, dev_notes, dev_labels,
                  codes_to_targets, max_note_len)

    def inf_data_sampler():
        while True:
            batches = iterate_minibatch(train_x, train_y, train_masks,
                                        eval_code_size, batch_size,
                                        shuffle=True)
            for batch in batches:
                yield batch

    syn_feats, syn_labels = synthesize(model, syn_num)
    syn_sampler = inf_syn_data_sampler(syn_feats, syn_labels, gan_batch_size)
    real_sampler = inf_data_sampler()
    n_data = len(syn_labels)
    n_batches = n_data // gan_batch_size
    log('Finetuning on GAN generated samples...')
    syn_indices = torch.LongTensor(syn_indices).to(device)
    dev_syn_indices = (torch.LongTensor(dev_zero_shot_indices).to(device)
                       if eval_zero else
                       torch.LongTensor(dev_few_shot_indices).to(device))
    test_syn_indices = (torch.LongTensor(eval_zero_shot_indices).to(device)
                        if eval_zero else
                        torch.LongTensor(test_few_shot_indices).to(device))
    best_f1 = -1
    best_epoch = -1

    def eval_finetune(eval_x, eva_y, eval_mask, eval_syn_indices):
        eval_losses = []
        y_true = []
        y_score = []
        with torch.set_grad_enabled(False):
            model.eval()
            generator.eval()
            discriminator.eval()
            label_rnn.eval()
            for batch in iterate_minibatch(eval_x, eva_y, eval_mask,
                                           eval_code_size,
                                           batch_size=eval_batch_size,
                                           shuffle=False):
                x, y, mask, _ = batch
                x, y, mask = x.to(device), y.to(device), mask.to(device)
                # forward pass
                logits, loss = model.forward(x, y, mask,
                                             label_indices=eval_syn_indices)
                probs = torch.sigmoid(logits[:, :eval_code_size])
                # eval stats
                y_true.append(y[:, eval_syn_indices].cpu().numpy())
                y_score.append(probs.cpu().numpy())
                eval_losses.append(loss.mean().data.cpu().numpy())
        y_score = np.vstack(y_score)
        y_true = np.vstack(y_true)
        metrics = all_metrics(y_score, y_true)
        return np.mean(eval_losses), metrics

    gan_model_path = gan_model_path.replace('model', 'npz')
    for epoch in range(finetune_epochs):
        with torch.set_grad_enabled(True):
            model.train()
            generator.eval()
            discriminator.eval()
            label_rnn.eval()
            train_losses = []
            for _ in range(n_batches):
                for _ in range(neg_iters):
                    x, y, mask, _ = next(real_sampler)
                    x, y, mask = x.to(device), y.to(device), mask.to(device)
                    # forward pass
                    optimizer.zero_grad()
                    _, loss = model.forward(x, y, mask,
                                            label_indices=syn_indices)
                    loss.backward()
                    optimizer.step()
                sample_indices = next(syn_sampler)
                syn_x = syn_feats[sample_indices]
                syn_y = syn_labels[sample_indices]
                syn_x, syn_y = torch.from_numpy(syn_x), torch.from_numpy(syn_y)
                syn_x, syn_y = syn_x.to(device), syn_y.to(device)
                optimizer.zero_grad()
                finetune_loss = model.forward_final(syn_x, syn_y) / syn_y.size(0)
                finetune_loss.backward()
                optimizer.step()
                train_losses.append(finetune_loss.data.cpu().numpy())
            # restore the pretrained classifier weights for all labels except
            # the synthesized ones
            temp = copy.deepcopy(model.output_fc.weight.data)
            model.output_fc.weight.data.copy_(pretrain_label_emb.data)
            model.output_fc.weight.data[syn_indices] = temp[syn_indices]
            del temp
        dev_loss, dev_metrics = eval_finetune(dev_x, dev_y, dev_masks,
                                              dev_syn_indices)
        log(f"Epoch {epoch}, train loss={np.mean(train_losses):.4f}, dev loss={dev_loss:.4f}\n")
        log(f"\t{metric_string(dev_metrics)}\n")
        curr_f1 = dev_metrics['f1_micro']
        if curr_f1 > best_f1 and epoch > finetune_epochs // 2:
            best_f1 = curr_f1
            best_epoch = epoch
            # save the best final code classifier based on dev F1 score
            np.savez(f'{MODEL_DIR}/ft_z{eval_zero}_{gan_model_path}',
                     model.output_fc.weight.data.cpu().numpy())
        if lr >= 1e-4:
            scheduler.step(epoch)
    return best_f1, best_epoch

def train(lr=1e-3, batch_size=8, eval_batch_size=16, num_epochs=30,
          max_note_len=2000, loss='bce', gpu='cuda:1', save_model=True,
          graph_encoder='conv', class_margin=False, C=0.):
    pprint.pprint(locals(), stream=sys.stderr)
    device = torch.device(gpu if torch.cuda.is_available() else "cpu")
    train_data = Dataset('train')
    dev_data = Dataset('dev')
    test_data = Dataset('test')
    train_notes, train_labels = train_data.get_data()
    dev_notes, dev_labels = dev_data.get_data()
    log(f'Loaded {len(train_notes)} train data, {len(dev_notes)} dev data...')
    n_train_data = len(train_notes)
    train_codes = train_data.get_all_codes()
    dev_codes = dev_data.get_all_codes()
    test_codes = test_data.get_all_codes()
    all_codes = train_codes.union(dev_codes).union(test_codes)
    all_codes = sorted(all_codes)
    codes_to_targets = codes_to_index_labels(all_codes, False)
    dev_eval_code_size = len(codes_to_targets)
    frequent_codes, few_shot_codes, zero_shot_codes, codes_counter = split_codes_by_count(
        train_labels, dev_labels, train_codes, dev_codes)
    frequent_indices = code_to_indices(frequent_codes, codes_to_targets)
    few_shot_indices = code_to_indices(few_shot_codes, codes_to_targets)
    zero_shot_indices = code_to_indices(zero_shot_codes, codes_to_targets)
    extended_codes_to_targets, adj_matrix, codes_to_parents = load_adj_matrix(
        codes_to_targets)
    target_count = targets_to_count(codes_to_targets,
                                    codes_counter) if class_margin else None
    eval_code_size = len(codes_to_targets)
    log('Preloading data in memory...')
    train_x, train_y, train_masks = preload_data(train_notes,
                                                 train_labels,
                                                 codes_to_targets,
                                                 max_note_len,
                                                 save_path=CACHE_PATH)
    dev_x, dev_y, dev_masks = preload_data(dev_notes, dev_labels,
                                           codes_to_targets, max_note_len)
    log(f'Building model on {device}...')
    word_emb, code_idx_matrix, code_idx_mask = init_emb_and_code_input(
        extended_codes_to_targets)
    num_neighbors = torch.from_numpy(
        adj_matrix.sum(axis=1).astype(np.float32)).to(device)
    adj_matrix = torch.from_numpy(
        adj_matrix.astype(np.float32)).to_sparse().to(device)  # L x L
    loss_fn = get_loss_fn(loss, reduction='sum')
    model = ConvLabelAttnModel(word_emb,
                               code_idx_matrix,
                               code_idx_mask,
                               adj_matrix,
                               num_neighbors,
                               loss_fn,
                               graph_encoder=graph_encoder,
                               eval_code_size=eval_code_size,
                               target_count=target_count,
                               C=C)
    model.to(device)
    log(f'Evaluating on {len(frequent_indices)} frequent codes, '
        f'{len(few_shot_indices)} few shot codes and {len(zero_shot_indices)} zero shot codes...')
    optimizer = get_optimizer(lr, model, weight_decay=1e-5)
    scheduler = get_scheduler(optimizer, num_epochs, ratios=(0.6, 0.85))
    log(f'Start training with {eval_code_size} codes')
    best_dev_f1 = 0.
    for epoch in range(num_epochs):
        train_losses = []
        # train one epoch
        with torch.set_grad_enabled(True):
            model.train()
            if gpu:
                torch.cuda.empty_cache()
            for batch in iterate_minibatch(train_x, train_y, train_masks,
                                           eval_code_size, batch_size,
                                           shuffle=True):
                x, y, mask, y_indices = batch
                x, y, mask, y_indices = x.to(device), y.to(device), mask.to(
                    device), y_indices.to(device)
                optimizer.zero_grad()
                # forward pass
                logits, loss = model.forward(x, y, mask)
                # backward pass
                loss.mean().backward()
                optimizer.step()
                # train stats
                train_losses.append(loss.data.cpu().numpy())
        dev_losses = []
        y_true = []
        y_score = []
        with torch.set_grad_enabled(False):
            model.eval()
            if gpu:
                torch.cuda.empty_cache()
            for batch in iterate_minibatch(dev_x, dev_y, dev_masks,
                                           eval_code_size,
                                           batch_size=eval_batch_size,
                                           shuffle=False):
                x, y, mask, _ = batch
                x, y, mask = x.to(device), y.to(device), mask.to(device)
                y_true.append(y.cpu().numpy()[:, :dev_eval_code_size])
                # forward pass
                logits, loss = model.forward(x, y, mask)
                probs = torch.sigmoid(logits[:, :dev_eval_code_size])
                # eval stats
                y_score.append(probs.cpu().numpy())
                dev_losses.append(loss.mean().data.cpu().numpy())
        y_score = np.vstack(y_score)
        y_true = np.vstack(y_true)
        dev_f1 = log_eval_metrics(epoch, y_score, y_true, frequent_indices,
                                  few_shot_indices, zero_shot_indices,
                                  train_losses, dev_losses)
        if epoch > int(
                num_epochs * 0.75) and dev_f1 > best_dev_f1 and save_model:
            best_dev_f1 = dev_f1
            torch.save(model.state_dict_to_save(), f"{MODEL_DIR}/{model.name}")
        # update lr
        scheduler.step(epoch)
    if save_model:
        torch.save(model.state_dict_to_save(),
                   f"{MODEL_DIR}/final_{model.name}")

if __name__ == '__main__':
    finetune_config = get_finetune_config()
    log('Finetuning code classifier with GAN generated features...')
    finetune_on_gan(eval_zero=finetune_config.eval_zero,
                    gpu=finetune_config.gpu,
                    lr=finetune_config.lr,
                    graph_encoder=finetune_config.graph_encoder,
                    syn_num=finetune_config.syn_num,
                    gan_batch_size=finetune_config.gan_batch_size,
                    class_margin=finetune_config.class_margin,
                    C=finetune_config.C,
                    gan_epoch=finetune_config.gan_epoch,
                    finetune_epochs=finetune_config.finetune_epochs,
                    gan_lr=finetune_config.gan_lr,
                    critic_iters=finetune_config.critic_iters,
                    ndh=finetune_config.ndh,
                    ngh=finetune_config.ngh,
                    add_zero=finetune_config.add_zero,
                    # the source call was truncated after add_zero; the
                    # keywords below are assumed from the config fields used
                    # by the matching GAN training entry point
                    reg_ratio=finetune_config.reg_ratio,
                    decoder=finetune_config.decoder,
                    top_k=finetune_config.top_k,
                    pool_mode=finetune_config.pool_mode)

def load_features(feat_path):
    log(f'Loading pretrained features from {feat_path}')
    with np.load(f'{FEATURE_DIR}/{feat_path}') as f:
        # features, labels, keywords, keyword masks
        fs, ls, kws, kwms = f['arr_0'], f['arr_1'], f['arr_2'], f['arr_3']
    return fs, ls, kws, kwms

def __init__(self, word_emb, code_idx_matrix, code_idx_mask, adj_matrix,
             num_neighbors, loss_fn, num_filters=200, kernel_sizes=(10, ),
             eval_code_size=None, label_hidden_size=200, graph_encoder='conv',
             target_count=None, C=0.):
    super(ConvLabelAttnModel, self).__init__()
    self.name = f"convattn_{graph_encoder}gnn"
    self.n_nodes = len(code_idx_matrix)
    self.register_buffer('code_idx_matrix', code_idx_matrix)
    self.register_buffer('code_idx_mask', code_idx_mask)
    self.register_buffer('adj_matrix', adj_matrix)
    self.register_buffer('num_neighbors', num_neighbors)
    self.loss_fn = loss_fn
    self.eval_code_size = eval_code_size
    self.emb = EmbLayer(word_emb.shape[1],
                        word_emb.shape[0],
                        W=word_emb,
                        freeze_emb=False)
    self.emb_drop = nn.Dropout(p=0.5)
    self.embed_size = self.emb.embed_size
    self.feat_size = self.embed_size
    self.attn_drop = nn.Dropout(p=0.2)
    graph_encoder = get_graph_encoder(graph_encoder)
    if graph_encoder is not None:
        log(f'Using {graph_encoder.__name__} for encoding ICD hierarchy...')
        self.graph_label_encoder = graph_encoder(self.embed_size,
                                                 label_hidden_size,
                                                 self.n_nodes)
        self.feat_size += label_hidden_size
    else:
        self.graph_label_encoder = None
    self.output_proj = nn.Linear(num_filters, self.feat_size)
    xavier_uniform_(self.output_proj.weight)
    self.conv_modules = nn.ModuleList()
    for kernel_size in kernel_sizes:
        self.conv_modules.append(
            nn.Conv1d(self.embed_size, num_filters, kernel_size=kernel_size))
        xavier_uniform_(self.conv_modules[-1].weight)
    self.proj = nn.Linear(num_filters, self.embed_size)
    xavier_uniform_(self.proj.weight)
    if target_count is not None:
        class_margin = torch.from_numpy(target_count)**0.25
        class_margin = class_margin.masked_fill(class_margin == 0, 1)
        self.register_buffer('class_margin', 1.0 / class_margin)
        self.C = C
        self.name += f'_cm{int(self.C)}'
    else:
        self.class_margin = None
        self.C = 0
    self.name += '.model'
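
# Class-margin arithmetic sketch: margins scale as count**0.25, so a label
# seen 16 times gets margin 16**0.25 = 2 and the registered buffer stores
# 1 / 2 = 0.5; zero-count labels are clamped to margin 1 before inversion so
# the buffer never divides by zero.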