import os
import time
from datetime import datetime

import numpy as np
import tensorflow as tf

# Project-local modules (import paths assumed from context):
from net import Net
from data import DataSet


class SolverMultigpu(object):

    def __init__(self, train=True, common_params=None, solver_params=None,
                 net_params=None, dataset_params=None):
        if common_params:
            self.gpus = [int(device)
                         for device in str(common_params['gpus']).split(',')]
            self.image_size = int(common_params['image_size'])
            self.height = self.image_size
            self.width = self.image_size
            # Per-GPU batch size; integer division keeps the placeholder
            # shapes below integral.
            self.batch_size = int(common_params['batch_size']) // len(self.gpus)
        if solver_params:
            self.learning_rate = float(solver_params['learning_rate'])
            self.moment = float(solver_params['moment'])
            self.max_steps = int(solver_params['max_iterators'])
            self.train_dir = str(solver_params['train_dir'])
            self.lr_decay = float(solver_params['lr_decay'])
            self.decay_steps = int(solver_params['decay_steps'])
        self.tower_name = 'Tower'
        self.num_gpus = len(self.gpus)
        self.train = train
        self.net = Net(train=train, common_params=common_params,
                       net_params=net_params)
        self.dataset = DataSet(common_params=common_params,
                               dataset_params=dataset_params)
        self.placeholders = []

    def construct_cpu_graph(self, scope):
        data_l = tf.placeholder(tf.float32,
                                (self.batch_size, self.height, self.width, 1))
        gt_ab_313 = tf.placeholder(
            tf.float32,
            (self.batch_size, self.height // 4, self.width // 4, 313))
        prior_boost_nongray = tf.placeholder(
            tf.float32,
            (self.batch_size, self.height // 4, self.width // 4, 1))
        conv8_313 = self.net.inference(data_l)
        self.net.loss(scope, conv8_313, prior_boost_nongray, gt_ab_313)

    def construct_tower_gpu(self, scope):
        data_l = tf.placeholder(tf.float32,
                                (self.batch_size, self.height, self.width, 1))
        gt_ab_313 = tf.placeholder(
            tf.float32,
            (self.batch_size, self.height // 4, self.width // 4, 313))
        prior_boost_nongray = tf.placeholder(
            tf.float32,
            (self.batch_size, self.height // 4, self.width // 4, 1))
        self.placeholders.append(data_l)
        self.placeholders.append(gt_ab_313)
        self.placeholders.append(prior_boost_nongray)
        conv8_313 = self.net.inference(data_l)
        new_loss, g_loss = self.net.loss(scope, conv8_313,
                                         prior_boost_nongray, gt_ab_313)
        tf.summary.scalar('new_loss', new_loss)
        tf.summary.scalar('total_loss', g_loss)
        return new_loss, g_loss

    def average_gradients(self, tower_grads):
        """Calculate the average gradient for each shared variable across all towers.

        Note that this function provides a synchronization point across all
        towers.

        Args:
          tower_grads: List of lists of (gradient, variable) tuples. The outer
            list is over individual gradients. The inner list is over the
            gradient calculation for each tower.
        Returns:
          List of pairs of (gradient, variable) where the gradient has been
          averaged across all towers.
        """
        average_grads = []
        for grad_and_vars in zip(*tower_grads):
            # Note that each grad_and_vars looks like the following:
            #   ((grad0_gpu0, var0_gpu0), ..., (grad0_gpuN, var0_gpuN))
            grads = []
            for g, _ in grad_and_vars:
                # Add a 0th dimension to the gradient to represent the tower.
                expanded_g = tf.expand_dims(g, 0)
                # Append on a 'tower' dimension which we will average over below.
                grads.append(expanded_g)
            # Average over the 'tower' dimension (TF >= 1.0 argument order).
            grad = tf.concat(grads, 0)
            grad = tf.reduce_mean(grad, 0)
            # Keep in mind that the Variables are redundant because they are
            # shared across towers, so we just return the first tower's
            # pointer to the Variable.
            v = grad_and_vars[0][1]
            grad_and_var = (grad, v)
            average_grads.append(grad_and_var)
        return average_grads

    def train_model(self):
        with tf.Graph().as_default(), tf.device('/cpu:0'):
            self.global_step = tf.get_variable(
                'global_step', [],
                initializer=tf.constant_initializer(0), trainable=False)
            learning_rate = tf.train.exponential_decay(
                self.learning_rate, self.global_step, self.decay_steps,
                self.lr_decay, staircase=True)
            opt = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                         beta2=0.99)

            # Build a CPU copy of the model first so the variables are created
            # once, then reuse them in every GPU tower.
            with tf.name_scope('cpu_model') as scope:
                self.construct_cpu_graph(scope)
            tf.get_variable_scope().reuse_variables()

            tower_grads = []
            for i in self.gpus:
                with tf.device('/gpu:%d' % i):
                    with tf.name_scope('%s_%d' % (self.tower_name, i)) as scope:
                        new_loss, self.total_loss = self.construct_tower_gpu(scope)
                        self.summaries = tf.get_collection(
                            tf.GraphKeys.SUMMARIES, scope)
                        grads = opt.compute_gradients(new_loss)
                        tower_grads.append(grads)

            grads = self.average_gradients(tower_grads)

            self.summaries.append(
                tf.summary.scalar('learning_rate', learning_rate))
            for grad, var in grads:
                if grad is not None:
                    self.summaries.append(
                        tf.summary.histogram(var.op.name + '/gradients', grad))

            apply_gradient_op = opt.apply_gradients(
                grads, global_step=self.global_step)

            for var in tf.trainable_variables():
                self.summaries.append(tf.summary.histogram(var.op.name, var))

            variable_averages = tf.train.ExponentialMovingAverage(
                0.999, self.global_step)
            variables_averages_op = variable_averages.apply(
                tf.trainable_variables())

            train_op = tf.group(apply_gradient_op, variables_averages_op)

            saver = tf.train.Saver(write_version=1)
            saver1 = tf.train.Saver()
            summary_op = tf.summary.merge(self.summaries)
            init = tf.global_variables_initializer()
            config = tf.ConfigProto(allow_soft_placement=True)
            config.gpu_options.allow_growth = True
            sess = tf.Session(config=config)
            sess.run(init)
            # saver1.restore(sess, self.pretrain_model)  # nilboy
            summary_writer = tf.summary.FileWriter(self.train_dir, sess.graph)

            for step in range(self.max_steps):
                start_time = time.time()
                t1 = time.time()

                # Fetch one global batch and split it evenly across the towers.
                feed_dict = {}
                np_feeds = []
                data_l, gt_ab_313, prior_boost_nongray = self.dataset.batch()
                for i in range(self.num_gpus):
                    np_feeds.append(
                        data_l[self.batch_size * i:self.batch_size * (i + 1), :, :, :])
                    np_feeds.append(
                        gt_ab_313[self.batch_size * i:self.batch_size * (i + 1), :, :, :])
                    np_feeds.append(
                        prior_boost_nongray[self.batch_size * i:self.batch_size * (i + 1), :, :, :])
                for i in range(len(self.placeholders)):
                    feed_dict[self.placeholders[i]] = np_feeds[i]
                t2 = time.time()

                _, loss_value = sess.run([train_op, self.total_loss],
                                         feed_dict=feed_dict)
                duration = time.time() - start_time
                t3 = time.time()
                print('io: ' + str(t2 - t1) + '; compute: ' + str(t3 - t2))

                assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

                if step % 1 == 0:
                    num_examples_per_step = self.batch_size * self.num_gpus
                    examples_per_sec = num_examples_per_step / duration
                    sec_per_batch = duration / self.num_gpus
                    format_str = ('%s: step %d, loss = %.2f '
                                  '(%.1f examples/sec; %.3f sec/batch)')
                    print(format_str % (datetime.now(), step, loss_value,
                                        examples_per_sec, sec_per_batch))

                if step % 10 == 0:
                    summary_str = sess.run(summary_op, feed_dict=feed_dict)
                    summary_writer.add_summary(summary_str, step)

                # Save the model checkpoint periodically.
                if step % 1000 == 0:
                    checkpoint_path = os.path.join(self.train_dir, 'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=step)
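# --- Usage sketch (illustrative, not part of the repo) ---
# A minimal example of driving SolverMultigpu, assuming the parameter keys
# that __init__ above actually reads. The concrete values and the contents of
# net_params/dataset_params are placeholders, not the project's real config.
common_params = {
    'gpus': '0,1',        # comma-separated GPU ids, parsed in __init__
    'image_size': 256,
    'batch_size': 32,     # global batch; split evenly across the GPUs
}
solver_params = {
    'learning_rate': 1e-4,
    'moment': 0.9,
    'max_iterators': 100000,
    'train_dir': './train_dir',
    'lr_decay': 0.1,
    'decay_steps': 30000,
}
solver = SolverMultigpu(train=True,
                        common_params=common_params,
                        solver_params=solver_params,
                        net_params={},       # forwarded to Net (repo-specific)
                        dataset_params={})   # forwarded to DataSet (repo-specific)
solver.train_model()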
import os
import time
from datetime import datetime

import cv2
import numpy as np
import tensorflow as tf

# Project-local modules (import paths assumed from context):
from net import Net
from net_att import Net_att
from data import DataSet
import slim_vgg


class Solver(object):

    def __init__(self, train=True, common_params=None, solver_params=None,
                 net_params=None, dataset_params=None):
        if common_params:
            self.device = common_params['device']
            self.image_size = int(common_params['image_size'])
            self.height = self.image_size
            self.width = self.image_size
            self.batch_size = int(common_params['batch_size'])
            self.num_gpus = 1
            # end_to_end: use the end-to-end attention model instead of
            # Richard Zhang's model.
            self.end_to_end = common_params['end_to_end'] != 'False'
            # use_attention_in_cost: weight the loss with the teacher's
            # attention maps in the cost function.
            self.use_attention_in_cost = common_params['use_attention_in_cost'] != 'False'
        if solver_params:
            self.learning_rate = float(solver_params['learning_rate'])
            self.moment = float(solver_params['moment'])
            self.max_steps = int(solver_params['max_iterators'])
            self.train_dir = str(solver_params['train_dir'])
            self.lr_decay = float(solver_params['lr_decay'])
            self.decay_steps = int(solver_params['decay_steps'])
        self.common_params = common_params
        self.net_params = net_params
        self.train = train
        self.dataset = DataSet(common_params=common_params,
                               dataset_params=dataset_params)

    def construct_graph_for_student(self):
        with tf.device(self.device):
            self.training_flag = tf.placeholder(tf.bool)
            self.res_hm1 = tf.placeholder(
                tf.float32, (self.batch_size, self.height // 4, self.width // 4))
            self.res_hm2 = tf.placeholder(
                tf.float32, (self.batch_size, self.height // 4, self.width // 4))
            self.res_hm3 = tf.placeholder(
                tf.float32, (self.batch_size, self.height // 4, self.width // 4))
            self.data_l = tf.placeholder(
                tf.float32, (self.batch_size, self.height, self.width, 1))
            self.gt_ab_313 = tf.placeholder(
                tf.float32,
                (self.batch_size, self.height // 4, self.width // 4, 313))
            self.prior_color_weight_nongray = tf.placeholder(
                tf.float32,
                (self.batch_size, self.height // 4, self.width // 4, 1))

            if self.end_to_end:
                self.net = Net_att(train=self.training_flag,
                                   common_params=self.common_params,
                                   net_params=self.net_params)
            else:
                self.net = Net(train=self.training_flag,
                               common_params=self.common_params,
                               net_params=self.net_params)
                # self.net = DenseNet(train=self.training_flag,
                #                     common_params=self.common_params,
                #                     net_params=self.net_params)

            self.conv8_313 = self.net.inference(self.data_l)
            new_loss, g_loss = self.net.loss(
                self.conv8_313, self.prior_color_weight_nongray,
                self.gt_ab_313, self.res_hm1, self.res_hm2, self.res_hm3,
                self.use_attention_in_cost)
            tf.summary.scalar('new_loss', new_loss)
            tf.summary.scalar('total_loss', g_loss)
        return new_loss, g_loss

    def construct_graph_for_teacher(self):
        with tf.device(self.device):
            inputs = tf.placeholder(tf.float32, shape=(None, 224, 224, 3))
            _, end_points = slim_vgg.vgg_16(inputs)
            # Heat-map tensors.
            hm1 = end_points['hm1']
            hm2 = end_points['hm2']
            hm3 = end_points['hm3']
        return inputs, hm1, hm2, hm3

    # Normalize an attention heat map: z-score each map over its spatial
    # positions, then resize it to size2 x size2.
    def process_attention(self, attention_hm, size1, size2=64):
        eps = 1e-5
        res_hm = attention_hm.reshape(self.batch_size, size1**2)
        # Center each heat map.
        centered_res_hm = res_hm - res_hm.mean(axis=1).reshape((self.batch_size, 1))
        # Divide by the (regularized) standard deviation.
        denom_res_hm = np.sqrt(
            (centered_res_hm**2).sum(axis=1) / (size1 * size1) + eps
        ).reshape((self.batch_size, 1))
        res_hm = centered_res_hm / denom_res_hm
        # Reshape back into spatial maps.
        res_hm = res_hm.reshape((self.batch_size, size1, size1))
        # Resize each map to size2 x size2 (64 x 64 by default).
        res_hm = np.concatenate(
            [cv2.resize(res_hm[i], (size2, size2))[None, :, :]
             for i in range(self.batch_size)], axis=0)
        return res_hm

    def train_model(self):
        with tf.device(self.device):
            # Student: construct the graph.
            new_loss, self.total_loss = self.construct_graph_for_student()

            # Initialize and configure the optimizer.
            self.global_step = tf.get_variable(
                'global_step', [],
                initializer=tf.constant_initializer(0), trainable=False)
            learning_rate = tf.train.exponential_decay(
                self.learning_rate, self.global_step, self.decay_steps,
                self.lr_decay, staircase=True)
            opt = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                         beta2=0.99)

            # Compute gradients, keep a moving average of the weights, and
            # update the weights.
            grads = opt.compute_gradients(new_loss)
            apply_gradient_op = opt.apply_gradients(
                grads, global_step=self.global_step)
            variable_averages = tf.train.ExponentialMovingAverage(
                0.999, self.global_step)
            variables_averages_op = variable_averages.apply(
                tf.trainable_variables())
            train_op = tf.group(apply_gradient_op, variables_averages_op)

            # Record values into the summary.
            self.summaries = tf.get_collection(tf.GraphKeys.SUMMARIES,
                                               scope='colorization')
            self.summaries.append(
                tf.summary.scalar('learning_rate', learning_rate))
            for grad, var in grads:
                if grad is not None:
                    self.summaries.append(
                        tf.summary.histogram(var.op.name + '/gradients', grad))
            for var in tf.trainable_variables():
                if var is not None:
                    self.summaries.append(
                        tf.summary.histogram(var.op.name, var))
            summary_op = tf.summary.merge(self.summaries)

            # Initialize and configure the student and teacher sessions.
            config = tf.ConfigProto(allow_soft_placement=True)
            config.gpu_options.allow_growth = True
            sess = tf.Session(config=config)
            sess_teacher = tf.Session(config=config)

            # Student: load an existing model or create a new one.
            saver_student = tf.train.Saver(
                tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                  scope='colorization'))
            # get_checkpoint_state expects the checkpoint *directory*.
            ckpt_student = tf.train.get_checkpoint_state('models')
            if ckpt_student and tf.train.checkpoint_exists(
                    ckpt_student.model_checkpoint_path):
                saver_student.restore(sess, ckpt_student.model_checkpoint_path)
            else:
                sess.run(tf.global_variables_initializer())

            # Teacher: load the pretrained VGG-16 model.
            inputs, hm1, hm2, hm3 = self.construct_graph_for_teacher()
            saver_teacher = tf.train.Saver(
                tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                  scope='vgg_16'))
            saver_teacher.restore(sess_teacher, 'models/vgg16.ckpt')

            # Student: initialize the summary writer.
            summary_writer = tf.summary.FileWriter(self.train_dir, sess.graph)

            for step in range(self.max_steps):
                start_time = time.time()

                # Get input data.
                images, data_l, gt_ab_313, prior_color_weight_nongray = \
                    self.dataset.batch()
                res_hm1 = np.zeros((self.batch_size, 64, 64))
                res_hm2 = np.zeros((self.batch_size, 64, 64))
                res_hm3 = np.zeros((self.batch_size, 64, 64))

                # Extract attention maps when they are used to weight the cost.
                if self.use_attention_in_cost:
                    # Teacher: forward pass to grab and process the heat maps.
                    res_pics = np.concatenate(
                        [cv2.resize(img, (224, 224),
                                    interpolation=cv2.INTER_AREA)[None, :, :, :]
                         for img in images], axis=0)
                    attention_hm1, attention_hm2, attention_hm3 = sess_teacher.run(
                        (hm1, hm2, hm3), feed_dict={inputs: res_pics})
                    res_hm1 = self.process_attention(attention_hm1, 56, 64)
                    res_hm2 = self.process_attention(attention_hm2, 28, 64)
                    res_hm3 = self.process_attention(attention_hm3, 7, 64)

                # Student: optimize the colorization objective.
                feed_d = {self.training_flag: self.train,
                          self.data_l: data_l,
                          self.gt_ab_313: gt_ab_313,
                          self.prior_color_weight_nongray: prior_color_weight_nongray,
                          self.res_hm1: res_hm1,
                          self.res_hm2: res_hm2,
                          self.res_hm3: res_hm3}
                _, loss_value = sess.run([train_op, self.total_loss],
                                         feed_dict=feed_d)
                duration = time.time() - start_time

                assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

                # Print training info periodically.
                if step % 1 == 0:
                    num_examples_per_step = self.batch_size * self.num_gpus
                    examples_per_sec = num_examples_per_step / duration
                    sec_per_batch = duration / self.num_gpus
                    format_str = ('%s: step %d, loss = %.2f '
                                  '(%.1f examples/sec; %.3f sec/batch)')
                    print(format_str % (datetime.now(), step, loss_value,
                                        examples_per_sec, sec_per_batch))

                # Record progress periodically.
                if step % 20 == 0:
                    summary_str = sess.run(summary_op, feed_dict=feed_d)
                    summary_writer.add_summary(summary_str, step)

                # Save the model checkpoint periodically.
                if step % 100 == 0:
                    checkpoint_path = os.path.join(self.train_dir, 'model.ckpt')
                    saver_student.save(sess, checkpoint_path)
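# --- Illustration (not part of the repo): what process_attention computes ---
# A self-contained numpy/cv2 sketch of the same normalization: z-score each
# flattened heat map, then resize it to 64 x 64. The shapes and random input
# are made up for demonstration.
import numpy as np
import cv2

batch_size, size1, size2 = 4, 56, 64
attention_hm = np.random.rand(batch_size, size1, size1).astype(np.float32)

flat = attention_hm.reshape(batch_size, size1 * size1)
centered = flat - flat.mean(axis=1, keepdims=True)
std = np.sqrt((centered ** 2).sum(axis=1, keepdims=True) / (size1 * size1) + 1e-5)
normalized = (centered / std).reshape(batch_size, size1, size1)
resized = np.stack([cv2.resize(m, (size2, size2)) for m in normalized])
print(resized.shape)                         # (4, 64, 64)
print(normalized.mean(), normalized.std())   # approximately 0 and 1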
import os
import time
from datetime import datetime

import numpy as np
import tensorflow as tf

# Project-local modules (import paths assumed from context):
from net import Net
from data import DataSet


class Solver(object):

    def __init__(self, train=True, common_params=None, solver_params=None,
                 net_params=None, dataset_params=None):
        if common_params:
            self.device_id = int(common_params['gpus'])
            self.image_size = int(common_params['image_size'])
            self.height = self.image_size
            self.width = self.image_size
            self.batch_size = int(common_params['batch_size'])
            self.num_gpus = 1
        if solver_params:
            self.learning_rate = float(solver_params['learning_rate'])
            self.moment = float(solver_params['moment'])
            self.max_steps = int(solver_params['max_iterators'])
            self.train_dir = str(solver_params['train_dir'])
            self.lr_decay = float(solver_params['lr_decay'])
            self.decay_steps = int(solver_params['decay_steps'])
        self.train = train
        self.net = Net(train=train, common_params=common_params,
                       net_params=net_params)
        self.dataset = DataSet(common_params=common_params,
                               dataset_params=dataset_params)

    def construct_graph(self, scope):
        with tf.device('/gpu:' + str(self.device_id)):
            self.data_l = tf.placeholder(
                tf.float32, (self.batch_size, self.height, self.width, 1))
            self.gt_ab_313 = tf.placeholder(
                tf.float32,
                (self.batch_size, self.height // 4, self.width // 4, 313))
            self.prior_boost_nongray = tf.placeholder(
                tf.float32,
                (self.batch_size, self.height // 4, self.width // 4, 1))
            self.conv8_313 = self.net.inference(self.data_l)
            new_loss, g_loss = self.net.loss(scope, self.conv8_313,
                                             self.prior_boost_nongray,
                                             self.gt_ab_313)
            tf.summary.scalar('new_loss', new_loss)
            tf.summary.scalar('total_loss', g_loss)
        return new_loss, g_loss

    def train_model(self):
        with tf.device('/gpu:' + str(self.device_id)):
            self.global_step = tf.get_variable(
                'global_step', [],
                initializer=tf.constant_initializer(0), trainable=False)
            learning_rate = tf.train.exponential_decay(
                self.learning_rate, self.global_step, self.decay_steps,
                self.lr_decay, staircase=True)
            opt = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                         beta2=0.99)

            with tf.name_scope('gpu') as scope:
                new_loss, self.total_loss = self.construct_graph(scope)
                self.summaries = tf.get_collection(tf.GraphKeys.SUMMARIES,
                                                   scope)

            grads = opt.compute_gradients(new_loss)

            self.summaries.append(
                tf.summary.scalar('learning_rate', learning_rate))
            for grad, var in grads:
                if grad is not None:
                    self.summaries.append(
                        tf.summary.histogram(var.op.name + '/gradients', grad))

            apply_gradient_op = opt.apply_gradients(
                grads, global_step=self.global_step)

            for var in tf.trainable_variables():
                self.summaries.append(tf.summary.histogram(var.op.name, var))

            variable_averages = tf.train.ExponentialMovingAverage(
                0.999, self.global_step)
            variables_averages_op = variable_averages.apply(
                tf.trainable_variables())

            train_op = tf.group(apply_gradient_op, variables_averages_op)

            saver = tf.train.Saver(write_version=1)
            saver1 = tf.train.Saver()
            summary_op = tf.summary.merge(self.summaries)
            init = tf.global_variables_initializer()
            config = tf.ConfigProto(allow_soft_placement=True)
            config.gpu_options.allow_growth = True
            sess = tf.Session(config=config)
            sess.run(init)
            # saver1.restore(sess, './models/model.ckpt')  # nilboy
            summary_writer = tf.summary.FileWriter(self.train_dir, sess.graph)

            for step in range(self.max_steps):
                start_time = time.time()
                t1 = time.time()
                data_l, gt_ab_313, prior_boost_nongray = self.dataset.batch()
                t2 = time.time()
                _, loss_value = sess.run(
                    [train_op, self.total_loss],
                    feed_dict={
                        self.data_l: data_l,
                        self.gt_ab_313: gt_ab_313,
                        self.prior_boost_nongray: prior_boost_nongray
                    })
                duration = time.time() - start_time
                t3 = time.time()
                print('io: ' + str(t2 - t1) + '; compute: ' + str(t3 - t2))
                assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

                if step % 1 == 0:
                    num_examples_per_step = self.batch_size * self.num_gpus
                    examples_per_sec = num_examples_per_step / duration
                    sec_per_batch = duration / self.num_gpus
                    format_str = ('%s: step %d, loss = %.2f '
                                  '(%.1f examples/sec; %.3f sec/batch)')
                    print(format_str % (datetime.now(), step, loss_value,
                                        examples_per_sec, sec_per_batch))

                if step % 10 == 0:
                    summary_str = sess.run(
                        summary_op,
                        feed_dict={
                            self.data_l: data_l,
                            self.gt_ab_313: gt_ab_313,
                            self.prior_boost_nongray: prior_boost_nongray
                        })
                    summary_writer.add_summary(summary_str, step)

                # Save the model checkpoint periodically.
                if step % 1000 == 0:
                    checkpoint_path = os.path.join(self.train_dir, 'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=step)
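# --- Illustration (not part of the repo): restoring the averaged weights ---
# train_model keeps an ExponentialMovingAverage of the trainable variables but
# saves the raw weights. At inference time the averaged (shadow) weights are
# usually the ones worth loading; the standard TF1 pattern, assuming the
# inference graph has been rebuilt with the same variable names and the
# checkpoint path shown here is hypothetical:
variable_averages = tf.train.ExponentialMovingAverage(0.999)
variables_to_restore = variable_averages.variables_to_restore()
ema_saver = tf.train.Saver(variables_to_restore)
with tf.Session() as sess:
    ema_saver.restore(sess, './train_dir/model.ckpt-1000')  # path assumed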
import os
import time
from datetime import datetime

import numpy as np
import tensorflow as tf

# Project-local modules (import paths assumed from context):
from net import Net
from data import DataSet

_LOG_FREQ = 10  # logging frequency in steps (module constant; value assumed)


class Solver(object):

    def __init__(self, train=True, common_params=None, solver_params=None,
                 net_params=None, dataset_params=None):
        if common_params:
            self.device_id = int(common_params['gpus'])
            self.image_size = int(common_params['image_size'])
            self.height = self.image_size
            self.width = self.image_size
            self.batch_size = int(common_params['batch_size'])
            self.num_gpus = 1
            self.d_repeat = int(common_params['d_repeat'])
            self.g_repeat = int(common_params['g_repeat'])
            self.ckpt = common_params['ckpt'] if 'ckpt' in common_params else None
            self.init_ckpt = common_params['init_ckpt'] if 'init_ckpt' in common_params else None
            self.restore_opt = common_params['restore_opt'] == '1'
            self.gan = common_params['gan'] == '1'
            self.prior_boost = common_params['prior_boost'] == '1'
            self.corr = common_params['correspondence'] == '1'
            if self.corr:
                print('Discriminator has correspondence.')
            else:
                print('Discriminator has no correspondence.')
            if self.gan:
                print('Using GAN.')
            else:
                print('Not using GAN.')
            if self.prior_boost:
                print('Using prior boost.')
            else:
                print('Not using prior boost.')
        if solver_params:
            self.learning_rate = float(solver_params['learning_rate'])
            self.D_learning_rate = float(solver_params['d_learning_rate'])
            print("Learning rate G: {0} D: {1}".format(self.learning_rate,
                                                       self.D_learning_rate))
            self.max_steps = int(solver_params['max_iterators'])
            self.train_dir = str(solver_params['train_dir'])
            self.lr_decay = float(solver_params['lr_decay'])
            self.decay_steps = int(solver_params['decay_steps'])
            self.moment = float(solver_params['moment'])
        self.train = train
        # State for lr_decay_on_plateau (not initialized in the original;
        # added so the method can be called safely).
        self.prev_loss = float('inf')
        self.increasing_count = 0
        self.net = Net(train=train, common_params=common_params,
                       net_params=net_params)
        self.dataset = DataSet(common_params=common_params,
                               dataset_params=dataset_params)
        self.val_dataset = DataSet(common_params=common_params,
                                   dataset_params=dataset_params,
                                   training=False)
        print("Solver initialization done.")

    def construct_graph(self, scope):
        with tf.device('/gpu:' + str(self.device_id)):
            self.data_l = tf.placeholder(
                tf.float32, (self.batch_size, self.height, self.width, 1))
            self.gt_ab_313 = tf.placeholder(
                tf.float32,
                (self.batch_size, self.height // 4, self.width // 4, 313))
            self.prior_boost_nongray = tf.placeholder(
                tf.float32,
                (self.batch_size, self.height // 4, self.width // 4, 1))
            conv8_313 = self.net.inference(self.data_l)
            new_loss, g_loss, wd_loss, rb_loss = self.net.loss(
                scope, conv8_313, self.prior_boost_nongray, self.gt_ab_313)
            tf.summary.scalar('new_loss', new_loss)
            tf.summary.scalar('rb_loss', rb_loss)
            tf.summary.scalar('wd_loss', wd_loss)
            tf.summary.scalar('total_loss', g_loss)
        return (new_loss, g_loss, rb_loss)

    # Note: this method assumes self.learning_rate_tensor is a tf.Variable
    # with value()/assign(); train_model below builds it with
    # tf.train.exponential_decay instead, so as written the method is unused.
    def lr_decay_on_plateau(self, sess, curr_loss, threshold=3):
        if curr_loss >= self.prev_loss:
            self.increasing_count += 1
            if self.increasing_count == threshold:
                # Decay.
                old_lr = self.learning_rate_tensor.value()
                sess.run(self.learning_rate_tensor.assign(old_lr * 0.1))
                print('Learning rate decayed to {0}.'.format(
                    self.learning_rate_tensor.eval(session=sess)))
                self.increasing_count = 0
        else:
            self.increasing_count = 0
        self.prev_loss = curr_loss

    def train_model(self):
        with tf.device('/gpu:' + str(self.device_id)):
            self.global_step = tf.get_variable(
                'global_step', [],
                initializer=tf.constant_initializer(0), trainable=False)
            self.learning_rate_tensor = tf.train.exponential_decay(
                self.learning_rate, self.global_step, self.decay_steps,
                self.lr_decay, staircase=True)

            with tf.name_scope('gpu') as scope:
                self.new_loss, self.total_loss, self.rb_loss = \
                    self.construct_graph(scope)
                self.summaries = tf.get_collection(tf.GraphKeys.SUMMARIES,
                                                   scope)
            self.summaries.append(
                tf.summary.scalar('learning_rate', self.learning_rate_tensor))

            opt = tf.train.AdamOptimizer(
                learning_rate=self.learning_rate_tensor,
                beta1=self.moment, beta2=0.99)
            G_vars = tf.trainable_variables(scope='G')
            grads = opt.compute_gradients(self.new_loss, var_list=G_vars)

            total_param = 0
            for var in tf.global_variables(scope='G'):
                print(var)
                total_param += np.prod(var.get_shape().as_list())
            print('Total params: {}.'.format(total_param))

            apply_gradient_op = opt.apply_gradients(
                grads, global_step=self.global_step)

            variable_averages = tf.train.ExponentialMovingAverage(
                0.999, self.global_step)
            variables_averages_op = variable_averages.apply(G_vars)

            train_op = tf.group(apply_gradient_op, variables_averages_op)

            savable_vars = tf.global_variables()
            saver = tf.train.Saver(savable_vars,
                                   write_version=tf.train.SaverDef.V2,
                                   max_to_keep=5,
                                   keep_checkpoint_every_n_hours=1)
            summary_op = tf.summary.merge(self.summaries)
            init = tf.global_variables_initializer()
            config = tf.ConfigProto(allow_soft_placement=True)
            config.gpu_options.allow_growth = True
            sess = tf.Session(config=config)
            print("Session configured.")

            if self.ckpt is not None:
                # sess.run(self.learning_rate_tensor.initializer)
                if self.restore_opt:
                    saver.restore(sess, self.ckpt)
                else:
                    sess.run(init)
                    # T_vars and D_vars (teacher/discriminator variable lists)
                    # are defined in the full GAN setup, not in this excerpt.
                    init_saver = tf.train.Saver(G_vars + T_vars + D_vars +
                                                [self.global_step])
                    init_saver.restore(sess, self.ckpt)
                print(self.ckpt + " restored.")
                start_step = sess.run(self.global_step)
                start_step -= int(start_step % 10)
                # start_step = 230000
                # sess.run(self.global_step.assign(start_step))
                print("Global step: {}".format(start_step))
            else:
                sess.run(init)
                print("Initialized.")
                start_step = 0

            if self.init_ckpt is not None:
                init_saver = tf.train.Saver(tf.global_variables(scope='G'))
                init_saver.restore(sess, self.init_ckpt)
                print('Init generator with {}.'.format(self.init_ckpt))

            summary_writer = tf.summary.FileWriter(self.train_dir, sess.graph)

            start_time = time.time()
            start_step = int(start_step)
            for step in range(start_step, self.max_steps, self.g_repeat):
                # Generator training.
                for _ in range(self.g_repeat):
                    data_l, gt_ab_313, prior_boost_nongray, _ = \
                        self.dataset.batch()
                    sess.run([train_op],
                             feed_dict={
                                 self.data_l: data_l,
                                 self.gt_ab_313: gt_ab_313,
                                 self.prior_boost_nongray: prior_boost_nongray
                             })

                if step % _LOG_FREQ < self.g_repeat:
                    duration = time.time() - start_time
                    num_examples_per_step = (self.batch_size * self.num_gpus *
                                             _LOG_FREQ)
                    examples_per_sec = num_examples_per_step / duration
                    sec_per_batch = duration / (self.num_gpus * _LOG_FREQ)
                    loss_value, new_loss_value, rb_loss_value = sess.run(
                        [self.total_loss, self.new_loss, self.rb_loss],
                        feed_dict={
                            self.data_l: data_l,
                            self.gt_ab_313: gt_ab_313,
                            self.prior_boost_nongray: prior_boost_nongray
                        })
                    format_str = ('%s: step %d, G loss = %.2f, '
                                  'new loss = %.2f rb loss = %.2f '
                                  '(%.1f examples/sec; %.3f sec/batch)')
                    # assert not np.isnan(loss_value), 'Model diverged with loss = NaN'
                    # assert not np.isnan(adv_loss_value), 'Adversarial diverged with loss = NaN'
                    # assert not np.isnan(D_loss_value), 'Discriminator diverged with loss = NaN'
                    print(format_str % (datetime.now(), step, loss_value,
                                        new_loss_value, rb_loss_value,
                                        examples_per_sec, sec_per_batch))
                    start_time = time.time()

                if step % 100 < self.g_repeat:
                    summary_str = sess.run(
                        summary_op,
                        feed_dict={
                            self.data_l: data_l,
                            self.gt_ab_313: gt_ab_313,
                            self.prior_boost_nongray: prior_boost_nongray
                        })

                    # Evaluate on the validation set.
                    eval_loss = 0.0
                    eval_loss_rb = 0.0
                    eval_iters = 30
                    for _ in range(eval_iters):
                        val_data_l, val_gt_ab_313, val_prior_boost_nongray, _ = \
                            self.val_dataset.batch()
                        loss_value, rb_loss_value = sess.run(
                            [self.total_loss, self.rb_loss],
                            feed_dict={
                                self.data_l: val_data_l,
                                self.gt_ab_313: val_gt_ab_313,
                                self.prior_boost_nongray: val_prior_boost_nongray
                            })
                        eval_loss += loss_value
                        eval_loss_rb += rb_loss_value
                    eval_loss /= eval_iters
                    eval_loss_rb /= eval_iters
                    # scalar_summary is a project helper that wraps a float in
                    # a tf.Summary proto (see the sketch after this class).
                    eval_loss_sum = scalar_summary('eval/loss', eval_loss)
                    eval_loss_rb_sum = scalar_summary('eval/loss_rb',
                                                      eval_loss_rb)
                    summary_writer.add_summary(eval_loss_sum, step)
                    summary_writer.add_summary(eval_loss_rb_sum, step)
                    print('Evaluation at step {0}: loss {1}, rebalanced loss {2}.'
                          .format(step, eval_loss, eval_loss_rb))

                    summary_writer.add_summary(summary_str, step)

                # Save the model checkpoint periodically.
                if step % 1000 < self.g_repeat:
                    checkpoint_path = os.path.join(self.train_dir, 'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=step)
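# --- Illustration (not part of the repo): a plausible scalar_summary helper ---
# The evaluation block above calls scalar_summary(tag, value) to log plain
# Python floats without adding a graph op. Its definition is not in this
# excerpt; a common TF1 implementation builds the Summary proto directly,
# which FileWriter.add_summary accepts as-is:
import tensorflow as tf

def scalar_summary(tag, value):
    # Wrap a Python float in a Summary proto under the given tag.
    return tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])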