def prepare_model(inference_config, device):
    """Build a ``TransformerNet`` on ``device`` and load trained weights into it.

    Args:
        inference_config: Mapping that provides ``'model_binaries_path'``
            (directory) and ``'model_name'`` (file name) of the checkpoint.
        device: Device the network is moved to before the weights are loaded.

    Returns:
        The network in evaluation mode with the checkpoint's ``'state_dict'``
        entry loaded (strict key matching).
    """
    net = TransformerNet().to(device)

    checkpoint_path = os.path.join(
        inference_config['model_binaries_path'],
        inference_config['model_name'],
    )
    # The checkpoint is always deserialized onto the CPU; load_state_dict then
    # copies the values into the parameters that already live on ``device``.
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    net.load_state_dict(checkpoint['state_dict'], strict=True)

    net.eval()
    return net
def test():
    """Evaluate the saved checkpoint of every fold on that fold's test file.

    For each of the five folds this rebuilds ``TransformerNet`` on the GPU,
    loads the weights stored under ``SAVE_DIR`` with the ``_599`` suffix,
    feeds ``test_<fold>.csv`` through the network with ``get_batch(1)`` and
    prints the number of correct predictions and the accuracy.  The mean
    accuracy over all folds is printed at the end.  Nothing is returned.

    Relies on module-level configuration (``SAVE_DIR``, ``DATA_path``,
    ``Pretrain_type`` and the network hyper-parameters).
    """
    num_folds = 5
    pretrain, code_id, word_id = source_prepare()
    print('%d different words, %d different codes' % (len(word_id), len(code_id)), flush=True)

    accuracy_sum = 0.
    for fold in range(num_folds):
        net = TransformerNet(Pretrain_type, pretrain, Max_seq_len, Embedding_size, Inner_hid_size,
                             len(code_id), D_k, D_v, dropout_ratio=Dropout, num_layers=Num_layers,
                             num_head=Num_head, Freeze=Freeze_emb).cuda()
        net.load_state_dict(torch.load(SAVE_DIR + 'data1_false_biobert_' + str(fold) + '_599'))
        net.eval()

        test_file = DATA_path + 'test_' + str(fold) + '.csv'
        test_data = tokenizer(word_id, code_id, test_file, pretrain_type=Pretrain_type)
        print('Fold %d: %d test data' % (fold, len(test_data.data)))
        print('max length: %d' % test_data.max_length, flush=True)

        test_data.reset_epoch()
        test_correct = 0
        num_batches = 0
        # Pure inference: disabling autograd avoids building a graph for
        # every forward pass, which only costs memory and time here.
        with torch.no_grad():
            while not test_data.epoch_finish:
                seq, label, _seq_length, _mask, seq_pos, standard_emb = test_data.get_batch(1)
                results = net(seq, seq_pos, standard_emb)
                _, idx = results.max(1)
                test_correct += len((idx == label).nonzero())
                num_batches += 1

        # Sanity check: one batch per sample means every sample was visited once.
        assert num_batches == len(test_data.data)
        test_accuracy = float(test_correct) / float(num_batches)
        print('[fold %d] test: %d correct, %.4f accuracy' % (fold, test_correct, test_accuracy), flush=True)
        accuracy_sum += test_accuracy

        # Release the fold's data before the next fold is tokenized.
        del test_data
        gc.collect()

    print('finial validation accuracy: %.4f' % (accuracy_sum / num_folds), flush=True)
