def train(sess, model, hps, logdir):
    """Run the training loop with periodic validation and best-loss checkpoints.

    This variant has no sampling/visualisation step; the ``dsample`` column is
    still measured (around an empty section) so the printed columns match the
    header line.

    Args:
        sess: Session whose graph is finalized before the first step.
        model: Object providing ``train(lr)``, ``test()`` and ``save(path)``.
            Element 0 of the averaged ``test()`` results is compared against
            the best value so far to decide whether to checkpoint.
        hps: Hyperparameters. Reads ``epochs``, ``train_its``,
            ``full_test_its``, ``lr``, ``n_train``, ``epochs_warmup``,
            ``n_batch_train``, ``local_batch_train``, ``epochs_full_valid``
            and ``verbose``.
        logdir: Path prefix for ``train.txt``, ``test.txt`` and the
            checkpoint. It is concatenated directly with the file names, so
            it should end with a path separator.
    """
    _print(hps)
    _print('Starting training. Logging to', logdir)
    _print(
        'epoch n_processed n_images ips dtrain dtest dsample dtot train_results test_results msg'
    )

    # Train
    sess.graph.finalize()
    n_processed = 0
    n_images = 0
    train_time = 0.0
    test_loss_best = 999999

    # The result loggers only exist on rank 0; every use below is guarded.
    if hvd.rank() == 0:
        train_logger = ResultLogger(logdir + "train.txt", **hps.__dict__)
        test_logger = ResultLogger(logdir + "test.txt", **hps.__dict__)

    # Number of processed images over which the learning rate ramps up.
    # A non-positive value disables warmup instead of dividing by zero.
    warmup_images = hps.n_train * hps.epochs_warmup

    tcurr = time.time()
    for epoch in range(1, hps.epochs):
        t = time.time()

        train_results = []
        for _ in range(hps.train_its):
            # Set learning rate, linearly annealed from 0 in the first
            # hps.epochs_warmup epochs.
            if warmup_images > 0:
                lr = hps.lr * min(1., n_processed / warmup_images)
            else:
                lr = hps.lr

            # Run a training step synchronously.
            _t = time.time()
            train_results += [model.train(lr)]
            if hps.verbose and hvd.rank() == 0:
                _print(n_processed, time.time() - _t, train_results[-1])
                sys.stdout.flush()

            # Images seen wrt anchor resolution
            n_processed += hvd.size() * hps.n_batch_train
            # Actual images seen at current resolution
            n_images += hvd.size() * hps.local_batch_train

        # Average the per-iteration results over the epoch.
        train_results = np.mean(np.asarray(train_results), axis=0)

        dtrain = time.time() - t
        ips = (hps.train_its * hvd.size() * hps.local_batch_train) / dtrain
        train_time += dtrain

        if hvd.rank() == 0:
            train_logger.log(epoch=epoch, n_processed=n_processed,
                             n_images=n_images, train_time=int(train_time),
                             **process_results(train_results))

        # Report on every one of the first 9 epochs, every 10th epoch below
        # 50, and on every full-validation epoch.
        if epoch < 10 or (epoch < 50 and epoch % 10 == 0) or epoch % hps.epochs_full_valid == 0:
            test_results = []
            msg = ''

            t = time.time()

            if epoch % hps.epochs_full_valid == 0:
                # Full validation run
                for _ in range(hps.full_test_its):
                    test_results += [model.test()]
                test_results = np.mean(np.asarray(test_results), axis=0)

                if hvd.rank() == 0:
                    test_logger.log(epoch=epoch, n_processed=n_processed,
                                    n_images=n_images,
                                    **process_results(test_results))

                    # Save checkpoint
                    if test_results[0] < test_loss_best:
                        test_loss_best = test_results[0]
                        model.save(logdir + "model_best_loss.ckpt")
                        msg += ' *'

            dtest = time.time() - t

            # Sample (nothing is sampled in this variant; see docstring).
            t = time.time()
            dsample = time.time() - t

            if hvd.rank() == 0:
                dcurr = time.time() - tcurr
                tcurr = time.time()
                _print(
                    epoch, n_processed, n_images,
                    "{:.1f} {:.1f} {:.1f} {:.1f} {:.1f}".format(
                        ips, dtrain, dtest, dsample, dcurr),
                    train_results, test_results, msg)

    if hvd.rank() == 0:
        _print("Finished!")
