def inference(inference_file, vocabulary_size, args):
    """Restore the latest checkpoint and return per-timestep softmax scores.

    inference_file: either a path (str), batched via data_loader, or an
        in-memory sequence (list / ndarray), batched via batch_iter. Each row
        is a space-separated string of integer token ids.
    vocabulary_size: model vocabulary size, forwarded to the model ctor.
    args: parsed CLI args (model, batch_size, model_dir, shuffle, ...).

    Returns a list (one entry per row) of per-timestep softmax distributions.
    Raises ValueError for an unknown args.model.
    """
    if args.model == "rnn":
        model = RNNLanguageModel(vocabulary_size, args)
    elif args.model == "birnn":
        model = BiRNNLanguageModel(vocabulary_size, args)
    else:
        raise ValueError("Unknown model option {}.".format(args.model))

    # NOTE: unlike train(), no optimizer/gradient ops are built here --
    # inference only needs the forward pass and a Saver to restore weights.
    saver = tf.train.Saver(max_to_keep=5)

    with tf.Session() as sess:

        def infer_step(batch_x):
            # batch_x may be a tf.data "get_next" tensor (file path input) or
            # an already-materialized batch (in-memory input).
            if isinstance(batch_x, tf.Tensor):
                batch_x = sess.run(batch_x)
            # Decode space-separated id strings into int matrices.
            rows = [[int(tok) for tok in row.strip().split()] for row in batch_x]
            # Disable dropout at inference (keep_prob=1.0); feeding
            # args.keep_prob here would randomly drop units and make the
            # scores non-deterministic. Matches the eval path in train().
            feed_dict = {model.x: rows, model.keep_prob: 1.0}
            logits = sess.run([model.logits], feed_dict=feed_dict)[0]
            # Normalize every timestep's logits into a distribution.
            return [[softmax(step_logits) for step_logits in row] for row in logits]

        # Initialize all variables, then overwrite with the checkpoint if any.
        sess.run(tf.global_variables_initializer())
        ckpt = tf.train.get_checkpoint_state(args.model_dir)
        if ckpt:
            saver.restore(sess, ckpt.model_checkpoint_path)

        scores = []
        if isinstance(inference_file, str):
            # File input: run the same get_next tensor until exhaustion.
            batch_x = data_loader(inference_file, args.batch_size, 1, args.shuffle)
            while True:
                try:
                    scores += infer_step(batch_x)
                except tf.errors.OutOfRangeError:
                    print('inference finished...')
                    break
        elif isinstance(inference_file, (list, np.ndarray, np.generic)):
            # In-memory input: batch_iter yields concrete batches directly.
            for batch_x in batch_iter(inference_file, args.batch_size, 1):
                scores += infer_step(batch_x)
    return scores
def evaluate_r2a(data, task, model, args, writer=None):
    '''
        Only applied to mode=test_r2a. Generate r2a for target data
        data: the data to be tested on
        task: the name of the task
        model: a dictionary of networks
        args: the overall argument
        writer: a file object. If not none, will write the prediction result
        and the generated attention to the file
    '''
    if writer:
        writer.write(
            'task\tlabel\traw\trationale\tpred_att\tgold_att\trat_freq\n')

    # Put every sub-network into evaluation mode.
    for net in model.values():
        net.eval()

    # Run over the whole test set once and concatenate per-batch arrays.
    accumulated = {}
    for batch in data_utils.data_loader(data, args.batch_size, shuffle=False):
        batch_res = evaluate_r2a_batch(model, task, batch, args, writer)
        for name, arr in batch_res.items():
            if name in accumulated:
                accumulated[name] = np.concatenate((accumulated[name], arr))
            else:
                accumulated[name] = arr

    loss_p2g = np.mean(accumulated['loss_p2g'])
    print("{:15s} {:s} {:.4f} {:s} {:.4f} {:s} {:.4f}".format(
        task,
        colored("l_p2g", "blue"), loss_p2g,
        colored("l_rat", "blue"), np.mean(accumulated['loss_rationale']),
        colored("l_unf", "blue"), np.mean(accumulated['loss_uniform']),
    ))
    return loss_p2g
