def train(args):
    """Train a GAN on synthetic embedding data.

    Args:
        args: parsed CLI arguments; must provide ``log`` (directory for
            log output) and ``epoch`` (number of training epochs).
    """
    # Logging setup: mirror INFO-level messages to a file under args.log.
    # makedirs (not mkdir) so a missing parent directory doesn't raise.
    if not os.path.exists(args.log):
        os.makedirs(args.log)
    logger = logging.getLogger()
    logging.basicConfig(level=logging.INFO)
    log_path = os.path.join(args.log, 'log')
    file_handler = logging.FileHandler(log_path)
    fmt = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
    file_handler.setFormatter(fmt)
    logger.addHandler(file_handler)

    # Record the full configuration for reproducibility.
    logger.info('Arguments...')
    for arg, val in vars(args).items():
        logger.info('{:>10} -----> {}'.format(arg, val))

    # 80/10/10 train/valid/test split of the synthetic data.
    x, y = gen_synthetic_data(DIM, DIM_EMB, NUM)
    train_x, test_x, train_y, test_y = train_test_split(x, y, test_size=0.2)
    valid_x, test_x, valid_y, test_y = train_test_split(test_x, test_y,
                                                        test_size=0.5)

    # Generator/discriminator pair, each with its own Adam optimizer.
    gen = Generator(DIM_EMB)
    dis = Discriminator(DIM_EMB)
    gen_opt = optimizers.Adam()
    dis_opt = optimizers.Adam()
    gen_opt.setup(gen)
    dis_opt.setup(dis)

    trainer = GANTrainer((gen, dis), (gen_opt, dis_opt), logger,
                         (valid_x, valid_y), args.epoch)
    trainer.fit(train_x, train_y)
def fit(self, data):
    """Fit the model to the given data.

    Preprocesses ``data``, builds the tensorpack input pipeline, trains
    the GAN (optionally resuming from a saved checkpoint) and finally
    prepares the sampling graph.

    Args:
        data(pandas.DataFrame): dataset to fit the model.

    Returns:
        None
    """
    self.preprocessor = Preprocessor(
        continuous_columns=self.continuous_columns)
    data = self.preprocessor.fit_transform(data)
    self.metadata = self.preprocessor.metadata

    # Data pipeline: dataflow -> fixed-size batches -> input queue.
    dataflow = TGANDataFlow(data, self.metadata)
    batch_data = BatchData(dataflow, self.batch_size)
    input_queue = QueueInput(batch_data)

    self.model = self.get_model(training=True)

    # Select the training strategy by name.
    if self.trainer == 'GANTrainer':
        trainer = GANTrainer(model=self.model, input_queue=input_queue)
    elif self.trainer == 'SeparateGANTrainer':
        trainer = SeparateGANTrainer(model=self.model, input_queue=input_queue)
    else:
        raise ValueError(
            'Incorrect trainer name. Use GANTrainer or SeparateGANTrainer')

    # Resume from the last checkpoint — and its epoch counter, recovered
    # from the last entry of stats.json — when allowed.
    self.restore_path = os.path.join(self.model_dir, 'checkpoint')

    if os.path.isfile(self.restore_path) and self.restore_session:
        session_init = SaverRestore(self.restore_path)
        with open(os.path.join(self.log_dir, 'stats.json')) as f:
            starting_epoch = json.load(f)[-1]['epoch_num'] + 1
    else:
        session_init = None
        starting_epoch = 1

    # 'k' tells tensorpack to keep the existing log dir when resuming.
    action = 'k' if self.restore_session else None
    logger.set_logger_dir(self.log_dir, action=action)

    callbacks = []
    if self.save_checkpoints:
        callbacks.append(ModelSaver(checkpoint_dir=self.model_dir))

    trainer.train_with_defaults(
        callbacks=callbacks,
        steps_per_epoch=self.steps_per_epoch,
        max_epoch=self.max_epoch,
        session_init=session_init,
        starting_epoch=starting_epoch)

    self.prepare_sampling()