def train():
    """Train and validate ``TransformerNet`` with five-fold cross-validation.

    For every fold a fresh network and Adam optimizer are created and trained
    for ``Epoch`` epochs on ``trainsplit_<fold>.csv``.  Every ``Val_every``
    epochs the network is scored on the training and validation files, the
    result is printed and a checkpoint is written under ``SAVE_DIR``.  Every
    ``LR_decay_epoch`` epochs the learning rate is decayed through
    ``adjust_learning_rate``.  The last validation accuracy of each fold is
    averaged and printed at the end.  Nothing is returned.

    Note: the fold's accuracy is taken from the most recent validation pass,
    so ``Epoch`` must be at least ``Val_every`` for one to exist.
    """
    num_folds = 5
    pretrain, code_id, word_id = source_prepare()
    print('%d different words, %d different codes' % (len(word_id), len(code_id)), flush=True)
    criterion = nn.CrossEntropyLoss()

    def count_correct(net, dataset):
        """Run one evaluation pass; return (correct predictions, predictions made)."""
        dataset.reset_epoch()
        correct = 0
        seen = 0
        # Evaluation needs no gradients; skip building the autograd graph.
        with torch.no_grad():
            while not dataset.epoch_finish:
                seq, label, _seq_length, _mask, seq_pos, standard_emb = dataset.get_batch(Batch_size)
                results = net(seq, seq_pos, standard_emb)
                _, idx = results.max(1)
                correct += len((idx == label).nonzero())
                # Count what was actually scored rather than assuming every
                # batch holds exactly Batch_size samples.
                seen += len(idx)
        return correct, seen

    accuracy_sum = 0.
    for fold in range(num_folds):
        net = TransformerNet(Pretrain_type, pretrain, Max_seq_len, Embedding_size, Inner_hid_size,
                             len(code_id), D_k, D_v, dropout_ratio=Dropout, num_layers=Num_layers,
                             num_head=Num_head, Freeze=Freeze_emb).cuda()
        optimizer = optim.Adam(net.parameters(), lr=Learning_rate, eps=1e-08, weight_decay=Weight_decay)

        train_file = DATA_path + 'trainsplit_' + str(fold) + '.csv'
        val_file = DATA_path + 'val_' + str(fold) + '.csv'
        train_data = tokenizer(word_id, code_id, train_file, pretrain_type=Pretrain_type)
        val_data = tokenizer(word_id, code_id, val_file, pretrain_type=Pretrain_type)
        print('Fold %d: %d training data, %d validation data' % (fold, len(train_data.data), len(val_data.data)))
        print('max length: %d %d' % (train_data.max_length, val_data.max_length), flush=True)

        for e in range(Epoch):
            # One optimisation pass over the training split.
            train_data.reset_epoch()
            net.train()
            while not train_data.epoch_finish:
                optimizer.zero_grad()
                seq, label, _seq_length, _mask, seq_pos, standard_emb = train_data.get_batch(Batch_size)
                results = net(seq, seq_pos, standard_emb)
                loss = criterion(results, label)
                loss.backward()
                optimizer.step()

            if (e + 1) % Val_every == 0:
                net.eval()
                train_correct, train_seen = count_correct(net, train_data)
                train_accuracy = float(train_correct) / float(train_seen)

                # Validation accuracy keeps the dataset size as denominator.
                val_correct, _ = count_correct(net, val_data)
                val_accuracy = float(val_correct) / float(len(val_data.data))

                # ``loss`` is the loss of the last training batch of this epoch.
                print('[fold %d epoch %d] training loss: %.4f, % d correct, %.4f accuracy;'
                      ' validation: %d correct, %.4f accuracy' %
                      (fold, e, loss.item(), train_correct, train_accuracy, val_correct, val_accuracy),
                      flush=True)
                torch.save(net.state_dict(), SAVE_DIR + 'data1_false_biobert_' + str(fold) + '_' + str(e))

            if (e + 1) % LR_decay_epoch == 0:
                adjust_learning_rate(optimizer, LR_decay)
                print('learning rate decay!', flush=True)

        accuracy_sum += val_accuracy

        # Release the fold's data before the next fold is tokenized.
        del train_data, val_data
        gc.collect()

    print('finial validation accuracy: %.4f' % (accuracy_sum / num_folds))
