def train_with_ctr(self):
    """Train via ``train_batch_with_ctr``, evaluating every ``save_rate`` steps.

    First, if ``cfg.dataset`` is ``'codraw'`` or ``'iclevr'``, weights are
    loaded into ``self.model.ctr.E`` from a fixed checkpoint file; for any
    other dataset nothing is loaded. Then ``self.cfg.epochs`` epochs are run
    over ``self.dataloader``. Whenever the step counter is a multiple of
    ``self.cfg.save_rate`` (including step 0), an evaluator is built before
    the batch is trained, and its results are printed and passed to
    ``self.logger.write_res``.
    """
    cfg = self.cfg

    # Saved encoder weights per dataset, checked in this order.
    encoder_checkpoints = (
        ('codraw', 'models/codraw_1.0_e.pt'),
        ('iclevr', 'models/iclevr_1.0_e.pt'),
    )
    for dataset_name, checkpoint_path in encoder_checkpoints:
        if cfg.dataset == dataset_name:
            self.model.ctr.E.load_state_dict(torch.load(checkpoint_path))
            break

    step = 0
    for epoch in tqdm(range(self.cfg.epochs), ascii=True):
        if cfg.dataset == 'codraw':
            # CoDraw data is re-shuffled at the start of every epoch.
            self.dataset.shuffle()

        for batch in self.dataloader:
            # `step` starts at 0 and only grows, so the save-rate test alone
            # decides when to evaluate.
            if step % self.cfg.save_rate == 0:
                torch.cuda.empty_cache()
                evaluator = Evaluator.factory(self.cfg, self.visualizer, self.logger)
                results = evaluator.evaluate(step)
                print('\nIter %d:' % (step))
                print(results)
                self.logger.write_res(step, results)
                del evaluator

            # The counter is advanced *before* training, so the first batch
            # is passed iteration 1.
            step += 1
            self.model.train_batch_with_ctr(
                batch, epoch, step, self.visualizer, self.logger)
def train(self):
    """Train via ``train_batch``, evaluating every ``save_rate`` steps.

    Runs ``self.cfg.epochs`` epochs over ``self.dataloader``. Whenever the
    step counter is a multiple of ``self.cfg.save_rate`` (including step 0),
    an evaluator is built before the batch is trained, and its results are
    printed and passed to ``self.logger.write_res``.
    """
    # Use the trainer's own config. A bare ``cfg`` here would resolve to a
    # module-level global that only exists when the file is run as a script,
    # raising NameError when the trainer is used from another module.
    cfg = self.cfg

    iteration_counter = 0
    for epoch in tqdm(range(cfg.epochs), ascii=True):
        if cfg.dataset == 'codraw':
            # CoDraw data is re-shuffled at the start of every epoch.
            self.dataset.shuffle()

        for batch in self.dataloader:
            # The counter never goes negative, so only the save rate matters.
            if iteration_counter % cfg.save_rate == 0:
                torch.cuda.empty_cache()
                evaluator = Evaluator.factory(cfg, self.visualizer, self.logger)
                res = evaluator.evaluate(iteration_counter)
                print('\nIter %d:' % (iteration_counter))
                print(res)
                self.logger.write_res(iteration_counter, res)
                del evaluator

            # Advanced before training: the first batch is passed iteration 1.
            iteration_counter += 1
            self.model.train_batch(batch, epoch, iteration_counter,
                                   self.visualizer, self.logger)