def evaluate(test_data, model, args, roc=False):
    """Evaluate `model` on `test_data`; optionally compute an interpolated TPR.

    Returns (loss, acc, recall, precision, f1, tpr); tpr is None unless
    roc=True, in which case it is the TPR interpolated on 100 FPR points.
    """
    y_true = np.array([], dtype=int)
    y_pred = np.array([], dtype=int)
    y_out = np.array([], dtype=int)
    batch_losses = []

    for batch in data_utils.data_loader(test_data, args.batch_size, 1):
        true, pred, loss, out = evaluate_batch(model, batch, args)
        y_true = np.concatenate((y_true, true))
        y_pred = np.concatenate((y_pred, pred))
        y_out = np.concatenate((y_out, out))
        batch_losses.append(loss)

    loss_total = sum(batch_losses) / len(batch_losses)
    acc, f1, recall, precision = _compute_score(y_pred=y_pred, y_true=y_true)

    tpr = None
    if roc:
        # Interpolate the TPR onto a fixed 100-point FPR grid so curves from
        # different runs can be averaged.
        fpr, raw_tpr, _ = metrics.roc_curve(y_true, y_out, pos_label=1)
        tpr = interp(np.linspace(0, 1, 100), fpr, raw_tpr)
        tpr[0] = 0.0

    print("{}, {:s} {:.6f}, "
          "{:s} {:>7.4f}, {:s} {:>7.4f}, {:s} {:>7.4f}, {:s} {:>7.4f}".format(
              datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S'),
              colored("loss", "red"), loss_total,
              colored(" acc", "blue"), acc,
              colored("recall", "blue"), recall,
              colored("precision", "blue"), precision,
              colored("f1", "blue"), f1))
    return loss_total, acc, recall, precision, f1, tpr
def test_perplexity(test_file, step):
    """Run one pass over `test_file` and return the corpus perplexity.

    Relies on the enclosing scope for `sess`, `model`, `args`, `summary_op`
    and `test_summary_writer`. Summaries are logged at `step`.
    """
    #test_batches = batch_iter(test_file, args.batch_size, 1)
    next_batch = data_loader(test_file, args.batch_size, 1)
    loss_sum = 0
    n_batches = 0
    #while not sess.should_stop():
    while True:
        try:
            raw_rows = sess.run(next_batch)
        except tf.errors.OutOfRangeError:
            # Dataset exhausted after one epoch.
            break
        # Each row is a space-separated string of token ids.
        ids = [[int(tok) for tok in row.strip().split()] for row in raw_rows]
        # keep_prob=1.0: dropout disabled while measuring perplexity.
        summaries, loss = sess.run(
            [summary_op, model.loss],
            feed_dict={model.x: ids, model.keep_prob: 1.0})
        test_summary_writer.add_summary(summaries, step)
        loss_sum += loss
        n_batches += 1
    # Perplexity = exp(mean cross-entropy loss).
    return np.exp(loss_sum / n_batches)