def train(content_img_name=None, style_img_name=None, features=None):
    """Fit a ``TransformerNet`` to one content image and one style image.

    The network is optimised for ``iteration_total`` steps with Adam on a
    weighted sum of a content loss (MSE between entry 1 of the feature lists
    of the output and of the content image) and a style loss (MSE between
    Gram matrices of the output's features and of the style image's
    features).  Progress is printed and appended to
    ``./logs/log_exp_<exp_num>.txt`` every 50 iterations, and the final
    weights are saved to ``./models/model_exp_<exp_num>.pt``.

    Args:
        content_img_name: Path passed to ``load_image`` for the content image.
        style_img_name: Path passed to ``load_image`` for the style image.
        features: Callable feature extractor returning an indexable sequence
            of feature maps.  Required despite the ``None`` default; it is
            called directly and, when ``use_cuda`` is set, moved with
            ``.cuda()``.

    Relies on the module-level names ``use_cuda``, ``iteration_total`` and
    ``exp_num``.  Nothing is returned.
    """
    transformer = TransformerNet()
    lr = 0.001
    weight_content = 1e5
    weight_style = 1e10
    mse_loss = torch.nn.MSELoss()

    # Style target: Gram matrices of the style image's features.  They are
    # constants of the optimisation, so no autograd graph is needed.
    style = load_image(style_img_name)
    style = style_transform(style)
    style = style.unsqueeze(0)
    style_v = normalize_batch(Variable(style))
    with torch.no_grad():
        features_style = features(style_v)
        gram_style = [gram_matrix(y) for y in features_style]

    transformer.train()
    x = load_image(content_img_name)
    x = content_transform(x)
    x = x.unsqueeze(0)
    x = Variable(x)
    if use_cuda:
        transformer.cuda()
        features.cuda()
        x = x.cuda()
        gram_style = [gram.cuda() for gram in gram_style]

    # Built after the optional device move so the optimizer is constructed
    # over the parameters in their final location.
    optimizer = torch.optim.Adam(transformer.parameters(), lr)

    # Content target, computed once.  Re-assigning ``x = normalize_batch(x)``
    # inside the loop would normalise the already-normalised image again on
    # every iteration, changing both the network input and the target.
    # NOTE(review): assumes normalize_batch returns a new tensor instead of
    # modifying its argument in place -- confirm.
    with torch.no_grad():
        content_target = features(normalize_batch(x))[1]

    count = 0
    log_name = './logs/log_exp_{}.txt'.format(exp_num)
    log = []

    def flush_log():
        """Print the buffered log lines and append them to the log file."""
        text = ''.join(log)
        print(text)
        with open(log_name, 'a') as f:
            f.write(text)
        log.clear()

    while count < iteration_total:
        optimizer.zero_grad()
        y = transformer(x)
        y = normalize_batch(y)
        features_y = features(y)

        loss_content = mse_loss(features_y[1], content_target)
        loss_style = 0.
        for ft_y, gm_s in zip(features_y, gram_style):
            gm_y = gram_matrix(ft_y)
            loss_style = loss_style + mse_loss(gm_y, gm_s)

        total_loss = weight_content * loss_content + weight_style * loss_style
        total_loss.backward()
        optimizer.step()

        count += 1
        msg = '{}\titeration: {}\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}\n'.format(
            time.ctime(), count, loss_content.item(), loss_style.item(), total_loss.item())
        log.append(msg)
        if count % 50 == 0:
            flush_log()

    # Do not drop the trailing lines when iteration_total is not a multiple of 50.
    if log:
        flush_log()

    # Save the weights from the CPU so the file loads without a GPU.
    transformer.eval()
    if use_cuda:
        transformer.cpu()
    save_model_name = './models/model_exp_{}.pt'.format(exp_num)
    torch.save(transformer.state_dict(), save_model_name)
