def evaluate_legacy_model(weight_files, num_game, seed, bomb, num_run=1, verbose=True):
    # model_lockers = []
    # greedy_extra = 0
    agents = []
    num_player = len(weight_files)
    assert num_player > 1, "1 weight file per player"

    for weight_file in weight_files:
        if verbose:
            print(
                "evaluating: %s\n\tfor %dx%d games"
                % (weight_file, num_run, num_game)
            )
        if "sad" in weight_file or "aux" in weight_file:
            sad = True
        else:
            sad = False

        device = "cuda:0"
        state_dict = torch.load(weight_file)
        input_dim = state_dict["net.0.weight"].size()[1]
        hid_dim = 512
        output_dim = state_dict["fc_a.weight"].size()[0]
        agent = r2d2.R2D2Agent(
            False, 3, 0.999, 0.9, device, input_dim, hid_dim, output_dim, 2, 5, False
        ).to(device)
        utils.load_weight(agent.online_net, weight_file, device)
        agents.append(agent)

    scores = []
    perfect = 0
    for i in range(num_run):
        _, _, score, p = evaluate(
            agents,
            num_game,
            num_game * i + seed,
            bomb,
            0,
            sad,
        )
        scores.extend(score)
        perfect += p

    mean = np.mean(scores)
    sem = np.std(scores) / np.sqrt(len(scores))
    perfect_rate = perfect / (num_game * num_run)
    if verbose:
        print("score: %f +/- %f" % (mean, sem), "; perfect: ", perfect_rate)
    return mean, sem, perfect_rate
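# A minimal usage sketch for evaluate_legacy_model(); the checkpoint paths below
# are hypothetical and assume one .pthw file per player.
if __name__ == "__main__":
    weight_files = ["exps/sad_player0.pthw", "exps/sad_player1.pthw"]  # hypothetical paths
    mean, sem, perfect_rate = evaluate_legacy_model(
        weight_files, num_game=1000, seed=1, bomb=0, num_run=1, verbose=True
    )
    print("final: %.3f +/- %.3f; perfect: %.3f" % (mean, sem, perfect_rate))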
def load_op_model(method, idx1, idx2):
    """Load OP models; OP models were trained only for 2 players."""
    root = os.path.dirname(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    )
    # assume model saved in root/models/op
    folder = os.path.join(root, "models", "op", method)
    agents = []
    for idx in [idx1, idx2]:
        if 0 <= idx < 3:
            num_fc = 1
            skip_connect = False
        elif 3 <= idx < 6:
            num_fc = 1
            skip_connect = True
        elif 6 <= idx < 9:
            num_fc = 2
            skip_connect = False
        else:
            num_fc = 2
            skip_connect = True
        weight_file = os.path.join(folder, f"M{idx}.pthw")
        if not os.path.exists(weight_file):
            print(f"Cannot find weight at: {weight_file}")
            assert False

        device = "cuda:0"
        state_dict = torch.load(weight_file)
        input_dim = state_dict["net.0.weight"].size()[1]
        hid_dim = 512
        output_dim = state_dict["fc_a.weight"].size()[0]
        agent = r2d2.R2D2Agent(
            False,
            3,
            0.999,
            0.9,
            device,
            input_dim,
            hid_dim,
            output_dim,
            2,
            5,
            False,
            num_fc_layer=num_fc,
            skip_connect=skip_connect,
        ).to(device)
        utils.load_weight(agent.online_net, weight_file, device)
        agents.append(agent)
    return agents
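# Hedged example: load a pair of OP agents and score them with the same
# evaluate() helper used above; the method name and indices are assumptions.
op_agents = load_op_model("sad", 0, 6)
_, _, op_scores, op_perfect = evaluate(op_agents, 1000, 1, 0, 0, False)
print("op pair mean score:", np.mean(op_scores))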
def setup(data_folder):
    np.random.seed(0)
    torch.cuda.manual_seed_all(0)
    torch.manual_seed(0)

    codec = get_encoder()
    dataset = NewsDataset(path=data_folder, ctx_length=128, codec=codec, start_from_zero=True)

    config = GPT2Config()
    model = GPT2LMHeadModel(config)
    if not os.path.exists('gpt2-pytorch_model.bin'):
        print("Downloading GPT-2 checkpoint...")
        url = 'https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin'
        r = requests.get(url, allow_redirects=True)
        open('gpt2-pytorch_model.bin', 'wb').write(r.content)
    model = load_weight(
        model, torch.load('gpt2-pytorch_model.bin', map_location=device))
    model = model.to(device)
    model.eval()
    return codec, model, dataset, config
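# Sketch of a call to setup(); 'data/news' is a hypothetical folder layout
# expected by NewsDataset.
codec, model, dataset, config = setup('data/news')
print("context window:", config.n_ctx)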
def predict(model_type, finetune_dataset, input_path, output_path,
            probability_output, batch_size, gpu=True):
    model = fastsal.fastsal(pretrain_mode=False, model_type=model_type)
    state_dict, opt_state = load_weight(
        'weights/{}_{}.pth'.format(finetune_dataset, model_type),
        remove_decoder=False)
    model.load_state_dict(state_dict)
    model.eval()  # inference mode
    if gpu:
        model.cuda()

    simple_data = img_dataset(input_path, output_path)
    simple_loader = DataLoader(simple_data, batch_size=batch_size,
                               shuffle=False, num_workers=4)
    for x, original_size_list, output_path_list in simple_loader:
        if gpu:
            x = x.float().cuda()
        y = model(x)
        if not probability_output:
            y = nn.Sigmoid()(y)
        # detach before converting to numpy, whether on CPU or GPU
        y = y.detach().cpu().numpy()
        for i, prediction in enumerate(y[:, 0, :, :]):
            img_output_path = output_path_list[i]
            original_size = original_size_list[i].numpy()
            print(img_output_path)
            if not probability_output:
                img_data = post_process_png(prediction, original_size)
                cv2.imwrite(img_output_path, img_data)
            else:
                img_data = post_process_probability2(prediction, original_size)
                np.save(img_output_path.split('.')[0], img_data)
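# Minimal invocation sketch for predict(); the input/output folders are
# hypothetical, the weight naming follows weights/{dataset}_{type}.pth as above.
predict(
    model_type='A',
    finetune_dataset='coco',
    input_path='images/',         # hypothetical folder of input images
    output_path='saliency_out/',  # hypothetical folder for saliency maps
    probability_output=False,
    batch_size=4,
    gpu=True,
)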
def main():
    # random seed
    seed = 1234
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # load dataset
    if args.dataset[0] == 'deepfashion':
        ds = pd.read_csv('./Anno/df_info.csv')
        from dataset import DeepFashionDataset as DataManager
    elif args.dataset[0] == 'fld':
        ds = pd.read_csv('./Anno/fld_info.csv')
        from dataset import FLDDataset as DataManager
    else:
        raise ValueError
    print('dataset : %s' % (args.dataset[0]))

    if not args.evaluate:
        train_dm = DataManager(ds[ds['evaluation_status'] == 'train'], root=args.root)
        train_dl = DataLoader(train_dm, batch_size=args.batchsize, shuffle=True)
        if not os.path.exists('models'):
            os.makedirs('models')

    test_dm = DataManager(ds[ds['evaluation_status'] == 'test'], root=args.root)
    test_dl = DataLoader(test_dm, batch_size=args.batchsize, shuffle=False)

    val_dm = DataManager(ds[ds['evaluation_status'] == 'val'], root=args.root)
    val_dl = DataLoader(val_dm, batch_size=args.batchsize, shuffle=True)

    # Load model
    print("Load the model...")
    net = torch.nn.DataParallel(StructNet().cuda())
    if args.weight_file is not None:
        weights = torch.load(args.weight_file)
        if args.update_weight:
            weights = utils.load_weight(net, weights)
        net.load_state_dict(weights)

    # evaluate only
    if args.evaluate:
        print("Evaluation only")
        test(net, test_dl, -1)
        return

    # learning parameters
    optimizer = torch.optim.Adam(net.parameters(), lr=args.learning_rate)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 2, 0.1)

    print('Start training')
    for epoch in range(args.epoch):
        train(net, optimizer, train_dl, epoch)
        lr_scheduler.step()
        test(net, val_dl, epoch)
def main(model_type, batch_size, dataset_name, dataset_path, size, width_bigger,
         pretrain_path, save_path, probability_output):
    # Datasets for SALICON
    if dataset_name == 'salicon':
        ds_validate = dataset.Salicon(dataset_path, mode='test', type=['vgg_img'], size=(size, ))
    elif dataset_name == 'mit300':
        ds_validate = mit300.dataset(dataset_path, type=('vgg_img'), size=(size, ))
        ds_validate.renew_list(width_bigger=width_bigger)
    elif dataset_name == 'mit1003':
        ds_validate = mit1003.dataset(dataset_path, type=('vgg_img'), size=(size, ))
        ds_validate.renew_list(width_bigger=width_bigger)
    elif dataset_name == 'dhf1k':
        ds_validate = dhf1k.dataset(dataset_path, mode='test', type=('vgg_img'), size=(size, ))
    file_list = ds_validate.list_names

    if pretrain_path:
        state_dict, opt_state = load_weight(pretrain_path, remove_decoder=False)
    else:
        print('please specify trained models.')
        exit()

    model = student_teacher.salgan_teacher_student(
        False, model_type, use_probability_gt=probability_output)
    model.student_net.load_state_dict(state_dict)
    model.cuda()

    dataloader = {
        'val': DataLoader(ds_validate, batch_size=batch_size, shuffle=False, num_workers=4)
    }

    print('--------------------------------------------->>>>>>')
    model.eval()
    with t.no_grad():
        train_one(model, dataloader, file_list, 'val', save_path, probability_output)
    print('--------------------------------------------->>>>>>')
def load_sad_model(weight_files):
    agents = []
    for weight_file in weight_files:
        print("loading:", weight_file)
        if "sad" in weight_file or "aux" in weight_file:
            sad = True
        else:
            sad = False

        device = "cuda:0"
        state_dict = torch.load(weight_file)
        input_dim = state_dict["net.0.weight"].size()[1]
        hid_dim = 512
        output_dim = state_dict["fc_a.weight"].size()[0]
        agent = r2d2.R2D2Agent(
            False, 3, 0.999, 0.9, device, input_dim, hid_dim, output_dim, 2, 5, False
        ).to(device)
        utils.load_weight(agent.online_net, weight_file, device)
        agents.append(agent)
    return agents
def setup(n_enc_layer=1):
    np.random.seed(0)
    torch.cuda.manual_seed_all(0)
    torch.manual_seed(0)

    codec = get_encoder()
    config = GPT2Config(n_enc_layer=n_enc_layer)
    model = GPT2LMHeadModel(config)
    if not os.path.exists('../gpt2-pytorch_model.bin'):
        print("Downloading GPT-2 checkpoint...")
        url = 'https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin'
        r = requests.get(url, allow_redirects=True)
        open('../gpt2-pytorch_model.bin', 'wb').write(r.content)
    model = load_weight(
        model, torch.load('../gpt2-pytorch_model.bin', map_location=device))
    model = model.to(device)
    return codec, model, config
def exportToONNX(weights, model_type, output):
    # Fastsal - Type Coco A
    model = fastsal.fastsal(pretrain_mode=False, model_type=model_type)
    state_dict, opt_state = load_weight(weights, remove_decoder=False)
    model.load_state_dict(state_dict)
    model.eval()

    # dummy input matching fastsal's expected (N, 3, 192, 256) shape
    x = torch.zeros((1, 3, 192, 256))
    torch_out = model(x)
    torch.onnx.export(
        model,                     # model being run
        x,                         # model input (or a tuple for multiple inputs)
        output,                    # where to save the model (can be a file or file-like object)
        export_params=True,        # store the trained parameter weights inside the model file
        opset_version=11,          # the ONNX version to export the model to
        do_constant_folding=True,  # whether to execute constant folding for optimization
        input_names=["input"],     # the model's input names
        output_names=["output"],   # the model's output names
    )
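# Optional sanity check after exportToONNX(); assumes the onnx package is
# installed and that 'fastsal.onnx' is the output path passed to the export.
import onnx

onnx_model = onnx.load("fastsal.onnx")
onnx.checker.check_model(onnx_model)  # raises if the exported graph is malformed
print(onnx.helper.printable_graph(onnx_model.graph))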
def start_eval(model_type, batch_size, dataset_name, dataset_path, model_path=None):
    if dataset_name == 'salicon':
        ds_validate = dataset.Salicon(dataset_path, mode='val',
                                      type=['vgg_img', 'sal_img', 'fixation'],
                                      size=[(192, 256), (480, 640), (480, 640)])
    elif dataset_name == 'mit1003':
        ds_validate = mit1003.dataset(dataset_path,
                                      type=['vgg_img', 'sal_img', 'fixation'],
                                      size=[(192, 256), None, None])

    if model_path:
        state_dict, opt_state = load_weight(model_path, remove_decoder=False)
    else:
        print('please specify trained models.')
        exit()

    model = student_teacher.salgan_teacher_student(False, model_type)
    model.student_net.load_state_dict(state_dict)
    model.cuda()
    # model.generator.load_state_dict(state_dict['state_dict'])

    # Dataloaders
    dataloader = {
        'val': DataLoader(ds_validate, batch_size=batch_size, shuffle=False, num_workers=4)
    }

    model.eval()
    with t.no_grad():
        train_one(model, dataloader, 'val')
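# Hypothetical call to start_eval(): score a type-C student checkpoint on the
# SALICON validation split; the dataset root below is an assumption.
start_eval('C', batch_size=8, dataset_name='salicon',
           dataset_path='/path/to/salicon',
           model_path='weights/salicon_C.pth')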
def start_train(batch_size, dataset_name, dataset_path, teacher_path, direct, model_name):
    if dataset_name == 'salicon':
        dataloader = salicon_data(batch_size, dataset_path)
    elif dataset_name == 'coco':
        dataloader = coco_data(batch_size, dataset_path)

    model = student_teacher.salgan_teacher_student(True, 'C', teacher_path)
    model.cuda()

    lr = 0.01
    lr_decay = 0.1
    optimizer = model.get_optimizer(lr)

    smallest_val = None
    best_epoch = None
    for epoch in range(0, 100, 1):
        model.train()
        loss_train, model = train_one(model, dataloader, optimizer, 'train')
        print('{} loss train {}, lr {}'.format(epoch, loss_train, lr))
        print('--------------------------------------------->>>>>>')

        model.eval()
        loss_val, model = train_one(model, dataloader, optimizer, 'val')
        print('--------------------------------------------->>>>>>')
        print('{} loss val {}'.format(epoch, loss_val))

        smallest_val, best_epoch, model, optimizer = save_weight(
            smallest_val, best_epoch, loss_val, epoch, direct, model_name, model, optimizer)

        if epoch == 15 or epoch == 30 or epoch == 60:
            path = '{}/{}/{}_{:f}.pth'.format(direct, model_name, best_epoch, smallest_val)
            state_dict, opt_state = load_weight(path, remove_decoder=False)
            model.student_net.load_state_dict(state_dict)
            # optimizer.load_state_dict(state_dict['optimizer'])
            for param_group in optimizer.param_groups:
                param_group['lr'] *= lr_decay
            lr = lr * lr_decay
def run(arg):
    torch.manual_seed(7)
    np.random.seed(7)
    print("lr %f, epoch_num %d, decay_rate %f, gamma %f"
          % (arg.lr, arg.epochs, arg.decay, arg.gamma))

    print("====>Loading data")
    if arg.dataset == 'mnist':
        train_data = get_mnist_dataset("train", arg.batch_size)
        G_output_dim = config.mnist_G_output_dim
        D_input_dim = config.mnist_D_input_dim
    elif arg.dataset == 'cifar':
        train_data = get_cifar_dataset("train", arg.batch_size)
        G_output_dim = config.cifar_G_output_dim
        D_input_dim = config.cifar_D_input_dim

    print("====>Building model")
    if arg.net == "DCGAN":
        g_net = DC_Generator(config.G_input_dim, config.num_filters, G_output_dim, verbose=False)
        d_net = DC_Discriminator(D_input_dim, config.num_filters[::-1], config.D_output_dim, verbose=False)
        # g_net = Generator()
        # d_net = Discriminator()

    g_net = g_net.to(config.device)
    d_net = d_net.to(config.device)
    g_optimizer = optim.Adam(g_net.parameters(), lr=arg.lr, betas=(0.5, 0.9))
    d_optimizer = optim.Adam(d_net.parameters(), lr=arg.lr, betas=(0.5, 0.9))

    if arg.mul_gpu:
        g_net = nn.DataParallel(g_net)
        d_net = nn.DataParallel(d_net)

    if arg.checkpoint is not None:
        print("loading pretrained model")
        g_pretrained_dict = torch.load(os.path.join(
            config.checkpoint_path,
            arg.dataset + "_" + arg.net + "_g_net_" + arg.checkpoint + '.pth'))
        d_pretrained_dict = torch.load(os.path.join(
            config.checkpoint_path,
            arg.dataset + "_" + arg.net + "_d_net_" + arg.checkpoint + '.pth'))
        g_net = load_weight(g_pretrained_dict, g_net)
        d_net = load_weight(d_pretrained_dict, d_net)

    # logging
    logger = get_logger()
    criterion = nn.BCELoss()
    print('Total params: %.2fM' % ((sum(p.numel() for p in g_net.parameters()) +
                                    sum(p.numel() for p in d_net.parameters())) / 1000000.0))
    print("start training: ", datetime.now())

    start_epoch = 0
    fix_noise = torch.autograd.Variable(
        torch.randn(config.nrow ** 2, config.G_input_dim)
        .view(-1, config.G_input_dim, 1, 1).cuda())
    g_losses = []
    d_losses = []
    for epoch in range(start_epoch, arg.epochs):
        prev_time = datetime.now()
        g_loss, d_loss = train(train_data, g_net, d_net, criterion,
                               g_optimizer, d_optimizer, epoch, logger)
        now_time = datetime.now()
        time_str = count_time(prev_time, now_time)
        print("train: current (%d/%d) batch g_loss is %f d_loss is %f time "
              "is %s" % (epoch, arg.epochs, g_loss, d_loss, time_str))
        g_losses.append(g_loss)
        d_losses.append(d_loss)
        plot_loss(d_losses, g_losses, epoch, arg.net, arg.dataset)
        plot_sample(g_net, fix_noise, epoch, net_name=arg.net, dataset_name=arg.dataset)
        if epoch % 2 == 0:
            save_checkpoint(arg.dataset, arg.net, epoch, g_net, d_net)
    save_checkpoint(arg.dataset, arg.net, arg.epochs, g_net, d_net)
def evaluate_legacy_model(
    weight_files,
    num_game,
    seed,
    bomb,
    agent_args,
    args,
    num_run=1,
    gen_cross_play=False,
    verbose=True,
):
    agents = []
    num_player = len(weight_files)
    assert num_player > 1, "1 weight file per player"
    env_sad = False

    for i, weight_file in enumerate(weight_files):
        if verbose:
            print(
                "evaluating: %s\n\tfor %dx%d games"
                % (weight_file, num_run, num_game)
            )
        if "sad" in weight_file:
            sad = True
            env_sad = True
        else:
            sad = False

        device = "cuda:0"
        state_dict = torch.load(weight_file)
        input_dim = state_dict["net.0.weight"].size()[1]
        output_dim = state_dict["fc_a.weight"].size()[0]

        if gen_cross_play:
            agent_name = weight_file.split("/")[-1].split(".")[0]
            with open(f"{args.weight_1_dir}/{agent_name}.txt", "r") as f:
                agent_args = {**json.load(f)}
        else:
            learnable_pretrain = True
            if i == 0:
                learnable_agent_name = agent_args["load_learnable_model"]
                if learnable_agent_name != "":
                    agent_args_file = f"{learnable_agent_name[:-4]}txt"
                else:
                    learnable_pretrain = False
            else:
                agent_args_file = f"{weight_file[:-4]}txt"

            if learnable_pretrain:
                with open(agent_args_file, "r") as f:
                    agent_args = {**json.load(f)}

        rnn_type = agent_args["rnn_type"]
        rnn_hid_dim = agent_args["rnn_hid_dim"]
        num_fflayer = agent_args["num_fflayer"]
        num_rnn_layer = agent_args["num_rnn_layer"]

        if rnn_type == "lstm":
            import r2d2_lstm as r2d2
        elif rnn_type == "gru":
            import r2d2_gru as r2d2

        agent = r2d2.R2D2Agent(
            False,
            3,
            0.999,
            0.9,
            device,
            input_dim,
            rnn_hid_dim,
            output_dim,
            num_fflayer,
            num_rnn_layer,
            5,
            False,
            sad=sad,
        ).to(device)
        utils.load_weight(agent.online_net, weight_file, device)
        agents.append(agent)

    scores = []
    perfect = 0
    for i in range(num_run):
        if args.is_rand:
            random.shuffle(agents)
        _, _, score, p = evaluate(
            agents,
            num_game,
            num_game * i + seed,
            bomb,
            0,
            env_sad,
        )
        scores.extend(score)
        perfect += p

    mean = np.mean(scores)
    sem = np.std(scores) / np.sqrt(len(scores))
    perfect_rate = perfect / (num_game * num_run)
    if verbose:
        print("score: %f +/- %f" % (mean, sem), "; perfect: ", perfect_rate)
    return mean, sem, perfect_rate
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @File  : convert_weight.py
# @Author: ruixi L
# @Date  : 2019/11/13

import tensorflow as tf

from model import Mymodel
from utils import load_weight

weight_path = r'./model.pb'

inputs = tf.placeholder(tf.float32, shape=[None, 784], name='x')
model = Mymodel(class_num=10)
logits = model.forward(inputs)
name_list = tf.trainable_variables()

with tf.Session() as sess:
    load_op = load_weight(name_list, weight_path)
    for i in range(len(load_op)):
        sess.run(tf.assign(name_list[i], load_op[i]))
    tf.train.Saver().save(sess, './checkpoint/mymodel')
    print('Weight conversion finished')
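# Hedged follow-up: restore the converted checkpoint in a fresh TF1 session.
# Assumes the same graph (inputs/logits built above) is still the default graph.
import numpy as np

with tf.Session() as sess:
    saver = tf.train.Saver()
    saver.restore(sess, './checkpoint/mymodel')
    out = sess.run(logits, feed_dict={inputs: np.zeros((1, 784), dtype=np.float32)})
    print(out.shape)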
    args.train_device,
    learnable_games[0].feature_size(),
    rnn_hid_dim,
    learnable_games[0].num_action(),
    num_fflayer,
    num_rnn_layer,
    args.hand_size,
    False,
    sad=learnable_sad,
)
learnable_agent.sync_target_with_online()

if args.load_learnable_model:
    print("*****loading pretrained model for learnable agent *****")
    utils.load_weight(learnable_agent.online_net, args.load_learnable_model, args.train_device)
    print("*****done*****")

if args.resume_cont_training:
    print("***** resuming continual training ... ")
    learnable_agent_ckpts = glob.glob(f"{args.save_dir}/*_zero_shot.pthw")
    learnable_agent_ckpts.sort(key=os.path.getmtime)
    print("restoring from ... ", learnable_agent_ckpts[-1])
    utils.load_weight(learnable_agent.online_net, learnable_agent_ckpts[-1], args.train_device)
    epoch_restore = int(
        learnable_agent_ckpts[-1].split("/")[-1].split(".")[0].split("_")[1][5:])
    print("epoch restore is ... ", epoch_restore)

learnable_agent = learnable_agent.to(args.train_device)
print(learnable_agent)
    args.multi_step,
    args.gamma,
    args.eta,
    args.train_device,
    games[0].feature_size(),
    args.rnn_hid_dim,
    games[0].num_action(),
    args.num_lstm_layer,
    args.hand_size,
    False,  # uniform priority
)
agent.sync_target_with_online()

if args.load_model:
    print("*****loading pretrained model*****")
    utils.load_weight(agent.online_net, args.load_model, args.train_device)
    print("*****done*****")

agent = agent.to(args.train_device)
optim = torch.optim.Adam(agent.online_net.parameters(), lr=args.lr, eps=args.eps)
print(agent)
eval_agent = agent.clone(args.train_device, {"vdn": False})

replay_buffer = rela.RNNPrioritizedReplay(
    args.replay_buffer_size,
    args.seed,
    args.priority_exponent,
    args.priority_weight,
    args.prefetch,
import model.fastSal as fastsal
from utils import load_weight
import torch

if __name__ == '__main__':
    coco_c = 'weights/coco_C.pth'        # coco_C
    coco_a = 'weights/coco_A.pth'        # coco_A
    salicon_c = 'weights/salicon_C.pth'  # salicon_C
    salicon_a = 'weights/salicon_A.pth'  # salicon_A
    x = torch.zeros((10, 3, 192, 256))

    model = fastsal.fastsal(pretrain_mode=False, model_type='A')
    state_dict, opt_state = load_weight(coco_a, remove_decoder=False)
    model.load_state_dict(state_dict)
    y = model(x)
    print(y.shape)

    model = fastsal.fastsal(pretrain_mode=False, model_type='A')
    state_dict, opt_state = load_weight(salicon_a, remove_decoder=False)
    model.load_state_dict(state_dict)
    y = model(x)
    print(y.shape)

    model = fastsal.fastsal(pretrain_mode=False, model_type='C')
    state_dict, opt_state = load_weight(coco_c, remove_decoder=False)
    model.load_state_dict(state_dict)
    y = model(x)
    print(y.shape)

    model = fastsal.fastsal(pretrain_mode=False, model_type='C')
def main():
    # random seed
    seed = 1234
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    from dataset_df2_loader import DeepFashionDataset as DataManager
    with open("./data/train/deepfashion2.json", 'r') as infile:
        ds = json.load(infile)
    ds = ds['annotations'][0:5]
    print("dataset", len(ds), args.batchsize, args.epoch)
    print('dataset : %s' % (args.dataset[0]))

    if not args.evaluate:
        train_dm = DataManager(ds, root=args.root)
        train_dl = DataLoader(train_dm, batch_size=args.batchsize, shuffle=True)
        if not os.path.exists('models'):
            os.makedirs('models')

    with open("./data/validation/deepfashion2_datafile_8.json", 'r') as infile:
        test_data = json.load(infile)
    test_dm = DataManager(
        test_data['annotations'][0:5],
        root="/media/chintu/bharath_ext_hdd/Bharath/Segmentation/Landmark detection/GLE_FLD-master/data/validation/image/"
    )
    test_dl = DataLoader(test_dm, batch_size=args.batchsize, shuffle=False)

    # Load model
    print("Load the model...")
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print("device:", device)
    net = torch.nn.DataParallel(Network(dataset=args.dataset, flag=args.glem)).to(device)
    if args.weight_file is not None:
        weights = torch.load(args.weight_file)
        if args.update_weight:
            weights = utils.load_weight(net, weights)
        net.load_state_dict(weights)

    # evaluate only
    if args.evaluate:
        print("Evaluation only")
        test(net, test_dl, 0)
        return

    # learning parameters
    optimizer = torch.optim.Adam(net.parameters(), lr=args.learning_rate)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 5, 0.1)

    print('Start training')
    for epoch in range(args.epoch):
        lr_scheduler.step()
        train(net, optimizer, train_dl, epoch)
        test(net, test_dl, epoch)
def start_train(model_type, batch_size, dataset_name, dataset_path, teacher_path,
                direct, model_name, pretrain_path=None, pseudo_path=None):
    if dataset_name == 'salicon':
        dataloader = salicon_data(batch_size, dataset_path)
        use_gt = True
        use_pseudo_gt = False
        use_teacher = True
        use_probability_gt = False
    elif dataset_name == 'coco':
        dataloader = coco_data(batch_size, dataset_path, input_type=['npy_img'],
                               pseudo_path=pseudo_path)
        use_gt = False
        use_pseudo_gt = True
        use_teacher = False
        use_probability_gt = True

    if pretrain_path:
        state_dict, opt_state = load_weight(pretrain_path, remove_decoder=False)
    else:
        state_dict = None

    model = student_teacher.salgan_teacher_student(
        False, model_type, teacher_path, state_dict,
        use_gt=use_gt, use_teacher=use_teacher,
        use_probability_gt=use_probability_gt, use_pseudo_gt=use_pseudo_gt)
    model.cuda()

    lr = 0.01
    lr_decay = 0.1
    optimizer = model.get_optimizer(lr)

    smallest_val = None
    best_epoch = None
    for epoch in range(0, 100, 1):
        # with t.no_grad():
        #     metrics = get_saliency_metrics(dataloader['metric'], model, N=100)
        model.train()
        loss_train, model = train_one(model, dataloader, optimizer, 'train',
                                      use_gt=use_gt, use_teacher=use_teacher,
                                      use_pseudo_gt=use_pseudo_gt)
        print('{} loss train {}, lr {}'.format(epoch, loss_train, lr))
        print('--------------------------------------------->>>>>>')

        model.eval()
        loss_val, model = train_one(model, dataloader, optimizer, 'val',
                                    use_gt=use_gt, use_teacher=use_teacher,
                                    use_pseudo_gt=use_pseudo_gt)
        print('--------------------------------------------->>>>>>')
        print('{} loss val {}'.format(epoch, loss_val))

        smallest_val, best_epoch, model, optimizer = save_weight(
            smallest_val, best_epoch, loss_val, epoch, direct, model_name, model, optimizer)

        if epoch == 15 or epoch == 30 or epoch == 60:
            path = '{}/{}/{}_{:f}.pth'.format(direct, model_name, best_epoch, smallest_val)
            state_dict, opt_state = load_weight(path, remove_decoder=False)
            model.student_net.load_state_dict(state_dict)
            for param_group in optimizer.param_groups:
                param_group['lr'] *= lr_decay
            lr = lr * lr_decay
s = obs["s"].unsqueeze(0) assert s.size(2) == self.in_dim x = self.net(s) o, (h, c) = self.lstm(x, (h0, c0)) a = self.fc_a(o).squeeze(0) return { "a": a, "h0": h.transpose(0, 1).contiguous(), "c0": c.transpose(0, 1).contiguous(), } ## main program ## parser = argparse.ArgumentParser(description="") parser.add_argument("--model", type=str, default=None) args = parser.parse_args() device = "cuda" state_dict = torch.load(args.model) in_dim = state_dict["net.0.weight"].size()[1] out_dim = state_dict["fc_a.weight"].size()[0] print("after loading model") search_model = LSTMNet(device, in_dim, 512, out_dim, 2, 5) utils.load_weight(search_model, args.model, device) save_path = args.model.rsplit(".", 1)[0] + ".sparta" print("saving model to:", save_path) torch.jit.save(search_model, save_path)
def text_generator(state_dict):
    parser = argparse.ArgumentParser()
    parser.add_argument("--text", type=str, required=True)
    parser.add_argument("--quiet", type=bool, default=False)
    parser.add_argument("--nsamples", type=int, default=1)
    parser.add_argument('--unconditional', action='store_true',
                        help='If true, unconditional generation.')
    parser.add_argument("--batch_size", type=int, default=-1)
    parser.add_argument("--length", type=int, default=-1)
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--top_k", type=int, default=40)
    args = parser.parse_args()

    if args.quiet is False:
        print(args)

    if args.batch_size == -1:
        args.batch_size = 1
    assert args.nsamples % args.batch_size == 0

    seed = random.randint(0, 2147483647)
    np.random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load Model
    enc = get_encoder()
    config = GPT2Config()
    model = GPT2LMHeadModel(config)
    model = load_weight(model, state_dict)
    model.to(device)
    model.eval()

    if args.length == -1:
        args.length = config.n_ctx // 2
    elif args.length > config.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" % config.n_ctx)

    print(args.text)
    context_tokens = enc.encode(args.text)

    generated = 0
    for _ in range(args.nsamples // args.batch_size):
        out = sample_sequence(
            model=model,
            length=args.length,
            context=context_tokens if not args.unconditional else None,
            start_token=enc.encoder['<|endoftext|>'] if args.unconditional else None,
            batch_size=args.batch_size,
            temperature=args.temperature,
            top_k=args.top_k,
            device=device,
        )
        out = out[:, len(context_tokens):].tolist()
        for i in range(args.batch_size):
            generated += 1
            text = enc.decode(out[i])
            if args.quiet is False:
                print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
                print(text)
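# Hedged entry point for text_generator(); assumes the gpt2-pytorch_model.bin
# checkpoint downloaded by setup() above is available in the working directory.
if __name__ == '__main__':
    if os.path.exists('gpt2-pytorch_model.bin'):
        state_dict = torch.load('gpt2-pytorch_model.bin', map_location='cpu')
        text_generator(state_dict)
    else:
        print('Please download gpt2-pytorch_model.bin first.')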