def __init__(self):
    self.tool = Tool(hparams.sens_num, hparams.key_len,
        hparams.sen_len, hparams.poem_len, 0.0)
    self.tool.load_dic(hparams.vocab_path, hparams.ivocab_path)
    vocab_size = self.tool.get_vocab_size()
    print("vocabulary size: %d" % (vocab_size))
    PAD_ID = self.tool.get_PAD_ID()
    B_ID = self.tool.get_B_ID()
    assert vocab_size > 0 and PAD_ID >= 0 and B_ID >= 0
    self.hps = hparams._replace(vocab_size=vocab_size,
        pad_idx=PAD_ID, bos_idx=B_ID)

    # load model
    model = MixPoetAUS(self.hps)

    # load trained model
    utils.restore_checkpoint(self.hps.model_dir, device, model)
    self.model = model.to(device)
    self.model.eval()
    # utils.print_parameter_list(self.model)

    # load poetry filter
    print("loading poetry filter...")
    self.filter = PoetryFilter(self.tool.get_vocab(),
        self.tool.get_ivocab(), self.hps.data_dir)
    print("--------------------------")
def main(device=torch.device('cuda:0')): # CLI arguments parser = arg.ArgumentParser( description='We all know what we are doing. Fighting!') parser.add_argument("--datasize", "-d", default="small", type=str, help="data size you want to use, small, medium, total") # Parsing args = parser.parse_args() # Data loaders datasize = args.datasize pathname = "data/nyu.zip" tr_loader, va_loader, te_loader = getTrainingValidationTestingData( datasize, pathname, batch_size=config("unet.batch_size")) # Model model = Net() # define loss function # criterion = torch.nn.L1Loss() # Attempts to restore the latest checkpoint if exists print("Loading unet...") model, start_epoch, stats = utils.restore_checkpoint( model, utils.config("unet.checkpoint")) acc, loss = utils.evaluate_model(model, te_loader, device) # axes = util.make_training_plot() print(f'Test Accuracy:{acc}') print(f'Test Loss:{loss}')
def main(device=torch.device('cuda:0')): # CLI arguments parser = arg.ArgumentParser( description='We all know what we are doing. Fighting!') parser.add_argument("--datasize", "-d", default="small", type=str, help="data size you want to use, small, medium, total") # Parsing args = parser.parse_args() # Data loaders datasize = args.datasize pathname = "data/nyu.zip" tr_loader, va_loader, te_loader = getTrainingValidationTestingData( datasize, pathname, batch_size=config("unet.batch_size")) # Model #model = Net() #model = Dense121() model = Dense169() model = model.to(device) # define loss function # criterion = torch.nn.L1Loss() # Attempts to restore the latest checkpoint if exists print("Loading unet...") model, start_epoch, stats = utils.restore_checkpoint( model, utils.config("unet.checkpoint")) acc, loss = utils.evaluate_model(model, te_loader, device) # axes = util.make_training_plot() print(f'Test Error:{acc}') print(f'Test Loss:{loss}') # Get Test Images img_list = glob("examples/" + "*.png") # Set model to eval mode model.eval() model = model.to(device) # Begin testing loop print("Begin Test Loop ...") for idx, img_name in enumerate(img_list): img = load_images([img_name]) img = torch.Tensor(img).float().to(device) print("Processing {}, Tensor Shape: {}".format(img_name, img.shape)) with torch.no_grad(): preds = model(img).squeeze(0) output = colorize(preds.data) output = output.transpose((1, 2, 0)) cv2.imwrite(img_name.split(".")[0] + "_result.png", output) print("Processing {} done.".format(img_name))
def main(device=torch.device('cuda:0')): # CLI arguments parser = arg.ArgumentParser( description='We all know what we are doing. Fighting!') parser.add_argument("--datasize", "-d", default="small", type=str, help="data size you want to use, small, medium, total") # Parsing args = parser.parse_args() # Data loaders # TODO: ####### Enter the model selection here! ##### modelSelection = input( 'Please input the type of model to be used(res50,dense121,dense169,mob_v2,mob):' ) datasize = args.datasize filename = "nyu_new.zip" pathname = f"data/{filename}" csv = "data/nyu_csv.zip" te_loader = getTestingData(datasize, csv, pathname, batch_size=config(modelSelection + ".batch_size")) # Model if modelSelection.lower() == 'res50': model = Res50() elif modelSelection.lower() == 'dense121': model = Dense121() elif modelSelection.lower() == 'mob_v2': model = Mob_v2() elif modelSelection.lower() == 'dense169': model = Dense169() elif modelSelection.lower() == 'mob': model = Net() elif modelSelection.lower() == 'squeeze': model = Squeeze() else: assert False, 'Wrong type of model selection string!' model = model.to(device) # define loss function # criterion = torch.nn.L1Loss() # Attempts to restore the latest checkpoint if exists print(f"Loading {mdoelSelection}...") model, start_epoch, stats = utils.restore_checkpoint( model, utils.config(modelSelection + ".checkpoint")) acc, loss = utils.evaluate_model(model, te_loader, device, test=True) # axes = util.make_training_plot() print(f'Test Error:{acc}') print(f'Test Loss:{loss}')
def train(mixpoet, tool, hps):
    last_epoch = utils.restore_checkpoint(hps.model_dir, device, mixpoet)
    if last_epoch is not None:
        print("checkpoint exists! directly recover!")
    else:
        print("checkpoint does not exist! train from scratch!")

    mix_trainer = MixTrainer(hps)
    mix_trainer.train(mixpoet, tool)
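The poetry-model snippets above and the working-memory variants below call a project-local `utils.restore_checkpoint(model_dir, device, model)` that loads weights in place and returns the last saved epoch, or None when nothing has been saved yet. A minimal sketch of such a helper, assuming checkpoints are written as `ckpt_<epoch>.pt` files containing `{'model': state_dict, 'epoch': int}`; the file layout and key names here are assumptions, not the project's actual format:

import glob
import os
import torch

def restore_checkpoint(model_dir, device, model):
    # Hypothetical loader: pick the newest ckpt_*.pt in model_dir,
    # load its weights into `model`, and report the stored epoch.
    ckpts = sorted(glob.glob(os.path.join(model_dir, "ckpt_*.pt")),
                   key=os.path.getmtime)
    if not ckpts:
        return None
    data = torch.load(ckpts[-1], map_location=device)
    model.load_state_dict(data["model"])
    return data["epoch"]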
def main(device=torch.device('cuda:0')):
    # Model
    modelSelection = input(
        'Please input the type of model to be used(res50,dense121,dense169,dense161,mob_v2,mob):'
    )
    if modelSelection.lower() == 'res50':
        model = Res50()
    elif modelSelection.lower() == 'dense121':
        model = Dense121()
    elif modelSelection.lower() == 'dense161':
        model = Dense161()
    elif modelSelection.lower() == 'mob_v2':
        model = Mob_v2()
    elif modelSelection.lower() == 'dense169':
        model = Dense169()
    elif modelSelection.lower() == 'mob':
        model = Net()
    elif modelSelection.lower() == 'squeeze':
        model = Squeeze()
    else:
        assert False, 'Wrong type of model selection string!'
    model = model.to(device)

    # Attempts to restore the latest checkpoint if exists
    print("Loading unet...")
    model, start_epoch, stats = utils.restore_checkpoint(
        model, utils.config(modelSelection + ".checkpoint"))

    # Get Test Images
    img_list = glob("examples/" + "*.png")

    # Set model to eval mode
    model.eval()
    model = model.to(device)

    # Begin testing loop
    print("Begin Test Loop ...")
    for idx, img_name in enumerate(img_list):
        img = load_images([img_name])
        img = torch.Tensor(img).float().to(device)
        print("Processing {}, Tensor Shape: {}".format(img_name, img.shape))
        with torch.no_grad():
            preds = model(img).squeeze(0)
        output = colorize(preds.data)
        output = output.transpose((1, 2, 0))
        cv2.imwrite(
            img_name.split(".")[0] + "_" + modelSelection + "_result.png",
            output)
        print("Processing {} done.".format(img_name))
def train(wm_model, tool, hps, specified_device):
    last_epoch = utils.restore_checkpoint(
        hps.model_dir, specified_device, wm_model)
    if last_epoch is not None:
        print("checkpoint exists! directly recover!")
    else:
        print("checkpoint does not exist! train from scratch!")

    wm_trainer = WMTrainer(hps, specified_device)
    wm_trainer.train(wm_model, tool)
def __init__(self, hps, device):
    self.tool = Tool(hps.sens_num, hps.sen_len,
        hps.key_len, hps.topic_slots, 0.0)
    self.tool.load_dic(hps.vocab_path, hps.ivocab_path)
    vocab_size = self.tool.get_vocab_size()
    print("vocabulary size: %d" % (vocab_size))
    PAD_ID = self.tool.get_PAD_ID()
    B_ID = self.tool.get_B_ID()
    assert vocab_size > 0 and PAD_ID >= 0 and B_ID >= 0
    self.hps = hps._replace(vocab_size=vocab_size,
        pad_idx=PAD_ID, bos_idx=B_ID)

    self.device = device

    # load model
    model = WorkingMemoryModel(self.hps, device)

    # load trained model
    utils.restore_checkpoint(self.hps.model_dir, device, model)
    self.model = model.to(device)
    self.model.eval()

    null_idxes = self.tool.load_function_tokens(
        self.hps.data_dir + "fchars.txt").to(self.device)
    self.model.set_null_idxes(null_idxes)
    self.model.set_tau(hps.min_tau)

    # load poetry filter
    print("loading poetry filter...")
    self.filter = PoetryFilter(self.tool.get_vocab(),
        self.tool.get_ivocab(), self.hps.data_dir)

    self.visual_tool = Visualization(hps.topic_slots, hps.his_mem_slots,
        "../log/")
    print("--------------------------")
def evaluate_checkpoint_on(restore_checkpoint, dataset_cfg, _run,
                           model_update_cfg={}):
    model_cfg, _, epoch = utils.restore_checkpoint(restore_checkpoint,
                                                   model_cfg=model_update_cfg,
                                                   map_location='cpu')
    # model_cfg['backbone']['output_dim'] = 256
    dataloaders = dataloader_builder.build(dataset_cfg)
    model = model_builder.build(model_cfg)

    # TODO needs to be from dataset
    if 'seg_class_mapping' in model_cfg:
        mapping = model_cfg['seg_class_mapping']
    else:
        mapping = None
    model.seg_mapping = mapping

    model = torch.nn.DataParallel(model, device_ids=_run.config['device_id'])
    model = model.cuda()

    return evaluate(dataloaders, model, epoch, keep=True)
optimizer = torch.optim.Adam([{
    'params': net.encoder.parameters(),
    'weight_decay': 1e-2
}, {
    'params': net.decoder.parameters(),
    'weight_decay': 0
}], lr=1e-4, eps=1e-6)
scheduler = StepLR(optimizer, step_size=1, gamma=0.8)
test_loss = 999999
epoch_loss = 999999
lr = 1e-4  # default learning rate when no checkpoint is restored
if restore_check is True:
    net, optimizer, scheduler, epoch_loss, test_loss = restore_checkpoint(
        net, optimizer, scheduler, masked, recent=True, inception=False)
    # recover the current learning rate from the restored optimizer
    for param in optimizer.param_groups:
        lr = param['lr']

trainNet(net,
         batch_size=8,
         n_epochs=100,
         learning_rate=lr,
         last_epoch_loss=epoch_loss,
         last_loss=test_loss,
         optimizer=optimizer,
         scheduler=scheduler,
         save=True)
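This snippet and the Inception-ResNet test entry further down assume a restore_checkpoint variant that also restores optimizer and scheduler state and returns the last recorded losses. A rough sketch of what such a helper could look like; the checkpoint path scheme and key names are assumptions made only for illustration:

import torch

def restore_checkpoint(net, optimizer=None, scheduler=None, masked=False,
                       recent=True, inception=False):
    # Hypothetical: pick a checkpoint file based on the flags, then restore
    # model/optimizer/scheduler state plus the last recorded losses.
    tag = "inception" if inception else "unet"
    if masked:
        tag += "_masked"
    name = "recent" if recent else "best"
    ckpt = torch.load(f"checkpoints/{tag}_{name}.pt", map_location="cpu")
    net.load_state_dict(ckpt["model"])
    if optimizer is not None:
        optimizer.load_state_dict(ckpt["optimizer"])
    if scheduler is not None:
        scheduler.load_state_dict(ckpt["scheduler"])
    return net, optimizer, scheduler, ckpt["epoch_loss"], ckpt["test_loss"]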
def main(device, tr_loader, va_loader, te_loader, modelSelection):
    """Train CNN and show training plots."""
    # CLI arguments
    # parser = arg.ArgumentParser(description='We all know what we are doing. Fighting!')
    # parser.add_argument("--datasize", "-d", default="small", type=str,
    #                     help="data size you want to use, small, medium, total")
    # Parsing
    # args = parser.parse_args()
    # Data loaders
    # datasize = args.datasize

    # Model
    if modelSelection.lower() == 'res50':
        model = Res50()
    elif modelSelection.lower() == 'dense121':
        model = Dense121()
    elif modelSelection.lower() == 'mobv2':
        model = Mob_v2()
    elif modelSelection.lower() == 'dense169':
        model = Dense169()
    elif modelSelection.lower() == 'mob':
        model = Net()
    elif modelSelection.lower() == 'squeeze':
        model = Squeeze()
    else:
        assert False, 'Wrong type of model selection string!'
    # model = Net()
    # model = Squeeze()
    model = model.to(device)

    # TODO: define loss function, and optimizer
    learning_rate = utils.config(modelSelection + ".learning_rate")
    criterion = DepthLoss(0.1).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    number_of_epoches = 10

    # Attempts to restore the latest checkpoint if exists
    print("Loading unet...")
    model, start_epoch, stats = utils.restore_checkpoint(
        model, utils.config(modelSelection + ".checkpoint"))

    running_va_loss = [] if 'va_loss' not in stats else stats['va_loss']
    running_va_acc = [] if 'va_err' not in stats else stats['va_err']
    running_tr_loss = [] if 'tr_loss' not in stats else stats['tr_loss']
    running_tr_acc = [] if 'tr_err' not in stats else stats['tr_err']

    tr_acc, tr_loss = utils.evaluate_model(model, tr_loader, device)
    acc, loss = utils.evaluate_model(model, va_loader, device)
    running_va_acc.append(acc)
    running_va_loss.append(loss)
    running_tr_acc.append(tr_acc)
    running_tr_loss.append(tr_loss)

    stats = {
        'va_err': running_va_acc,
        'va_loss': running_va_loss,
        'tr_err': running_tr_acc,
        'tr_loss': running_tr_loss,
        # 'num_of_epoch': 0
    }

    # Loop over the entire dataset multiple times
    # for epoch in range(start_epoch, config('cnn.num_epochs')):
    epoch = start_epoch
    # while curr_patience < patience:
    while epoch < number_of_epoches:
        # Train model
        utils.train_epoch(device, tr_loader, model, criterion, optimizer)

        # Save checkpoint
        utils.save_checkpoint(model, epoch + 1,
                              utils.config(modelSelection + ".checkpoint"),
                              stats)

        # Evaluate model
        tr_acc, tr_loss = utils.evaluate_model(model, tr_loader, device)
        va_acc, va_loss = utils.evaluate_model(model, va_loader, device)
        running_va_acc.append(va_acc)
        running_va_loss.append(va_loss)
        running_tr_acc.append(tr_acc)
        running_tr_loss.append(tr_loss)
        epoch += 1

    print("Finished Training")

    utils.make_plot(running_tr_loss, running_tr_acc, running_va_loss,
                    running_va_acc)
def play(args):
    env = create_mario_env(args.env_name, ACTIONS[args.move_set])

    observation_space = env.observation_space.shape[0]
    action_space = env.action_space.n

    model = ActorCritic(observation_space, action_space)

    checkpoint_file = \
        f"{args.env_name}/{args.model_id}_{args.algorithm}_params.tar"
    checkpoint = restore_checkpoint(checkpoint_file)
    assert args.env_name == checkpoint['env'], \
        f"This checkpoint is for different environment: {checkpoint['env']}"
    args.model_id = checkpoint['id']

    print(f"Environment: {args.env_name}")
    print(f" Agent: {args.model_id}")

    model.load_state_dict(checkpoint['model_state_dict'])

    state = env.reset()
    state = torch.from_numpy(state)
    reward_sum = 0
    done = True

    episode_length = 0
    start_time = time.time()
    for step in count():
        episode_length += 1
        # shared model sync
        if done:
            cx = torch.zeros(1, 512)
            hx = torch.zeros(1, 512)
        else:
            cx = cx.data
            hx = hx.data

        with torch.no_grad():
            value, logit, (hx, cx) = model((state.unsqueeze(0), (hx, cx)))
        prob = F.softmax(logit, dim=-1)
        action = prob.max(-1, keepdim=True)[1]

        action_idx = action.item()
        action_out = ACTIONS[args.move_set][action_idx]

        state, reward, done, info = env.step(action_idx)
        reward_sum += reward

        print(
            f"{emojize(':mushroom:')} World {info['world']}-{info['stage']} | "
            f"{emojize(':video_game:')}: [ {' + '.join(action_out):^13s} ] | ",
            end='\r',
        )

        env.render()

        if done:
            t = time.time() - start_time
            print(
                f"{emojize(':mushroom:')} World {info['world']}-{info['stage']} |" +
                f" {emojize(':video_game:')}: [ {' + '.join(action_out):^13s} ] | " +
                f"ID: {args.model_id}, " +
                f"Time: {time.strftime('%H:%M:%S', time.gmtime(t)):^9s}, " +
                f"Reward: {reward_sum: 10.2f}, " +
                f"Progress: {(info['x_pos'] / 3225) * 100: 3.2f}%",
                end='\r',
                flush=True,
            )
            reward_sum = 0
            episode_length = 0
            time.sleep(args.reset_delay)
            state = env.reset()
            state = torch.from_numpy(state)
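Both the play loop above and the A3C launcher below read a `*_params.tar` file through `restore_checkpoint` and expect keys such as 'env', 'id', 'step', 'model_state_dict', and 'optimizer_state_dict'. A minimal sketch of a matching save/load pair, assuming plain `torch.save`/`torch.load` serialization; the writer actually used by the project is not shown here:

import torch

def save_checkpoint(path, env_name, model_id, step, model, optimizer):
    # Persist everything the play/train entry points expect to find.
    torch.save({
        'env': env_name,
        'id': model_id,
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, path)

def restore_checkpoint(path, map_location='cpu'):
    # Return the raw dict; callers pick out the keys they need.
    return torch.load(path, map_location=map_location)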
def main(args):
    print(f" Session ID: {args.uuid}")

    # logging
    log_dir = f'logs/{args.env_name}/{args.model_id}/{args.uuid}/'
    args_logger = setup_logger('args', log_dir, f'args.log')
    env_logger = setup_logger('env', log_dir, f'env.log')

    if args.debug:
        debug.packages()
    os.environ['OMP_NUM_THREADS'] = "1"

    if torch.cuda.is_available():
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        devices = ",".join([str(i) for i in range(torch.cuda.device_count())])
        os.environ["CUDA_VISIBLE_DEVICES"] = devices

    args_logger.info(vars(args))
    env_logger.info(vars(os.environ))

    env = create_atari_environment(args.env_name)

    shared_model = ActorCritic(env.observation_space.shape[0],
                               env.action_space.n)

    if torch.cuda.is_available():
        shared_model = shared_model.cuda()

    shared_model.share_memory()

    optimizer = SharedAdam(shared_model.parameters(), lr=args.lr)
    optimizer.share_memory()

    if args.load_model:  # TODO Load model before initializing optimizer
        checkpoint_file = f"{args.env_name}/{args.model_id}_{args.algorithm}_params.tar"
        checkpoint = restore_checkpoint(checkpoint_file)
        assert args.env_name == checkpoint['env'], \
            "Checkpoint is for different environment"
        args.model_id = checkpoint['id']
        args.start_step = checkpoint['step']
        print("Loading model from checkpoint...")
        print(f"Environment: {args.env_name}")
        print(f" Agent: {args.model_id}")
        print(f" Start: Step {args.start_step}")
        shared_model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        print(f"Environment: {args.env_name}")
        print(f" Agent: {args.model_id}")

    torch.manual_seed(args.seed)

    print(
        FontColor.BLUE +
        f"CPUs: {mp.cpu_count(): 3d} | " +
        f"GPUs: {None if not torch.cuda.is_available() else torch.cuda.device_count()}" +
        FontColor.END
    )

    processes = []

    counter = mp.Value('i', 0)
    lock = mp.Lock()

    # Queue training processes
    num_processes = args.num_processes
    no_sample = args.non_sample  # count of non-sampling processes

    if args.num_processes > 1:
        num_processes = args.num_processes - 1

    samplers = num_processes - no_sample

    for rank in range(0, num_processes):
        device = 'cpu'
        if torch.cuda.is_available():
            device = 0  # TODO: Need to move to distributed to handle multigpu
        if rank < samplers:  # random action
            p = mp.Process(
                target=train,
                args=(rank, args, shared_model, counter, lock, optimizer,
                      device),
            )
        else:  # best action
            p = mp.Process(
                target=train,
                args=(rank, args, shared_model, counter, lock, optimizer,
                      device, False),
            )
        p.start()
        time.sleep(1.)
        processes.append(p)

    # Queue test process
    p = mp.Process(target=test,
                   args=(args.num_processes, args, shared_model, counter, 0))
    p.start()
    processes.append(p)

    for p in processes:
        p.join()
    for i, data in enumerate(dataset, 0):
        inputs, labels, mask = data
        inputs, labels, mask = inputs.to(device), labels.to(device), mask.to(device)
        outputs = net(inputs.float())
        loss_size = loss(outputs, labels, mask)
        total_loss[index] = total_loss[index] + loss_size.data
        if debug:
            if (i + 1) % (len(dataset) // 3 + 1) == 0:
                print("{:d}%".format(int(i / len(dataset) * 100)))
            if (i + 1) % (len(dataset) // 10 + 1) == 0:
                draw_debug(inputs, labels, mask, outputs, name='test')
    total_loss[index] = total_loss[index] / len(dataset)
    print('Test results:\n\trel: {:.6f}'.format(total_loss[0]))
    return total_loss[0]


if __name__ == '__main__':
    net = InceptionResNetV2().to(device)
    net, _, _, _, _ = utils.restore_checkpoint(net,
                                               None,
                                               None,
                                               inception=True,
                                               masked=True,
                                               recent=False)
    net.eval()
    print(testNet(net))
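The loop above reports a masked "rel" metric through a callable `loss(outputs, labels, mask)`. For reference, a masked mean absolute relative error (a common depth-estimation "rel" metric) could be written as follows; this is an illustrative sketch under that assumption, not the project's actual loss definition:

import torch

def masked_rel_error(outputs, labels, mask, eps=1e-6):
    # Mean of |pred - gt| / gt, computed only over valid (mask == 1) pixels.
    diff = torch.abs(outputs - labels) / (labels + eps)
    diff = diff * mask
    return diff.sum() / mask.sum().clamp(min=1)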
def main(device=torch.device('cuda:0')): # CLI arguments parser = arg.ArgumentParser( description='We all know what we are doing. Fighting!') parser.add_argument("--datasize", "-d", default="small", type=str, help="data size you want to use, small, medium, total") # Parsing args = parser.parse_args() # Data loaders datasize = args.datasize pathname = "data/nyu.zip" # Model modelSelection = input( 'Please input the type of model to be used(res50,dense121,dense169,mob_v2,mob):' ) # Model if modelSelection.lower() == 'res50': model = Res50() elif modelSelection.lower() == 'dense121': model = Dense121() elif modelSelection.lower() == 'mob_v2': model = Mob_v2() elif modelSelection.lower() == 'dense169': model = Dense169() elif modelSelection.lower() == 'mob': model = Net() elif modelSelection.lower() == 'squeeze': model = Squeeze() else: assert False, 'Wrong type of model selection string!' model = model.to(device) # Attempts to restore the latest checkpoint if exists print("Loading unet...") model, start_epoch, stats = utils.restore_checkpoint( model, utils.config(modelSelection + ".checkpoint")) # Get Test Images img_list = glob("examples/" + "*.png") # Set model to eval mode model.eval() model = model.to(device) # Begin testing loop print("Begin Test Loop ...") for idx, img_name in enumerate(img_list): img = load_images([img_name]) img = torch.Tensor(img).float().to(device) print("Processing {}, Tensor Shape: {}".format(img_name, img.shape)) with torch.no_grad(): preds = model(img).squeeze(0) output = colorize(preds.data) output = output.transpose((1, 2, 0)) cv2.imwrite( img_name.split(".")[0] + "_" + modelSelection + "_result.png", output) print("Processing {} done.".format(img_name))
config.eval.batch_size = batch_size
random_seed = 0  #@param {"type": "integer"}

sigmas = mutils.get_sigmas(config)
scaler = datasets.get_data_scaler(config)
inverse_scaler = datasets.get_data_inverse_scaler(config)
score_model = mutils.create_model(config)

optimizer = get_optimizer(config, score_model.parameters())
ema = ExponentialMovingAverage(score_model.parameters(),
                               decay=config.model.ema_rate)
state = dict(step=0, optimizer=optimizer, model=score_model, ema=ema)

state = restore_checkpoint(ckpt_filename, state, config.device)
ema.copy_to(score_model.parameters())


#@title Visualization code

def image_grid(x):
    size = config.data.image_size
    channels = config.data.num_channels
    img = x.reshape(-1, size, size, channels)
    w = int(np.sqrt(img.shape[0]))
    img = img.reshape((w, w, size, size, channels)).transpose(
        (0, 2, 1, 3, 4)).reshape((w * size, w * size, channels))
    return img


def show_samples(x):
    x = x.permute(0, 2, 3, 1).detach().cpu().numpy()
    img = image_grid(x)
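The `show_samples` helper above is cut off after building the grid. A plausible completion, assuming the intent is simply to display the grid with matplotlib (the original continuation is not shown here):

import matplotlib.pyplot as plt

def show_samples(x):
    # Move the batch to HWC numpy, tile it into a grid, and display it.
    x = x.permute(0, 2, 3, 1).detach().cpu().numpy()
    img = image_grid(x)
    plt.figure(figsize=(8, 8))
    plt.axis('off')
    plt.imshow(img)
    plt.show()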
def run_train(dataloader_cfg, model_cfg, scheduler_cfg, optimizer_cfg,
              loss_cfg, validation_cfg, checkpoint_frequency,
              restore_checkpoint, max_epochs, _run):
    # Let cuDNN benchmark conv implementations and choose the fastest.
    # Only good if sizes stay the same within the main loop!
    torch.backends.cudnn.benchmark = True
    exit_handler = ExitHandler()

    device = _run.config['device']
    device_id = _run.config['device_id']

    # during training just one dataloader
    dataloader = dataloader_builder.build(dataloader_cfg)[0]

    epoch = 0
    if restore_checkpoint is not None:
        model_cfg, optimizer_cfg, epoch = utils.restore_checkpoint(
            restore_checkpoint, model_cfg, optimizer_cfg)

    def overwrite(to_overwrite, dic):
        to_overwrite.update(dic)
        return to_overwrite

    # some models depend on the dataset, for example num_joints
    model_cfg = overwrite(dataloader.dataset.info, model_cfg)
    model = model_builder.build(model_cfg)

    loss_cfg['model'] = model
    loss = loss_builder.build(loss_cfg)
    loss = loss.to(device)

    parameters = list(model.parameters()) + list(loss.parameters())
    optimizer = optimizer_builder.build(optimizer_cfg, parameters)
    lr_scheduler = scheduler_builder.build(scheduler_cfg, optimizer, epoch)

    if validation_cfg is None:
        validation_dataloaders = None
    else:
        validation_dataloaders = dataloader_builder.build(validation_cfg)
    keep = False

    file_logger = log.get_file_logger()
    logger = log.get_logger()

    model = torch.nn.DataParallel(model, device_ids=device_id)
    model.cuda()
    model = model.train()

    trained_models = []

    exit_handler.register(file_logger.save_checkpoint, model, optimizer,
                          "atexit", model_cfg)

    start_training_time = time.time()
    end = time.time()
    while epoch < max_epochs:
        epoch += 1
        lr_scheduler.step()
        logger.info("Starting Epoch %d/%d", epoch, max_epochs)
        len_batch = len(dataloader)
        acc_time = 0
        for batch_id, data in enumerate(dataloader):
            optimizer.zero_grad()

            endpoints = model(data, model.module.endpoints)
            logger.debug("datasets %s", list(data['split_info'].keys()))
            data.update(endpoints)
            # theoretically the losses could also be calculated distributed.
            losses = loss(endpoints, data)
            loss_mean = torch.mean(losses)
            loss_mean.backward()
            optimizer.step()

            acc_time += time.time() - end
            end = time.time()
            report_after_batch(_run=_run, logger=logger,
                               batch_id=batch_id, batch_len=len_batch,
                               acc_time=acc_time, loss_mean=loss_mean,
                               max_mem=torch.cuda.max_memory_allocated())

        if epoch % checkpoint_frequency == 0:
            path = file_logger.save_checkpoint(model, optimizer, epoch,
                                               model_cfg)
            trained_models.append(path)

        report_after_epoch(_run=_run, epoch=epoch, max_epoch=max_epochs)

        if validation_dataloaders is not None and \
           epoch % checkpoint_frequency == 0:
            model.eval()
            # cuDNN benchmarking is only good if sizes stay the same within
            # the main loop; not the case for segmentation
            torch.backends.cudnn.benchmark = False
            score = evaluate(validation_dataloaders, model, epoch, keep=keep)
            logger.info(score)
            log_score(score, _run, prefix="val_", step=epoch)
            torch.backends.cudnn.benchmark = True
            model.train()

    report_after_training(_run=_run, max_epoch=max_epochs,
                          total_time=time.time() - start_training_time)

    path = file_logger.save_checkpoint(model, optimizer, epoch, model_cfg)
    if path:
        trained_models.append(path)
    file_logger.close()

    # TODO get best performing val model
    evaluate_last = _run.config['training'].get('evaluate_last', 1)
    if len(trained_models) < evaluate_last:
        logger.info("Only saved %d models (evaluate_last=%d)",
                    len(trained_models), evaluate_last)

    return trained_models[-evaluate_last:]
def main(device, tr_loader, va_loader, te_loader, modelSelection):
    """Train CNN and show training plots."""
    # Model
    if modelSelection.lower() == 'res50':
        model = Res50()
    elif modelSelection.lower() == 'dense121':
        model = Dense121()
    elif modelSelection.lower() == 'dense161':
        model = Dense161()
    elif modelSelection.lower() == 'mobv2':
        model = Mob_v2()
    elif modelSelection.lower() == 'dense169':
        model = Dense169()
    elif modelSelection.lower() == 'mob':
        model = Net()
    elif modelSelection.lower() == 'squeeze':
        model = Squeeze()
    else:
        assert False, 'Wrong type of model selection string!'
    model = model.to(device)

    # TODO: define loss function, and optimizer
    learning_rate = utils.config(modelSelection + ".learning_rate")
    criterion = DepthLoss(0.1).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    number_of_epoches = 10

    # Attempts to restore the latest checkpoint if exists
    print("Loading unet...")
    model, start_epoch, stats = utils.restore_checkpoint(
        model, utils.config(modelSelection + ".checkpoint"))

    running_va_loss = [] if 'va_loss' not in stats else stats['va_loss']
    running_va_acc = [] if 'va_err' not in stats else stats['va_err']
    running_tr_loss = [] if 'tr_loss' not in stats else stats['tr_loss']
    running_tr_acc = [] if 'tr_err' not in stats else stats['tr_err']

    tr_acc, tr_loss = utils.evaluate_model(model, tr_loader, device)
    acc, loss = utils.evaluate_model(model, va_loader, device)
    running_va_acc.append(acc)
    running_va_loss.append(loss)
    running_tr_acc.append(tr_acc)
    running_tr_loss.append(tr_loss)

    stats = {
        'va_err': running_va_acc,
        'va_loss': running_va_loss,
        'tr_err': running_tr_acc,
        'tr_loss': running_tr_loss,
    }

    # Loop over the entire dataset multiple times
    # for epoch in range(start_epoch, config('cnn.num_epochs')):
    epoch = start_epoch
    # while curr_patience < patience:
    while epoch < number_of_epoches:
        # Train model
        utils.train_epoch(device, tr_loader, model, criterion, optimizer)

        # Save checkpoint
        utils.save_checkpoint(model, epoch + 1,
                              utils.config(modelSelection + ".checkpoint"),
                              stats)

        # Evaluate model
        tr_acc, tr_loss = utils.evaluate_model(model, tr_loader, device)
        va_acc, va_loss = utils.evaluate_model(model, va_loader, device)
        running_va_acc.append(va_acc)
        running_va_loss.append(va_loss)
        running_tr_acc.append(tr_acc)
        running_tr_loss.append(tr_loss)
        epoch += 1

    print("Finished Training")

    utils.make_plot(running_tr_loss, running_tr_acc, running_va_loss,
                    running_va_acc)
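Several of these entry points read hyperparameters with dotted keys such as `utils.config(modelSelection + ".learning_rate")` and `utils.config("unet.checkpoint")`. A minimal sketch of how such a lookup could be implemented over a JSON config file; the file name `config.json` and its nesting are assumptions for illustration only:

import json

def config(attr, filename="config.json"):
    # Resolve a dotted key like "dense169.learning_rate" against nested JSON.
    with open(filename) as f:
        cfg = json.load(f)
    node = cfg
    for part in attr.split("."):
        node = node[part]
    return node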
def evaluate(config, workdir, eval_folder="eval"): """Evaluate trained models. Args: config: Configuration to use. workdir: Working directory for checkpoints. eval_folder: The subfolder for storing evaluation results. Default to "eval". """ # Create directory to eval_folder eval_dir = os.path.join(workdir, eval_folder) tf.io.gfile.makedirs(eval_dir) # Build data pipeline train_ds, eval_ds, _ = datasets.get_dataset( config, uniform_dequantization=config.data.uniform_dequantization, evaluation=True) # Create data normalizer and its inverse scaler = datasets.get_data_scaler(config) inverse_scaler = datasets.get_data_inverse_scaler(config) # Initialize model score_model = mutils.create_model(config) optimizer = losses.get_optimizer(config, score_model.parameters()) ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate) state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0) checkpoint_dir = os.path.join(workdir, "checkpoints") # Setup SDEs if config.training.sde.lower() == 'vpsde': sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales) sampling_eps = 1e-3 elif config.training.sde.lower() == 'subvpsde': sde = sde_lib.subVPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales) sampling_eps = 1e-3 elif config.training.sde.lower() == 'vesde': sde = sde_lib.VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max, N=config.model.num_scales) sampling_eps = 1e-5 else: raise NotImplementedError(f"SDE {config.training.sde} unknown.") # Create the one-step evaluation function when loss computation is enabled if config.eval.enable_loss: optimize_fn = losses.optimization_manager(config) continuous = config.training.continuous likelihood_weighting = config.training.likelihood_weighting reduce_mean = config.training.reduce_mean eval_step = losses.get_step_fn( sde, train=False, optimize_fn=optimize_fn, reduce_mean=reduce_mean, continuous=continuous, likelihood_weighting=likelihood_weighting) # Create data loaders for likelihood evaluation. Only evaluate on uniformly dequantized data train_ds_bpd, eval_ds_bpd, _ = datasets.get_dataset( config, uniform_dequantization=True, evaluation=True) if config.eval.bpd_dataset.lower() == 'train': ds_bpd = train_ds_bpd bpd_num_repeats = 1 elif config.eval.bpd_dataset.lower() == 'test': # Go over the dataset 5 times when computing likelihood on the test dataset ds_bpd = eval_ds_bpd bpd_num_repeats = 5 else: raise ValueError( f"No bpd dataset {config.eval.bpd_dataset} recognized.") # Build the likelihood computation function when likelihood is enabled if config.eval.enable_bpd: likelihood_fn = likelihood.get_likelihood_fn(sde, inverse_scaler) # Build the sampling function when sampling is enabled if config.eval.enable_sampling: sampling_shape = (config.eval.batch_size, config.data.num_channels, config.data.image_size, config.data.image_size) sampling_fn = sampling.get_sampling_fn(config, sde, sampling_shape, inverse_scaler, sampling_eps) # Use inceptionV3 for images with resolution higher than 256. 
inceptionv3 = config.data.image_size >= 256 inception_model = evaluation.get_inception_model(inceptionv3=inceptionv3) begin_ckpt = config.eval.begin_ckpt logging.info("begin checkpoint: %d" % (begin_ckpt, )) for ckpt in range(begin_ckpt, config.eval.end_ckpt + 1): # Wait if the target checkpoint doesn't exist yet waiting_message_printed = False ckpt_filename = os.path.join(checkpoint_dir, "checkpoint_{}.pth".format(ckpt)) while not tf.io.gfile.exists(ckpt_filename): if not waiting_message_printed: logging.warning("Waiting for the arrival of checkpoint_%d" % (ckpt, )) waiting_message_printed = True time.sleep(60) # Wait for 2 additional mins in case the file exists but is not ready for reading ckpt_path = os.path.join(checkpoint_dir, f'checkpoint_{ckpt}.pth') try: state = restore_checkpoint(ckpt_path, state, device=config.device) except: time.sleep(60) try: state = restore_checkpoint(ckpt_path, state, device=config.device) except: time.sleep(120) state = restore_checkpoint(ckpt_path, state, device=config.device) ema.copy_to(score_model.parameters()) # Compute the loss function on the full evaluation dataset if loss computation is enabled if config.eval.enable_loss: all_losses = [] eval_iter = iter(eval_ds) # pytype: disable=wrong-arg-types for i, batch in enumerate(eval_iter): eval_batch = torch.from_numpy(batch['image']._numpy()).to( config.device).float() eval_batch = eval_batch.permute(0, 3, 1, 2) eval_batch = scaler(eval_batch) eval_loss = eval_step(state, eval_batch) all_losses.append(eval_loss.item()) if (i + 1) % 1000 == 0: logging.info("Finished %dth step loss evaluation" % (i + 1)) # Save loss values to disk or Google Cloud Storage all_losses = np.asarray(all_losses) with tf.io.gfile.GFile( os.path.join(eval_dir, f"ckpt_{ckpt}_loss.npz"), "wb") as fout: io_buffer = io.BytesIO() np.savez_compressed(io_buffer, all_losses=all_losses, mean_loss=all_losses.mean()) fout.write(io_buffer.getvalue()) # Compute log-likelihoods (bits/dim) if enabled if config.eval.enable_bpd: bpds = [] for repeat in range(bpd_num_repeats): bpd_iter = iter(ds_bpd) # pytype: disable=wrong-arg-types for batch_id in range(len(ds_bpd)): batch = next(bpd_iter) eval_batch = torch.from_numpy(batch['image']._numpy()).to( config.device).float() eval_batch = eval_batch.permute(0, 3, 1, 2) eval_batch = scaler(eval_batch) bpd = likelihood_fn(score_model, eval_batch)[0] bpd = bpd.detach().cpu().numpy().reshape(-1) bpds.extend(bpd) logging.info( "ckpt: %d, repeat: %d, batch: %d, mean bpd: %6f" % (ckpt, repeat, batch_id, np.mean(np.asarray(bpds)))) bpd_round_id = batch_id + len(ds_bpd) * repeat # Save bits/dim to disk or Google Cloud Storage with tf.io.gfile.GFile( os.path.join( eval_dir, f"{config.eval.bpd_dataset}_ckpt_{ckpt}_bpd_{bpd_round_id}.npz" ), "wb") as fout: io_buffer = io.BytesIO() np.savez_compressed(io_buffer, bpd) fout.write(io_buffer.getvalue()) # Generate samples and compute IS/FID/KID when enabled if config.eval.enable_sampling: num_sampling_rounds = config.eval.num_samples // config.eval.batch_size + 1 for r in range(num_sampling_rounds): logging.info("sampling -- ckpt: %d, round: %d" % (ckpt, r)) # Directory to save samples. 
Different for each host to avoid writing conflicts this_sample_dir = os.path.join(eval_dir, f"ckpt_{ckpt}") tf.io.gfile.makedirs(this_sample_dir) samples, n = sampling_fn(score_model) samples = np.clip( samples.permute(0, 2, 3, 1).cpu().numpy() * 255., 0, 255).astype(np.uint8) samples = samples.reshape( (-1, config.data.image_size, config.data.image_size, config.data.num_channels)) # Write samples to disk or Google Cloud Storage with tf.io.gfile.GFile( os.path.join(this_sample_dir, f"samples_{r}.npz"), "wb") as fout: io_buffer = io.BytesIO() np.savez_compressed(io_buffer, samples=samples) fout.write(io_buffer.getvalue()) # Force garbage collection before calling TensorFlow code for Inception network gc.collect() latents = evaluation.run_inception_distributed( samples, inception_model, inceptionv3=inceptionv3) # Force garbage collection again before returning to JAX code gc.collect() # Save latent represents of the Inception network to disk or Google Cloud Storage with tf.io.gfile.GFile( os.path.join(this_sample_dir, f"statistics_{r}.npz"), "wb") as fout: io_buffer = io.BytesIO() np.savez_compressed(io_buffer, pool_3=latents["pool_3"], logits=latents["logits"]) fout.write(io_buffer.getvalue()) # Compute inception scores, FIDs and KIDs. # Load all statistics that have been previously computed and saved for each host all_logits = [] all_pools = [] this_sample_dir = os.path.join(eval_dir, f"ckpt_{ckpt}") stats = tf.io.gfile.glob( os.path.join(this_sample_dir, "statistics_*.npz")) for stat_file in stats: with tf.io.gfile.GFile(stat_file, "rb") as fin: stat = np.load(fin) if not inceptionv3: all_logits.append(stat["logits"]) all_pools.append(stat["pool_3"]) if not inceptionv3: all_logits = np.concatenate(all_logits, axis=0)[:config.eval.num_samples] all_pools = np.concatenate(all_pools, axis=0)[:config.eval.num_samples] # Load pre-computed dataset statistics. data_stats = evaluation.load_dataset_stats(config) data_pools = data_stats["pool_3"] # Compute FID/KID/IS on all samples together. if not inceptionv3: inception_score = tfgan.eval.classifier_score_from_logits( all_logits) else: inception_score = -1 fid = tfgan.eval.frechet_classifier_distance_from_activations( data_pools, all_pools) # Hack to get tfgan KID work for eager execution. tf_data_pools = tf.convert_to_tensor(data_pools) tf_all_pools = tf.convert_to_tensor(all_pools) kid = tfgan.eval.kernel_classifier_distance_from_activations( tf_data_pools, tf_all_pools).numpy() del tf_data_pools, tf_all_pools logging.info( "ckpt-%d --- inception_score: %.6e, FID: %.6e, KID: %.6e" % (ckpt, inception_score, fid, kid)) with tf.io.gfile.GFile( os.path.join(eval_dir, f"report_{ckpt}.npz"), "wb") as f: io_buffer = io.BytesIO() np.savez_compressed(io_buffer, IS=inception_score, fid=fid, kid=kid) f.write(io_buffer.getvalue())
def main(device=torch.device('cuda:0')): # CLI arguments parser = arg.ArgumentParser( description='We all know what we are doing. Fighting!') parser.add_argument("--datasize", "-d", default="small", type=str, help="data size you want to use, small, medium, total") # Parsing args = parser.parse_args() # Data loaders datasize = args.datasize pathname = "data/nyu.zip" tr_loader, va_loader, te_loader = getTrainingValidationTestingData( datasize, pathname, batch_size=config("unet.batch_size")) # Model model = Net() # TODO: define loss function, and optimizer learning_rate = utils.config("unet.learning_rate") criterion = DepthLoss(0.1) optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) number_of_epoches = 10 # # print("Number of float-valued parameters:", util.count_parameters(model)) # Attempts to restore the latest checkpoint if exists print("Loading unet...") model, start_epoch, stats = utils.restore_checkpoint( model, utils.config("unet.checkpoint")) # axes = utils.make_training_plot() # Evaluate the randomly initialized model # evaluate_epoch( # axes, tr_loader, va_loader, te_loader, model, criterion, start_epoch, stats # ) # loss = criterion() # initial val loss for early stopping # prev_val_loss = stats[0][1] running_va_loss = [] running_va_acc = [] running_tr_loss = [] running_tr_acc = [] # TODO: define patience for early stopping # patience = 1 # curr_patience = 0 # tr_acc, tr_loss = utils.evaluate_model(model, tr_loader, device) acc, loss = utils.evaluate_model(model, va_loader, device) running_va_acc.append(acc) running_va_loss.append(loss) running_tr_acc.append(tr_acc) running_tr_loss.append(tr_loss) # Loop over the entire dataset multiple times # for epoch in range(start_epoch, config('cnn.num_epochs')): epoch = start_epoch # while curr_patience < patience: while epoch < number_of_epoches: # Train model utils.train_epoch(tr_loader, model, criterion, optimizer) tr_acc, tr_loss = utils.evaluate_model(model, tr_loader, device) va_acc, va_loss = utils.evaluate_model(model, va_loader, device) running_va_acc.append(va_acc) running_va_loss.append(va_loss) running_tr_acc.append(tr_acc) running_tr_loss.append(tr_loss) # Evaluate model # evaluate_epoch( # axes, tr_loader, va_loader, te_loader, model, criterion, epoch + 1, stats # ) # Save model parameters utils.save_checkpoint(model, epoch + 1, utils.config("unet.checkpoint"), stats) # update early stopping parameters """ curr_patience, prev_val_loss = early_stopping( stats, curr_patience, prev_val_loss ) """ epoch += 1 print("Finished Training") # Save figure and keep plot open # utils.save_training_plot() # utils.hold_training_plot() utils.make_plot(running_tr_loss, running_tr_acc, running_va_loss, running_va_acc)
def train(config, workdir):
    """Runs the training pipeline.

    Args:
      config: Configuration to use.
      workdir: Working directory for checkpoints and TF summaries. If this
        contains checkpoint training will be resumed from the latest
        checkpoint.
    """
    # Create directories for experimental logs
    sample_dir = os.path.join(workdir, "samples")
    tf.io.gfile.makedirs(sample_dir)

    tb_dir = os.path.join(workdir, "tensorboard")
    tf.io.gfile.makedirs(tb_dir)
    writer = tensorboard.SummaryWriter(tb_dir)

    # Initialize model.
    score_model = mutils.create_model(config)
    ema = ExponentialMovingAverage(score_model.parameters(),
                                   decay=config.model.ema_rate)
    optimizer = losses.get_optimizer(config, score_model.parameters())
    state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0)

    # Create checkpoints directory
    checkpoint_dir = os.path.join(workdir, "checkpoints")
    # Intermediate checkpoints to resume training after pre-emption in cloud
    # environments
    checkpoint_meta_dir = os.path.join(workdir, "checkpoints-meta",
                                       "checkpoint.pth")
    tf.io.gfile.makedirs(checkpoint_dir)
    tf.io.gfile.makedirs(os.path.dirname(checkpoint_meta_dir))
    # Resume training when intermediate checkpoints are detected
    state = restore_checkpoint(checkpoint_meta_dir, state, config.device)
    initial_step = int(state['step'])

    # Build data iterators
    train_ds, eval_ds, _ = datasets.get_dataset(
        config, uniform_dequantization=config.data.uniform_dequantization)
    train_iter = iter(train_ds)  # pytype: disable=wrong-arg-types
    eval_iter = iter(eval_ds)  # pytype: disable=wrong-arg-types
    # Create data normalizer and its inverse
    scaler = datasets.get_data_scaler(config)
    inverse_scaler = datasets.get_data_inverse_scaler(config)

    # Setup SDEs
    if config.training.sde.lower() == 'vpsde':
        sde = sde_lib.VPSDE(beta_min=config.model.beta_min,
                            beta_max=config.model.beta_max,
                            N=config.model.num_scales)
        sampling_eps = 1e-3
    elif config.training.sde.lower() == 'subvpsde':
        sde = sde_lib.subVPSDE(beta_min=config.model.beta_min,
                               beta_max=config.model.beta_max,
                               N=config.model.num_scales)
        sampling_eps = 1e-3
    elif config.training.sde.lower() == 'vesde':
        sde = sde_lib.VESDE(sigma_min=config.model.sigma_min,
                            sigma_max=config.model.sigma_max,
                            N=config.model.num_scales)
        sampling_eps = 1e-5
    else:
        raise NotImplementedError(f"SDE {config.training.sde} unknown.")

    # Build one-step training and evaluation functions
    optimize_fn = losses.optimization_manager(config)
    continuous = config.training.continuous
    reduce_mean = config.training.reduce_mean
    likelihood_weighting = config.training.likelihood_weighting
    train_step_fn = losses.get_step_fn(
        sde,
        train=True,
        optimize_fn=optimize_fn,
        reduce_mean=reduce_mean,
        continuous=continuous,
        likelihood_weighting=likelihood_weighting)
    eval_step_fn = losses.get_step_fn(
        sde,
        train=False,
        optimize_fn=optimize_fn,
        reduce_mean=reduce_mean,
        continuous=continuous,
        likelihood_weighting=likelihood_weighting)

    # Building sampling functions
    if config.training.snapshot_sampling:
        sampling_shape = (config.training.batch_size,
                          config.data.num_channels, config.data.image_size,
                          config.data.image_size)
        sampling_fn = sampling.get_sampling_fn(config, sde, sampling_shape,
                                               inverse_scaler, sampling_eps)

    num_train_steps = config.training.n_iters

    # In case there are multiple hosts (e.g., TPU pods), only log to host 0
    logging.info("Starting training loop at step %d." % (initial_step, ))

    for step in range(initial_step, num_train_steps + 1):
        # Convert data to JAX arrays and normalize them. Use ._numpy() to
        # avoid copy.
        batch = torch.from_numpy(next(train_iter)['image']._numpy()).to(
            config.device).float()
        batch = batch.permute(0, 3, 1, 2)
        batch = scaler(batch)
        # Execute one training step
        loss = train_step_fn(state, batch)
        if step % config.training.log_freq == 0:
            logging.info("step: %d, training_loss: %.5e" %
                         (step, loss.item()))
            writer.add_scalar("training_loss", loss, step)

        # Save a temporary checkpoint to resume training after pre-emption
        # periodically
        if step != 0 and step % config.training.snapshot_freq_for_preemption == 0:
            save_checkpoint(checkpoint_meta_dir, state)

        # Report the loss on an evaluation dataset periodically
        if step % config.training.eval_freq == 0:
            eval_batch = torch.from_numpy(
                next(eval_iter)['image']._numpy()).to(config.device).float()
            eval_batch = eval_batch.permute(0, 3, 1, 2)
            eval_batch = scaler(eval_batch)
            eval_loss = eval_step_fn(state, eval_batch)
            logging.info("step: %d, eval_loss: %.5e" %
                         (step, eval_loss.item()))
            writer.add_scalar("eval_loss", eval_loss.item(), step)

        # Save a checkpoint periodically and generate samples if needed
        if step != 0 and step % config.training.snapshot_freq == 0 or step == num_train_steps:
            # Save the checkpoint.
            save_step = step // config.training.snapshot_freq
            save_checkpoint(
                os.path.join(checkpoint_dir, f'checkpoint_{save_step}.pth'),
                state)

            # Generate and save samples
            if config.training.snapshot_sampling:
                ema.store(score_model.parameters())
                ema.copy_to(score_model.parameters())
                sample, n = sampling_fn(score_model)
                ema.restore(score_model.parameters())
                this_sample_dir = os.path.join(sample_dir,
                                               "iter_{}".format(step))
                tf.io.gfile.makedirs(this_sample_dir)
                nrow = int(np.sqrt(sample.shape[0]))
                image_grid = make_grid(sample, nrow, padding=2)
                sample = np.clip(
                    sample.permute(0, 2, 3, 1).cpu().numpy() * 255, 0,
                    255).astype(np.uint8)
                with tf.io.gfile.GFile(
                        os.path.join(this_sample_dir, "sample.np"),
                        "wb") as fout:
                    np.save(fout, sample)

                with tf.io.gfile.GFile(
                        os.path.join(this_sample_dir, "sample.png"),
                        "wb") as fout:
                    save_image(image_grid, fout)
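The two SDE pipelines above pass a `state` dict (optimizer, model, ema, step) through `restore_checkpoint(path, state, device)` and `save_checkpoint(path, state)`. A minimal sketch of such a pair, assuming plain `torch.save`/`torch.load` serialization; the exact helpers used by the project may differ:

import os
import torch

def save_checkpoint(ckpt_path, state):
    # Serialize the live state dict into plain state_dicts plus the step count.
    saved_state = {
        'optimizer': state['optimizer'].state_dict(),
        'model': state['model'].state_dict(),
        'ema': state['ema'].state_dict(),
        'step': state['step'],
    }
    torch.save(saved_state, ckpt_path)

def restore_checkpoint(ckpt_path, state, device):
    # If nothing has been saved yet, return the state unchanged so training
    # starts from step 0.
    if not os.path.exists(ckpt_path):
        return state
    loaded = torch.load(ckpt_path, map_location=device)
    state['optimizer'].load_state_dict(loaded['optimizer'])
    state['model'].load_state_dict(loaded['model'])
    state['ema'].load_state_dict(loaded['ema'])
    state['step'] = loaded['step']
    return state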