def main(_=None):
  FLAGS = flags.FLAGS  # pylint: disable=invalid-name,redefined-outer-name
  config = FLAGS
  FLAGS.__dict__['config'] = config

  # Build the graph.
  with tf.Graph().as_default():
    model_dict = model_config.get(FLAGS)
    data_dict = data_config.get(FLAGS)

    model = model_dict.model
    trainset = data_dict.trainset
    validset = data_dict.validset

    # Optimisation target.
    validset = tools.maybe_convert_dataset(validset)
    trainset = tools.maybe_convert_dataset(trainset)

    # Building the model on both splits creates the variables that are
    # restored from the snapshot below; the returned tensors are unused here.
    t1 = model(trainset)  # pylint: disable=unused-variable
    t2 = model(validset)  # pylint: disable=unused-variable

    sess = tf.Session()
    saver = tf.train.Saver()
    saver.restore(sess, FLAGS.snapshot)

    # `_capsule.save_op` is assumed to be defined elsewhere in this module.
    # The number of evaluation batches is the dataset size divided by the
    # batch size.
    if config.dataset == 'mnist':
      _collect_results(sess, _capsule.save_op, validset,
                       10000 // FLAGS.batch_size)
      _collect_results(sess, _capsule.save_op, trainset,
                       60000 // FLAGS.batch_size)
    elif config.dataset == 'svhn':
      _collect_results(sess, _capsule.save_op, validset,
                       26032 // FLAGS.batch_size)
      _collect_results(sess, _capsule.save_op, trainset,
                       73257 // FLAGS.batch_size)
    elif config.dataset == 'cifar10':
      _collect_results(sess, _capsule.save_op, validset,
                       10000 // FLAGS.batch_size)
      _collect_results(sess, _capsule.save_op, trainset,
                       50000 // FLAGS.batch_size)
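
# `_collect_results` is called above and below but is not defined in this
# excerpt. The helper sketched here is an assumption about its behaviour: it
# runs the given tensors for `n_batches` batches and stacks the numpy outputs
# per key. The repository's actual implementation may differ.
def _collect_results(sess, tensors, dataset, n_batches):  # hypothetical sketch
  """Runs `tensors` for `n_batches` batches and gathers the outputs per key."""
  del dataset  # The input pipeline is assumed to feed batches on its own.
  collected = {}
  for _ in range(n_batches):
    vals = sess.run(tensors)
    for k, v in vals.items():
      collected.setdefault(k, []).append(np.asarray(v))
  return AttrDict({
      k: np.concatenate(v, axis=0) if v[0].ndim > 0 else np.stack(v)
      for k, v in collected.items()})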
def main(_=None):
  FLAGS = flags.FLAGS  # pylint: disable=invalid-name,redefined-outer-name
  config = FLAGS
  FLAGS.__dict__['config'] = config

  FLAGS.logdir = FLAGS.logdir.format(name=FLAGS.name)
  logdir = FLAGS.logdir
  logging.info('logdir: %s', logdir)

  if os.path.exists(logdir) and FLAGS.overwrite:
    logging.info('"overwrite" is set to True. Deleting logdir at "%s".', logdir)
    shutil.rmtree(logdir)

  # Build the graph.
  with tf.Graph().as_default():
    model_dict = model_config.get(FLAGS)
    data_dict = data_config.get(FLAGS)

    lr = model_dict.lr
    opt = model_dict.opt
    model = model_dict.model

    trainset = data_dict.trainset
    validset = data_dict.validset

    lr = tf.convert_to_tensor(lr)
    tf.summary.scalar('learning_rate', lr)

    # Training setup.
    global_step = tf.train.get_or_create_global_step()

    # Optimisation target.
    validset = tools.maybe_convert_dataset(validset)
    trainset = tools.maybe_convert_dataset(trainset)
    target, gvs = model.make_target(trainset, opt)

    if gvs is None:
      gvs = opt.compute_gradients(target)

    suppress_inf_and_nans = (config.grad_value_clip > 0
                             or config.grad_norm_clip > 0)
    report = tools.gradient_summaries(gvs, suppress_inf_and_nans)
    report['target'] = target

    valid_report = dict()

    gvs = tools.clip_gradients(gvs, value_clip=config.grad_value_clip,
                               norm_clip=config.grad_norm_clip)

    try:
      report.update(model.make_report(trainset))
      valid_report.update(model.make_report(validset))
    except AttributeError:
      logging.warning('Model %s has no "make_report" method.', str(model))
      raise

    plot_dict, plot_params = None, None
    if config.plot:
      try:
        plot_dict, plot_params = model.make_plot(trainset, 'train')
        valid_plot, valid_params = model.make_plot(validset, 'valid')
        plot_dict.update(valid_plot)
        if plot_params is not None:
          plot_params.update(valid_params)
      except AttributeError:
        logging.warning('Model %s has no "make_plot" method.', str(model))

    report = tools.scalar_logs(report, config.ema, 'train',
                               global_update=config.global_ema_update)
    report['lr'] = lr
    valid_report = tools.scalar_logs(
        valid_report, config.ema, 'valid',
        global_update=config.global_ema_update)

    reports_keys = sorted(report.keys())

    def _format(k):
      if k in ('lr', 'learning_rate'):
        return '.2E'
      return '.3f'

    report_template = ', '.join(['{}: {}{}:{}{}'.format(
        k, '{', k, _format(k), '}') for k in reports_keys])

    logging.info('Trainable variables:')
    tools.log_variables_by_scope()

    # Inspect gradients.
    for g, v in gvs:
      if g is None:
        logging.warning('No gradient for variable: %s.', v.name)

    tools.log_num_params()

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    if FLAGS.check_numerics:
      update_ops += [tf.add_check_numerics_ops()]

    with tf.control_dependencies(update_ops):
      train_step = opt.apply_gradients(gvs, global_step=global_step)

    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True

    with tf.train.SingularMonitoredSession(
        hooks=create_hooks(FLAGS, plot_dict, plot_params),
        checkpoint_dir=logdir, config=sess_config) as sess:

      train_itr, _ = sess.run([global_step, update_ops])
      train_tensors = [global_step, train_step]
      report_tensors = [report, valid_report]
      all_tensors = report_tensors + train_tensors

      while train_itr < config.max_train_steps:

        if train_itr % config.report_loss_steps == 0:
          report_vals, valid_report_vals, train_itr, _ = sess.run(all_tensors)

          logging.info('')
          logging.info('train:')
          logging.info('#%s: %s', train_itr,
                       report_template.format(**report_vals))

          logging.info('valid:')
          valid_logs = dict(report_vals)
          valid_logs.update(valid_report_vals)
          logging.info('#%s: %s', train_itr,
                       report_template.format(**valid_logs))

          vals_to_check = list(report_vals.values())
          if (np.isnan(vals_to_check).any()
              or np.isinf(vals_to_check).any()):
            logging.fatal('NaN or Inf in reports: %s; breaking...',
                          report_template.format(**report_vals))

        else:
          train_itr, _ = sess.run(train_tensors)
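
# Illustration of the `report_template` construction in main() above. The keys
# and values below are hypothetical; in the real script they come from the
# model's make_report() output and the gradient summaries.
def _demo_report_template():  # hypothetical helper, not part of the original script
  keys = ['lr', 'target']

  def _format(k):
    return '.2E' if k in ('lr', 'learning_rate') else '.3f'

  template = ', '.join('{}: {}{}:{}{}'.format(k, '{', k, _format(k), '}')
                       for k in sorted(keys))
  # template == 'lr: {lr:.2E}, target: {target:.3f}'
  return template.format(lr=3e-5, target=41.25)  # 'lr: 3.00E-05, target: 41.250'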
def main(_=None):
  FLAGS = flags.FLAGS  # pylint: disable=invalid-name,redefined-outer-name
  config = FLAGS
  FLAGS.__dict__['config'] = config

  # Build the graph.
  with tf.Graph().as_default():
    model_dict = model_config.get(FLAGS)
    data_dict = data_config.get(FLAGS)

    model = model_dict.model
    trainset = data_dict.trainset
    validset = data_dict.validset

    # Optimisation target.
    validset = tools.maybe_convert_dataset(validset)
    trainset = tools.maybe_convert_dataset(trainset)
    train_tensors = model(trainset)
    valid_tensors = model(validset)

    sess = tf.Session()
    saver = tf.train.Saver()
    saver.restore(sess, FLAGS.snapshot)

    valid_results = _collect_results(sess, valid_tensors, validset,
                                     10000 // FLAGS.batch_size)
    train_results = _collect_results(sess, train_tensors, trainset,
                                     60000 // FLAGS.batch_size)

    results = AttrDict(train=train_results, valid=valid_results)

    # Linear classification.
    print('Linear classification accuracy:')
    for k, v in results.items():
      print('\t{}: prior={:.04f}, posterior={:.04f}'.format(
          k, v.prior_acc.mean(), v.posterior_acc.mean()))

    # Unsupervised classification via clustering.
    print('Bipartite matching classification accuracy:')
    for field in 'posterior_pres prior_pres'.split():
      kmeans = sklearn.cluster.KMeans(
          n_clusters=10,
          precompute_distances=True,
          n_jobs=-1,
          max_iter=1000,
      ).fit(results.train[field])

      train_acc = cluster_classify(results.train[field], results.train.label,
                                   10, kmeans)
      valid_acc = cluster_classify(results.valid[field], results.valid.label,
                                   10, kmeans)

      print('\t{}: train_acc={:.04f}, valid_acc={:.04f}'.format(
          field, train_acc, valid_acc))

    checkpoint_folder = osp.dirname(FLAGS.snapshot)
    figure_filename = osp.join(checkpoint_folder, FLAGS.tsne_figure_name)
    print('Saving TSNE plot at "{}"'.format(figure_filename))
    make_tsne_plot(valid_results.posterior_pres, valid_results.label,
                   figure_filename)
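
# `cluster_classify` is used above but not defined in this excerpt. Below is a
# minimal sketch of the usual bipartite-matching evaluation, assuming scipy is
# available: k-means cluster assignments are matched to ground-truth classes
# with the Hungarian algorithm and the resulting accuracy is returned. The
# repository's own implementation may differ in details.
def cluster_classify(features, labels, n_classes, kmeans):  # hypothetical sketch
  """Matches clusters to classes by bipartite matching and scores accuracy."""
  import scipy.optimize

  preds = kmeans.predict(features)
  # Count matrix: rows are clusters, columns are true classes.
  counts = np.zeros([kmeans.n_clusters, n_classes], dtype=np.int64)
  for p, l in zip(preds, labels):
    counts[p, int(l)] += 1
  # Maximise the number of correctly matched samples.
  row_ind, col_ind = scipy.optimize.linear_sum_assignment(-counts)
  return counts[row_ind, col_ind].sum() / float(len(labels))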
def main(_=None):
  FLAGS = flags.FLAGS  # pylint: disable=invalid-name,redefined-outer-name
  config = FLAGS
  FLAGS.__dict__['config'] = config

  FLAGS.logdir = FLAGS.logdir.format(name=FLAGS.name)
  logdir = FLAGS.logdir
  logging.info('logdir: %s', logdir)

  if os.path.exists(logdir) and FLAGS.overwrite:
    logging.info('"overwrite" is set to True. Deleting logdir at "%s".', logdir)
    shutil.rmtree(logdir)

  # Buffers for the time/accuracy curves; one entry per report step, up to 600.
  fig_time = np.zeros([600])
  fig_train = np.zeros([600])
  fig_valid = np.zeros([600])

  # Build the graph.
  with tf.Graph().as_default():
    model_dict = model_config.get(FLAGS)
    data_dict = data_config.get(FLAGS)

    lr = model_dict.lr
    opt = model_dict.opt
    model = model_dict.model

    trainset = data_dict.trainset
    validset = data_dict.validset

    lr = tf.convert_to_tensor(lr)
    tf.summary.scalar('learning_rate', lr)

    # Training setup.
    global_step = tf.train.get_or_create_global_step()

    # Optimisation target.
    validset = tools.maybe_convert_dataset(validset)
    trainset = tools.maybe_convert_dataset(trainset)
    target, gvs = model.make_target(trainset, opt)

    if gvs is None:
      logging.info('gvs is None; computing gradients with opt.compute_gradients().')
      gvs = opt.compute_gradients(target)

    suppress_inf_and_nans = (config.grad_value_clip > 0
                             or config.grad_norm_clip > 0)
    report = tools.gradient_summaries(gvs, suppress_inf_and_nans)
    report['target'] = target

    valid_report = dict()

    gvs = tools.clip_gradients(gvs, value_clip=config.grad_value_clip,
                               norm_clip=config.grad_norm_clip)

    try:
      report.update(model.make_report(trainset))
      valid_report.update(model.make_report(validset))
    except AttributeError:
      logging.warning('Model %s has no "make_report" method.', str(model))
      raise

    plot_dict, plot_params = None, None
    if config.plot:
      try:
        plot_dict, plot_params = model.make_plot(trainset, 'train')
        valid_plot, valid_params = model.make_plot(validset, 'valid')
        plot_dict.update(valid_plot)
        if plot_params is not None:
          plot_params.update(valid_params)
      except AttributeError:
        logging.warning('Model %s has no "make_plot" method.', str(model))

    report = tools.scalar_logs(report, config.ema, 'train',
                               global_update=config.global_ema_update)
    report['lr'] = lr
    valid_report = tools.scalar_logs(
        valid_report, config.ema, 'valid',
        global_update=config.global_ema_update)

    reports_keys = sorted(report.keys())

    def _format(k):
      if k in ('lr', 'learning_rate'):
        return '.2E'
      return '.3f'

    report_template = ', '.join([
        '{}: {}{}:{}{}'.format(k, '{', k, _format(k), '}')
        for k in reports_keys
    ])

    logging.info('Trainable variables:')
    tools.log_variables_by_scope()

    # Inspect gradients.
    for g, v in gvs:
      if g is None:
        logging.warning('No gradient for variable: %s.', v.name)

    tools.log_num_params()

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    if FLAGS.check_numerics:
      update_ops += [tf.add_check_numerics_ops()]

    with tf.control_dependencies(update_ops):
      train_step = opt.apply_gradients(gvs, global_step=global_step)

    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True

    # run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    # run_metadata = tf.RunMetadata()

    with tf.train.SingularMonitoredSession(
        hooks=create_hooks(FLAGS, plot_dict, plot_params),
        checkpoint_dir=logdir, config=sess_config) as sess:

      train_itr, _ = sess.run([global_step, update_ops])
      train_tensors = [global_step, train_step]
      report_tensors = [report, valid_report]
      all_tensors = report_tensors + train_tensors

      start_time = time.time()
      while train_itr < config.max_train_steps:

        if train_itr % config.report_loss_steps == 0:
          # Index of the current report step in the curve buffers.
          i = train_itr // config.report_loss_steps
          report_vals, valid_report_vals, train_itr, _ = sess.run(all_tensors)

          logging.info('')
          logging.info('train:')
          logging.info('#%s: %s', train_itr,
                       report_template.format(**report_vals))

          end_time = time.time()
          fig_time[i] = end_time - start_time

          logging.info('valid:')
          valid_logs = dict(report_vals)
          # Record the train accuracy before the valid values overwrite the key,
          # and the valid accuracy afterwards.
          fig_train[i] = valid_logs['best_cls_acc']
          valid_logs.update(valid_report_vals)
          logging.info('#%s: %s', train_itr,
                       report_template.format(**valid_logs))
          fig_valid[i] = valid_logs['best_cls_acc']

          vals_to_check = list(report_vals.values())
          if (np.isnan(vals_to_check).any()
              or np.isinf(vals_to_check).any()):
            logging.fatal('NaN or Inf in reports: %s; breaking...',
                          report_template.format(**report_vals))

        else:
          # Disabled: one-off Chrome-trace profiling of a single training step.
          # if train_itr == 10030:
          #   print('here 10030')
          #   train_itr, _ = sess.run(train_tensors, options=run_options,
          #                           run_metadata=run_metadata)
          #   tl = timeline.Timeline(run_metadata.step_stats)
          #   ctf = tl.generate_chrome_trace_format()
          #   tl_str = 'trainTL_' + config.dataset + '_' + str(config.batch_size)
          #   with open(tl_str, 'w') as f:
          #     f.write(ctf)
          #   break
          # else:
          train_itr, _ = sess.run(train_tensors)

  # Save the curves and plot train/valid accuracy against wall-clock time.
  fig_time_str = config.dataset + '_' + str(config.batch_size) + '_time.npy'
  fig_train_str = config.dataset + '_' + str(config.batch_size) + '_train.npy'
  fig_valid_str = config.dataset + '_' + str(config.batch_size) + '_valid.npy'
  np.save(fig_time_str, fig_time)
  np.save(fig_train_str, fig_train)
  np.save(fig_valid_str, fig_valid)

  print('Plotting the time/accuracy curve.')
  fig, ax1 = plt.subplots()
  ax2 = ax1.twinx()
  lns1 = ax1.plot(fig_time, fig_train, label="train_accuracy")
  lns2 = ax2.plot(fig_time, fig_valid, 'r', label="valid_accuracy")
  ax1.set_xlabel('time')
  ax1.set_ylabel('accuracy')
  lns = lns1 + lns2
  labels = ["train", "valid"]
  plt.legend(lns, labels, loc=7)
  plt.savefig('./time_accuracy.jpg')
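
# The .npy files written above can be reloaded later to redraw the curve
# without re-running training. A small sketch, assuming the same
# '<dataset>_<batch_size>_*.npy' naming convention as in main(); the function
# name and signature are hypothetical.
def replot_time_accuracy(dataset, batch_size, out_path='./time_accuracy.jpg'):
  """Reloads saved curves and redraws the time/accuracy figure."""
  prefix = '{}_{}'.format(dataset, batch_size)
  fig_time = np.load(prefix + '_time.npy')
  fig_train = np.load(prefix + '_train.npy')
  fig_valid = np.load(prefix + '_valid.npy')

  fig, ax1 = plt.subplots()
  ax2 = ax1.twinx()
  lns1 = ax1.plot(fig_time, fig_train, label='train_accuracy')
  lns2 = ax2.plot(fig_time, fig_valid, 'r', label='valid_accuracy')
  ax1.set_xlabel('time')
  ax1.set_ylabel('accuracy')
  plt.legend(lns1 + lns2, ['train', 'valid'], loc=7)
  plt.savefig(out_path)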