def evaluate_task(data, task, tar_data, model, optimizer, args, writer=None):
    '''
        For mode = train_r2a and test_r2a, evaluate the network on a test data
        for a source task.
        For mode = train_clf, evaluate the network on a test data for the
        current task.

        data: the data to be tested on
        task: the task of the data
        tar_data: target data, used for evaluating l_wd when mode=train_r2a
        or test_r2a
        model: a dictionary of networks
        optimizer: the optimizer that updates the network weights, it can be
        none. Used for estimating l_wd
        args: the overall argument
        writer: a file object. If not none, will write the prediction result
        and the generated attention to the file
    '''
    # write the header of the output file
    if writer:
        writer.write(
            'task\tlabel\traw\trationale\tpred_att\tgold_att\trat_freq\n')

    # initialize the optimizer
    # Only the critic gets a fresh optimizer when none is supplied (used for
    # the Wasserstein-distance estimate).
    if optimizer is None:
        optimizer = {}
        optimizer['critic'] = torch.optim.Adam(filter(
            lambda p: p.requires_grad, model['critic'].parameters()),
            lr=args.lr)

    # Switch every sub-network to evaluation mode.
    for key in model.keys():
        model[key].eval()

    # if training or testing r2a, use target data to evaluate l_wd
    # NOTE(review): both loaders are gated on args.tar_dataset == '' -- the
    # src_batches gate mirrors tar_batches even though it loads `data`;
    # confirm this asymmetry is intentional.
    tar_batches = None if args.tar_dataset == '' else\
        data_utils.data_loader(tar_data, args.batch_size, oneEpoch=False)
    src_batches = None if args.tar_dataset == '' else\
        data_utils.data_loader(data, args.batch_size, oneEpoch=False)

    # obtain an iterator to go through the test data
    batches = data_utils.data_loader(data, args.batch_size, shuffle=False)

    total = {}
    # Iterate over the test data. Concatenate all the results.
    for batch in batches:
        cur_res = evaluate_batch(model, optimizer, task, batch, src_batches,
                                 tar_batches, args, writer)
        # store results of current batch
        for key, value in cur_res.items():
            if key not in total:
                total[key] = value
            else:
                total[key] = np.concatenate((total[key], value))

    # average loss across all batches
    loss_lbl = np.mean(total['loss_lbl'])
    loss_r2a = np.mean(total['loss_r2a'])
    loss_a2r = np.mean(total['loss_a2r'])
    loss_wd = np.mean(total['loss_wd'])
    loss_src_lm = np.mean(total['loss_src_lm'])
    loss_tar_lm = np.mean(total['loss_tar_lm'])

    # Weighted combinations (weights come from args); computed element-wise
    # over the per-example arrays before averaging.
    loss_total = np.mean(total['loss_lbl'] + args.l_wd * total['loss_wd'] +
                         args.l_r2a * total['loss_r2a'] +
                         args.l_a2r * total['loss_a2r'] +
                         args.l_lm *
                         (total['loss_src_lm'] + total['loss_tar_lm']))
    # Objective used to schedule the encoder (label + wd + lm terms).
    loss_encoder = np.mean(total['loss_lbl'] + args.l_wd * total['loss_wd'] +
                           args.l_lm *
                           (total['loss_src_lm'] + total['loss_tar_lm']))
    # Objective used for model selection (label + r2a + a2r terms).
    loss_lbl_r2a = np.mean(total['loss_lbl'] +
                           args.l_r2a * total['loss_r2a'] +
                           args.l_a2r * total['loss_a2r'])

    print("{:15s} {:s} {:.4f}, {:s} {:.4f}, {:s} {:.4f} * {:.1e}, {:s} {:.4f} * {:.1e} {:s} {:.4f} * {:.1e},"\
          " {:s} {:.4f} * {:.1e}, {:s} {:.4f} * {:.1e}".format(
              task,
              colored("l_tot", "red"), loss_total,
              colored("l_lbl", "red"), loss_lbl,
              colored("l_wd", "red"), loss_wd, args.l_wd,
              colored("l_src_lm", "red"), loss_src_lm, args.l_lm,
              colored("l_tar_lm", "red"), loss_tar_lm, args.l_lm,
              colored("l_r2a", "red"), loss_r2a, args.l_r2a,
              colored("l_a2r", "red"), loss_a2r, args.l_a2r))

    # Classification scores only make sense for multi-class tasks; -1 marks
    # "not applicable" (e.g. regression).
    acc, f1, recall, precision = -1, -1, -1, -1
    if args.num_classes[task] > 1:
        acc, f1, recall, precision = _compute_score(
            y_pred=total['pred_lbl'], y_true=total['true_lbl'],
            num_classes=args.num_classes[task])
        print(
            "{:15s} {:s} {:>7.4f}, {:s} {:>7.4f}, {:s} {:>7.4f}, {:s} {:>7.4f}"
            .format('',
                    colored("acc", "blue"), acc,
                    colored("recall", "blue"), recall,
                    colored("precision", "blue"), precision,
                    colored("f1", "blue"), f1))
        print(
            metrics.confusion_matrix(y_true=total['true_lbl'],
                                     y_pred=total['pred_lbl']))

    return {
        'loss_lbl': loss_lbl,
        'loss_r2a': loss_r2a,
        'loss_lbl_r2a': loss_lbl_r2a,
        'loss_a2r': loss_a2r,
        'loss_wd': loss_wd,
        'loss_src_lm': loss_src_lm,
        'loss_tar_lm': loss_tar_lm,
        'loss_encoder': loss_encoder,
        'loss_total': loss_total,
        'acc': acc,
        'f1': f1,
        'recall': recall,
        'precision': precision,
    }
