def main():
  if not os.path.exists(FLAGS.checkpoint_path):
    os.makedirs(FLAGS.checkpoint_path)
  checkpoint_file_path = FLAGS.checkpoint_path + "/checkpoint.ckpt"
  latest_checkpoint_file_path = tf.train.latest_checkpoint(
      FLAGS.checkpoint_path)

  if not os.path.exists(FLAGS.output_path):
    os.makedirs(FLAGS.output_path)

  # Step 1: Construct the dataset op
  epoch_number = FLAGS.epoch_number
  if epoch_number <= 0:
    epoch_number = -1
  train_buffer_size = FLAGS.train_batch_size * 3
  validation_buffer_size = FLAGS.train_batch_size * 3

  train_filename_list = [filename for filename in FLAGS.train_files.split(",")]
  train_filename_placeholder = tf.placeholder(tf.string, shape=[None])
  train_dataset = tf.data.TFRecordDataset(train_filename_placeholder)
  train_dataset = train_dataset.map(parse_tfrecords_function).repeat(
      epoch_number).batch(FLAGS.train_batch_size).shuffle(
          buffer_size=train_buffer_size)
  train_dataset_iterator = train_dataset.make_initializable_iterator()
  batch_labels, batch_ids, batch_values = train_dataset_iterator.get_next()

  validation_filename_list = [
      filename for filename in FLAGS.validation_files.split(",")
  ]
  validation_filename_placeholder = tf.placeholder(tf.string, shape=[None])
  validation_dataset = tf.data.TFRecordDataset(validation_filename_placeholder)
  validation_dataset = validation_dataset.map(
      parse_tfrecords_function).repeat().batch(
          FLAGS.validation_batch_size).shuffle(
              buffer_size=validation_buffer_size)
  validation_dataset_iterator = validation_dataset.make_initializable_iterator()
  validation_labels, validation_ids, validation_values = validation_dataset_iterator.get_next()

  # Define the model
  logits = inference(batch_ids, batch_values, True)
  batch_labels = tf.to_int64(batch_labels)
  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
      logits=logits, labels=batch_labels)
  loss = tf.reduce_mean(cross_entropy, name="loss")
  global_step = tf.Variable(0, name="global_step", trainable=False)

  if FLAGS.enable_lr_decay:
    logging.info(
        "Enable learning rate decay rate: {}".format(FLAGS.lr_decay_rate))
    starter_learning_rate = FLAGS.learning_rate
    learning_rate = tf.train.exponential_decay(
        starter_learning_rate,
        global_step,
        100000,
        FLAGS.lr_decay_rate,
        staircase=True)
  else:
    learning_rate = FLAGS.learning_rate

  optimizer = util.get_optimizer_by_name(FLAGS.optimizer, learning_rate)
  train_op = optimizer.minimize(loss, global_step=global_step)

  # Re-use the Variables for the train/validation/inference towers
  tf.get_variable_scope().reuse_variables()

  # Define accuracy op for train data
  train_accuracy_logits = inference(batch_ids, batch_values, False)
  train_softmax = tf.nn.softmax(train_accuracy_logits)
  train_correct_prediction = tf.equal(
      tf.argmax(train_softmax, 1), batch_labels)
  train_accuracy = tf.reduce_mean(
      tf.cast(train_correct_prediction, tf.float32))

  # Define auc op for train data: one-hot encode labels for streaming_auc
  batch_labels = tf.cast(batch_labels, tf.int32)
  sparse_labels = tf.reshape(batch_labels, [-1, 1])
  derived_size = tf.shape(batch_labels)[0]
  indices = tf.reshape(tf.range(0, derived_size, 1), [-1, 1])
  concated = tf.concat(axis=1, values=[indices, sparse_labels])
  outshape = tf.stack([derived_size, FLAGS.label_size])
  new_train_batch_labels = tf.sparse_to_dense(concated, outshape, 1.0, 0.0)
  _, train_auc = tf.contrib.metrics.streaming_auc(train_softmax,
                                                  new_train_batch_labels)

  # Define accuracy op for validate data
  validate_accuracy_logits = inference(validation_ids, validation_values,
                                       False)
  validate_softmax = tf.nn.softmax(validate_accuracy_logits)
  validate_batch_labels = tf.to_int64(validation_labels)
  validate_correct_prediction = tf.equal(
      tf.argmax(validate_softmax, 1), validate_batch_labels)
  validate_accuracy = tf.reduce_mean(
      tf.cast(validate_correct_prediction, tf.float32))

  # Define auc op for validate data
  validate_batch_labels = tf.cast(validate_batch_labels, tf.int32)
  sparse_labels = tf.reshape(validate_batch_labels, [-1, 1])
  derived_size = tf.shape(validate_batch_labels)[0]
  indices = tf.reshape(tf.range(0, derived_size, 1), [-1, 1])
  concated = tf.concat(axis=1, values=[indices, sparse_labels])
  outshape = tf.stack([derived_size, FLAGS.label_size])
  new_validate_batch_labels = tf.sparse_to_dense(concated, outshape, 1.0, 0.0)
  _, validate_auc = tf.contrib.metrics.streaming_auc(validate_softmax,
                                                     new_validate_batch_labels)

  # Define inference op
  sparse_index = tf.placeholder(tf.int64, [None, 2])
  sparse_ids = tf.placeholder(tf.int64, [None])
  sparse_values = tf.placeholder(tf.float32, [None])
  sparse_shape = tf.placeholder(tf.int64, [2])
  inference_ids = tf.SparseTensor(sparse_index, sparse_ids, sparse_shape)
  inference_values = tf.SparseTensor(sparse_index, sparse_values, sparse_shape)
  inference_logits = inference(inference_ids, inference_values, False)
  inference_softmax = tf.nn.softmax(inference_logits)
  inference_op = tf.argmax(inference_softmax, 1)
  keys_placeholder = tf.placeholder(tf.int32, shape=[None, 1])
  keys = tf.identity(keys_placeholder)

  signature_def_map = {
      signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
      signature_def_utils.build_signature_def(
          inputs={
              "keys": utils.build_tensor_info(keys_placeholder),
              "indexs": utils.build_tensor_info(sparse_index),
              "ids": utils.build_tensor_info(sparse_ids),
              "values": utils.build_tensor_info(sparse_values),
              "shape": utils.build_tensor_info(sparse_shape)
          },
          outputs={
              "keys": utils.build_tensor_info(keys),
              "softmax": utils.build_tensor_info(inference_softmax),
              "prediction": utils.build_tensor_info(inference_op)
          },
          method_name=signature_constants.PREDICT_METHOD_NAME)
  }

  # Initialize saver and summary
  saver = tf.train.Saver()
  tf.summary.scalar("loss", loss)
  tf.summary.scalar("train_accuracy", train_accuracy)
  tf.summary.scalar("train_auc", train_auc)
  tf.summary.scalar("validate_accuracy", validate_accuracy)
  tf.summary.scalar("validate_auc", validate_auc)
  summary_op = tf.summary.merge_all()
  init_op = [
      tf.global_variables_initializer(),
      tf.local_variables_initializer()
  ]

  # Create session to run
  with tf.Session() as sess:
    writer = tf.summary.FileWriter(FLAGS.output_path, sess.graph)
    sess.run(init_op)
    sess.run(
        train_dataset_iterator.initializer,
        feed_dict={train_filename_placeholder: train_filename_list})
    sess.run(
        validation_dataset_iterator.initializer,
        feed_dict={validation_filename_placeholder: validation_filename_list})

    if FLAGS.mode == "train":
      # Restore session and start queue runner
      util.restore_from_checkpoint(sess, saver, latest_checkpoint_file_path)
      coord = tf.train.Coordinator()
      threads = tf.train.start_queue_runners(coord=coord, sess=sess)
      start_time = datetime.datetime.now()

      try:
        while not coord.should_stop():
          if FLAGS.benchmark_mode:
            sess.run(train_op)
          else:
            _, step = sess.run([train_op, global_step])

            # Print state while training
            if step % FLAGS.steps_to_validate == 0:
              loss_value, train_accuracy_value, train_auc_value, validate_accuracy_value, auc_value, summary_value = sess.run(
                  [
                      loss, train_accuracy, train_auc, validate_accuracy,
                      validate_auc, summary_op
                  ])
              end_time = datetime.datetime.now()
              logging.info(
                  "[{}] Step: {}, loss: {}, train_acc: {}, train_auc: {}, valid_acc: {}, valid_auc: {}".
                  format(end_time - start_time, step, loss_value,
                         train_accuracy_value, train_auc_value,
                         validate_accuracy_value, auc_value))
              writer.add_summary(summary_value, step)
              saver.save(sess, checkpoint_file_path, global_step=step)
              start_time = end_time
      except tf.errors.OutOfRangeError:
        if FLAGS.benchmark_mode:
          print("Finish training for benchmark")
          exit(0)
        else:
          # Export the model after training
          util.save_model(
              FLAGS.model_path,
              FLAGS.model_version,
              sess,
              signature_def_map,
              is_save_graph=False)
      finally:
        coord.request_stop()
        coord.join(threads)

    elif FLAGS.mode == "save_model":
      if not util.restore_from_checkpoint(sess, saver,
                                          latest_checkpoint_file_path):
        logging.error("No checkpoint found, exit now")
        exit(1)

      util.save_model(
          FLAGS.model_path,
          FLAGS.model_version,
          sess,
          signature_def_map,
          is_save_graph=False)

    elif FLAGS.mode == "inference":
      if not util.restore_from_checkpoint(sess, saver,
                                          latest_checkpoint_file_path):
        logging.error("No checkpoint found, exit now")
        exit(1)

      # Load inference test data in LibSVM format
      inference_result_file_name = "./inference_result.txt"
      inference_test_file_name = "./data/a8a_test.libsvm"
      labels = []
      feature_ids = []
      feature_values = []
      feature_index = []
      ins_num = 0
      for line in open(inference_test_file_name, "r"):
        tokens = line.split(" ")
        labels.append(int(tokens[0]))
        feature_num = 0
        for feature in tokens[1:]:
          feature_id, feature_value = feature.split(":")
          feature_ids.append(int(feature_id))
          feature_values.append(float(feature_value))
          feature_index.append([ins_num, feature_num])
          feature_num += 1
        ins_num += 1

      # Run inference
      start_time = datetime.datetime.now()
      prediction, prediction_softmax = sess.run(
          [inference_op, inference_softmax],
          feed_dict={
              sparse_index: feature_index,
              sparse_ids: feature_ids,
              sparse_values: feature_values,
              sparse_shape: [ins_num, FLAGS.feature_size]
          })
      end_time = datetime.datetime.now()

      # Compute accuracy
      label_number = len(labels)
      correct_label_number = 0
      for i in range(label_number):
        if labels[i] == prediction[i]:
          correct_label_number += 1
      accuracy = float(correct_label_number) / label_number

      # Compute auc
      expected_labels = np.array(labels)
      predict_labels = prediction_softmax[:, 0]
      fpr, tpr, thresholds = metrics.roc_curve(
          expected_labels, predict_labels, pos_label=0)
      auc = metrics.auc(fpr, tpr)
      logging.info("[{}] Inference accuracy: {}, auc: {}".format(
          end_time - start_time, accuracy, auc))

      # Save result into the file
      np.savetxt(inference_result_file_name, prediction_softmax, delimiter=",")
      logging.info(
          "Save result to file: {}".format(inference_result_file_name))

    elif FLAGS.mode == "inference_with_tfrecords":
      if not util.restore_from_checkpoint(sess, saver,
                                          latest_checkpoint_file_path):
        logging.error("No checkpoint found, exit now")
        exit(1)

      # Load inference test data
      inference_result_file_name = "./inference_result.txt"
      inference_test_file_name = "./data/a8a/a8a_test.libsvm.tfrecords"
      batch_feature_index = []
      batch_labels = []
      batch_ids = []
      batch_values = []
      ins_num = 0

      # Read from TFRecords files
      for serialized_example in tf.python_io.tf_record_iterator(
          inference_test_file_name):
        # Get serialized example from file
        example = tf.train.Example()
        example.ParseFromString(serialized_example)
        label = example.features.feature["label"].float_list.value
        ids = example.features.feature["ids"].int64_list.value
        values = example.features.feature["values"].float_list.value
        #print("label: {}, features: {}".format(label, " ".join([str(id) + ":" + str(value) for id, value in zip(ids, values)])))
        batch_labels.append(label)
        # Notice that using extend() instead of append() to flatten the values
        batch_ids.extend(ids)
        batch_values.extend(values)
        for i in range(len(ids)):
          batch_feature_index.append([ins_num, i])
        ins_num += 1

      # Run inference
      start_time = datetime.datetime.now()
      prediction, prediction_softmax = sess.run(
          [inference_op, inference_softmax],
          feed_dict={
              sparse_index: batch_feature_index,
              sparse_ids: batch_ids,
              sparse_values: batch_values,
              sparse_shape: [ins_num, FLAGS.feature_size]
          })
      end_time = datetime.datetime.now()

      # Compute accuracy
      label_number = len(batch_labels)
      correct_label_number = 0
      for i in range(label_number):
        if batch_labels[i] == prediction[i]:
          correct_label_number += 1
      accuracy = float(correct_label_number) / label_number

      # Compute auc
      expected_labels = np.array(batch_labels)
      predict_labels = prediction_softmax[:, 0]
      fpr, tpr, thresholds = metrics.roc_curve(
          expected_labels, predict_labels, pos_label=0)
      auc = metrics.auc(fpr, tpr)
      logging.info("[{}] Inference accuracy: {}, auc: {}".format(
          end_time - start_time, accuracy, auc))

      # Save result into the file
      np.savetxt(inference_result_file_name, prediction_softmax, delimiter=",")
      logging.info(
          "Save result to file: {}".format(inference_result_file_name))
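# NOTE: a minimal sketch of the `parse_tfrecords_function` that the dataset
# `.map()` calls above assume; it is not part of the original file. The
# feature names ("label", "ids", "values") and their types are inferred from
# the `inference_with_tfrecords` branch, which reads the same schema.
def parse_tfrecords_function(serialized_example):
  # ids/values are variable-length per example, so VarLenFeature is used,
  # which yields tf.SparseTensor objects compatible with the sparse
  # `inference()` inputs used above.
  features = tf.parse_single_example(
      serialized_example,
      features={
          "label": tf.FixedLenFeature([], tf.float32),
          "ids": tf.VarLenFeature(tf.int64),
          "values": tf.VarLenFeature(tf.float32)
      })
  return features["label"], features["ids"], features["values"]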
def main_loop():
    util.cancel_shutdown()
    losses = []

    args = g.args

    if not args.local:
        g.logger.info(
            f'Distributed initializing process group with '
            f'{args.dist_backend}, {args.dist_url}, {util.get_world_size()}')
        dist.init_process_group(
            backend=args.dist_backend,
            # init_method=args.dist_url,
            # world_size=util.get_world_size()
        )
        assert (util.get_world_size() == dist.get_world_size())
        g.logger.info(
            f"Distributed: success ({args.local_rank}/{dist.get_world_size()})")

    g.logger.info("creating new model")
    g.state = TrainState(args)

    g.state.model = MemTransformerLM(g.ntokens, args.n_layer, args.n_head,
                                     args.d_model, args.d_head, args.d_inner,
                                     args.dropout, args.dropatt,
                                     tie_weight=args.tied,
                                     d_embed=args.d_embed,
                                     div_val=args.div_val,
                                     tie_projs=g.tie_projs,
                                     pre_lnorm=args.pre_lnorm,
                                     tgt_len=args.tgt_len,
                                     ext_len=args.ext_len,
                                     mem_len=args.mem_len,
                                     cutoffs=g.cutoffs,
                                     same_length=args.same_length,
                                     attn_type=args.attn_type,
                                     clamp_len=args.clamp_len,
                                     sample_softmax=args.sample_softmax,
                                     freeze_below=args.freeze_below)
    g.state.model.to(g.device)
    optimizer_setup(g.state)
    if args.checkpoint:
        if args.checkpoint_secondary:
            g.logger.info(f"restoring extra checkpoint")
            util.restore_from_checkpoint(g.state.model, g.state.optimizer,
                                         args.checkpoint_secondary,
                                         args.optim_state_dict)
        g.logger.info(f"Restoring model from {args.checkpoint}" +
                      (f" and optimizer from {args.optim_state_dict}"
                       if args.optim_state_dict else ""))
        util.restore_from_checkpoint(g.state.model, g.state.optimizer,
                                     args.checkpoint, args.optim_state_dict)
    else:
        g.state.model.apply(weights_init)
        # ensure embedding init is not overridden by out_layer in case of weight sharing
        g.state.model.word_emb.apply(weights_init)

    model: MemTransformerLM = g.state.model
    optimizer = g.state.optimizer

    if g.state.args.fp16:
        model = FP16_Module(model)
        optimizer = FP16_Optimizer(
            optimizer,
            static_loss_scale=g.state.args.static_loss_scale,
            dynamic_loss_scale=g.state.args.dynamic_loss_scale,
            dynamic_loss_args={'init_scale': 2 ** 16},
            verbose=False)

    # log model info
    # n_all_param = sum([p.nelement() for p in model.parameters()])
    # log_tb('sizes/params', n_all_param)
    # n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])
    # log_tb('sizes/non_emb_params', n_nonemb_param)
    # g.logger.info('params %s non_emb_params %s', n_all_param, n_nonemb_param)

    # scheduler
    if args.scheduler == 'cosine':
        # Divide by 1e6 for numerical stability.
        g.state.scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, args.max_tokens // 1e6, eta_min=args.eta_min)
    elif args.scheduler == 'finder':
        g.state.scheduler: LRFinder = LRFinder(optimizer, args.max_tokens,
                                               init_value=args.lr / 1e3)
    else:
        assert args.scheduler == 'constant'
        g.state.scheduler = util.NoOp()

    # Setup distributed model
    if args.local:
        model = nn.DataParallel(model, dim=1)
    else:
        # Uncomment find_unused_parameters and upgrade to torch 1.1 for adaptive embedding.
        model = DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank)  # , find_unused_parameters=True)

    if util.get_global_rank() == 0:
        if not args.test:
            wandb.config.update(vars(args))
            # wandb.watch(model)
        g.event_writer.add_text('args', str(args))  # TODO: replace with log_tb

    accumulated_loss = 0
    # At any point you can hit Ctrl + C to break out of training early.
    try:
        for epoch in itertools.count(start=g.state.last_epoch):
            print(f"epoch -- {epoch}, token_count -- {g.state.token_count}")
            model.train()

            log_tb('sizes/batch_size', args.batch_size)
            log_tb('sizes/seq_size', args.tgt_len)

            if g.state.partial_epoch:
                # reuse previously loaded tr_iter and states
                assert g.state.tr_iter is not None
                assert g.state.mems is not None
            else:
                g.state.tr_iter = g.corpus.get_dist_iterator(
                    'train',
                    rank=util.get_global_rank(),
                    max_rank=util.get_world_size(),
                    bsz=args.batch_size,
                    bptt=args.tgt_len,
                    device=g.device,
                    ext_len=args.ext_len,
                    skip_files=g.args.skip_files)
                g.state.mems = tuple()
            g.state.last_epoch = epoch

            log_start_time = time.time()
            tokens_per_epoch = 0
            for batch, (data, target, seq_len) in enumerate(g.state.tr_iter):
                # assert seq_len == data.shape[0]
                # for i in range(1, data.shape[0]):
                #     assert torch.all(torch.eq(data[i], target[i - 1]))
                #     break
                # print(g.state.token_count, data)

                if g.state.train_step % args.eval_interval == 0:
                    evaluate_and_log(model, g.va_iter, 'val_short-mem-1',
                                     generate_text=False, reset_mems_interval=1)
                    evaluate_and_log(model, g.va_iter, 'val_short-mem-2',
                                     generate_text=False, reset_mems_interval=2)
                    evaluate_and_log(model, g.va_iter, 'val_short-mem-3',
                                     generate_text=False, reset_mems_interval=3)
                    evaluate_and_log(model, g.va_iter, 'val')
                    if g.va_custom_iter:
                        evaluate_and_log(g.state.model, g.va_custom_iter,
                                         g.args.valid_custom,
                                         generate_text=False)

                batch_total = torch.tensor(data.shape[1]).to(g.device)
                if args.local:  # TODO(y): factor out (need way to see if dist was inited)
                    batch_total = batch_total.sum()
                else:
                    batch_total = util.dist_sum_tensor(batch_total)  # global batch size
                batch_total = util.toscalar(batch_total)

                should_log = (g.state.train_step < args.verbose_log_steps) or \
                             (g.state.train_step + 1) % args.log_interval == 0

                model.zero_grad()

                ret = model(data, target, *g.state.mems)
                loss, g.state.mems = ret[0], ret[1:]

                loss: torch.Tensor = loss.float().mean().type_as(loss)
                with timeit('backwards', noop=not should_log):
                    if args.fp16:
                        optimizer.backward(loss)
                    else:
                        loss.backward()
                loss0 = util.toscalar(loss)
                util.record('loss', loss0)

                util.record('params', torch.sum(util.flat_param(model)).item())
                losses.append(loss0)
                accumulated_loss += loss0

                if args.fp16:
                    optimizer.clip_master_grads(args.clip)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

                # step-wise learning rate annealing
                if hasattr(optimizer, 'overflow') and optimizer.overflow:
                    g.logger.info("skipped iteration")
                else:
                    if args.scheduler in ['cosine', 'constant', 'dev_perf']:
                        # linear warmup stage
                        if g.state.token_count < args.warmup_tokens:
                            curr_lr = args.lr * float(
                                g.state.token_count) / args.warmup_tokens
                            optimizer.param_groups[0]['lr'] = curr_lr
                        elif args.scheduler == 'cosine':
                            # Divide by 1e6 for numerical stability.
                            g.state.scheduler.step(
                                g.state.token_count // 1000 // 1000)
                    else:
                        g.state.scheduler.step(g.state.token_count)
                optimizer.step()
                g.state.train_step += 1

                consumed_tokens = data.shape[0] * data.shape[1]
                world_size = int(os.environ.get("WORLD_SIZE", "8"))
                if world_size > 8:  # correction factor for multiple machines
                    consumed_tokens = consumed_tokens * (world_size // 8)
                tokens_per_epoch += consumed_tokens
                g.state.token_count += consumed_tokens
                g.token_count = g.state.token_count
                if g.state.token_count >= args.max_tokens:
                    g.state.partial_epoch = True
                    raise StopIteration  # break out of parent train loop

                if should_log:
                    elapsed_time = time.time() - log_start_time
                    elapsed_steps = g.state.train_step - g.state.last_log_step

                    # compute average loss over last logging interval
                    cur_loss = accumulated_loss / elapsed_steps
                    cur_loss_mean = util.dist_mean(cur_loss)
                    log_str = f'| epoch {epoch:3d} step {g.state.train_step:>8d} ' \
                              f'| {batch:>6d} batches ' \
                              f'| lr {optimizer.param_groups[0]["lr"]:.3g} ' \
                              f'| ms/batch {elapsed_time * 1000 / elapsed_steps:5.2f} ' \
                              f'| loss {cur_loss:5.2f}'
                    if args.dataset in ['enwik8', 'text8']:
                        log_str += f' | bpc {cur_loss / math.log(2):9.5f}'
                    else:
                        log_str += f' | ppl {math.exp(cur_loss):9.3f}'
                    g.logger.info(log_str)
                    log_tb('learning/epoch', epoch)
                    log_tb('_loss', cur_loss_mean)  # the most important thing
                    log_tb('learning/loss', cur_loss_mean)
                    log_tb('learning/ppl', math.exp(cur_loss_mean))

                    # currently step timings are not synchronized in multi-machine
                    # case (see #4). Can add torch.distributed.barrier() to get
                    # more accurate timings, but this may add slowness.
                    log_tb('times/step', 1000 * elapsed_time / elapsed_steps)
                    current_lr = optimizer.param_groups[0]['lr']
                    log_tb('learning/lr', current_lr)

                    # 32 is the "canonical" batch size
                    linear_scaling_factor = batch_total / 32  # TODO(y): merge logic from master
                    log_tb('learning/base_lr', current_lr / linear_scaling_factor)
                    if args.optim == 'lamb':
                        log_lamb_rs(optimizer, g.event_writer, g.state.token_count)

                    time_per_batch = elapsed_time / elapsed_steps
                    time_per_sample = time_per_batch / args.batch_size
                    time_per_token = time_per_sample / args.tgt_len
                    log_tb('times/batches_per_sec', 1 / time_per_batch)
                    log_tb('times/samples_per_sec', 1 / time_per_sample)
                    log_tb('times/tokens_per_sec', 1 / time_per_token)

                    if str(g.device) == 'cuda':
                        log_tb("memory/allocated_gb",
                               torch.cuda.memory_allocated() / 1e9)
                        log_tb("memory/max_allocated_gb",
                               torch.cuda.max_memory_allocated() / 1e9)
                        log_tb("memory/cached_gb",
                               torch.cuda.memory_cached() / 1e9)
                        log_tb("memory/max_cached_gb",
                               torch.cuda.max_memory_cached() / 1e9)

                    accumulated_loss = 0
                    log_start_time = time.time()
                    g.state.last_log_step = g.state.train_step

            if args.checkpoint_each_epoch:
                g.logger.info(f'Saving checkpoint for epoch {epoch}')
                util.dist_save_checkpoint(model, optimizer, args.logdir,
                                          suffix=f'{epoch}')
            if tokens_per_epoch == 0:
                logging.info("Zero tokens in last epoch, breaking")
                break

            g.state.partial_epoch = False

    except KeyboardInterrupt:
        g.logger.info('-' * 100)
        g.logger.info('Exiting from training early')
    except StopIteration:
        pass

    return losses
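# NOTE: the schedule above combines a token-based linear warmup with a cosine
# anneal that steps in units of one million tokens. A self-contained
# approximation of that policy, for reference only (the function name and
# signature are illustrative, not from the original file):
import math

def lr_at(token_count, base_lr, warmup_tokens, max_tokens, eta_min=0.0):
    """Return the learning rate for a given cumulative token count."""
    if token_count < warmup_tokens:
        # linear warmup stage, as in the training loop above
        return base_lr * float(token_count) / warmup_tokens
    # cosine anneal measured in millions of tokens, mirroring
    # CosineAnnealingLR(optimizer, args.max_tokens // 1e6) stepped with
    # token_count // 1000 // 1000
    t = (token_count // 1e6) / (max_tokens // 1e6)
    return eta_min + (base_lr - eta_min) * 0.5 * (1 + math.cos(math.pi * min(t, 1.0)))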
def main(): """ Train the TensorFlow models. """ # Get hyper-parameters if os.path.exists(FLAGS.checkpoint_path) == False: os.makedirs(FLAGS.checkpoint_path) checkpoint_file_path = FLAGS.checkpoint_path + "/checkpoint.ckpt" latest_checkpoint_file_path = tf.train.latest_checkpoint( FLAGS.checkpoint_path) if os.path.exists(FLAGS.output_path) == False: os.makedirs(FLAGS.output_path) # Step 1: Construct the dataset op epoch_number = FLAGS.epoch_number if epoch_number <= 0: epoch_number = -1 train_buffer_size = FLAGS.train_batch_size * 3 validation_buffer_size = FLAGS.train_batch_size * 3 train_filename_list = [filename for filename in FLAGS.train_files.split(",")] train_filename_placeholder = tf.placeholder(tf.string, shape=[None]) if FLAGS.file_format == "tfrecords": train_dataset = tf.data.TFRecordDataset(train_filename_placeholder) train_dataset = train_dataset.map(parse_tfrecords_function).repeat( epoch_number).batch(FLAGS.train_batch_size).shuffle( buffer_size=train_buffer_size) elif FLAGS.file_format == "csv": # Skip the header or not train_dataset = tf.data.TextLineDataset(train_filename_placeholder) train_dataset = train_dataset.map(parse_csv_function).repeat( epoch_number).batch(FLAGS.train_batch_size).shuffle( buffer_size=train_buffer_size) train_dataset_iterator = train_dataset.make_initializable_iterator() train_features_op, train_label_op = train_dataset_iterator.get_next() validation_filename_list = [ filename for filename in FLAGS.validation_files.split(",") ] validation_filename_placeholder = tf.placeholder(tf.string, shape=[None]) if FLAGS.file_format == "tfrecords": validation_dataset = tf.data.TFRecordDataset( validation_filename_placeholder) validation_dataset = validation_dataset.map( parse_tfrecords_function).repeat(epoch_number).batch( FLAGS.validation_batch_size).shuffle( buffer_size=validation_buffer_size) elif FLAGS.file_format == "csv": validation_dataset = tf.data.TextLineDataset( validation_filename_placeholder) validation_dataset = validation_dataset.map(parse_csv_function).repeat( epoch_number).batch(FLAGS.validation_batch_size).shuffle( buffer_size=validation_buffer_size) validation_dataset_iterator = validation_dataset.make_initializable_iterator( ) validation_features_op, validation_label_op = validation_dataset_iterator.get_next( ) # Step 2: Define the model input_units = FLAGS.feature_size output_units = FLAGS.label_size logits = inference(train_features_op, input_units, output_units, True) if FLAGS.loss == "sparse_cross_entropy": cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits( logits=logits, labels=train_label_op) loss = tf.reduce_mean(cross_entropy, name="loss") elif FLAGS.loss == "cross_entropy": cross_entropy = tf.nn.cross_entropy_with_logits( logits=logits, labels=train_label_op) loss = tf.reduce_mean(cross_entropy, name="loss") elif FLAGS.loss == "mean_square": msl = tf.square(logits - train_label_op, name="msl") loss = tf.reduce_mean(msl, name="loss") global_step = tf.Variable(0, name="global_step", trainable=False) learning_rate = FLAGS.learning_rate if FLAGS.enable_lr_decay: logging.info( "Enable learning rate decay rate: {}".format(FLAGS.lr_decay_rate)) starter_learning_rate = FLAGS.learning_rate learning_rate = tf.train.exponential_decay( starter_learning_rate, global_step, 100000, FLAGS.lr_decay_rate, staircase=True) optimizer = util.get_optimizer_by_name(FLAGS.optimizer, learning_rate) train_op = optimizer.minimize(loss, global_step=global_step) # Need to re-use the Variables for training and validation 
tf.get_variable_scope().reuse_variables() # Define accuracy op and auc op for train train_accuracy_logits = inference(train_features_op, input_units, output_units, False) train_softmax_op, train_accuracy_op = model.compute_softmax_and_accuracy( train_accuracy_logits, train_label_op) train_auc_op = model.compute_auc(train_softmax_op, train_label_op, FLAGS.label_size) # Define accuracy op and auc op for validation validation_accuracy_logits = inference(validation_features_op, input_units, output_units, False) validation_softmax_op, validation_accuracy_op = model.compute_softmax_and_accuracy( validation_accuracy_logits, validation_label_op) validation_auc_op = model.compute_auc(validation_softmax_op, validation_label_op, FLAGS.label_size) # Define inference op inference_features = tf.placeholder( "float", [None, FLAGS.feature_size], name="features") inference_logits = inference(inference_features, input_units, output_units, False) inference_softmax_op = tf.nn.softmax( inference_logits, name="inference_softmax") inference_prediction_op = tf.argmax( inference_softmax_op, 1, name="inference_prediction") keys_placeholder = tf.placeholder(tf.int32, shape=[None, 1], name="keys") keys_identity = tf.identity(keys_placeholder, name="inference_keys") signature_def_map = { signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature_def_utils.build_signature_def( inputs={ "keys": utils.build_tensor_info(keys_placeholder), "features": utils.build_tensor_info(inference_features) }, outputs={ "keys": utils.build_tensor_info(keys_identity), "prediction": utils.build_tensor_info(inference_prediction_op), }, method_name="tensorflow/serving/predictss"), "serving_detail": signature_def_utils.build_signature_def( inputs={ "keys": utils.build_tensor_info(keys_placeholder), "features": utils.build_tensor_info(inference_features) }, outputs={ "keys": utils.build_tensor_info(keys_identity), "prediction": utils.build_tensor_info(inference_prediction_op), "softmax": utils.build_tensor_info(inference_softmax_op), }, method_name="sdfas") } # Initialize saver and summary saver = tf.train.Saver() tf.summary.scalar("loss", loss) if FLAGS.scenario == "classification": tf.summary.scalar("train_accuracy", train_accuracy_op) tf.summary.scalar("train_auc", train_auc_op) tf.summary.scalar("validate_accuracy", validation_accuracy_op) tf.summary.scalar("validate_auc", validation_auc_op) summary_op = tf.summary.merge_all() init_op = [ tf.global_variables_initializer(), tf.local_variables_initializer() ] # Step 3: Create session to run with tf.Session() as sess: writer = tf.summary.FileWriter(FLAGS.output_path, sess.graph) sess.run(init_op) sess.run( [ train_dataset_iterator.initializer, validation_dataset_iterator.initializer ], feed_dict={ train_filename_placeholder: train_filename_list, validation_filename_placeholder: validation_filename_list }) if FLAGS.mode == "train": if FLAGS.resume_from_checkpoint: util.restore_from_checkpoint(sess, saver, latest_checkpoint_file_path) try: start_time = datetime.datetime.now() while True: if FLAGS.enable_benchmark: sess.run(train_op) else: _, global_step_value = sess.run([train_op, global_step]) # Step 4: Display training metrics after steps if global_step_value % FLAGS.steps_to_validate == 0: if FLAGS.scenario == "classification": loss_value, train_accuracy_value, train_auc_value, validate_accuracy_value, validate_auc_value, summary_value = sess.run( [ loss, train_accuracy_op, train_auc_op, validation_accuracy_op, validation_auc_op, summary_op ]) end_time = datetime.datetime.now() 
logging.info( "[{}] Step: {}, loss: {}, train_acc: {}, train_auc: {}, valid_acc: {}, valid_auc: {}". format(end_time - start_time, global_step_value, loss_value, train_accuracy_value, train_auc_value, validate_accuracy_value, validate_auc_value)) elif FLAGS.scenario == "regression": loss_value, summary_value = sess.run([loss, summary_op]) end_time = datetime.datetime.now() logging.info("[{}] Step: {}, loss: {}".format( end_time - start_time, global_step_value, loss_value)) writer.add_summary(summary_value, global_step_value) saver.save( sess, checkpoint_file_path, global_step=global_step_value) start_time = end_time except tf.errors.OutOfRangeError: if FLAGS.enable_benchmark: logging.info("Finish training for benchmark") else: # Step 5: Export the model after training util.save_model( FLAGS.model_path, FLAGS.model_version, sess, signature_def_map, is_save_graph=False) elif FLAGS.mode == "savedmodel": if util.restore_from_checkpoint(sess, saver, latest_checkpoint_file_path) == False: logging.error("No checkpoint for exporting model, exit now") return util.save_model( FLAGS.model_path, FLAGS.model_version, sess, signature_def_map, is_save_graph=False) elif FLAGS.mode == "inference": if util.restore_from_checkpoint(sess, saver, latest_checkpoint_file_path) == False: logging.error("No checkpoint for inference, exit now") return # Load test data inference_result_file_name = FLAGS.inference_result_file inference_test_file_name = FLAGS.inference_data_file inference_data = np.genfromtxt(inference_test_file_name, delimiter=",") inference_data_features = inference_data[:, 0:9] inference_data_labels = inference_data[:, 9] # Run inference start_time = datetime.datetime.now() prediction, prediction_softmax = sess.run( [inference_prediction_op, inference_softmax_op], feed_dict={inference_features: inference_data_features}) end_time = datetime.datetime.now() # Compute accuracy label_number = len(inference_data_labels) correct_label_number = 0 for i in range(label_number): if inference_data_labels[i] == prediction[i]: correct_label_number += 1 accuracy = float(correct_label_number) / label_number # Compute auc y_true = np.array(inference_data_labels) y_score = prediction_softmax[:, 1] fpr, tpr, thresholds = metrics.roc_curve(y_true, y_score, pos_label=1) auc = metrics.auc(fpr, tpr) logging.info("[{}] Inference accuracy: {}, auc: {}".format( end_time - start_time, accuracy, auc)) # Save result into the file np.savetxt(inference_result_file_name, prediction_softmax, delimiter=",") logging.info( "Save result to file: {}".format(inference_result_file_name))
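# NOTE: `model.compute_softmax_and_accuracy` and `model.compute_auc` are
# helpers from a `model` module that is not part of this file. A plausible
# sketch, assuming they factor out the inline accuracy/AUC graph from the
# sparse main() earlier in this file:
def compute_softmax_and_accuracy(logits, labels):
  softmax_op = tf.nn.softmax(logits)
  correct_prediction = tf.equal(tf.argmax(softmax_op, 1), tf.to_int64(labels))
  accuracy_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  return softmax_op, accuracy_op

def compute_auc(softmax_op, labels, label_size):
  # One-hot encode the integer labels, then use the streaming AUC metric,
  # exactly as done inline in the sparse model above.
  labels = tf.cast(labels, tf.int32)
  sparse_labels = tf.reshape(labels, [-1, 1])
  derived_size = tf.shape(labels)[0]
  indices = tf.reshape(tf.range(0, derived_size, 1), [-1, 1])
  concated = tf.concat(axis=1, values=[indices, sparse_labels])
  outshape = tf.stack([derived_size, label_size])
  dense_labels = tf.sparse_to_dense(concated, outshape, 1.0, 0.0)
  _, auc_op = tf.contrib.metrics.streaming_auc(softmax_op, dense_labels)
  return auc_op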
def main():
    global global_token_count, event_writer, train_step, train_loss, last_log_step, \
        best_val_loss, epoch, model

    if args.local_rank > 0:
        pass  # skip shutdown when rank is explicitly set + not zero rank
    else:
        os.system('shutdown -c')

    if not args.local:
        logger.info(
            f'Distributed initializing process group with {args.dist_backend}, '
            f'{args.dist_url}, {util.get_world_size()}')
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=util.get_world_size())
        assert (util.get_world_size() == dist.get_world_size())
        logger.info(
            f"Distributed: success ({args.local_rank}/{dist.get_world_size()})")

    model = MemTransformerLM(ntokens, args.n_layer, args.n_head, args.d_model,
                             args.d_head, args.d_inner, args.dropout,
                             args.dropatt, tie_weight=args.tied,
                             d_embed=args.d_embed, div_val=args.div_val,
                             tie_projs=tie_projs, pre_lnorm=args.pre_lnorm,
                             tgt_len=args.tgt_len, ext_len=args.ext_len,
                             mem_len=args.mem_len, cutoffs=cutoffs,
                             same_length=args.same_length,
                             attn_type=args.attn_type,
                             clamp_len=args.clamp_len,
                             sample_softmax=args.sample_softmax)

    # log model info
    n_all_param = sum([p.nelement() for p in model.parameters()])
    log_tb('sizes/params', n_all_param)
    n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])
    log_tb('sizes/non_emb_params', n_nonemb_param)
    logger.info('params %s non_emb_params %s', n_all_param, n_nonemb_param)

    # optimizer
    if args.optim.lower() == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.mom)
    elif args.optim.lower() == 'lamb':
        optimizer = Lamb(model.parameters(), lr=args.lr, weight_decay=args.wd)
    else:
        assert args.optim.lower() == 'adam'
        optimizer = optim.Adam(model.parameters(), lr=args.lr,
                               weight_decay=args.wd)

    # scheduler
    if args.scheduler == 'cosine':
        # Divide by 1e6 for numerical stability.
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, args.max_tokens // 1e6, eta_min=args.eta_min)
    elif args.scheduler == 'finder':
        scheduler = LRFinder(optimizer, args.max_tokens,
                             init_value=args.lr / 1e3)
    elif args.scheduler == 'constant':
        # assumed no-op fallback so train() below always receives a scheduler
        # (mirrors util.NoOp() in the main_loop variant above); the original
        # left `scheduler` undefined on this path
        scheduler = util.NoOp()

    model.apply(weights_init)
    # ensure embedding init is not overridden by out_layer in case of weight sharing
    model.word_emb.apply(weights_init)

    if args.checkpoint:
        if global_rank == 0:
            util.restore_from_checkpoint(model=model,
                                         checkpoint_fn=args.checkpoint)

    model = model.to(device)
    if args.fp16:
        model = FP16_Module(model)
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.static_loss_scale,
                                   dynamic_loss_scale=args.dynamic_loss_scale,
                                   dynamic_loss_args={'init_scale': 2 ** 16},
                                   verbose=False)

    if args.local:
        model = nn.DataParallel(model, dim=1)
    else:
        # Uncomment find_unused_parameters and upgrade to torch 1.1 for adaptive embedding.
        model = DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank)  # , find_unused_parameters=True)

    if global_rank == 0:
        event_writer = SummaryWriter(args.logdir)
        event_writer.add_text('args', str(args))

    # test checkpoint writing
    if args.checkpoint_each_epoch:
        logger.info(f'Saving checkpoint for epoch {epoch}')
        util.dist_save_checkpoint(model, optimizer, args.logdir, suffix=f'{0}')

    # Loop over epochs.
    train_step = 0
    train_loss = 0
    last_log_step = 0
    best_val_loss = None
    va_iter, te_iter = [
        corpus.get_dist_iterator(split, global_rank, max_rank,
                                 args.batch_size * 2, args.tgt_len,
                                 device=device, ext_len=args.ext_len)
        for split in ('valid', 'test')
    ]

    # At any point you can hit Ctrl + C to break out of training early.
    try:
        for epoch in itertools.count(start=1):
            train(va_iter, optimizer, scheduler)
    except KeyboardInterrupt:
        logger.info('-' * 100)
        logger.info('Exiting from training early')
    except StopIteration:
        pass

    # Eval one more time.
    evaluate_and_log(optimizer, va_iter, 'val', train_step=-1)

    # Load the best saved model.
    logger.info("Loading best checkpoint")
    model_file = os.path.join(args.logdir, 'model-best.pt')
    if os.path.exists(model_file):
        with open(model_file, 'rb') as model_f:
            with timeit('load'):
                if args.local:
                    model = torch.load(model_f)
                else:
                    model = torch.load(
                        model_f,
                        map_location=lambda storage, loc: storage.cuda(args.local_rank))
                    model = DistributedDataParallel(
                        model,
                        device_ids=[args.local_rank],
                        output_device=args.local_rank)
    else:
        logger.warning('no model file, using current model for loss')

    # Run on test data.
    evaluate_and_log(optimizer, te_iter, 'test', -1)
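# NOTE: `timeit` above is a helper from this codebase, not the stdlib module.
# A minimal sketch of a compatible context manager; the `noop` flag matches
# its use as `timeit('backwards', noop=not should_log)` in the training loop
# earlier in this file, and `logger` is assumed to be the module-level logger:
import time
from contextlib import contextmanager

@contextmanager
def timeit(tag, noop=False):
    # measure wall-clock time of the enclosed block; log it unless noop is set
    start = time.time()
    yield
    if not noop:
        logger.info(f'{tag} took {time.time() - start:.3f}s')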
def main_loop():
    util.cancel_shutdown()
    losses = []

    args = g.args

    if not args.local:
        g.logger.info(
            f'Distributed initializing process group with {args.dist_backend}, '
            f'{args.dist_url}, {util.get_world_size()}')
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=util.get_world_size())
        assert (util.get_world_size() == dist.get_world_size())
        g.logger.info(
            f"Distributed: success ({args.local_rank}/{dist.get_world_size()})")

    if args.load_state_fn:
        g.state = load_state(args.load_state_fn)
        g.logger.info(f"Restoring training from {args.load_state_fn}")
    else:
        g.logger.info("creating new model")
        g.state = TrainState(args)

        g.state.model = MemTransformerLM(g.ntokens, args.n_layer, args.n_head,
                                         args.d_model, args.d_head,
                                         args.d_inner, args.dropout,
                                         args.dropatt, tie_weight=args.tied,
                                         d_embed=args.d_embed,
                                         div_val=args.div_val,
                                         tie_projs=g.tie_projs,
                                         pre_lnorm=args.pre_lnorm,
                                         tgt_len=args.tgt_len,
                                         ext_len=args.ext_len,
                                         mem_len=args.mem_len,
                                         cutoffs=g.cutoffs,
                                         same_length=args.same_length,
                                         attn_type=args.attn_type,
                                         clamp_len=args.clamp_len,
                                         sample_softmax=args.sample_softmax)
        if args.checkpoint:
            util.restore_from_checkpoint(g.state.model,
                                         checkpoint_fn=args.checkpoint)
        else:
            g.state.model.apply(weights_init)
            # ensure embedding init is not overridden by out_layer in case of weight sharing
            g.state.model.word_emb.apply(weights_init)

        g.state.model.to(g.device)
        optimizer_setup(g.state)

    model: MemTransformerLM = g.state.model
    optimizer = g.state.optimizer

    # log model info
    # n_all_param = sum([p.nelement() for p in model.parameters()])
    # log_tb('sizes/params', n_all_param)
    # n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])
    # log_tb('sizes/non_emb_params', n_nonemb_param)
    # g.logger.info('params %s non_emb_params %s', n_all_param, n_nonemb_param)

    # scheduler
    if not g.args.load_state_fn:
        if args.scheduler == 'cosine':
            # Divide by 1e6 for numerical stability.
            g.state.scheduler = optim.lr_scheduler.CosineAnnealingLR(
                optimizer, args.max_tokens // 1e6, eta_min=args.eta_min)
        elif args.scheduler == 'finder':
            g.state.scheduler: LRFinder = LRFinder(optimizer, args.max_tokens,
                                                   init_value=args.lr / 1e3)
        else:
            assert args.scheduler == 'constant'
            g.state.scheduler = util.NoOp()

    # Setup distributed model
    if args.local:
        model = nn.DataParallel(model, dim=1)
    else:
        # Uncomment find_unused_parameters and upgrade to torch 1.1 for adaptive embedding.
        model = DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank)  # , find_unused_parameters=True)

    if util.get_global_rank() == 0:
        if not args.test:
            wandb.config.update(vars(args))
            # wandb.watch(model)

    g.event_writer.add_text('args', str(args))  # TODO: replace with log_tb

    accumulated_loss = 0
    # At any point you can hit Ctrl + C to break out of training early.
    try:
        for epoch in itertools.count(start=g.state.last_epoch):
            print(f"epoch -- {epoch}, token_count -- {g.state.token_count}")
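# NOTE: `TrainState` bundles everything needed to resume training via
# `load_state`. A sketch based only on the fields referenced in the training
# loops above; the original class definition is not in this file:
class TrainState:
    def __init__(self, args):
        self.args = args
        self.model = None
        self.optimizer = None
        self.scheduler = None
        self.tr_iter = None          # training iterator, reused on resume
        self.mems = None             # transformer memory carried across steps
        self.last_epoch = 0
        self.train_step = 0
        self.token_count = 0
        self.last_log_step = 0
        self.partial_epoch = False   # True when stopped mid-epoch (max_tokens reached)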
def main(): """ Train the TensorFlow models. """ # Get hyper-parameters if os.path.exists(FLAGS.checkpoint_path) == False: os.makedirs(FLAGS.checkpoint_path) checkpoint_file_path = FLAGS.checkpoint_path + "/checkpoint.ckpt" latest_checkpoint_file_path = tf.train.latest_checkpoint( FLAGS.checkpoint_path) if os.path.exists(FLAGS.output_path) == False: os.makedirs(FLAGS.output_path) # Step 1: Construct the dataset op epoch_number = FLAGS.epoch_number if epoch_number <= 0: epoch_number = -1 train_buffer_size = FLAGS.train_batch_size * 3 validation_buffer_size = FLAGS.train_batch_size * 3 train_filename_list = [ filename for filename in FLAGS.train_files.split(",") ] train_filename_placeholder = tf.placeholder(tf.string, shape=[None]) if FLAGS.file_format == "tfrecords": train_dataset = tf.data.TFRecordDataset(train_filename_placeholder) train_dataset = train_dataset.map(parse_tfrecords_function).repeat( epoch_number).batch( FLAGS.train_batch_size).shuffle(buffer_size=train_buffer_size) elif FLAGS.file_format == "csv": # Skip the header or not train_dataset = tf.data.TextLineDataset(train_filename_placeholder) train_dataset = train_dataset.map(parse_csv_function).repeat( epoch_number).batch( FLAGS.train_batch_size).shuffle(buffer_size=train_buffer_size) train_dataset_iterator = train_dataset.make_initializable_iterator() train_features_op, train_label_op = train_dataset_iterator.get_next() validation_filename_list = [ filename for filename in FLAGS.validation_files.split(",") ] validation_filename_placeholder = tf.placeholder(tf.string, shape=[None]) if FLAGS.file_format == "tfrecords": validation_dataset = tf.data.TFRecordDataset( validation_filename_placeholder) validation_dataset = validation_dataset.map( parse_tfrecords_function).repeat(epoch_number).batch( FLAGS.validation_batch_size).shuffle( buffer_size=validation_buffer_size) elif FLAGS.file_format == "csv": validation_dataset = tf.data.TextLineDataset( validation_filename_placeholder) validation_dataset = validation_dataset.map(parse_csv_function).repeat( epoch_number).batch(FLAGS.validation_batch_size).shuffle( buffer_size=validation_buffer_size) validation_dataset_iterator = validation_dataset.make_initializable_iterator( ) validation_features_op, validation_label_op = validation_dataset_iterator.get_next( ) # Step 2: Define the model input_units = FLAGS.feature_size output_units = FLAGS.label_size logits = inference(train_features_op, input_units, output_units, True) if FLAGS.loss == "sparse_cross_entropy": cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits( logits=logits, labels=train_label_op) loss = tf.reduce_mean(cross_entropy, name="loss") elif FLAGS.loss == "cross_entropy": cross_entropy = tf.nn.cross_entropy_with_logits(logits=logits, labels=train_label_op) loss = tf.reduce_mean(cross_entropy, name="loss") elif FLAGS.loss == "mean_square": msl = tf.square(logits - train_label_op, name="msl") loss = tf.reduce_mean(msl, name="loss") global_step = tf.Variable(0, name="global_step", trainable=False) learning_rate = FLAGS.learning_rate if FLAGS.enable_lr_decay: logging.info("Enable learning rate decay rate: {}".format( FLAGS.lr_decay_rate)) starter_learning_rate = FLAGS.learning_rate learning_rate = tf.train.exponential_decay(starter_learning_rate, global_step, 100000, FLAGS.lr_decay_rate, staircase=True) optimizer = util.get_optimizer_by_name(FLAGS.optimizer, learning_rate) train_op = optimizer.minimize(loss, global_step=global_step) # Need to re-use the Variables for training and validation 
tf.get_variable_scope().reuse_variables() # Define accuracy op and auc op for train train_accuracy_logits = inference(train_features_op, input_units, output_units, False) train_softmax_op, train_accuracy_op = model.compute_softmax_and_accuracy( train_accuracy_logits, train_label_op) train_auc_op = model.compute_auc(train_softmax_op, train_label_op, FLAGS.label_size) # Define accuracy op and auc op for validation validation_accuracy_logits = inference(validation_features_op, input_units, output_units, False) validation_softmax_op, validation_accuracy_op = model.compute_softmax_and_accuracy( validation_accuracy_logits, validation_label_op) validation_auc_op = model.compute_auc(validation_softmax_op, validation_label_op, FLAGS.label_size) # Define inference op inference_features = tf.placeholder("float", [None, FLAGS.feature_size], name="features") inference_logits = inference(inference_features, input_units, output_units, False) inference_softmax_op = tf.nn.softmax(inference_logits, name="inference_softmax") inference_prediction_op = tf.argmax(inference_softmax_op, 1, name="inference_prediction") keys_placeholder = tf.placeholder(tf.int32, shape=[None, 1], name="keys") keys_identity = tf.identity(keys_placeholder, name="inference_keys") signature_def_map = { signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature_def_utils.build_signature_def( inputs={ "keys": utils.build_tensor_info(keys_placeholder), "features": utils.build_tensor_info(inference_features) }, outputs={ "keys": utils.build_tensor_info(keys_identity), "prediction": utils.build_tensor_info(inference_prediction_op), }, method_name="tensorflow/serving/predictss"), "serving_detail": signature_def_utils.build_signature_def( inputs={ "keys": utils.build_tensor_info(keys_placeholder), "features": utils.build_tensor_info(inference_features) }, outputs={ "keys": utils.build_tensor_info(keys_identity), "prediction": utils.build_tensor_info(inference_prediction_op), "softmax": utils.build_tensor_info(inference_softmax_op), }, method_name="sdfas") } # Initialize saver and summary saver = tf.train.Saver() tf.summary.scalar("loss", loss) if FLAGS.scenario == "classification": tf.summary.scalar("train_accuracy", train_accuracy_op) tf.summary.scalar("train_auc", train_auc_op) tf.summary.scalar("validate_accuracy", validation_accuracy_op) tf.summary.scalar("validate_auc", validation_auc_op) summary_op = tf.summary.merge_all() init_op = [ tf.global_variables_initializer(), tf.local_variables_initializer() ] # Step 3: Create session to run with tf.Session() as sess: writer = tf.summary.FileWriter(FLAGS.output_path, sess.graph) sess.run(init_op) sess.run( [ train_dataset_iterator.initializer, validation_dataset_iterator.initializer ], feed_dict={ train_filename_placeholder: train_filename_list, validation_filename_placeholder: validation_filename_list }) if FLAGS.mode == "train": if FLAGS.resume_from_checkpoint: util.restore_from_checkpoint(sess, saver, latest_checkpoint_file_path) try: start_time = datetime.datetime.now() while True: if FLAGS.enable_benchmark: sess.run(train_op) else: _, global_step_value = sess.run( [train_op, global_step]) # Step 4: Display training metrics after steps if global_step_value % FLAGS.steps_to_validate == 0: if FLAGS.scenario == "classification": loss_value, train_accuracy_value, train_auc_value, validate_accuracy_value, validate_auc_value, summary_value = sess.run( [ loss, train_accuracy_op, train_auc_op, validation_accuracy_op, validation_auc_op, summary_op ]) end_time = datetime.datetime.now() 
logging.info( "[{}] Step: {}, loss: {}, train_acc: {}, train_auc: {}, valid_acc: {}, valid_auc: {}" .format(end_time - start_time, global_step_value, loss_value, train_accuracy_value, train_auc_value, validate_accuracy_value, validate_auc_value)) elif FLAGS.scenario == "regression": loss_value, summary_value = sess.run( [loss, summary_op]) end_time = datetime.datetime.now() logging.info("[{}] Step: {}, loss: {}".format( end_time - start_time, global_step_value, loss_value)) writer.add_summary(summary_value, global_step_value) saver.save(sess, checkpoint_file_path, global_step=global_step_value) start_time = end_time except tf.errors.OutOfRangeError: if FLAGS.enable_benchmark: logging.info("Finish training for benchmark") else: # Step 5: Export the model after training util.save_model(FLAGS.model_path, FLAGS.model_version, sess, signature_def_map, is_save_graph=False) elif FLAGS.mode == "savedmodel": if util.restore_from_checkpoint( sess, saver, latest_checkpoint_file_path) == False: logging.error("No checkpoint for exporting model, exit now") return util.save_model(FLAGS.model_path, FLAGS.model_version, sess, signature_def_map, is_save_graph=False) elif FLAGS.mode == "inference": if util.restore_from_checkpoint( sess, saver, latest_checkpoint_file_path) == False: logging.error("No checkpoint for inference, exit now") return # Load test data inference_result_file_name = FLAGS.inference_result_file inference_test_file_name = FLAGS.inference_data_file inference_data = np.genfromtxt(inference_test_file_name, delimiter=",") inference_data_features = inference_data[:, 0:9] inference_data_labels = inference_data[:, 9] # Run inference start_time = datetime.datetime.now() prediction, prediction_softmax = sess.run( [inference_prediction_op, inference_softmax_op], feed_dict={inference_features: inference_data_features}) end_time = datetime.datetime.now() # Compute accuracy label_number = len(inference_data_labels) correct_label_number = 0 for i in range(label_number): if inference_data_labels[i] == prediction[i]: correct_label_number += 1 accuracy = float(correct_label_number) / label_number # Compute auc y_true = np.array(inference_data_labels) y_score = prediction_softmax[:, 1] fpr, tpr, thresholds = metrics.roc_curve(y_true, y_score, pos_label=1) auc = metrics.auc(fpr, tpr) logging.info("[{}] Inference accuracy: {}, auc: {}".format( end_time - start_time, accuracy, auc)) # Save result into the file np.savetxt(inference_result_file_name, prediction_softmax, delimiter=",") logging.info( "Save result to file: {}".format(inference_result_file_name))