def train(sess, model, hps, logdir, visualise):
    """Run the training loop with periodic validation, checkpoints and sampling.

    Args:
        sess: Session whose graph is finalized before the first step.
        model: Object providing ``train(lr)``, ``test()`` and ``save(path)``.
            Element 0 of the averaged ``test()`` results is compared against
            the best value so far to decide whether to checkpoint.
        hps: Hyperparameters. Reads ``epochs``, ``train_its``,
            ``full_test_its``, ``lr``, ``n_train``, ``epochs_warmup``,
            ``n_batch_train``, ``local_batch_train``, ``epochs_full_valid``,
            ``epochs_full_sample`` and ``verbose``.
        logdir: Path prefix for ``train.txt``, ``test.txt`` and the
            checkpoint. It is concatenated directly with the file names, so
            it should end with a path separator.
        visualise: Callable invoked as ``visualise(epoch)`` on sampling
            epochs. It is called on every rank, not just rank 0.
    """
    _print(hps)
    _print('Starting training. Logging to', logdir)
    _print('epoch n_processed n_images ips dtrain dtest dsample dtot train_results test_results msg')

    # Train
    sess.graph.finalize()
    n_processed = 0
    n_images = 0
    train_time = 0.0
    test_loss_best = 999999

    # The result loggers only exist on rank 0; every use below is guarded.
    if hvd.rank() == 0:
        train_logger = ResultLogger(logdir + "train.txt", **hps.__dict__)
        test_logger = ResultLogger(logdir + "test.txt", **hps.__dict__)

    # Number of processed images over which the learning rate ramps up.
    # A non-positive value disables warmup instead of dividing by zero.
    warmup_images = hps.n_train * hps.epochs_warmup

    tcurr = time.time()
    for epoch in range(1, hps.epochs):
        t = time.time()

        train_results = []
        for _ in range(hps.train_its):
            # Set learning rate, linearly annealed from 0 in the first
            # hps.epochs_warmup epochs.
            if warmup_images > 0:
                lr = hps.lr * min(1., n_processed / warmup_images)
            else:
                lr = hps.lr

            # Run a training step synchronously.
            _t = time.time()
            train_results += [model.train(lr)]
            if hps.verbose and hvd.rank() == 0:
                _print(n_processed, time.time()-_t, train_results[-1])
                sys.stdout.flush()

            # Images seen wrt anchor resolution
            n_processed += hvd.size() * hps.n_batch_train
            # Actual images seen at current resolution
            n_images += hvd.size() * hps.local_batch_train

        # Average the per-iteration results over the epoch.
        train_results = np.mean(np.asarray(train_results), axis=0)

        dtrain = time.time() - t
        ips = (hps.train_its * hvd.size() * hps.local_batch_train) / dtrain
        train_time += dtrain

        if hvd.rank() == 0:
            train_logger.log(epoch=epoch, n_processed=n_processed,
                             n_images=n_images, train_time=int(train_time),
                             **process_results(train_results))

        # Report on every one of the first 9 epochs, every 10th epoch below
        # 50, and on every full-validation epoch.
        if epoch < 10 or (epoch < 50 and epoch % 10 == 0) or epoch % hps.epochs_full_valid == 0:
            test_results = []
            msg = ''

            t = time.time()

            if epoch % hps.epochs_full_valid == 0:
                # Full validation run
                for _ in range(hps.full_test_its):
                    test_results += [model.test()]
                test_results = np.mean(np.asarray(test_results), axis=0)

                if hvd.rank() == 0:
                    test_logger.log(epoch=epoch, n_processed=n_processed,
                                    n_images=n_images,
                                    **process_results(test_results))

                    # Save checkpoint
                    if test_results[0] < test_loss_best:
                        test_loss_best = test_results[0]
                        model.save(logdir+"model_best_loss.ckpt")
                        msg += ' *'

            dtest = time.time() - t

            # Sample
            t = time.time()
            if epoch == 1 or epoch == 10 or epoch % hps.epochs_full_sample == 0:
                visualise(epoch)
            dsample = time.time() - t

            if hvd.rank() == 0:
                dcurr = time.time() - tcurr
                tcurr = time.time()
                _print(epoch, n_processed, n_images,
                       "{:.1f} {:.1f} {:.1f} {:.1f} {:.1f}".format(
                           ips, dtrain, dtest, dsample, dcurr),
                       train_results, test_results, msg)

    if hvd.rank() == 0:
        _print("Finished!")