def train(train_data, dev_data, model, args):
    """Train the r2a / clf model with early stopping on the dev set.

    Snapshots go to tmp-runs/<timestamp>; the best snapshot is restored at
    the end and optionally copied to saved-runs when args.save is set.
    Returns (dev_res, output_dir, model).
    """
    timestamp = str(int(time.time() * 1e7))
    out_dir = os.path.abspath(
        os.path.join(os.path.curdir, "tmp-runs", timestamp))
    print("Saving the model to {}\n".format(out_dir))
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    # Early-stopping state: best dev loss seen, its snapshot path, and the
    # count of consecutive non-improving epochs.
    best = 100
    best_path = ""
    sub_cycle = 0

    optimizer, scheduler = _init_optimizer(model, args)

    # Target / unlabeled-source streams are only needed outside pure
    # classifier mode; they loop forever (oneEpoch=False).
    tar_train_batches = None if (args.mode == 'train_clf' or args.mode == 'test_clf') else \
        data_utils.data_loader(train_data[args.tar_dataset], args.batch_size,
                               oneEpoch=False)
    src_unlbl_train_batches = None if (args.mode == 'train_clf' or args.mode == 'test_clf') else \
        data_utils.data_loader(train_data[args.src_dataset[0]],
                               args.batch_size, oneEpoch=False)
    src_train_batches = data_utils.data_dict_loader(train_data,
                                                    args.src_dataset,
                                                    args.batch_size)
    tar_dev_data = None if args.tar_dataset == '' else dev_data[
        args.tar_dataset]

    ep = 1
    while True:
        start = time.time()
        train_res = []
        # args.dispatcher toggles only the progress bar (tqdm), not training.
        if args.dispatcher:
            for i in range(args.epoch_size):
                cur_res = train_batch(model, next(src_train_batches),
                                      src_unlbl_train_batches,
                                      tar_train_batches, optimizer, args)
                train_res.append(cur_res)
        else:
            for batch in tqdm(range(args.epoch_size), dynamic_ncols=True):
                cur_res = train_batch(model, next(src_train_batches),
                                      src_unlbl_train_batches,
                                      tar_train_batches, optimizer, args)
                train_res.append(cur_res)
        end = time.time()

        print("\n{}, Updates {:5d}, Time Cost: {} seconds".format(
            datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S'),
            ep * args.epoch_size, end - start))
        _print_train_res(train_res, args)

        # evaluate on dev set
        print('=== DEV ===')
        dev_res = []
        for task in args.src_dataset:
            # Dump per-epoch predictions/attentions for this task to a file.
            writer = open(
                os.path.join(out_dir, str(ep)) + '.' + task + '.out', 'w')
            cur_res = evaluate_task(dev_data[task], task, tar_dev_data, model,
                                    optimizer, args, writer=writer)
            writer.close()
            dev_res.append(cur_res)
            # Per-task LR schedule follows that task's label loss.
            scheduler[task].step(cur_res['loss_lbl'])
        dev_res = print_dev_res(dev_res, args)

        # adjust the encoder loss based on avg. loss lbl plus avg. loss wd
        scheduler['encoder'].step(dev_res['loss_encoder'])
        # adjust the encoder loss based on avg. r2a loss
        scheduler['r2a'].step(dev_res['loss_r2a'])

        # Model selection metric depends on the mode: loss_lbl_r2a for r2a
        # modes, plain label loss for train_clf.
        if (args.mode != 'train_clf' and dev_res['loss_lbl_r2a'] < best) or\
           (args.mode == 'train_clf' and dev_res['loss_lbl'] < best):
            best = dev_res[
                'loss_lbl_r2a'] if args.mode != 'train_clf' else dev_res[
                    'loss_lbl']
            best_path = os.path.join(out_dir, str(ep))
            model_utils.save_model(model, best_path)
            sub_cycle = 0
        else:
            sub_cycle += 1

        # Stop after patience*2 consecutive epochs without improvement.
        if sub_cycle == args.patience * 2:
            break
        ep += 1

    print("End of training. Restore the best weights")
    model = model_utils.load_saved_model(best_path, args)

    print('=== BEST DEV ===')
    dev_res = []
    for task in args.src_dataset:
        cur_res = evaluate_task(dev_data[task], task, tar_dev_data, model,
                                None, args)
        dev_res.append(cur_res)
    dev_res = print_dev_res(dev_res, args)

    print("Deleting model snapshot")
    os.system("rm -rf {}/*".format(out_dir))  # delete model snapshot for space

    if args.save:
        print("Save the best model to director saved-runs")
        best_dir = os.path.abspath(os.path.join(os.path.curdir, "saved-runs", args.mode, \
            "-".join(args.src_dataset) + '_' + args.tar_dataset + '_' + timestamp))
        if not os.path.exists(best_dir):
            os.makedirs(best_dir)
        best_dir = os.path.join(best_dir, 'best')
        model_utils.save_model(model, best_dir)
        # Persist the full argument set next to the saved weights.
        with open(best_dir + '_args.txt', 'w') as f:
            for attr, value in sorted(args.__dict__.items()):
                f.write("{}={}\n".format(attr.upper(), value))
        return dev_res, best_dir, model
    return dev_res, out_dir, model