def train(self):
    """Train via ``train_batch`` with timing info, evaluating periodically.

    Runs ``self.cfg.epochs`` epochs over ``self.dataloader``. Each batch is
    trained first, and its duration is printed. ``train_batch`` receives
    the total number of iterations and the previous batch's duration (0 for
    the first batch). After training a batch whose iteration index is a
    multiple of ``self.cfg.save_rate`` (including 0), an evaluator is built
    and ``evaluate(iteration_counter, self.model)`` is called.
    """
    # Use the trainer's own config. A bare ``cfg`` here would resolve to a
    # module-level global that only exists when the file is run as a script,
    # raising NameError when the trainer is used from another module.
    cfg = self.cfg

    # Datasets that are re-shuffled at the start of every epoch.
    shuffled_datasets = ('codraw', 'codrawDialog', 'gandraw',
                         'gandraw_clean', 'gandraw_64', 'gandraw_64_DA')

    iteration_counter = 0
    num_batches = len(self.dataloader)
    total_iterations = num_batches * cfg.epochs
    current_batch_time = 0  # duration of the previous batch, in seconds

    for epoch in range(cfg.epochs):
        if cfg.dataset in shuffled_datasets:
            self.dataset.shuffle()

        for batch in self.dataloader:
            # perf_counter is monotonic, so durations are not distorted by
            # wall-clock adjustments the way time.time() deltas can be.
            current_batch_start = time.perf_counter()
            self.model.train_batch(batch, epoch, iteration_counter,
                                   self.visualizer, self.logger,
                                   total_iters=total_iterations,
                                   current_batch_t=current_batch_time)
            current_batch_time = time.perf_counter() - current_batch_start
            print("batch_time is: {}".format(current_batch_time))

            # The counter never goes negative, so only the save rate matters.
            if iteration_counter % cfg.save_rate == 0:
                print("Run Evaluation")
                torch.cuda.empty_cache()
                evaluator = Evaluator.factory(cfg, self.visualizer, self.logger)
                evaluator.evaluate(iteration_counter, self.model)
                del evaluator

            iteration_counter += 1
def train(self):
    """Train via ``train_batch``, running an evaluation every ``save_rate`` steps.

    Runs ``self.cfg.epochs`` epochs over ``self.dataloader``. Whenever the
    step counter is a multiple of ``self.cfg.save_rate`` (including step 0),
    an evaluator is built and ``evaluate(iteration_counter)`` is called
    before the batch is trained. Its return value is not used here.
    """
    # Use the trainer's own config. A bare ``cfg`` here would resolve to a
    # module-level global that only exists when the file is run as a script,
    # raising NameError when the trainer is used from another module.
    cfg = self.cfg

    iteration_counter = 0
    for epoch in range(cfg.epochs):
        if cfg.dataset == 'codraw':
            # CoDraw data is re-shuffled at the start of every epoch.
            self.dataset.shuffle()

        for batch in self.dataloader:
            # The counter never goes negative, so only the save rate matters.
            if iteration_counter % cfg.save_rate == 0:
                torch.cuda.empty_cache()
                evaluator = Evaluator.factory(cfg, self.visualizer, self.logger)
                evaluator.evaluate(iteration_counter)
                del evaluator

            # Advanced before training: the first batch is passed iteration 1.
            iteration_counter += 1
            self.model.train_batch(batch, epoch, iteration_counter,
                                   self.visualizer, self.logger)