def main(config, resume):
    """Build data loaders, models, losses and optimizers, then run GAN training."""
    history_logger = Logger()

    # Data pipelines.
    loader_train = get_instance(module_data, 'train_data_loader', config)
    loader_valid = get_instance(module_data, 'valid_data_loader', config)

    # Network architectures (printed for a quick sanity check).
    nets = {}
    nets["generator"] = get_instance(module_g_arch, 'generator_arch', config)
    print(nets["generator"])
    nets["local_discriminator"] = get_instance(
        module_dl_arch, 'local_discriminator_arch', config)
    print(nets["local_discriminator"])

    # Loss functions and evaluation metrics.
    criteria = {
        "vanilla_gan": torch.nn.BCELoss(),
        "lsgan": torch.nn.MSELoss(),
        "ce": module_loss.cross_entropy2d,
        "pg": module_loss.PG_Loss(),
        "mask_ce": module_loss.Masked_CrossEntropy(),
    }
    metric_fns = [getattr(module_metric, name) for name in config['metrics']]

    # Optimizers over trainable parameters only.
    # (LR scheduling intentionally disabled; delete every line containing
    # lr_scheduler to keep it off.)
    opts = {}
    gen_params = filter(lambda p: p.requires_grad,
                        nets["generator"].parameters())
    disc_params = filter(lambda p: p.requires_grad,
                         nets["local_discriminator"].parameters())
    opts["generator"] = get_instance(
        torch.optim, 'generator_optimizer', config, gen_params)
    opts["local_discriminator"] = get_instance(
        torch.optim, 'discriminator_optimizer', config, disc_params)

    trainer = GANTrainer(nets, opts, criteria, metric_fns,
                         resume=resume,
                         config=config,
                         data_loader=loader_train,
                         valid_data_loader=loader_valid,
                         train_logger=history_logger)

    print("pretrain models")
    trainer.pre_train()
    print("training")
    trainer.train()

    evaluator = UnetEvaluator(trainer.generator, trainer.config)
    evaluator.evaluate()
video_transforms, is_train=False) testloader = torch.utils.data.DataLoader( testdataset, batch_size=cfg.TRAIN.ST_BATCH_SIZE * num_gpu, drop_last=True, shuffle=False, num_workers=int(cfg.WORKERS)) if args.eval_fid: algo = Infer(output_dir, 1.0) algo.eval_fid2(testloader, video_transforms, image_transforms) elif args.eval_fvd: algo = Infer(output_dir, 1.0) algo.eval_fvd(imageloader, storyloader, testloader, cfg.STAGE) elif args.load_ckpt != None: # For inference training result algo = Infer(output_dir, 1.0, args.load_ckpt) algo.inference(imageloader, storyloader, testloader, cfg.STAGE) else: # For training model algo = GANTrainer(output_dir, args, ratio=1.0) algo.train(imageloader, storyloader, testloader, cfg.STAGE) else: datapath = '%s/test/val_captions.t7' % (cfg.DATA_DIR) algo = GANTrainer(output_dir) algo.sample(datapath, cfg.STAGE)
# Timestamped output directory: ../output/<dataset>_<config>_<time>.
timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
output_dir = '../output/%s_%s_%s' % \
    (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp)