tf_confusion_ph = tf.placeholder(tf.float32, shape=None, name='confusion_summary') # Create a scalar summary object for the loss so it can be displayed tf_loss_summary = tf.summary.scalar('loss', tf_loss_ph) tf_accuracy_summary = tf.summary.scalar('accuracy', tf_accuracy_ph) tf_confusion_summary = tf.summary.scalar('confusion acc', tf_confusion_ph) # Merge all summaries together performance_summaries = tf.summary.merge( [tf_loss_summary, tf_accuracy_summary, tf_confusion_summary]) # load iterator and setup model num_images = train_data.shape[0] num_val = validation_data.shape[0] loader_data = data_utils.data_loader(train_data, train_labels, num_images) batch_image, batch_label = loader_data.get_next() # with tf.device("/device:GPU:0"): model = architect.CNN(batch_image, batch_label) with graph.as_default(): # write graph writer = tf.summary.FileWriter(params.LOG_DIR, sess.graph) with tf.device("/device:GPU:0"): # initialize variable sess.run(tf.global_variables_initializer()) for epoch in range(1, params.EPOCHS + 1): print("[INFO] Epoch {}/{} - Batch Size {} - {} images".format( epoch, params.EPOCHS, params.BATCH_SIZE, num_images)) # set up data iterator for training iterator = int(num_images / params.BATCH_SIZE) feed_dict_train = {input: train_data, label: train_labels}
def train(train_data, dev_data, model, args):
    """Train a classifier with ReduceLROnPlateau and dev-loss early stopping.

    Saves per-epoch best weights under tmp-runs/<timestamp>, restores the
    best snapshot, optionally copies it to saved-runs when args.save is set.
    Returns (loss, acc, recall, precision, f1, best_dir).
    """
    # get time stamp for snapshot path
    timestamp = str(int(time.time() * 1e7))
    out_dir = os.path.abspath(
        os.path.join(os.path.curdir, "tmp-runs", timestamp))
    print("Saving the model to {}\n".format(out_dir))
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    # Only parameters with requires_grad are optimized.
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                        model.parameters()), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min', patience=args.patience, factor=0.1, verbose=True)

    # Early-stopping state.
    best = 100
    best_path = ""
    sub_cycle = 0
    ep = 1
    while True:
        start = time.time()
        batches = data_utils.data_loader(train_data, args.batch_size)
        # args.dispatcher only toggles the tqdm progress bar.
        if args.dispatcher:
            for batch in batches:
                train_batch(model, batch, optimizer, args)
        else:
            for batch in tqdm(batches,
                              total=math.ceil(
                                  len(train_data['label']) / args.batch_size),
                              dynamic_ncols=True):
                train_batch(model, batch, optimizer, args)
        end = time.time()

        print("{}, Epoch {:3d}, Time Cost: {} seconds, temperature: {:.4f}".
              format(
                  datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S'),
                  ep, end - start, args.temperature))

        # Train-set evaluation is sampled every 10 epochs only (it is slow).
        if ep % 10 == 0:
            print("Train:", end=" ")
            evaluate(train_data, model, args)

        print("Dev  :", end=" ")
        # visualizing rationales during training, this will slow down
        # training process
        writer = data_utils.generate_writer(os.path.join(out_dir, str(ep)))
        cur_loss, _, _, _, _ = evaluate(dev_data, model, args, writer)
        data_utils.close_writer(writer)
        # cur_loss, _, _, _, _ = evaluate(dev_data, model, args, None)

        scheduler.step(cur_loss)  # auto adjust the lr when loss stop improving

        if cur_loss < best:
            best = cur_loss
            torch.save(model.state_dict(), os.path.join(out_dir, str(ep)))
            best_path = os.path.join(out_dir, str(ep))
            print("Saved current best weights to {}\n".format(best_path))
            sub_cycle = 0
        else:
            sub_cycle += 1

        # Stop after patience*2 consecutive non-improving epochs.
        if sub_cycle == args.patience * 2:
            break
        ep += 1

    print("End of training. Restore the best weights.")
    model.load_state_dict(torch.load(best_path))

    print("Best development performance during training")
    loss, acc, recall, precision, f1 = evaluate(dev_data, model, args)

    print("Deleting model snapshot")
    os.system("rm -rf {}/*".format(out_dir))  # delete model snapshot for space

    if args.save:
        print("Save the best model to director saved-runs")
        best_dir = os.path.abspath(
            os.path.join(
                os.path.curdir, "saved-runs",
                args.dataset + '_' + str(args.num_classes) + '_' + timestamp))
        if not os.path.exists(best_dir):
            os.makedirs(best_dir)
        best_dir = os.path.join(best_dir, 'best')
        torch.save(model, best_dir)
        print("Best model is saved to {:s}".format(best_dir))
        # Persist the full argument set next to the saved model.
        with open(best_dir + '_args.txt', 'w') as f:
            for attr, value in sorted(args.__dict__.items()):
                f.write("{}={}\n".format(attr.upper(), value))
    return loss, acc, recall, precision, f1, best_dir