def main(hps):
    """Set up Horovod, data, model and logging, then run the training loop.

    Args:
        hps: Hyperparameters. ``train_its``, ``test_its`` and
            ``full_test_its`` are written onto it from ``get_its(hps)``.
            Also reads ``seed``, ``logdir``, ``epochs``, ``lr``, ``n_train``,
            ``epochs_warmup``, ``n_batch_train``, ``local_batch_train``,
            ``epochs_full_valid``, ``epochs_full_sample`` and ``verbose``.
    """
    # Initialize Horovod.
    hvd.init()

    # Create tensorflow session
    sess = tensorflow_session()

    # Seed TensorFlow and NumPy with a value that differs per rank.
    tf.set_random_seed(hvd.rank() + hvd.size() * hps.seed)
    np.random.seed(hvd.rank() + hvd.size() * hps.seed)

    # Get data and set train_its and valid_its
    train_iterator, test_iterator, data_init = get_data(hps, sess)
    hps.train_its, hps.test_its, hps.full_test_its = get_its(hps)

    # Create log dir. exist_ok avoids a check-then-create race when several
    # ranks start at once, and makedirs also creates missing parents.
    logdir = os.path.abspath(hps.logdir) + "/"
    os.makedirs(logdir, exist_ok=True)

    # Create model
    import model
    model = model.model(sess, hps, train_iterator, test_iterator, data_init)

    # Initialize visualization functions
    draw_samples = init_visualizations(hps, model, logdir)

    _print(hps)
    _print('Starting training. Logging to', logdir)
    _print('epoch n_processed n_images ips dtrain dtest dsample dtot train_results test_results msg')

    # Train
    sess.graph.finalize()
    n_processed = 0
    n_images = 0
    train_time = 0.0
    test_loss_best = 999999

    # The result loggers only exist on rank 0; every use below is guarded.
    if hvd.rank() == 0:
        train_logger = ResultLogger(logdir + "train.txt", **hps.__dict__)
        test_logger = ResultLogger(logdir + "test.txt", **hps.__dict__)

    # Number of processed images over which the learning rate ramps up.
    # A non-positive value disables warmup instead of dividing by zero.
    warmup_images = hps.n_train * hps.epochs_warmup

    tcurr = time.time()
    for epoch in range(1, hps.epochs):
        t = time.time()

        train_results = []
        for _ in range(hps.train_its):
            # Set learning rate, linearly annealed from 0 in the first
            # hps.epochs_warmup epochs.
            if warmup_images > 0:
                lr = hps.lr * min(1., n_processed / warmup_images)
            else:
                lr = hps.lr

            # Run a training step synchronously.
            _t = time.time()
            train_results += [model.train(lr)]
            if hps.verbose and hvd.rank() == 0:
                _print(n_processed, time.time()-_t, train_results[-1])
                sys.stdout.flush()

            # Images seen wrt anchor resolution
            n_processed += hvd.size() * hps.n_batch_train
            # Actual images seen at current resolution
            n_images += hvd.size() * hps.local_batch_train

        # Average the per-iteration results over the epoch.
        train_results = np.mean(np.asarray(train_results), axis=0)

        dtrain = time.time() - t
        ips = (hps.train_its * hvd.size() * hps.local_batch_train) / dtrain
        train_time += dtrain

        if hvd.rank() == 0:
            train_logger.log(epoch=epoch, n_processed=n_processed,
                             n_images=n_images, train_time=int(train_time),
                             **process_results(train_results))

        # Report on every one of the first 9 epochs, every 10th epoch below
        # 50, and on every full-validation epoch.
        if epoch < 10 or (epoch < 50 and epoch % 10 == 0) or epoch % hps.epochs_full_valid == 0:
            test_results = []
            msg = ''

            t = time.time()

            if epoch % hps.epochs_full_valid == 0:
                # Full validation run
                for _ in range(hps.full_test_its):
                    test_results += [model.test()]
                test_results = np.mean(np.asarray(test_results), axis=0)

                if hvd.rank() == 0:
                    test_logger.log(epoch=epoch, n_processed=n_processed,
                                    n_images=n_images,
                                    **process_results(test_results))

                    # Save checkpoint
                    if test_results[0] < test_loss_best:
                        test_loss_best = test_results[0]
                        model.save(logdir+"model_best_loss.ckpt")
                        msg += ' *'

            dtest = time.time() - t

            # Sample
            t = time.time()
            if epoch == 1 or epoch == 10 or epoch % hps.epochs_full_sample == 0:
                draw_samples(epoch)
            dsample = time.time() - t

            if hvd.rank() == 0:
                dcurr = time.time() - tcurr
                tcurr = time.time()
                _print(epoch, n_processed, n_images,
                       "{:.1f} {:.1f} {:.1f} {:.1f} {:.1f}".format(
                           ips, dtrain, dtest, dsample, dcurr),
                       train_results, test_results, msg)

    if hvd.rank() == 0:
        _print("Finished!")