def train(args):
    """Train a ``TransformerNet`` for style transfer on the CPU and save it.

    Images from ``args.dataset`` are resized and centre-cropped to
    ``args.image_size`` and scaled to the 0-255 range.  Each batch is passed
    through the transformer; the loss is the weighted sum of a content loss
    (MSE on the ``relu3_3`` features of ``Vgg16``) and a style loss (MSE
    between Gram matrices of the output's features and those of the style
    image).  Progress is printed every 200 batches and the final weights are
    written into ``args.save_model_dir``.

    Args:
        args: Namespace providing ``seed``, ``image_size``, ``dataset``,
            ``batch_size``, ``style_image``, ``style_size``, ``lr``,
            ``epochs``, ``content_weight``, ``style_weight`` and
            ``save_model_dir``.
    """
    device = torch.device("cpu")
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # load dataset
    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    # load style image, repeated so there is one copy per batch element
    style_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)

    # load network
    transformer = TransformerNet().to(device)
    vgg = Vgg16(requires_grad=False).to(device)

    # define optimizer and loss function
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    for e in range(args.epochs):
        transformer.train()
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            image_original = x.to(device)
            # Feed the device-resident tensor, not the raw loader output, so
            # the input always lives where the network does.
            image_transformed = transformer(image_original)

            image_original = utils.normalize_batch(image_original)
            image_transformed = utils.normalize_batch(image_transformed)

            # extract features for the content loss
            features_original = vgg(image_original)
            features_transformed = vgg(image_transformed)
            content_loss = args.content_weight * mse_loss(
                features_transformed.relu3_3, features_original.relu3_3)

            # style loss; the style Gram matrices are cut to the current
            # batch size because the last batch may be smaller
            style_loss = 0.
            for ft_y, gm_s in zip(features_transformed, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            if (batch_id + 1) % 200 == 0:
                print("Epoch {}:[{}/{}]".format(e + 1, count, len(train_dataset)))

    # save model
    transformer.eval().cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(args.content_weight) + "_" + str(args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)
    print("\nDone, trained model saved at", save_model_path)
def train(args):
    """Train a ``TransformerNet`` for style transfer with a selectable loss network.

    ``args.backbone`` picks the feature extractor (``"vgg"`` or ``"resnet"``)
    together with the layers used for the content and style losses.  Each
    batch from ``args.dataset`` is passed through the transformer; the loss
    is the weighted content MSE plus the weighted MSE between Gram matrices
    of the output's style-layer features and those of the style image.
    Running averages are printed every ``args.log_interval`` batches,
    optional checkpoints are written every ``args.checkpoint_interval``
    batches, and the final weights are saved into ``args.save_model_dir``.

    Args:
        args: Namespace providing ``cuda``, ``backbone``, ``seed``,
            ``image_size``, ``dataset``, ``batch_size``, ``lr``,
            ``style_image``, ``style_size``, ``epochs``, ``content_weight``,
            ``style_weight``, ``log_interval``, ``checkpoint_model_dir``
            (may be ``None`` to disable checkpoints),
            ``checkpoint_interval`` and ``save_model_dir``.

    Raises:
        ValueError: If ``args.backbone`` is neither ``"vgg"`` nor ``"resnet"``.
    """
    # Reject an unknown backbone up front instead of failing later on an
    # unbound local once the layer lists are first used.
    if args.backbone == "vgg":
        content_layer = ['relu_4']
        style_layer = ['relu_2', 'relu_4', 'relu_7', 'relu_10']
    elif args.backbone == "resnet":
        content_layer = ["conv_3"]
        style_layer = ["conv_1", "conv_2", "conv_3", "conv_4"]
    else:
        raise ValueError(
            "unsupported backbone {!r}; expected 'vgg' or 'resnet'".format(args.backbone))
    # Order-preserving union, so shared layers are requested only once.
    total_layer = list(dict.fromkeys(content_layer + style_layer))

    device = torch.device("cuda" if args.cuda else "cpu")
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    if args.backbone == "vgg":
        loss_model = vgg16().eval().to(device)
    else:
        loss_model = resnet().eval().to(device)

    # Style image repeated so there is one copy per batch element.
    style_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])
    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)

    feature_style = loss_model(utils.normalize_batch(style), style_layer)
    gram_style = {
        key: utils.gram_matrix(val)
        for key, val in feature_style.items()
    }

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            x = x.to(device)
            y = transformer(x)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            feature_y = loss_model(y, total_layer)
            feature_x = loss_model(x, content_layer)

            content_loss = 0.
            for layer in content_layer:
                content_loss += args.content_weight * mse_loss(
                    feature_y[layer], feature_x[layer])

            # The style Gram matrices are cut to the current batch size
            # because the last batch may be smaller.
            style_loss = 0.
            for name in style_layer:
                gm_y = utils.gram_matrix(feature_y[name])
                style_loss += mse_loss(gm_y, gram_style[name][:n_batch, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss) / (batch_id + 1))
                print(mesg)

            if args.checkpoint_model_dir is not None and (
                    batch_id + 1) % args.checkpoint_interval == 0:
                # Save from the CPU in eval mode, then restore the training setup.
                transformer.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(
                    e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_model_dir,
                                               ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                transformer.to(device).train()

    # save model
    transformer.eval().cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(
        time.ctime()).replace(' ', '_') + "_" + str(
            args.content_weight) + "_" + str(args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)
    print("\nDone, trained model saved at", save_model_path)