# GPU count is derived from the comma-separated GPU id list.
num_gpu = len(cfg.GPU_ID.split(','))
if cfg.TRAIN.FLAG:
    # Augmentation: random crop + horizontal flip, then normalize to
    # [-1, 1] (the usual convention for tanh-output GAN generators).
    image_transform = transforms.Compose([
        transforms.RandomCrop(cfg.IMSIZE),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    dataset = TextDataset(cfg.DATA_DIR, 'train',
                          imsize=cfg.IMSIZE,
                          transform=image_transform)
    assert dataset
    # Batch size scales with the number of GPUs; incomplete last batch
    # is dropped so every step sees a full batch.
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=cfg.TRAIN.BATCH_SIZE * num_gpu,
        drop_last=True, shuffle=True, num_workers=int(cfg.WORKERS))

    algo = GANTrainer(output_dir)
    algo.train(dataloader, cfg.STAGE)
else:
    # Sampling mode: generate images from pre-computed test captions.
    datapath = '%s/test/val_captions_custom.t7' % (cfg.DATA_DIR)
    algo = GANTrainer(output_dir)
    algo.sample(datapath, cfg.STAGE)
vid.append(image_transform(im)) except RuntimeError as err: print(err, "/", im.shape) raise vid = torch.stack(vid).permute(1, 0, 2, 3) return vid video_transforms = functools.partial(video_transform, image_transform=image_transforms) storydataset = data.StoryDataset(dir_path, video_transforms, cfg.VIDEO_LEN, True) imagedataset = data.ImageDataset(dir_path, image_transforms, cfg.VIDEO_LEN, True) testdataset = data.StoryDataset(dir_path, video_transforms, cfg.VIDEO_LEN, False) imageloader = torch.utils.data.DataLoader( imagedataset, batch_size=cfg.TRAIN.IM_BATCH_SIZE * num_gpu, drop_last=True, shuffle=True, num_workers=int(cfg.WORKERS)) storyloader = torch.utils.data.DataLoader( storydataset, batch_size=cfg.TRAIN.ST_BATCH_SIZE * num_gpu, drop_last=True, shuffle=True, num_workers=int(cfg.WORKERS)) testloader = torch.utils.data.DataLoader( testdataset, batch_size=24 * num_gpu, drop_last=True, shuffle=False, num_workers=int(cfg.WORKERS)) algo = GANTrainer(output_dir, cfg.ST_WEIGHT, test_sample_save_dir) algo.train(imageloader, storyloader, testloader)
# Story-level and image-level datasets share the same directory and
# transforms; the final True/False flag selects the train/test split.
storydataset = data.StoryDataset(dir_path, video_transforms, cfg.VIDEO_LEN, True)
imagedataset = data.ImageDataset(dir_path, image_transforms, cfg.VIDEO_LEN, True)
testdataset = data.StoryDataset(dir_path, video_transforms, cfg.VIDEO_LEN, False)

# Training loaders: batch size scales with the number of GPUs; last
# incomplete batch is dropped.
imageloader = torch.utils.data.DataLoader(
    imagedataset, batch_size=cfg.TRAIN.IM_BATCH_SIZE * num_gpu,
    drop_last=True, shuffle=True, num_workers=int(cfg.WORKERS))
storyloader = torch.utils.data.DataLoader(
    storydataset, batch_size=cfg.TRAIN.ST_BATCH_SIZE * num_gpu,
    drop_last=True, shuffle=True, num_workers=int(cfg.WORKERS))
# Evaluation loader: fixed batch size of 24, deterministic order.
testloader = torch.utils.data.DataLoader(
    testdataset, batch_size=24, drop_last=True, shuffle=False,
    num_workers=int(cfg.WORKERS))

algo = GANTrainer(output_dir, cfg, cfg.ST_WEIGHT, test_sample_save_dir,
                  cfg.TENSORBOARD)
algo.train(imageloader, storyloader, testloader)
def main(args):
    """SRGAN entry point: build models and data, then train and/or evaluate.

    Args:
        args: namespace with boolean flags ``train`` and ``test``.

    Hyperparameters are read from ``config.yaml`` in the working directory.
    """
    # Fix RNG seeds for reproducibility.
    np.random.seed(0)
    torch.manual_seed(0)

    with open('config.yaml', 'r') as file:
        config_dict = yaml.safe_load(file.read())
    config = mapper(**config_dict)

    disc_model = Discriminator(input_shape=(config.data.channels,
                                            config.data.hr_height,
                                            config.data.hr_width))
    gen_model = GeneratorResNet()
    feature_extractor_model = FeatureExtractor()
    plt.ion()

    # Model placement: distributed > plain GPU > bail out (CPU unsupported).
    if config.distributed:
        disc_model.to(device)
        disc_model = nn.parallel.DistributedDataParallel(disc_model)
        gen_model.to(device)
        gen_model = nn.parallel.DistributedDataParallel(gen_model)
        feature_extractor_model.to(device)
        feature_extractor_model = nn.parallel.DistributedDataParallel(
            feature_extractor_model)
    elif config.gpu:
        disc_model = disc_model.to(device)
        gen_model = gen_model.to(device)
        feature_extractor_model = feature_extractor_model.to(device)
    else:
        return

    # NOTE(review): train and test datasets are built from the same path
    # and shapes — presumably ImageDataset splits internally; confirm.
    train_dataset = ImageDataset(config.data.path,
                                 hr_shape=(config.data.hr_height,
                                           config.data.hr_width),
                                 lr_shape=(config.data.lr_height,
                                           config.data.lr_width))
    test_dataset = ImageDataset(config.data.path,
                                hr_shape=(config.data.hr_height,
                                          config.data.hr_width),
                                lr_shape=(config.data.lr_height,
                                          config.data.lr_width))

    if config.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=config.data.batch_size,
        shuffle=config.data.shuffle, num_workers=config.data.workers,
        pin_memory=config.data.pin_memory, sampler=train_sampler)
    val_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=config.data.batch_size,
        shuffle=config.data.shuffle, num_workers=config.data.workers,
        pin_memory=config.data.pin_memory)

    if args.train:
        # Trainer: one shared MSE criterion, one Adam optimizer per network.
        trainer = GANTrainer(config.train, train_loader,
                             (disc_model, gen_model, feature_extractor_model))
        criterion = nn.MSELoss().to(device)
        disc_optimizer = torch.optim.Adam(disc_model.parameters(),
                                          config.train.hyperparameters.lr)
        gen_optimizer = torch.optim.Adam(gen_model.parameters(),
                                         config.train.hyperparameters.lr)
        fe_optimizer = torch.optim.Adam(feature_extractor_model.parameters(),
                                        config.train.hyperparameters.lr)
        trainer.setCriterion(criterion)
        trainer.setDiscOptimizer(disc_optimizer)
        trainer.setGenOptimizer(gen_optimizer)
        trainer.setFEOptimizer(fe_optimizer)

        # Evaluator shares the models and criterion with the trainer.
        evaluator = GANEvaluator(
            config.evaluate, val_loader,
            (disc_model, gen_model, feature_extractor_model))
        evaluator.setCriterion(criterion)

    if args.test:
        pass

    # Benchmark mode picks the fastest cuDNN kernels for fixed input sizes.
    cudnn.benchmark = True

    start_epoch = 0
    best_precision = 0

    # Optionally resume from a checkpoint.
    if config.train.resume:
        [start_epoch, best_precision] = trainer.load_saved_checkpoint(
            checkpoint=None)

    # Change value to test.hyperparameters on testing.
    for epoch in range(start_epoch, config.train.hyperparameters.total_epochs):
        if config.distributed:
            train_sampler.set_epoch(epoch)

        if args.train:
            trainer.adjust_learning_rate(epoch)
            trainer.train(epoch)
            prec1 = evaluator.evaluate(epoch)

        if args.test:
            pass

        # Remember best prec@1 and save a checkpoint.
        if args.train:
            is_best = prec1 > best_precision
            best_precision = max(prec1, best_precision)
            trainer.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': disc_model.state_dict(),
                    'best_precision': best_precision,
                    # BUG FIX: was `optimizer.state_dict()` — `optimizer` is
                    # undefined in this scope (NameError at save time). Save
                    # the discriminator optimizer to match the saved
                    # discriminator weights.
                    'optimizer': disc_optimizer.state_dict(),
                },
                is_best, checkpoint=None)