def evaluate(test_data, model, args, writer=None):
    """Evaluate `model` on `test_data`, print the loss breakdown, and return
    (loss_lbl, acc, recall, precision, f1); the scores are -1 when
    args.num_classes == 1 (no classification metrics)."""
    totals = {}
    for batch in data_utils.data_loader(test_data, args.batch_size):
        batch_res = evaluate_batch(model, batch, args, writer)
        # Concatenate each per-batch array onto the running totals.
        for name, arr in batch_res.items():
            if name in totals:
                totals[name] = np.concatenate((totals[name], arr))
            else:
                totals[name] = arr

    loss_lbl = np.mean(totals['loss_lbl'])
    loss_selection = np.mean(totals['loss_selection'])
    loss_variation = np.mean(totals['loss_variation'])
    prob_selection = np.mean(totals['prob_selection'])
    prob_variation = np.mean(totals['prob_variation'])
    combined = (loss_lbl + args.l_selection * loss_selection +
                args.l_variation * loss_variation)
    stamp = datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S')

    if args.num_classes == 1:
        print(
            "{} {:s} {:.6f}, {:s} {:.6f}, {:s} {:.6f}, {:s} {:.6f}, {:s} {:.6f}, {:s} {:.6f}"
            .format(
                stamp,
                colored("loss", "red"), combined,
                colored("l_lbl", "red"), loss_lbl,
                colored("l_sel", "red"), loss_selection,
                colored("l_var", "red"), loss_variation,
                colored("p_sel", "red"), prob_selection,
                colored("p_var", "red"), prob_variation))
        return loss_lbl, -1, -1, -1, -1

    acc, f1, recall, precision = _compute_score(
        y_pred=totals['pred_lbl'], y_true=totals['true_lbl'],
        num_classes=args.num_classes)
    print(
        "{} {:s} {:.6f}, {:s} {:.6f}, {:s} {:.6f}, {:s} {:.6f}, {:s} {:.6f}, {:s} {:.6f}\n"
        " "
        "{:s} {:>7.4f}, {:s} {:>7.4f}, {:s} {:>7.4f}, {:s} {:>7.4f}".
        format(
            stamp,
            colored("loss", "red"), combined,
            colored("l_lbl", "red"), loss_lbl,
            colored("l_sel", "red"), loss_selection,
            colored("l_var", "red"), loss_variation,
            colored("p_sel", "red"), prob_selection,
            colored("p_var", "red"), prob_variation,
            colored(" acc", "blue"), acc,
            colored("recall", "blue"), recall,
            colored("precision", "blue"), precision,
            colored("f1", "blue"), f1))
    return loss_lbl, acc, recall, precision, f1