def train(self):
    """Train, track the best validation iteration, then evaluate it on test data.

    Runs ``self.cfg.epochs`` epochs over ``self.dataloader``. After training
    a batch whose iteration index is a multiple of ``self.cfg.save_rate``
    (including 0), the model is evaluated on validation data and the
    metrics are printed. The iteration with the highest
    ``'scene_sim_score'`` is remembered. Only scores strictly greater than
    the running best (initially 0) update it, so iteration 0 is kept if no
    score exceeds 0. After all epochs, that iteration is evaluated with
    ``use_test=True`` and the test metrics are printed.
    """
    # Use the trainer's own config. A bare ``cfg`` here would resolve to a
    # module-level global that only exists when the file is run as a script,
    # raising NameError when the trainer is used from another module.
    cfg = self.cfg

    # Datasets that are re-shuffled at the start of every epoch.
    shuffled_datasets = ('codraw', 'gandraw', 'gandraw_64', 'gandraw_64_DA')

    iteration_counter = 0  # index of the batch currently being trained
    # Iteration with the best validation score so far. Presumably a model is
    # saved at that iteration for the final test evaluation to use — confirm
    # in Evaluator.
    best_saved_iteration = 0
    best_scene_sim_score = 0
    num_batches = len(self.dataloader)
    total_iterations = num_batches * cfg.epochs
    current_batch_time = 0  # duration of the previous batch, in seconds
    # The same list object is handed to every evaluator created below.
    visualize_images = []

    def print_report(header, report):
        # Print a header line followed by one line per metric.
        print(header)
        for key, value in report.items():
            print("{metric_name}: {metric_value}; \n".format(
                metric_name=key, metric_value=value))

    for epoch in range(cfg.epochs):
        if cfg.dataset in shuffled_datasets:
            self.dataset.shuffle()

        for batch in self.dataloader:
            if cfg.gan_type == "recurrent_gan":
                self.model.train_batch(batch, epoch, iteration_counter,
                                       self.visualizer, self.logger)
            else:
                # perf_counter is monotonic, so durations are not distorted
                # by wall-clock adjustments the way time.time() deltas can be.
                current_batch_start = time.perf_counter()
                self.model.train_batch(batch, epoch, iteration_counter,
                                       self.visualizer, self.logger,
                                       total_iters=total_iterations,
                                       current_batch_t=current_batch_time)
                current_batch_time = time.perf_counter() - current_batch_start
                print("batch_time is: {}".format(current_batch_time))

            # The counter never goes negative, so only the save rate matters.
            if iteration_counter % cfg.save_rate == 0:
                torch.cuda.empty_cache()
                evaluator = Evaluator.factory(
                    cfg, self.visualizer, self.logger,
                    visualize_images=visualize_images)
                metrics_report = evaluator.evaluate(iteration_counter)
                print_report(
                    "evaluation results for iter: {} on validation data: \n"
                    .format(iteration_counter), metrics_report)
                # Update the best scene_sim_score.
                if metrics_report['scene_sim_score'] > best_scene_sim_score:
                    best_scene_sim_score = metrics_report['scene_sim_score']
                    best_saved_iteration = iteration_counter
                del evaluator

            iteration_counter += 1

    # Evaluate the best validation iteration on the test data.
    torch.cuda.empty_cache()
    evaluator = Evaluator.factory(cfg, self.visualizer, self.logger,
                                  visualize_images=visualize_images)
    metrics_report = evaluator.evaluate(best_saved_iteration, use_test=True)
    print("best iteration is: {}".format(best_saved_iteration))
    print_report(
        "evaluation results for iter: {} on test data: \n".format(
            best_saved_iteration), metrics_report)
    del evaluator
import json
import time

from geneva.data.datasets import DATASETS
from geneva.evaluation.evaluate import Evaluator
from geneva.utils.config import keys, parse_config
from geneva.utils.visualize import VisdomPlotter
from geneva.models.models import MODELS
from geneva.data import gandraw_dataset


if __name__ == '__main__':
    # Evaluate one manually chosen iteration on the test data and print the
    # resulting metrics.

    # Only this script path needs easydict, so import it here to keep the
    # module importable without it. The original used `json` and `easydict`
    # without importing either.
    import easydict

    config_file = "example_args/gandraw_args.json"
    with open(config_file, 'r', encoding='utf-8') as f:
        cfg = json.load(f)
    # Wrap the dict so settings can be read as attributes (cfg.exp_name).
    cfg = easydict.EasyDict(cfg)

    best_iteration = 1500  # chosen by hand

    # Initialize the evaluator.
    visualizer = VisdomPlotter(env_name=cfg.exp_name, server=cfg.vis_server)
    logger = None
    evaluator = Evaluator.factory(cfg, visualizer, logger, visualize_images=[])
    metric_report = evaluator.evaluate(best_iteration, use_test=True)
    print("evaluation results for iter: {} on test data: \n".format(
        best_iteration))
    for key, value in metric_report.items():
        print("{metric_name}: {metric_value}; \n".format(
            metric_name=key, metric_value=value))