def main(args):
    """Run the self-supervised pretrain → finetune → evaluate pipeline.

    Args:
        args: experiment configuration namespace (experiment paths, task
            names, pretrain/finetune objectives, checkpoints, device).
    """
    # Preparation: experiment directory and file logging.
    if not os.path.exists(args.exp_dir):
        os.makedirs(args.exp_dir)
    config_logging(os.path.join(args.exp_dir, "%s.log" % args.exp_name))
    log.info("Experiment %s" % (args.exp_name))
    log.info("Receive config %s" % (args.__str__()))

    log.info("Start creating tasks")
    pretrain_task = [get_task(taskname, args)
                     for taskname in args.pretrain_task]
    finetune_tasks = [get_task(taskname, args)
                      for taskname in args.finetune_tasks]

    log.info("Start loading data")
    # Pretrain data is only needed when some pretraining objective is active.
    if args.image_pretrain_obj != "none" or args.view_pretrain_obj != "none":
        for task in pretrain_task:
            task.load_data()
    for task in finetune_tasks:
        task.load_data()

    log.info("Start creating models")
    if len(pretrain_task):
        if args.image_pretrain_obj != "none":
            image_ssl_model = get_model("image_ssl", args)
            log.info("Loaded image ssl model")
        if args.view_pretrain_obj != "none":
            view_ssl_model = get_model("view_ssl", args)
            log.info("Loaded view ssl model")
    if args.finetune_obj != "none":
        sup_model = get_model("sup", args)
        log.info("Loaded supervised model")

    # Pretrain: train each active SSL model, or fall back to a provided
    # checkpoint path when that objective is disabled.
    if len(pretrain_task):
        if args.image_pretrain_obj != "none":
            image_ssl_model.to(args.device)
            pretrain = Trainer("pretrain", image_ssl_model, pretrain_task[0], args)
            pretrain.train()
            image_pretrain_complete_ckpt = os.path.join(
                args.exp_dir,
                "image_pretrain_%s_complete.pth" % pretrain_task[0].name
            )
            save_model(image_pretrain_complete_ckpt, image_ssl_model)
        else:
            if args.imagessl_load_ckpt:
                image_pretrain_complete_ckpt = args.imagessl_load_ckpt
        if args.view_pretrain_obj != "none":
            view_ssl_model.to(args.device)
            pretrain = Trainer("pretrain", view_ssl_model, pretrain_task[0], args)
            pretrain.train()
            view_pretrain_complete_ckpt = os.path.join(
                args.exp_dir,
                "view_pretrain_%s_complete.pth" % pretrain_task[0].name
            )
            save_model(view_pretrain_complete_ckpt, view_ssl_model)
        else:
            if args.viewssl_load_ckpt:
                view_pretrain_complete_ckpt = args.viewssl_load_ckpt

    # Finetune and test.
    for task in finetune_tasks:
        # BUG FIX: was `is not "none"` — identity comparison against a
        # string literal is unreliable (and a SyntaxWarning); use `!=`.
        if args.imagessl_load_ckpt != "none":
            pretrained_dict = torch.load(image_pretrain_complete_ckpt,
                                         map_location=torch.device('cpu'))
            model_dict = sup_model.state_dict()
            # Map pretrained "patch" weights onto the supervised model's
            # "image" network, keeping only keys that exist there.
            pretrained_dict = {
                k.replace("patch", "image"): v
                for k, v in pretrained_dict.items()
                if k.replace("patch", "image") in model_dict
            }
            model_dict.update(pretrained_dict)
            sup_model.load_state_dict(model_dict)

        # Adversarial objectives use a generator/discriminator pair and
        # the GAN trainer; everything else uses the plain trainer.
        if "adv" in args.finetune_obj:
            sup_model["generator"].to(args.device)
            sup_model["discriminator"].to(args.device)
            finetune = GANTrainer("finetune", sup_model, task, args)
        else:
            sup_model.to(args.device)
            finetune = Trainer("finetune", sup_model, task, args)
        finetune.train()
        finetune.eval("test")

        if "adv" in args.finetune_obj:
            finetune_generator_complete_ckpt = os.path.join(
                args.exp_dir, "finetune_%s_generator_complete.pth" % task.name
            )
            save_model(finetune_generator_complete_ckpt,
                       sup_model["generator"])
            finetune_discriminator_complete_ckpt = os.path.join(
                args.exp_dir,
                "finetune_%s_discriminator_complete.pth" % task.name
            )
            save_model(finetune_discriminator_complete_ckpt,
                       sup_model["discriminator"])
        else:
            finetune_complete_ckpt = os.path.join(
                args.exp_dir, "finetune_%s_complete.pth" % task.name
            )
            save_model(finetune_complete_ckpt, sup_model)

    # evaluate
    # TODO: evaluate result on test split, write prediction for leaderboard
    # submission (for datasets without test labels)
    log.info("Done")
    return