def train(train_file, test_file, vocabulary_size, args):
    """Train the (bi)RNN language model, periodically checkpointing and
    evaluating test perplexity.

    train_file/test_file: paths fed to data_loader (rows are space-separated
        integer token ids).
    vocabulary_size: model vocabulary size.
    args: parsed CLI args (model, learning_rate, decay_*, batch_size,
        num_epochs, check_steps, model_dir, keep_prob, ...).
    Raises ValueError for an unknown args.model.
    """
    if args.model == "rnn":
        model = RNNLanguageModel(vocabulary_size, args)
    elif args.model == "birnn":
        model = BiRNNLanguageModel(vocabulary_size, args)
    else:
        raise ValueError("Unknown model option {}.".format(args.model))

    # Define training procedure: clipped gradients + exponentially decayed
    # Adam learning rate, with global_step advanced by apply_gradients.
    global_step = tf.Variable(0, trainable=False)
    params = tf.trainable_variables()
    gradients = tf.gradients(model.loss, params)
    clipped_gradients, _ = tf.clip_by_global_norm(gradients, 10.0)
    learning_rate = tf.train.exponential_decay(args.learning_rate,
                                               global_step,
                                               args.decay_steps,
                                               args.decay_rate,
                                               staircase=True)
    optimizer = tf.train.AdamOptimizer(learning_rate)
    train_op = optimizer.apply_gradients(zip(clipped_gradients, params),
                                         global_step=global_step)

    # Summary
    loss_summary = tf.summary.scalar("loss", model.loss)
    summary_op = tf.summary.merge([loss_summary])

    saver = tf.train.Saver(max_to_keep=5)

    with tf.Session() as sess:

        def train_step(batch_x):
            # Decode a batch of space-separated id strings into int matrices.
            batch_x = sess.run(batch_x)
            batch_x = [row.strip().split() for row in batch_x]
            batch_x = list(
                map(lambda x: list(map(lambda y: int(y), x)), batch_x))
            feed_dict = {model.x: batch_x, model.keep_prob: args.keep_prob}
            _, step, summaries, loss = sess.run(
                [train_op, global_step, summary_op, model.loss],
                feed_dict=feed_dict)
            train_summary_writer.add_summary(summaries, step)
            if step % 100 == 1 and step != 1:
                print("step {0}: loss = {1}".format(step, loss))

        def test_perplexity(test_file, step):
            iterator = data_loader(test_file, args.batch_size, 1)
            sess.run(iterator.initializer)
            # FIX: build get_next() ONCE. Calling it inside the loop (as the
            # original did) adds a new op to the graph on every batch.
            test_batch_x = iterator.get_next()
            losses, iters = 0, 0
            while True:
                try:
                    batch_x = sess.run(test_batch_x)
                except tf.errors.OutOfRangeError:
                    break
                batch_x = [row.strip().split() for row in batch_x]
                batch_x = list(
                    map(lambda x: list(map(lambda y: int(y), x)), batch_x))
                # keep_prob=1.0: dropout disabled while evaluating.
                feed_dict = {model.x: batch_x, model.keep_prob: 1.0}
                summaries, loss = sess.run([summary_op, model.loss],
                                           feed_dict=feed_dict)
                test_summary_writer.add_summary(summaries, step)
                losses += loss
                iters += 1
            # Perplexity = exp(mean cross-entropy loss).
            return np.exp(losses / iters)

        iterator = data_loader(train_file, args.batch_size, args.num_epochs)
        sess.run(iterator.initializer)
        # FIX: one get_next() tensor reused for the whole training run
        # (originally recreated every iteration, leaking graph nodes).
        next_train_batch = iterator.get_next()

        # Initialize all variables
        sess.run(tf.global_variables_initializer())
        train_summary_writer = tf.summary.FileWriter(args.model + "-train",
                                                     sess.graph)
        test_summary_writer = tf.summary.FileWriter(args.model + "-test",
                                                    sess.graph)
        # Resume from the latest checkpoint when one exists.
        ckpt = tf.train.get_checkpoint_state(args.model_dir)
        if ckpt:
            saver.restore(sess, ckpt.model_checkpoint_path)

        while True:
            try:
                train_step(next_train_batch)
            except tf.errors.OutOfRangeError:
                print('training finish...')
                break
            step = tf.train.global_step(sess, global_step)
            # FIX: was `== 1 & step != 1` -- bitwise & binds tighter than ==,
            # so the original condition was a mis-parsed chained comparison.
            if step % args.check_steps == 1 and step != 1:
                print('step: ', step)
                perplexity = test_perplexity(test_file, step)
                print("\ttest perplexity: {}".format(perplexity))
                print("\tlearning_rate: {}".format(sess.run(learning_rate)))
                checkpoint_path = os.path.join(args.model_dir, "MLE.ckpt")
                saver.save(sess, checkpoint_path, global_step=global_step)
                print('Saving model at step %s' % step)
def train(train_data, dev_data, model, args):
    """Train for at most args.num_epochs epochs, keeping a deep copy of the
    best-by-dev-loss model (early stopping after patience*2 bad epochs).
    Returns (model, (loss, acc, recall, precision, f1))."""
    # Early-stopping state; best model kept in memory via deepcopy.
    best = 100
    sub_cycle = 0
    best_model = None

    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min', patience=args.patience, factor=0.1, verbose=True)

    for ep in range(args.num_epochs):
        start = time.time()
        batches = data_utils.data_loader(train_data, args.batch_size, 1)
        # Show a tqdm bar only for the local (no data_dir) setup.
        if args.data_dir == '':
            for batch in tqdm(
                    batches,
                    total=math.ceil(
                        len(train_data['label']) / args.batch_size)):
                train_batch(model, batch, optimizer, args)
        else:
            for batch in batches:
                train_batch(model, batch, optimizer, args)
        end = time.time()

        print("{}, Epoch {:3d}, Time Cost: {} seconds".format(
            datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S'),
            ep, end - start))

        # Evaluate every epoch (ep % 1 == 0 is always true; kept as written).
        if ep % 1 == 0:
            print("Train:", end=" ")
            evaluate(train_data, model, args)
            print("Dev :", end=" ")
            cur_loss, _, _, _, _, _ = evaluate(dev_data, model, args)
            # Auto-adjust the LR when dev loss stops improving.
            scheduler.step(cur_loss)
            if cur_loss < best:
                best = cur_loss
                best_model = copy.deepcopy(model)
                sub_cycle = 0
            else:
                sub_cycle += 1
            # Stop after patience*2 consecutive non-improving epochs.
            if sub_cycle == args.patience * 2:
                break
        sys.stdout.flush()

    print("End of training. Restore the best weights")
    model = copy.deepcopy(best_model)

    print("Best development performance during training")
    loss, acc, recall, precision, f1, _ = evaluate(dev_data, model, args)

    if args.save:
        # get time stamp for snapshot path
        timestamp = str(int(time.time() * 1e7))
        out_dir = os.path.abspath(
            os.path.join(os.path.curdir, "runs", timestamp))
        print("Saving the model to {}\n".format(out_dir))
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
        print("Save the best model")
        torch.save(model, os.path.join(out_dir, "best"))
        print("Best model is saved to {:s}".format(
            os.path.join(out_dir, "best")))
    return model, (loss, acc, recall, precision, f1)