def train(sess, model, hps, logdir, visualise):
    """Run the (resumable) training loop with validation and checkpointing.

    When ``hps.restore_path`` is non-empty, the epoch counter, ``n_processed``
    and the best train/test losses are restored from
    ``last_checkpoint_state.json`` inside that directory. Checkpoints are
    written under ``<logdir>/saved_models/`` whenever the epoch-averaged train
    loss or the validation loss improves, each alongside a JSON state file.

    Args:
        sess: Session whose graph is finalized before the first step.
        model: Object providing ``train(lr)``, ``test()`` and ``save(path)``.
            Element 0 of the averaged results is treated as the loss.
        hps: Hyperparameters. Reads ``restore_path``, ``epochs``,
            ``train_its``, ``full_test_its``, ``lr``, ``n_train``,
            ``epochs_warmup``, ``n_batch_train``, ``local_batch_train``,
            ``epochs_full_valid``, ``epochs_full_sample`` and ``verbose``.
        logdir: Path prefix for ``train.txt`` / ``test.txt`` (concatenated
            directly, so it should end with a separator) and parent of
            ``saved_models``.
        visualise: Callable invoked as ``visualise(epoch)`` on sampling
            epochs. It is called on every rank, not just rank 0.
    """
    _print(hps)
    _print('Starting training. Logging to', logdir)
    _print('\nepoch train [t_loss, bits_x_u, bits_x_o, bits_y, reg] '
           'test n_processed n_images (ips, dtrain, dtest, dsample, dtot), msg \n')

    # Train
    sess.graph.finalize()

    checkpoint_state_json_path = os.path.join(hps.restore_path, "last_checkpoint_state.json")

    # Restore the metadata of the last checkpoint.
    if hps.restore_path != '':
        state_dict = load_runtime_state(checkpoint_state_json_path)
        # Resume with the epoch after the one recorded in the state file.
        current_epoch = state_dict['current_epoch'] + 1
        n_processed = state_dict['n_processed']
        prev_train_loss = state_dict['train_loss_best']
        prev_test_loss = state_dict['test_loss_best']
        _print('Loaded the lastest checkpoint (epoch %d, n_p %d) from %s' %
               (current_epoch, n_processed, checkpoint_state_json_path))
    else:
        n_processed = 0
        current_epoch = 0
        prev_test_loss = None
        prev_train_loss = None

    n_images = 0
    train_time = 0.0
    # Coerce restored losses to float: state files may hold them as strings,
    # and a str cannot be ordered against a numeric loss.
    test_loss_best = float(prev_test_loss) if prev_test_loss is not None else 10.0
    train_loss_best = float(prev_train_loss) if prev_train_loss is not None else 10.0

    # The result loggers only exist on rank 0; every use below is guarded.
    if hvd.rank() == 0:
        train_logger = ResultLogger(logdir + "train.txt", **hps.__dict__)
        test_logger = ResultLogger(logdir + "test.txt", **hps.__dict__)

    # Number of processed images over which the learning rate ramps up.
    # A non-positive value disables warmup instead of dividing by zero.
    warmup_images = hps.n_train * hps.epochs_warmup

    tcurr = time.time()
    for epoch in range(current_epoch, hps.epochs):
        t = time.time()

        train_results = []
        for _ in range(hps.train_its):
            # Set learning rate, linearly annealed from 0 in the first
            # hps.epochs_warmup epochs.
            if warmup_images > 0:
                lr = hps.lr * min(1., n_processed / warmup_images)
            else:
                lr = hps.lr

            # Run a training step synchronously.
            _t = time.time()
            train_results += [model.train(lr)]
            if hps.verbose and hvd.rank() == 0:
                _print(n_processed, time.time()-_t, train_results[-1])
                sys.stdout.flush()

            # Images seen wrt anchor resolution
            n_processed += hvd.size() * hps.n_batch_train
            # Actual images seen at current resolution
            n_images += hvd.size() * hps.local_batch_train

        # Average the per-iteration results over the epoch.
        train_results = np.mean(np.asarray(train_results), axis=0)

        # Checkpoint on improved training loss.
        # NOTE(review): this save is not guarded by a rank check, so every
        # rank writes the same files -- confirm that is intended.
        if train_results[0] < train_loss_best:
            save_subdir = os.path.join(logdir, 'saved_models', 'best_train_loss')
            os.makedirs(save_subdir, exist_ok=True)
            train_loss_best = train_results[0]
            model.save(os.path.join(save_subdir, 'model_best_train_loss.ckpt'))
            write_runtime_state(
                os.path.join(save_subdir, 'last_checkpoint_state.json'),
                {
                    'current_epoch': epoch,
                    'n_processed': n_processed,
                    # Stored as floats, matching the best-test-loss state file.
                    'train_loss_best': float(train_loss_best),
                    'test_loss_best': float(test_loss_best),
                },
            )

        # dtrain is taken after the checkpoint above, so it includes save time.
        dtrain = time.time() - t
        ips = (hps.train_its * hvd.size() * hps.local_batch_train) / dtrain
        train_time += dtrain

        if hvd.rank() == 0:
            train_logger.log(epoch=epoch, n_processed=n_processed,
                             n_images=n_images, train_time=int(train_time),
                             **process_results(train_results))

        # Report on every epoch below 10, every 10th epoch below 50, and on
        # every full-validation epoch.
        if epoch < 10 or (epoch < 50 and epoch % 10 == 0) or epoch % hps.epochs_full_valid == 0:
            test_results = []
            msg = ''

            t = time.time()

            if epoch % hps.epochs_full_valid == 0:
                # Full validation run
                for _ in range(hps.full_test_its):
                    test_results += [model.test()]
                test_results = np.mean(np.asarray(test_results), axis=0)

                if hvd.rank() == 0:
                    test_logger.log(epoch=epoch, n_processed=n_processed,
                                    n_images=n_images,
                                    **process_results(test_results))

                    # Save checkpoint
                    if test_results[0] < test_loss_best:
                        save_subdir = os.path.join(logdir, 'saved_models', 'best_test_loss')
                        os.makedirs(save_subdir, exist_ok=True)
                        test_loss_best = test_results[0]
                        model.save(os.path.join(save_subdir, "model_best_loss.ckpt"))
                        write_runtime_state(
                            os.path.join(save_subdir, 'last_checkpoint_state.json'),
                            {
                                'current_epoch': epoch,
                                'n_processed': n_processed,
                                'train_loss_best': float(train_loss_best),
                                'test_loss_best': float(test_loss_best),
                            },
                        )
                        msg += ' *'

            dtest = time.time() - t

            # Sample
            t = time.time()
            if epoch == 1 or epoch == 10 or epoch % hps.epochs_full_sample == 0:
                visualise(epoch)
            dsample = time.time() - t

            if hvd.rank() == 0:
                dcurr = time.time() - tcurr
                tcurr = time.time()
                train_results = list2string(train_results)
                test_results = list2string(test_results)
                _print("{:<10} [{:<20}] [{:<20}] {:>10} {:>10} ({:.1f} {:.1f} {:.1f} {:.1f} {:.1f})".format(
                    epoch, train_results, test_results, n_processed, n_images,
                    ips, dtrain, dtest, dsample, dcurr), msg)

    if hvd.rank() == 0:
        _print("Finished!")