def main(gpu_id, data_dir, manual_seed, cuda, train_flag, image_size,
         batch_size, workers, stage, dataset_name, config_name, max_epoch,
         snapshot_interval, net_g, net_d, z_dim, generator_lr,
         discriminator_lr, lr_decay_epoch, coef_kl, stage1_g, embedding_type,
         condition_dim, df_dim, gf_dim, res_num, text_dim, regularizer):
    """Seed RNGs, build the data pipeline, and dispatch to train/eval/sample.

    The dataloader and GANTrainer construction were duplicated verbatim in
    the train and birds-eval branches; they are factored into local helpers.
    """
    # Reproducible seeding; a random seed is drawn when none is given.
    if manual_seed is None:
        manual_seed = random.randint(1, 10000)
    random.seed(manual_seed)
    torch.manual_seed(manual_seed)
    if cuda:
        torch.cuda.manual_seed_all(manual_seed)

    # Timestamped output directory: ../output/<dataset>_<config>_<time>.
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
    output_dir = '../output/%s_%s_%s' % (dataset_name, config_name, timestamp)
    num_gpu = len(gpu_id.split(','))

    def _build_dataloader():
        # Shared by train and birds-eval: crop/flip augmentation,
        # normalize to [-1, 1], shuffled multi-GPU loader over 'train'.
        image_transform = transforms.Compose([
            transforms.RandomCrop(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset = TextDataset(data_dir, 'train', imsize=image_size,
                              transform=image_transform)
        assert dataset
        return torch.utils.data.DataLoader(
            dataset, batch_size=batch_size * num_gpu,
            drop_last=True, shuffle=True, num_workers=int(workers))

    def _build_trainer():
        # Identical 16-argument GANTrainer construction used by all branches.
        return GANTrainer(output_dir, max_epoch, snapshot_interval, gpu_id,
                          batch_size, train_flag, net_g, net_d, cuda,
                          stage1_g, z_dim, generator_lr, discriminator_lr,
                          lr_decay_epoch, coef_kl, regularizer)

    if train_flag:
        dataloader = _build_dataloader()
        algo = _build_trainer()
        algo.train(dataloader, stage, text_dim, gf_dim, condition_dim,
                   z_dim, df_dim, res_num)
    elif dataset_name == 'birds' and train_flag is False:
        dataloader = _build_dataloader()
        algo = _build_trainer()
        algo.birds_eval(dataloader, stage)
    else:
        # Sampling mode: generate from pre-computed test captions.
        datapath = '%s/test/val_captions.t7' % (data_dir)
        algo = _build_trainer()
        algo.sample(datapath, stage)
mkdir_p(os.path.join(output_dir, 'model_reserve')) else: split = 'train' batch_size = cfg.TRAIN.BATCH_SIZE * num_gpu shuffle_flag = True Dataset = choose_dataset(cfg.DATASET_NAME) dataset = Dataset(cfg.DATA_DIR, split, imsize=cfg.IMSIZE) # Note the batchsize setting is here dataloader = torch.utils.data.DataLoader( dataset, batch_size=batch_size, drop_last=True, shuffle=shuffle_flag, num_workers=int(cfg.WORKERS)) # Initialize the main class which includes the training and evaluation algo = GANTrainer(output_dir, cfg_path=args.cfg_file) if cfg.TRAIN.FLAG: algo.train(dataloader) elif args.eval: date_str = datetime.datetime.now().strftime('%b-%d-%I%M%p-%G') if args.FID_eval: '''Do FID evaluations.''' f = open(os.path.join(output_dir, 'all_FID_eval.txt'), 'a') for net_G_name in net_G_names: cfg.NET_G = net_G_name algo.sample(dataloader, eval_name='eval', eval_num=args.eval_num) fid_score_now = \
dataset = TextImageDataset(data_dir=cfg.DATA_DIR, ann_file=cfg.ANN_FILE, imsize=cfg.IMSIZE, emb_model=cfg.EMB_MODEL, transform=image_transform, vocab_file=vocab) dataloader = torch.utils.data.DataLoader( dataset, batch_size=cfg.TRAIN.BATCH_SIZE * num_gpu, collate_fn=collate_fn, drop_last=True, shuffle=True, num_workers=int(cfg.WORKERS)) algo = GANTrainer(output_dir, cap_model, vocab, eval_utils, my_resnet, dataset.word2idx, dataset.emb, dataset.idx2word, vocab_cap=vocab_cap, eval_kwargs=vars(opt)) algo.train(dataloader, cfg.STAGE) else: datapath = '%s/test/val_captions.t7' % (cfg.DATA_DIR) algo = GANTrainer(output_dir) algo.sample(datapath, cfg.STAGE)