# --- Script fragment: CLI args and GRU generator / discriminator setup. ---
# NOTE(review): `parser` and the --index/--window_size arguments are defined
# before this view; GRU, Discriminator and data_loader are project classes.
parser.add_argument('--hidden_size', default=10, type=int)
parser.add_argument('--lr', default=1e-3, type=float)
parser.add_argument('--batch_size', default=512, type=int)
parser.add_argument('--input_size', default=1, type=int)
parser.add_argument('--num_layers', default=10, type=int)
args = parser.parse_args()

# Unpack hyper-parameters into locals.
index = args.index
window_size = args.window_size
hidden_size = args.hidden_size
learning_rate = args.lr
batch_size = args.batch_size
input_size = args.input_size
num_layers = args.num_layers

# Load training windows and wrap them in a (non-shuffled) DataLoader --
# shuffle=False preserves the original sample order.
data = data_loader(window_size=window_size, index=index)
train_time, train_data, train_label = data.train_data_loader()
train_data = torch.tensor(train_data, dtype=torch.float32)
train_label = torch.tensor(train_label, dtype=torch.float32)
train_dataset = TensorDataset(train_data, train_label)
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=False)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator = GRU(input_size=input_size,
                hidden_size=hidden_size,
                num_layers=num_layers,
                num_classes=1,
                device=device)
generator = generator.to(device)
# Warm-start the generator from a pretrained checkpoint on disk.
generator.load_state_dict(torch.load('model.ckpt'))
discriminator = Discriminator()
discriminator = discriminator.to(device)
def run(e):
    """Train the GAN estimator selected by e.config.model.

    e: experiment object providing .config (hyper-parameters) and .log.
    Returns [stacked theta history, discriminator state_dict, e.config.m1,
    the raw data array].
    Raises ValueError for an unknown e.config.model.
    """
    e.log.info("*" * 25 + " MODEL INITIALIZATION " + "*" * 25)

    # Dispatch table instead of an 11-arm if/elif chain; an unknown model
    # now fails fast with a clear error instead of a NameError later.
    model_classes = {
        "js": models.JS_GAN,
        "tv": models.TV_GAN,
        "tv_pen": models.TV_GAN_penalty,
        "freq": models.FREQ_GAN,
        "freqb": models.FREQB_GAN,
        "freqc": models.FREQC_GAN,
        "freqc2": models.FREQC2_GAN,
        "freqd": models.FREQD_GAN,
        "freqe": models.FREQE_GAN,
        "freqf": models.FREQF_GAN,
        "freqg": models.FREQG_GAN,
    }
    try:
        model_class = model_classes[e.config.model.lower()]
    except KeyError:
        raise ValueError("Unknown model option {}.".format(e.config.model))

    # Two-component Gaussian mixture sampler.
    data = data_utils.data_loader(mean1=e.config.m1,
                                  var1=e.config.v1,
                                  mean2=e.config.m2,
                                  var2=e.config.v2,
                                  mix_weight=[e.config.mi1, e.config.mi2],
                                  dim=e.config.isize,
                                  size=e.config.data_size,
                                  batch_size=e.config.bsize)
    e.log.info("data mean: " + str(data.data.mean(0)))

    model = model_class(
        input_dim=e.config.isize,
        # Initialize near the empirical mean, offset by 1/sqrt(d).
        data_init=(data.get_mean() +
                   1 / np.sqrt(e.config.isize)).astype('float32'),
        experiment=e)
    e.log.info(model)

    tot_it = curr_dstep = 0
    for epoch in range(e.config.n_epoch):
        data_batch = data.prepare()  # generator a list
        for d in data_batch:
            model.train()
            tot_it += 1
            # Alternate e.config.ds discriminator updates per generator
            # update.
            if curr_dstep < e.config.ds:
                model.trainD(d)
                curr_dstep += 1
            else:
                model.trainG(d)
                curr_dstep = 0
                # NOTE(review): update_theta assumed to follow each
                # generator step (grouped with trainG in the original) --
                # confirm against the original indentation.
                model.update_theta()
            if tot_it % e.config.eval_every == 0 or tot_it % len(data) == 0:
                model.eval()
                model.save()
                e.log.info("estimate: " +
                           str(np.stack(model.thetalist).mean(0)))
    return [
        np.stack(model.thetalist),
        model.netD.state_dict(),
        e.config.m1,
        data.data,
    ]