def train(self):
    """Jointly train the ProtoNet (FSL head) and the IC model.

    Flow: an "Init Update" forward pass seeds the pseudo labels in
    ``self.produce_class``; then each epoch forwards tasks through the
    ProtoNet, feeds query features to the IC model, refreshes pseudo
    labels, and optimizes the weighted sum of the FSL and IC losses.
    Validates every ``Config.val_freq`` epochs and checkpoints the best
    model to ``Config.pn_dir`` / ``Config.ic_dir``.
    """
    Tools.print()
    Tools.print("Training...")

    # Init Update: one pass over the data only to assign initial pseudo labels.
    try:
        self.proto_net.eval()
        self.ic_model.eval()
        # Fix: the original message contained an unfilled "{}" placeholder.
        Tools.print("Init label .......")
        self.produce_class.reset()
        # Fix: wrap in no_grad — this pass performs no optimization, so building
        # the autograd graph only wasted memory (matches the IC-only trainer).
        with torch.no_grad():
            for task_data, task_labels, task_index in tqdm(self.task_train_loader):
                ic_labels = RunnerTool.to_cuda(task_index[:, -1])
                task_data, task_labels = RunnerTool.to_cuda(task_data), RunnerTool.to_cuda(task_labels)
                log_p_y, query_features = self.proto(task_data)
                ic_out_logits, ic_out_l2norm = self.ic_model(query_features)
                self.produce_class.cal_label(ic_out_l2norm, ic_labels)
                pass
            pass
        Tools.print("Epoch: {}/{}".format(self.produce_class.count, self.produce_class.count_2))
    finally:
        pass

    for epoch in range(Config.train_epoch):
        self.proto_net.train()
        self.ic_model.train()

        Tools.print()
        self.produce_class.reset()
        all_loss, all_loss_fsl, all_loss_ic = 0.0, 0.0, 0.0
        for task_data, task_labels, task_index in tqdm(self.task_train_loader):
            ic_labels = RunnerTool.to_cuda(task_index[:, -1])
            task_data, task_labels = RunnerTool.to_cuda(task_data), RunnerTool.to_cuda(task_labels)

            ###########################################################################
            # 1 calculate features
            log_p_y, query_features = self.proto(task_data)
            ic_out_logits, ic_out_l2norm = self.ic_model(query_features)

            # 2 pseudo labels: read the current assignment, then refresh it
            #    with the new features (order matters).
            ic_targets = self.produce_class.get_label(ic_labels)
            self.produce_class.cal_label(ic_out_l2norm, ic_labels)

            # 3 loss: soft cross-entropy over log p(y), plus the IC
            #    pseudo-label classification loss, each with its own weight.
            loss_fsl = -(log_p_y * task_labels).sum() / task_labels.sum() * Config.loss_fsl_ratio
            loss_ic = self.ic_loss(ic_out_logits, ic_targets) * Config.loss_ic_ratio
            loss = loss_fsl + loss_ic
            all_loss += loss.item()
            all_loss_fsl += loss_fsl.item()
            all_loss_ic += loss_ic.item()

            # 4 backward
            self.proto_net.zero_grad()
            self.ic_model.zero_grad()
            loss.backward()
            # torch.nn.utils.clip_grad_norm_(self.proto_net.parameters(), 0.5)
            # torch.nn.utils.clip_grad_norm_(self.ic_model.parameters(), 0.5)
            self.proto_net_optim.step()
            self.ic_model_optim.step()
            ###########################################################################
            pass

        ###########################################################################
        # print epoch averages and the pseudo-label change counts
        Tools.print("{:6} loss:{:.3f} fsl:{:.3f} ic:{:.3f} lr:{}".format(
            epoch + 1, all_loss / len(self.task_train_loader),
            all_loss_fsl / len(self.task_train_loader),
            all_loss_ic / len(self.task_train_loader),
            self.proto_net_scheduler.get_last_lr()))
        Tools.print("Train: [{}] {}/{}".format(epoch, self.produce_class.count, self.produce_class.count_2))
        self.proto_net_scheduler.step()
        self.ic_model_scheduler.step()
        ###########################################################################

        ###########################################################################
        # Val: periodically evaluate and keep the best checkpoint.
        if epoch % Config.val_freq == 0:
            self.proto_net.eval()
            self.ic_model.eval()
            self.test_tool_ic.val(epoch=epoch)
            val_accuracy = self.test_tool_fsl.val(episode=epoch, is_print=True)
            if val_accuracy > self.best_accuracy:
                self.best_accuracy = val_accuracy
                torch.save(self.proto_net.state_dict(), Config.pn_dir)
                torch.save(self.ic_model.state_dict(), Config.ic_dir)
                Tools.print("Save networks for epoch: {}".format(epoch))
                pass
            pass
        ###########################################################################
        pass
    pass
def _train_epoch(self):
    """Run one training epoch; returns averaged losses and SOD/GCN metrics.

    Returns (avg_loss, avg_loss1, avg_loss2, avg_mae, best_F_sod,
    avg_mae2, best_F_gcn), where the F-scores are the max F-beta
    (beta^2 = 0.3) over the precision/recall thresholds.
    """
    # NOTE(review): model stays in eval() during training — presumably to
    # freeze BN/dropout statistics; confirm this is intentional.
    self.model.eval()

    # Statistics accumulators; 1e-6 avoids divide-by-zero in the F-score.
    th_num = 25
    epoch_loss, epoch_loss1, epoch_loss2, nb_data = 0, 0, 0, 0
    epoch_mae, epoch_prec, epoch_recall = 0.0, np.zeros(
        shape=(th_num, )) + 1e-6, np.zeros(shape=(th_num, )) + 1e-6
    epoch_mae2, epoch_prec2, epoch_recall2 = 0.0, np.zeros(
        shape=(th_num, )) + 1e-6, np.zeros(shape=(th_num, )) + 1e-6

    # Run: gradients are accumulated over iter_size batches before each step.
    iter_size = 10
    self.model.zero_grad()
    tr_num = len(self.train_loader)
    for i, (images, _, labels_sod, batched_graph, batched_pixel_graph,
            segments, _, _) in enumerate(self.train_loader):
        # Data: move images, graph labels and graph index tensors to device.
        images = images.float().to(self.device)
        labels = batched_graph.y.to(self.device)
        labels_sod = torch.unsqueeze(torch.Tensor(labels_sod), dim=1).to(self.device)
        batched_graph.batch = batched_graph.batch.to(self.device)
        batched_graph.edge_index = batched_graph.edge_index.to(self.device)
        batched_pixel_graph.batch = batched_pixel_graph.batch.to(
            self.device)
        batched_pixel_graph.edge_index = batched_pixel_graph.edge_index.to(
            self.device)
        batched_pixel_graph.data_where = batched_pixel_graph.data_where.to(
            self.device)

        # Forward: model yields GCN (superpixel) and SOD (pixel) logits.
        gcn_logits, gcn_logits_sigmoid, _, sod_logits, sod_logits_sigmoid = self.model.forward(
            images, batched_graph, batched_pixel_graph)

        loss_fuse1 = F.binary_cross_entropy_with_logits(sod_logits, labels_sod, reduction='sum')
        loss_fuse2 = F.binary_cross_entropy_with_logits(gcn_logits, labels, reduction='sum')
        # loss = (loss_fuse1 + loss_fuse2) / iter_size
        # NOTE(review): only loss_fuse1 is scaled by iter_size; loss_fuse2 is
        # weighted 2x per batch instead — confirm this asymmetry is intended.
        loss = loss_fuse1 / iter_size + 2 * loss_fuse2
        # loss = loss_fuse1 / iter_size

        loss.backward()
        # Optimizer step only every iter_size batches (gradient accumulation).
        if (i + 1) % iter_size == 0:
            self.optimizer.step()
            self.optimizer.zero_grad()
            pass

        labels_val = labels.cpu().detach().numpy()
        labels_sod_val = labels_sod.cpu().detach().numpy()
        gcn_logits_sigmoid_val = gcn_logits_sigmoid.cpu().detach().numpy()
        sod_logits_sigmoid_val = sod_logits_sigmoid.cpu().detach().numpy()

        # Stat
        nb_data += images.size(0)
        epoch_loss += loss.detach().item()
        epoch_loss1 += loss_fuse1.detach().item()
        epoch_loss2 += loss_fuse2.detach().item()

        # cal 1: pixel-level (SOD) MAE / precision / recall
        mae = self._eval_mae(sod_logits_sigmoid_val, labels_sod_val)
        prec, recall = self._eval_pr(sod_logits_sigmoid_val, labels_sod_val, th_num)
        epoch_mae += mae
        epoch_prec += prec
        epoch_recall += recall

        # cal 2: superpixel-level (GCN) MAE / precision / recall
        mae2 = self._eval_mae(gcn_logits_sigmoid_val, labels_val)
        prec2, recall2 = self._eval_pr(gcn_logits_sigmoid_val, labels_val, th_num)
        epoch_mae2 += mae2
        epoch_prec2 += prec2
        epoch_recall2 += recall2

        # Print running averages periodically.
        if i % self.train_print_freq == 0:
            Tools.print(
                "{:4d}-{:4d} loss={:.4f}({:.4f}+{:.4f})-{:.4f}({:.4f}+{:.4f}) "
                "sod-mse={:.4f}({:.4f}) gcn-mse={:.4f}({:.4f})".format(
                    i, tr_num, loss.detach().item(), loss_fuse1.detach().item(),
                    loss_fuse2.detach().item(), epoch_loss / (i + 1),
                    epoch_loss1 / (i + 1), epoch_loss2 / (i + 1), mae,
                    epoch_mae / (i + 1), mae2, epoch_mae2 / nb_data))
            pass
        pass

    # Results: SOD metrics averaged per batch, GCN metrics per sample.
    avg_loss, avg_loss1, avg_loss2 = epoch_loss / tr_num, epoch_loss1 / tr_num, epoch_loss2 / tr_num
    avg_mae, avg_prec, avg_recall = epoch_mae / tr_num, epoch_prec / tr_num, epoch_recall / tr_num
    # F-beta with beta^2 = 0.3, evaluated at every threshold.
    score = (1 + 0.3) * avg_prec * avg_recall / (0.3 * avg_prec + avg_recall)
    avg_mae2, avg_prec2, avg_recall2 = epoch_mae2 / nb_data, epoch_prec2 / nb_data, epoch_recall2 / nb_data
    score2 = (1 + 0.3) * avg_prec2 * avg_recall2 / (0.3 * avg_prec2 + avg_recall2)
    return avg_loss, avg_loss1, avg_loss2, avg_mae, score.max(
    ), avg_mae2, score2.max()
def load_model(self):
    """Restore ProtoNet weights from ``Config.pn_dir`` if the file exists."""
    if not os.path.exists(Config.pn_dir):
        return
    self.proto_net.load_state_dict(torch.load(Config.pn_dir))
    Tools.print("load proto net success from {}".format(Config.pn_dir))
    pass
class Config(object):
    """Static hyper-parameter container for the IC + FSL experiment.

    Evaluated at class-definition time: sets CUDA_VISIBLE_DEVICES, builds
    checkpoint directories and prints the derived model name.
    """
    gpu_id = "0,1,2,3"
    gpu_num = len(gpu_id.split(","))
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

    num_workers = 32

    #######################################################################################
    # IC (instance classification) settings
    ic_ratio = 1
    ic_knn = 100
    ic_out_dim = 2048
    ic_val_freq = 10
    # ic_resnet, ic_modify_head, ic_net_name = resnet18, False, "res18"
    ic_resnet, ic_modify_head, ic_net_name = resnet34, True, "res34_head"
    ic_learning_rate = 0.01
    ic_train_epoch = 1200
    ic_first_epoch, ic_t_epoch = 400, 200
    # batch size scales with the number of visible GPUs
    ic_batch_size = 64 * 4 * gpu_num
    ic_adjust_learning_rate = RunnerTool.adjust_learning_rate1
    #######################################################################################

    ###############################################################################################
    # FSL (few-shot learning) settings
    fsl_num_way = 5
    fsl_num_shot = 1
    fsl_episode_size = 15
    fsl_test_episode = 600
    # fsl_matching_net, fsl_net_name, fsl_batch_size = MatchingNet(hid_dim=64, z_dim=64), "conv4", 96
    fsl_matching_net, fsl_net_name, fsl_batch_size = ResNet12Small(
        avg_pool=True, drop_rate=0.1), "resnet12", 32
    fsl_learning_rate = 0.01
    fsl_batch_size = fsl_batch_size * gpu_num

    # fsl_val_freq = 5
    # fsl_train_epoch = 200
    # fsl_lr_schedule = [100, 150]
    # fsl_val_freq = 2
    # fsl_train_epoch = 100
    # fsl_lr_schedule = [50, 80]
    fsl_val_freq = 2
    fsl_train_epoch = 50
    fsl_lr_schedule = [30, 40]
    ###############################################################################################

    # Experiment identifier derived from the settings above.
    model_name = "{}_{}_{}_{}_{}_{}_{}_{}_{}_{}".format(
        gpu_id.replace(",", ""), ic_net_name, ic_train_epoch, ic_batch_size,
        ic_out_dim, fsl_net_name, fsl_train_epoch, fsl_num_way, fsl_num_shot,
        fsl_batch_size)

    # Dataset root chosen per machine/platform at import time.
    if "Linux" in platform.platform():
        data_root = '/mnt/4T/Data/data/UFSL/tiered-imagenet'
        if not os.path.isdir(data_root):
            data_root = '/media/ubuntu/4T/ALISURE/Data/UFSL/tiered-imagenet'
    else:
        data_root = "F:\\data\\tiered-imagenet"

    ###############################################################################################
    # Debug overrides (disabled)
    # ic_batch_size = 16
    # fsl_batch_size = 16
    # ic_train_epoch = 2
    # ic_first_epoch, ic_t_epoch = 1, 1
    # ic_val_freq = 1
    # fsl_train_epoch = 8
    # fsl_lr_schedule = [4, 6]
    # fsl_test_episode = 20
    # fsl_val_freq = 2
    # data_root = os.path.join(data_root, "small")
    ###############################################################################################

    # Checkpoint destinations (directories are created here).
    _root_path = "../tiered_imagenet/models_mn/two_ic_ufsl_2net_res_sgd_acc_duli_nete"
    mn_dir = Tools.new_dir("{}/{}_mn.pkl".format(_root_path, model_name))
    ic_dir = Tools.new_dir("{}/{}_ic.pkl".format(_root_path, model_name))

    # Optional pretrained IC checkpoint to warm-start from.
    # ic_dir_checkpoint = None
    # ic_dir_checkpoint = "../tiered_imagenet/models/ic_res_xx/3_resnet_18_64_2048_1_1900_300_200_False_ic.pkl"
    ic_dir_checkpoint = "../tiered_imagenet/models_mn/two_ic_ufsl_2net_res_sgd_acc_duli_nete/123_res34_head_1200_384_2048_conv4_100_5_1_288_ic.pkl"

    Tools.print(model_name)
    Tools.print(data_root)
    pass
def __init__(self, data_root_path, down_ratio=4, sp_size=4, train_print_freq=100,
             test_print_freq=50, root_ckpt_dir="./ckpt2/norm3", lr=None,
             num_workers=8, use_gpu=True, gpu_id="1", has_bn=True, normalize=True,
             residual=False, concat=True, weight_decay=0.0, is_sgd=False):
    """Build datasets, loaders, model, optimizer and loss for the runner.

    lr: optional [[start_epoch, lr], ...] schedule; defaults depend on
    the optimizer choice (SGD vs Adam) selected via is_sgd.
    """
    self.train_print_freq = train_print_freq
    self.test_print_freq = test_print_freq

    self.device = gpu_setup(use_gpu=use_gpu, gpu_id=gpu_id)
    self.root_ckpt_dir = Tools.new_dir(root_ckpt_dir)

    # batch_size=1: one image (with its superpixel graphs) per step.
    self.train_dataset = MyDataset(data_root_path=data_root_path, is_train=True,
                                   down_ratio=down_ratio, sp_size=sp_size)
    self.test_dataset = MyDataset(data_root_path=data_root_path, is_train=False,
                                  down_ratio=down_ratio, sp_size=sp_size)

    self.train_loader = DataLoader(
        self.train_dataset, batch_size=1, shuffle=True,
        num_workers=num_workers, collate_fn=self.train_dataset.collate_fn)
    self.test_loader = DataLoader(self.test_dataset, batch_size=1,
                                  shuffle=False, num_workers=num_workers,
                                  collate_fn=self.test_dataset.collate_fn)

    self.model = MyGCNNet(has_bn=has_bn, normalize=normalize,
                          residual=residual, concat=concat).to(self.device)

    # NOTE: `parameters` is a one-shot iterator — it is consumed by
    # exactly one of the optimizer constructors below.
    parameters = filter(lambda p: p.requires_grad, self.model.parameters())
    if is_sgd:
        self.lr_s = [[0, 0.01], [50, 0.001], [90, 0.0001]
                     ] if lr is None else lr
        self.optimizer = torch.optim.SGD(parameters, lr=self.lr_s[0][1],
                                         momentum=0.9, weight_decay=weight_decay)
    else:
        self.lr_s = [[0, 0.001], [50, 0.0001], [90, 0.00001]
                     ] if lr is None else lr
        self.optimizer = torch.optim.Adam(parameters, lr=self.lr_s[0][1],
                                          weight_decay=weight_decay)

    Tools.print("Total param: {} lr_s={} Optimizer={}".format(
        self._view_model_param(self.model), self.lr_s, self.optimizer))
    self._print_network(self.model)

    # BCELoss expects sigmoid probabilities as input.
    self.loss_class = nn.BCELoss().to(self.device)
    pass
import os import time import numpy as np import tensorflow as tf from keras import layers import keras.backend as k from keras import callbacks from keras.models import Model from keras.optimizers import Adam from alisuretool.Tools import Tools import keras.backend.tensorflow_backend as ktf from sklearn.neighbors import NearestNeighbors IS_VIS = True VIS_DIR = Tools.new_dir("./data/test_2019_11_19") class SketchyData(object): def __init__(self, n_x=4096, n_y=4096, n_z=1024, n_d=2048): self.n_x = n_x self.n_y = n_y self.n_z = n_z self.n_d = n_d pass # 读取所有的数据 @staticmethod def _load_data(data_path_alisure="./data/alisure", data_path_ext="./data/ZSSBIR_data"): # image image_vgg_features = np.load(
def train(self):
    """Train the IC model alone with iteratively refreshed pseudo labels.

    An initial no-grad pass seeds the pseudo labels via ``produce_class``;
    each epoch then reads the current pseudo label per sample, refreshes
    it from the new embedding, and minimizes the classification loss.
    Saves the best checkpoint (by ``test_tool_ic`` accuracy) to
    ``Config.ic_dir``.
    """
    Tools.print()
    Tools.print("Training...")
    best_accuracy = 0.0

    # Init Update: assign initial pseudo labels (no optimization).
    try:
        self.ic_model.eval()
        # NOTE(review): "{}" placeholder is never filled — prints literally.
        Tools.print("Init label {} .......")
        self.produce_class.reset()
        with torch.no_grad():
            for image, label, idx in tqdm(self.ic_train_loader):
                image, idx = RunnerTool.to_cuda(image), RunnerTool.to_cuda(
                    idx)
                ic_out_logits, ic_out_l2norm = self.ic_model(image)
                self.produce_class.cal_label(ic_out_l2norm, idx)
                pass
            pass
        Tools.print("Epoch: {}/{}".format(self.produce_class.count, self.produce_class.count_2))
    finally:
        pass

    for epoch in range(1, 1 + Config.ic_train_epoch):
        self.ic_model.train()

        Tools.print()
        # Manual LR schedule (warm-up/decay handled by adjust_learning_rate).
        ic_lr = self.adjust_learning_rate(self.ic_model_optim, epoch,
                                          Config.ic_first_epoch, Config.ic_t_epoch,
                                          Config.ic_learning_rate)
        Tools.print('Epoch: [{}] ic_lr={}'.format(epoch, ic_lr))

        all_loss = 0.0
        self.produce_class.reset()
        for image, label, idx in tqdm(self.ic_train_loader):
            image, label, idx = RunnerTool.to_cuda(
                image), RunnerTool.to_cuda(label), RunnerTool.to_cuda(idx)

            ###########################################################################
            # 1 calculate features
            ic_out_logits, ic_out_l2norm = self.ic_model(image)

            # 2 calculate labels: read the current pseudo label, then
            #   refresh it with the new embedding (order matters).
            ic_targets = self.produce_class.get_label(idx)
            self.produce_class.cal_label(ic_out_l2norm, idx)

            # 3 loss
            loss = self.ic_loss(ic_out_logits, ic_targets)
            all_loss += loss.item()

            # 4 backward
            self.ic_model.zero_grad()
            loss.backward()
            self.ic_model_optim.step()
            ###########################################################################
            pass

        ###########################################################################
        # print epoch loss and pseudo-label change counts
        Tools.print("{:6} loss:{:.3f}".format(
            epoch, all_loss / len(self.ic_train_loader)))
        Tools.print("Train: [{}] {}/{}".format(epoch, self.produce_class.count,
                                               self.produce_class.count_2))
        ###########################################################################

        ###########################################################################
        # Val: keep the best checkpoint.
        if epoch % Config.ic_val_freq == 0:
            self.ic_model.eval()
            val_accuracy = self.test_tool_ic.val(epoch=epoch)
            if val_accuracy > best_accuracy:
                best_accuracy = val_accuracy
                torch.save(self.ic_model.state_dict(), Config.ic_dir)
                Tools.print("Save networks for epoch: {}".format(epoch))
                pass
            pass
        ###########################################################################
        pass
    pass
# Entry-point configuration: build a RunnerSPE and train it.
_data_root_path = "/mnt/4T/ALISURE/DUTS"
_train_print_freq = 1000
_test_print_freq = 1000
_num_workers = 10
_use_gpu = True

# One GPU id is selected by (un)commenting.
# _gpu_id = "0"
# _gpu_id = "1"
_gpu_id = "2"
# _gpu_id = "3"

_epochs = 30

# Super Param Group 1
_is_sgd = False
_weight_decay = 5e-4
# _lr = [[0, 5e-05], [20, 5e-06]]
_lr = [[0, 1e-5], [20, 1e-6]]  # [[start_epoch, lr], ...] schedule for Adam

_sp_size, _down_ratio = 4, 4
# Checkpoints are grouped per GPU id.
_root_ckpt_dir = "./ckpt/PYG_ResNet1_NoBN/{}".format(_gpu_id)
Tools.print("epochs:{} ckpt:{} sp size:{} down_ratio:{} workers:{} gpu:{} is_sgd:{} weight_decay:{}".format(
    _epochs, _root_ckpt_dir, _sp_size, _down_ratio, _num_workers, _gpu_id,
    _is_sgd, _weight_decay))

runner = RunnerSPE(data_root_path=_data_root_path,
                   root_ckpt_dir=_root_ckpt_dir, sp_size=_sp_size,
                   is_sgd=_is_sgd, lr=_lr, down_ratio=_down_ratio,
                   weight_decay=_weight_decay,
                   train_print_freq=_train_print_freq,
                   test_print_freq=_test_print_freq,
                   num_workers=_num_workers, use_gpu=_use_gpu,
                   gpu_id=_gpu_id)
runner.train(_epochs, start_epoch=0)
pass
def train_mlc(self, start_epoch=0, model_file_name=None):
    """Train the multi-label classification (MLC) network.

    Optionally loads a checkpoint first, evaluates at epoch 0, then for
    each epoch: adjust LR, train one pass with BCE loss, periodically
    save and evaluate. A final checkpoint and evaluation happen after
    the loop.
    """
    if model_file_name is not None:
        Tools.print("Load model form {}".format(model_file_name),
                    txt_path=self.config.mlc_save_result_txt)
        self.load_model(model_file_name)
        pass

    # Baseline evaluation before any training.
    self.eval_mlc(epoch=0)
    for epoch in range(start_epoch, self.config.mlc_epoch_num):
        Tools.print()
        self._adjust_learning_rate(
            self.optimizer, epoch, lr=self.config.mlc_lr,
            change_epoch=self.config.mlc_change_epoch)
        Tools.print('Epoch:{:03d}, lr={:.6f}'.format(
            epoch, self.optimizer.param_groups[0]['lr']),
            txt_path=self.config.mlc_save_result_txt)

        ###########################################################################
        # 1 train the model for one epoch
        all_loss = 0.0
        self.net.train()
        # reset() reshuffles/resamples the training set each epoch.
        self.dataset_mlc_train.reset()
        for i, (inputs, labels) in tqdm(enumerate(self.data_loader_mlc_train),
                                        total=len(self.data_loader_mlc_train)):
            inputs, labels = inputs.type(
                torch.FloatTensor).cuda(), labels.cuda()
            self.optimizer.zero_grad()
            result = self.net(inputs)
            loss = self.bce_loss(result, labels)
            loss.backward()
            self.optimizer.step()
            all_loss += loss.item()
            pass
        ###########################################################################
        Tools.print("[E:{:3d}/{:3d}] mlc loss:{:.4f}".format(
            epoch, self.config.mlc_epoch_num,
            all_loss / len(self.data_loader_mlc_train)),
            txt_path=self.config.mlc_save_result_txt)

        # 2 save the model periodically
        if epoch % self.config.mlc_save_epoch_freq == 0:
            Tools.print()
            save_file_name = Tools.new_dir(
                os.path.join(self.config.mlc_model_dir,
                             "mlc_{}.pth".format(epoch)))
            torch.save(self.net.state_dict(), save_file_name)
            Tools.print("Save Model to {}".format(save_file_name),
                        txt_path=self.config.mlc_save_result_txt)
            Tools.print()
            pass
        ###########################################################################

        ###########################################################################
        # 3 evaluate the model periodically
        if epoch % self.config.mlc_eval_epoch_freq == 0:
            self.eval_mlc(epoch=epoch)
            pass
        ###########################################################################
        pass

    # Final Save
    Tools.print()
    save_file_name = Tools.new_dir(
        os.path.join(self.config.mlc_model_dir,
                     "mlc_final_{}.pth".format(self.config.mlc_epoch_num)))
    torch.save(self.net.state_dict(), save_file_name)
    Tools.print("Save Model to {}".format(save_file_name),
                txt_path=self.config.mlc_save_result_txt)
    Tools.print()

    self.eval_mlc(epoch=self.config.mlc_epoch_num)
    pass
def main(self):
    """Compute (or load a cached) 2-D t-SNE embedding and save the plot."""
    tsne = TSNE(n_components=2)
    Tools.print("begin to fit_transform {}".format(self.config.result_png))

    if not os.path.exists(self.config.result_pkl):
        # No cache: fit t-SNE on the features and persist the result.
        Tools.print("not exist pkl, and now to fit")
        data, label = self.deal_feature()
        result = {"fit": tsne.fit_transform(data), "label": label}
        Tools.write_to_pkl(self.config.result_pkl, result)
    else:
        # Cache hit: reuse the previously computed embedding.
        Tools.print("exist pkl, and now to load")
        result = Tools.read_from_pkl(self.config.result_pkl)
        pass

    Tools.print("begin to embedding")
    fig = self.plot_embedding(result["fit"], result["label"])
    Tools.print("begin to save")
    plt.savefig(self.config.result_png)
    # plt.show(fig)
    pass
Tools.new_dir(one.replace("train_final", "train_final_resize"))) pass pass all_image = glob(os.path.join(cam_path, "**/*.png"), recursive=True) split_id = 1000 pools = multiprocessing.Pool(processes=multiprocessing.cpu_count()) for i in range(split_id + 1): now_image = all_image[len(all_image) // split_id * i:len(all_image) // split_id * (i + 1)] pools.apply_async(main, args=(i, now_image)) pass Tools.print("begin") pools.close() pools.join() Tools.print("over") # Check for one in tqdm(all_image): if not os.path.exists(one.replace("train_final", "train_final_resize")): im = Image.open( one.replace( cam_path, os.path.join(voc12_root, "ILSVRC2017_DET/ILSVRC/Data/DET")).replace( ".png", ".JPEG")) Image.open(one).resize(im.size, resample=Image.NEAREST).save( Tools.new_dir(one.replace("train_final", "train_final_resize")))
import os
import numpy as np
from glob import glob
from tqdm import tqdm
from alisuretool.Tools import Tools
from deal_data_0_global_info import get_data_root_path, get_project_path


if __name__ == '__main__':
    # Merge per-image object labels with a separate "person" detection list
    # and write the combined [labels, path] records to a new pickle.
    data_root = get_data_root_path()
    image_info_path = os.path.join(data_root, "deal", "image_info_list2.pkl")
    person_pkl = os.path.join(data_root, "deal", "person2.pkl")
    result_image_info_path = os.path.join(data_root, "deal",
                                          "image_info_list_change_person2.pkl")

    image_info_list = Tools.read_from_pkl(image_info_path)
    # Each person_info entry looks like (flag, image_path); flag == 1 means
    # a person was detected — presumably; verify against the producer script.
    person_info_list = Tools.read_from_pkl(person_pkl)

    result_image_info_list = []
    for i, (image_info, person_info) in tqdm(enumerate(zip(image_info_list,
                                                           person_info_list)),
                                             total=len(image_info_list)):
        # Log entries whose path is missing or misaligned with the person list.
        if not os.path.exists(image_info["image_path"]) or image_info["image_path"] != person_info[1]:
            Tools.print(image_info["image_path"])
            pass
        # Unique class ids of all objects; class 124 is appended when the
        # person flag is set.
        image_label = list(set([one[2] for one in image_info["object"]]
                               + ([124] if person_info[0] == 1 else [])))
        image_path = image_info["image_path"]
        result_image_info_list.append([image_label, image_path])
        pass

    Tools.write_to_pkl(_path=result_image_info_path, _data=result_image_info_list)
    pass
def train(self):
    """Jointly train feature encoder, relation network and IC model.

    An "Init Update" pass seeds the pseudo labels in ``produce_class``;
    each epoch then optimizes the weighted FSL (relation) loss plus the
    IC pseudo-label loss, tracks how often tasks were built correctly
    (``task_ok``), validates every ``Config.val_freq`` epochs and keeps
    the best checkpoints.
    """
    Tools.print()
    Tools.print("Training...")

    # Init Update: one pass over the data only to assign initial pseudo labels.
    try:
        self.feature_encoder.eval()
        self.relation_network.eval()
        self.ic_model.eval()
        # Fix: the original message contained an unfilled "{}" placeholder.
        Tools.print("Init label .......")
        self.produce_class.reset()
        # Fix: wrap in no_grad — no optimization happens here, so building the
        # autograd graph only wasted memory (matches the IC-only trainer).
        with torch.no_grad():
            for task_data, task_labels, task_index, task_ok in tqdm(
                    self.task_train_loader):
                ic_labels = RunnerTool.to_cuda(task_index[:, -1])
                task_data, task_labels = RunnerTool.to_cuda(
                    task_data), RunnerTool.to_cuda(task_labels)
                relations, query_features = self.compare_fsl(task_data)
                ic_out_logits, ic_out_l2norm = self.ic_model(query_features)
                self.produce_class.cal_label(ic_out_l2norm, ic_labels)
                pass
            pass
        Tools.print("Epoch: {}/{}".format(self.produce_class.count, self.produce_class.count_2))
    finally:
        pass

    for epoch in range(Config.train_epoch):
        self.feature_encoder.train()
        self.relation_network.train()
        self.ic_model.train()

        Tools.print()
        self.produce_class.reset()
        Tools.print(self.task_train.classes)
        is_ok_total, is_ok_acc = 0, 0
        all_loss, all_loss_fsl, all_loss_ic = 0.0, 0.0, 0.0
        for task_data, task_labels, task_index, task_ok in tqdm(
                self.task_train_loader):
            ic_labels = RunnerTool.to_cuda(task_index[:, -1])
            task_data, task_labels = RunnerTool.to_cuda(
                task_data), RunnerTool.to_cuda(task_labels)

            ###########################################################################
            # 1 calculate features
            relations, query_features = self.compare_fsl(task_data)
            ic_out_logits, ic_out_l2norm = self.ic_model(query_features)

            # 2 pseudo labels: read the current assignment, then refresh it
            #   with the new features (order matters).
            ic_targets = self.produce_class.get_label(ic_labels)
            self.produce_class.cal_label(ic_out_l2norm, ic_labels)

            # 3 loss
            loss_fsl = self.fsl_loss(relations, task_labels) * Config.loss_fsl_ratio
            loss_ic = self.ic_loss(ic_out_logits, ic_targets) * Config.loss_ic_ratio
            loss = loss_fsl + loss_ic
            all_loss += loss.item()
            all_loss_fsl += loss_fsl.item()
            all_loss_ic += loss_ic.item()

            # 4 backward (with gradient clipping on all three networks)
            self.feature_encoder.zero_grad()
            self.relation_network.zero_grad()
            self.ic_model.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.feature_encoder.parameters(), 0.5)
            torch.nn.utils.clip_grad_norm_(self.relation_network.parameters(), 0.5)
            torch.nn.utils.clip_grad_norm_(self.ic_model.parameters(), 0.5)
            self.feature_encoder_optim.step()
            self.relation_network_optim.step()
            self.ic_model_optim.step()

            # is ok: fraction of correctly-constructed task samples
            is_ok_acc += torch.sum(torch.cat(task_ok))
            is_ok_total += torch.prod(torch.tensor(torch.cat(task_ok).shape))
            ###########################################################################
            pass

        ###########################################################################
        # print epoch averages, task-ok ratio and pseudo-label change counts
        Tools.print(
            "{:6} loss:{:.3f} fsl:{:.3f} ic:{:.3f} ok:{:.3f}({}/{}) lr:{}".
            format(epoch + 1, all_loss / len(self.task_train_loader),
                   all_loss_fsl / len(self.task_train_loader),
                   all_loss_ic / len(self.task_train_loader),
                   int(is_ok_acc) / int(is_ok_total), is_ok_acc, is_ok_total,
                   self.feature_encoder_scheduler.get_last_lr()))
        Tools.print("Train: [{}] {}/{}".format(epoch, self.produce_class.count,
                                               self.produce_class.count_2))
        self.feature_encoder_scheduler.step()
        self.relation_network_scheduler.step()
        self.ic_model_scheduler.step()
        ###########################################################################

        ###########################################################################
        # Val: periodically evaluate and keep the best checkpoints.
        if epoch % Config.val_freq == 0:
            self.feature_encoder.eval()
            self.relation_network.eval()
            self.ic_model.eval()
            self.test_tool_ic.val(epoch=epoch)
            val_accuracy = self.test_tool_fsl.val(episode=epoch, is_print=True)
            if val_accuracy > self.best_accuracy:
                self.best_accuracy = val_accuracy
                torch.save(self.feature_encoder.state_dict(), Config.fe_dir)
                torch.save(self.relation_network.state_dict(), Config.rn_dir)
                torch.save(self.ic_model.state_dict(), Config.ic_dir)
                Tools.print("Save networks for epoch: {}".format(epoch))
                pass
            pass
        ###########################################################################
        pass
    pass
# Motion-triggered frame capture from a webcam using MOG2 background
# subtraction; frames with sustained motion are written to result_dir.
result_name = "08-14-2"
cap = cv2.VideoCapture(1)
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
background_sub_tractor_mog2 = cv2.createBackgroundSubtractorMOG2()
count = 0
while True:
    ret, frame = cap.read()
    if ret:
        # Foreground mask, denoised with a small morphological opening.
        mask = background_sub_tractor_mog2.apply(frame)
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
        # mask values are {0, 127, 255}; // 125 maps them to {0, 1, 2},
        # so this thresholds the amount of moving area.
        if np.sum(mask // 125) > 1000:
            count += 1
            Tools.print("{} {}".format(Tools.get_format_time(), count))
            # Only save after motion persisted for several frames.
            if count > 5:
                file_name = "{}_{}.png".format(
                    Tools.get_format_time().replace(":", "_"), count)
                result_file_name = os.path.join(result_dir, result_name, file_name)
                cv2.imwrite(Tools.new_dir(result_file_name), frame)
                pass
            pass
        cv2.imshow('frame', frame)
        cv2.imshow('mask', mask)
        # Fix: the original compared the key code to the *string* "q"
        # (always False), so the quit key never worked; compare to ord("q").
        if cv2.waitKey(30) & 0xff == ord("q"):
            break
    else:
        break
# Fix: release the camera and close windows on exit (resource cleanup).
cap.release()
cv2.destroyAllWindows()
def load_model(self, model_file_name):
    """Load weights from *model_file_name* into self.model (non-strict)."""
    state_dict = torch.load(model_file_name)
    self.model.load_state_dict(state_dict, strict=False)
    Tools.print('Load Model: {}'.format(model_file_name))
    pass
def train(self):
    """Train the ProtoNet with a mixup-style triple loss.

    For each task batch, synthetic samples are built by mixing each
    image with its "next" image in the task (fixed 0.5 weights here);
    originals and mixes are embedded together and ``mixup_loss``
    combines a triple loss with a mixup consistency loss.
    """
    Tools.print()
    Tools.print("Training...")

    for epoch in range(Config.train_epoch):
        self.proto_net.train()

        Tools.print()
        all_loss, all_loss_triple, all_loss_mixup = 0.0, 0.0, 0.0
        for task_tuple, inputs, task_index in tqdm(self.task_train_loader):
            # inputs: (batch_size, num, c, w, h) — num images per task sample.
            batch_size, num, c, w, h = inputs.shape

            # Mixing coefficients; random Beta sampling is disabled in favor
            # of a fixed 0.5/0.5 mix.
            # beta = np.random.beta(1, 1, [batch_size, num])  # 64, 3
            beta = np.zeros([batch_size, num]) + 0.5  # 64, 3
            beta_lambda = np.hstack([beta, 1 - beta])  # 64, 6
            # Broadcast the per-image coefficient over (c, w, h).
            beta_lambda_tile = np.tile(beta_lambda[..., None, None, None],
                                       [c, w, h])
            # Pair each image with the next one (cyclically) and mix:
            # inputs_1[:, k] = beta*x_k + (1-beta)*x_{k+1 mod num}.
            inputs_1 = torch.cat(
                [inputs, inputs[:, 1:, ...], inputs[:, 0:1, ...]],
                dim=1) * beta_lambda_tile
            inputs_1 = (inputs_1[:, 0:num, ...] + inputs_1[:, num:, ...]).float()
            # Concatenate originals and mixes, flatten to a single batch.
            now_inputs = torch.cat([inputs, inputs_1], dim=1).view(-1, c, w, h)
            now_inputs = RunnerTool.to_cuda(now_inputs)

            # 1 calculate features (normalized), regrouped per task sample.
            net_out = self.proto_net(now_inputs)
            net_out = self.norm(net_out)
            _, out_c, out_w, out_h = net_out.shape
            z = net_out.view(batch_size, -1, out_c, out_w, out_h)

            # 2 calculate loss
            loss, loss_triple, loss_mixup = self.mixup_loss(
                z, beta_lambda=beta_lambda)
            all_loss += loss.item()
            all_loss_mixup += loss_mixup.item()
            all_loss_triple += loss_triple.item()

            # 3 backward with gradient clipping
            self.proto_net.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.proto_net.parameters(), 0.5)
            self.proto_net_optim.step()
            ###########################################################################
            pass

        ###########################################################################
        # print epoch averages
        Tools.print(
            "{:6} loss:{:.3f} triple:{:.3f} mixup:{:.3f} lr:{}".format(
                epoch + 1, all_loss / len(self.task_train_loader),
                all_loss_triple / len(self.task_train_loader),
                all_loss_mixup / len(self.task_train_loader),
                self.proto_net_scheduler.get_last_lr()))
        self.proto_net_scheduler.step()
        ###########################################################################

        ###########################################################################
        # Val: periodically evaluate and keep the best checkpoint.
        if epoch % Config.val_freq == 0:
            Tools.print()
            Tools.print("Test {} {} .......".format(
                epoch, Config.model_name))
            self.proto_net.eval()
            val_accuracy = self.test_tool.val(episode=epoch, is_print=True)
            if val_accuracy > self.best_accuracy:
                self.best_accuracy = val_accuracy
                torch.save(self.proto_net.state_dict(), Config.pn_dir)
                Tools.print("Save networks for epoch: {}".format(epoch))
                pass
            pass
        ###########################################################################
        pass
    pass
# Plot the remaining two curves and finish the figure (legend, grid,
# axes cosmetics), then save as a PDF.
ln3, = plt.plot(x_data, y_data3, color='blue', linewidth=2.0, linestyle='-')
ln4, = plt.plot(x_data, y_data4, color='blue', linewidth=2.0, linestyle='-.')

# NOTE(review): the same four handles/labels are repeated four times —
# presumably placeholder entries for a multi-column legend; confirm.
plt.legend(handles=[
    ln1, ln2, ln3, ln4, ln1, ln2, ln3, ln4, ln1, ln2, ln3, ln4, ln1, ln2,
    ln3, ln4
],
           labels=[
               "AAA", "BBB", "AAA", "BBB", "AAA", "BBB", "AAA", "BBB",
               "AAA", "BBB", "AAA", "BBB", "AAA", "BBB", "AAA", "BBB"
           ],
           loc='best', ncol=2, fontsize=14)
plt.grid(linestyle='--')
plt.ylim(0.2, 1.0)
plt.locator_params("y", nbins=10)
# plt.xlabel('Shot', fontsize=16)
# plt.ylabel('Accuracy', fontsize=16)
plt.tick_params(labelsize=14)
plt.subplots_adjust(top=0.96, bottom=0.06, left=0.08, right=0.98,
                    hspace=0, wspace=0)
plt.savefig(Tools.new_dir(os.path.join("plot", "shot", "demo_shot.pdf")))
plt.show()
def train(self, epochs, start_epoch=0):
    """Full training loop: initial test, then per-epoch train/checkpoint/test.

    Prints SOD and GCN MAE/F-score summaries for both the training and
    testing passes of every epoch.
    """
    # Baseline evaluation before any training (reported as epoch 0).
    test_loss, test_loss1, test_loss2, test_mae, test_score, test_mae2, test_score2 = self.test()
    Tools.print('E:{:2d}, Test sod-mae-score={:.4f}-{:.4f} '
                'gcn-mae-score={:.4f}-{:.4f} loss={:.4f}({:.4f}+{:.4f})'.format(
                    0, test_mae, test_score, test_mae2, test_score2,
                    test_loss, test_loss1, test_loss2))

    for epoch in range(start_epoch, epochs):
        Tools.print()
        Tools.print("Start Epoch {}".format(epoch))

        # Apply the LR schedule for this epoch, then log the effective LR.
        self._lr(epoch)
        Tools.print('Epoch:{:02d},lr={:.4f}'.format(
            epoch, self.optimizer.param_groups[0]['lr']))

        (train_loss, train_loss1, train_loss2, train_mae, train_score,
         train_mae2, train_score2) = self._train_epoch()
        self._save_checkpoint(self.model, self.root_ckpt_dir, epoch)
        test_loss, test_loss1, test_loss2, test_mae, test_score, test_mae2, test_score2 = self.test()

        Tools.print('E:{:2d}, Train sod-mae-score={:.4f}-{:.4f} '
                    'gcn-mae-score={:.4f}-{:.4f} loss={:.4f}({:.4f}+{:.4f})'.format(
                        epoch, train_mae, train_score, train_mae2,
                        train_score2, train_loss, train_loss1, train_loss2))
        Tools.print('E:{:2d}, Test sod-mae-score={:.4f}-{:.4f} '
                    'gcn-mae-score={:.4f}-{:.4f} loss={:.4f}({:.4f}+{:.4f})'.format(
                        epoch, test_mae, test_score, test_mae2, test_score2,
                        test_loss, test_loss1, test_loss2))
        pass
    pass
def train(self):
    """Supervised training of the matching network with periodic validation.

    Keeps the best checkpoint (by validation accuracy) at Config.mn_dir.
    """
    Tools.print()
    Tools.print("Training...")

    for epoch in range(1, 1 + Config.train_epoch):
        self.matching_net.train()
        Tools.print()

        loss_sum = 0.0
        for task_data, task_labels, task_index in tqdm(self.task_train_loader):
            task_data = RunnerTool.to_cuda(task_data)
            task_labels = RunnerTool.to_cuda(task_labels)

            # forward: similarity predictions for the episode
            predicts = self.matching(task_data)

            # loss
            loss = self.loss(predicts, task_labels)
            loss_sum += loss.item()

            # backward
            self.matching_net.zero_grad()
            loss.backward()
            self.matching_net_optim.step()
            pass

        # epoch summary + LR schedule step
        Tools.print("{:6} loss:{:.3f} lr:{}".format(
            epoch, loss_sum / len(self.task_train_loader),
            self.matching_net_scheduler.get_last_lr()))
        self.matching_net_scheduler.step()

        # periodic validation; save on improvement
        if epoch % Config.val_freq == 0:
            Tools.print()
            Tools.print("Test {} {} .......".format(epoch, Config.model_name))
            self.matching_net.eval()
            val_accuracy = self.test_tool.val(episode=epoch, is_print=True,
                                              has_test=False)
            if val_accuracy > self.best_accuracy:
                self.best_accuracy = val_accuracy
                torch.save(self.matching_net.state_dict(), Config.mn_dir)
                Tools.print("Save networks for epoch: {}".format(epoch))
                pass
            pass
        pass
    pass
def inference_ss(self, model_file_name=None, data_loader=None, save_path=None):
    """Run multi-scale semantic-segmentation inference and score it.

    For each image the per-scale logits are upsampled back to the original
    image size, averaged, and argmax'd into a prediction; metrics accumulate
    over the whole loader. Final PNG predictions go to "<save_path>_final";
    already-finished images are skipped so the run is resumable.

    :param model_file_name: optional checkpoint to load first.
    :param data_loader: loader yielding (inputs, labels, image_info_list);
        batch size must be 1 (asserted below).
    :param save_path: if truthy, also dump the source image plus colorized
        label/prediction maps here.
    :return: the metrics dict from StreamSegMetrics.get_results().
    """
    if model_file_name is not None:
        Tools.print("Load model form {}".format(model_file_name),
                    txt_path=self.config.ss_save_result_txt)
        self.load_model(model_file_name)
        pass

    final_save_path = Tools.new_dir("{}_final".format(save_path))

    self.net.eval()
    metrics = StreamSegMetrics(self.config.ss_num_classes)
    with torch.no_grad():
        for i, (inputs, labels, image_info_list) in tqdm(enumerate(data_loader),
                                                         total=len(data_loader)):
            assert len(image_info_list) == 1  # one image per batch

            # Label: resize the ground truth back to the original image size.
            # Images with any side >= max_size are interpolated on CPU,
            # presumably to avoid GPU OOM on large tensors — TODO confirm.
            max_size = 1000
            size = Image.open(image_info_list[0]).size
            basename = os.path.basename(image_info_list[0])
            final_name = os.path.join(final_save_path, basename.replace(".JPEG", ".png"))
            if os.path.exists(final_name):
                continue  # resumable: skip images already predicted

            if size[0] < max_size and size[1] < max_size:
                targets = F.interpolate(torch.unsqueeze(
                    labels[0].float().cuda(), dim=0), size=(size[1], size[0]),
                    mode="nearest").detach().cpu()
            else:
                targets = F.interpolate(torch.unsqueeze(labels[0].float(), dim=0),
                                        size=(size[1], size[0]), mode="nearest")
            targets = targets[0].long().numpy()

            # Prediction: sum upsampled logits over all scales, then average.
            outputs = 0
            for input_index, input_one in enumerate(inputs):
                output_one = self.net(input_one.float().cuda())
                if size[0] < max_size and size[1] < max_size:
                    outputs += F.interpolate(
                        output_one, size=(size[1], size[0]),
                        mode="bilinear", align_corners=False).detach().cpu()
                else:
                    outputs += F.interpolate(output_one.detach().cpu(),
                                             size=(size[1], size[0]),
                                             mode="bilinear", align_corners=False)
                pass
            pass
            outputs = outputs / len(inputs)
            preds = outputs.max(dim=1)[1].numpy()

            # Accumulate confusion-matrix statistics.
            metrics.update(targets, preds)

            if save_path:
                # Side outputs: original image, colorized label, colorized
                # prediction, and the raw gray prediction used for scoring.
                Image.open(image_info_list[0]).save(
                    os.path.join(save_path, basename))
                DataUtil.gray_to_color(
                    np.asarray(targets[0], dtype=np.uint8)).save(
                    os.path.join(save_path, basename.replace(".JPEG", "_l.png")))
                DataUtil.gray_to_color(np.asarray(
                    preds[0], dtype=np.uint8)).save(
                    os.path.join(save_path, basename.replace(".JPEG", ".png")))
                Image.fromarray(np.asarray(
                    preds[0],
                    dtype=np.uint8)).save(final_name)
                pass
            pass
        pass

    score = metrics.get_results()
    Tools.print("{}".format(metrics.to_str(score)), txt_path=self.config.ss_save_result_txt)
    return score
def train(self):
    """Train the matching network on few-shot (FSL) tasks.

    Tracks a per-epoch "ok" ratio (fraction of samples flagged ok by the
    loader), validates every Config.fsl_val_freq epochs, and saves the
    weights whenever validation accuracy improves.

    Improvement over the original: the element count previously computed as
    ``torch.prod(torch.tensor(torch.cat(task_ok).shape))`` is just
    ``Tensor.numel()``; both counters are now accumulated as plain Python
    ints — same values, no throwaway tensors.
    """
    Tools.print()
    Tools.print("Training...")
    best_accuracy = 0.0

    for epoch in range(1, 1 + Config.fsl_train_epoch):
        self.matching_net.train()
        Tools.print()
        all_loss, is_ok_total, is_ok_acc = 0.0, 0, 0
        for task_data, task_labels, task_index, task_ok in tqdm(
                self.task_train_loader):
            task_data, task_labels = RunnerTool.to_cuda(
                task_data), RunnerTool.to_cuda(task_labels)

            ###########################################################################
            # 1 calculate features
            relations = self.matching(task_data)

            # 2 loss
            loss = self.fsl_loss(relations, task_labels)
            all_loss += loss.item()

            # 3 backward
            self.matching_net.zero_grad()
            loss.backward()
            self.matching_net_optim.step()

            # 4 "is ok" bookkeeping: count positive flags and total samples.
            ok_flags = torch.cat(task_ok)
            is_ok_acc += int(ok_flags.sum())
            is_ok_total += ok_flags.numel()
            ###########################################################################
            pass

        ###########################################################################
        # Epoch summary, then advance the LR schedule.
        Tools.print("{:6} loss:{:.3f} ok:{:.3f}({}/{}) lr:{}".format(
            epoch, all_loss / len(self.task_train_loader),
            int(is_ok_acc) / int(is_ok_total), is_ok_acc, is_ok_total,
            self.matching_net_scheduler.get_last_lr()))
        self.matching_net_scheduler.step()
        ###########################################################################

        ###########################################################################
        # Validation: keep the best checkpoint.
        if epoch % Config.fsl_val_freq == 0:
            self.matching_net.eval()
            val_accuracy = self.test_tool_fsl.val(episode=epoch, is_print=True,
                                                  has_test=False)
            if val_accuracy > best_accuracy:
                best_accuracy = val_accuracy
                torch.save(self.matching_net.state_dict(), Config.mn_dir)
                Tools.print("Save networks for epoch: {}".format(epoch))
                pass
            pass
        ###########################################################################
        pass
    pass
def __init__(self):
    """Configuration for the weakly-supervised segmentation (SS) pipeline.

    Holds GPU selection, pipeline-stage switches, checkpoint/output paths,
    and training hyper-parameters. Commented-out blocks are previous
    experiment variants kept for reference.
    """
    # GPU ids: gpu_id_1 for single-GPU stages (eval/inference),
    # gpu_id_4 for multi-GPU training.
    # self.gpu_id_1, self.gpu_id_4 = "0", "0, 1, 2, 3"
    # self.gpu_id_1, self.gpu_id_4 = "1", "0, 1, 2, 3"
    self.gpu_id_1, self.gpu_id_4 = "2", "0, 1, 2, 3"
    # self.gpu_id_1, self.gpu_id_4 = "3", "0, 1, 2, 3"

    # Pipeline control
    self.only_train_ss = False  # whether to train SS
    self.is_balance_data = True
    self.only_eval_ss = False  # whether to evaluate SS
    self.only_inference_ss = True  # whether to run SS inference

    # 1 trained on raw data
    # self.scales = (1.0, 0.5, 1.5)
    # self.model_file_name = "../../../WSS_Model_SS/4_DeepLabV3PlusResNet50_201_10_18_1_352/ss_final_10.pth"
    # self.eval_save_path = "../../../WSS_Model_SS_EVAL/4_DeepLabV3PlusResNet50_201_10_18_1_352/ss_final_10_scales"

    # 2 trained on balanced data
    # self.scales = (1.0, 0.5, 1.5)
    # self.model_file_name = "../../../WSS_Model_SS/6_DeepLabV3PlusResNet50_201_10_18_1_352_balance/ss_8.pth"
    # self.eval_save_path = "../../../WSS_Model_SS_EVAL/6_DeepLabV3PlusResNet50_201_10_18_1_352_balance/ss_8_scales"

    # 3 trained on balanced data
    # self.scales = (1.0, 0.5, 1.5, 2.0)
    # self.model_file_name = "../../../WSS_Model_SS/6_DeepLabV3PlusResNet50_201_10_18_1_352_balance/ss_8.pth"
    # self.eval_save_path = "../../../WSS_Model_SS_EVAL/6_DeepLabV3PlusResNet50_201_10_18_1_352_balance/ss_8_scales_4"

    # 4 trained on balanced data
    # self.scales = (1.0, 0.75, 0.5, 1.25, 1.5, 1.75, 2.0)
    # self.model_file_name = "../../../WSS_Model_SS/6_DeepLabV3PlusResNet50_201_10_18_1_352_balance/ss_8.pth"
    # self.eval_save_path = "../../../WSS_Model_SS_EVAL/6_DeepLabV3PlusResNet50_201_10_18_1_352_balance/ss_8_scales_7"

    # 4 trained on balanced data (active variant: ResNet101, 7 scales)
    self.scales = (1.0, 0.75, 0.5, 1.25, 1.5, 1.75, 2.0)
    self.model_file_name = "../../../WSS_Model_SS/7_DeepLabV3PlusResNet101_201_10_18_1_352_balance/ss_7.pth"
    self.eval_save_path = "../../../WSS_Model_SS_EVAL/7_DeepLabV3PlusResNet101_201_10_18_1_352_balance/ss_7_scales_7"
    # self.model_file_name = "../../../WSS_Model_SS/7_DeepLabV3PlusResNet101_201_10_18_1_352_balance/ss_final_10.pth"
    # self.eval_save_path = "../../../WSS_Model_SS_EVAL/7_DeepLabV3PlusResNet101_201_10_18_1_352_balance/ss_10_scales_7"

    # Pseudo labels produced by other methods
    # self.label_path = "/mnt/4T/ALISURE/USS/WSS_CAM/cam/1_CAMNet_200_32_256_0.5"
    # self.label_path = "/mnt/4T/ALISURE/USS/WSS_CAM/cam_4/1_200_32_256_0.5"
    # self.label_path = "/mnt/4T/ALISURE/USS/WSS_CAM/cam_4/2_1_200_32_256_0.5"
    self.label_path = "/media/ubuntu/4T/ALISURE/USS/ConTa/pseudo_mask/result/2/sem_seg"

    self.ss_num_classes = 201
    self.ss_epoch_num = 10
    self.ss_milestones = [5, 8]
    # NOTE(review): uses len(split)-1, i.e. one fewer than the listed GPUs
    # (6 * 3 = 18, matching the "_18_" in the model name) — confirm intended.
    self.ss_batch_size = 6 * (len(self.gpu_id_4.split(",")) - 1)
    self.ss_lr = 0.001
    self.ss_save_epoch_freq = 1
    self.ss_eval_epoch_freq = 1

    # Input image size
    self.ss_size = 352
    self.output_stride = 16

    # Network
    self.Net = DeepLabV3Plus
    # self.arch, self.arch_name = deeplabv3_resnet50, "DeepLabV3PlusResNet50"
    self.arch, self.arch_name = deeplabv3plus_resnet101, "DeepLabV3PlusResNet101"

    self.data_root_path = self.get_data_root_path()

    # Multi-GPU only when training; otherwise pin to the single GPU.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(
        self.gpu_id_4) if self.only_train_ss else str(self.gpu_id_1)

    run_name = "7"
    self.model_name = "{}_{}_{}_{}_{}_{}_{}{}".format(
        run_name, self.arch_name, self.ss_num_classes, self.ss_epoch_num,
        self.ss_batch_size, self.ss_save_epoch_freq, self.ss_size,
        "_balance" if self.is_balance_data else "")
    Tools.print(self.model_name)

    self.ss_model_dir = "../../../WSS_Model_SS/{}".format(self.model_name)
    self.ss_save_result_txt = Tools.new_dir("{}/result.txt".format(
        self.ss_model_dir))
    pass
# Script-level experiment configuration for the PYG ResNet-GCN runner.
# NOTE(review): this fragment is truncated — the RunnerSPE(...) call below is
# cut off mid-argument-list, and _epochs / _num_workers / _gpu_id /
# _data_root_path / _train_print_freq are defined outside this view.
_is_sgd = False
_weight_decay = 5e-4
# LR schedule as [epoch, lr] milestones: 5e-5 from epoch 0, 5e-6 from epoch 20.
_lr = [[0, 5e-5], [20, 5e-6]]
_improved = True
_has_bn = True
_has_residual = True
_is_normalize = True
_concat = True
_sp_size, _down_ratio = 3, 8
_root_ckpt_dir = "./ckpt/PYG_ResNet_GCN_more/sigmoid_{}".format(_gpu_id)
# Log the full configuration before constructing the runner.
Tools.print(
    "epochs:{} ckpt:{} sp size:{} down_ratio:{} workers:{} gpu:{} has_residual:{} "
    "is_normalize:{} has_bn:{} improved:{} concat:{} is_sgd:{} weight_decay:{}"
    .format(_epochs, _root_ckpt_dir, _sp_size, _down_ratio, _num_workers,
            _gpu_id, _has_residual, _is_normalize, _has_bn, _improved, _concat,
            _is_sgd, _weight_decay))
runner = RunnerSPE(data_root_path=_data_root_path,
                   root_ckpt_dir=_root_ckpt_dir,
                   sp_size=_sp_size,
                   is_sgd=_is_sgd,
                   lr=_lr,
                   residual=_has_residual,
                   normalize=_is_normalize,
                   down_ratio=_down_ratio,
                   has_bn=_has_bn,
                   concat=_concat,
                   weight_decay=_weight_decay,
                   train_print_freq=_train_print_freq,
def train_ss(self, start_epoch=0, model_file_name=None):
    """Train the semantic-segmentation network.

    Evaluates ~10 times within each epoch and once at each epoch boundary
    (per ss_eval_epoch_freq); checkpoints per ss_save_epoch_freq and writes
    a final checkpoint at the end.

    :param start_epoch: epoch index to resume from.
    :param model_file_name: optional checkpoint to load before training.

    Fixes over the original:
    * ``(i + 1) % (len(loader) // 10)`` raised ZeroDivisionError whenever
      the loader had fewer than 10 batches; the divisor is now hoisted and
      clamped to at least 1.
    * after a mid-epoch ``eval_ss`` the net is put back into train mode,
      since eval_ss presumably switches it to eval (BN/Dropout frozen) —
      the original left the remaining batches in whatever mode eval_ss set.
    """
    if model_file_name is not None:
        Tools.print("Load model form {}".format(model_file_name),
                    txt_path=self.config.ss_save_result_txt)
        self.load_model(model_file_name)
        pass

    # self.eval_ss(epoch=0)
    for epoch in range(start_epoch, self.config.ss_epoch_num):
        Tools.print()
        Tools.print('Epoch:{:2d}, lr={:.6f} lr2={:.6f}'.format(
            epoch, self.optimizer.param_groups[0]['lr'],
            self.optimizer.param_groups[1]['lr']),
            txt_path=self.config.ss_save_result_txt)

        ###########################################################################
        # 1 Train one epoch
        all_loss = 0.0
        self.net.train()

        if self.config.is_balance_data:
            # Re-sample the balanced subset each epoch.
            self.dataset_ss_train.reset()
            pass

        # Mid-epoch evaluation cadence: ~10 evals per epoch; max(1, ...)
        # avoids a ZeroDivisionError on loaders with fewer than 10 batches.
        eval_every = max(1, len(self.data_loader_ss_train) // 10)
        for i, (inputs, labels) in tqdm(enumerate(self.data_loader_ss_train),
                                        total=len(self.data_loader_ss_train)):
            inputs, labels = inputs.float().cuda(), labels.long().cuda()
            self.optimizer.zero_grad()
            result = self.net(inputs)
            loss = self.ce_loss(result, labels)
            loss.backward()
            self.optimizer.step()
            all_loss += loss.item()
            if (i + 1) % eval_every == 0:
                self.eval_ss(epoch=epoch)
                # Restore train mode for the remaining batches.
                self.net.train()
                pass
            pass
        self.scheduler.step()
        ###########################################################################

        Tools.print("[E:{:3d}/{:3d}] ss loss:{:.4f}".format(
            epoch, self.config.ss_epoch_num,
            all_loss / len(self.data_loader_ss_train)),
            txt_path=self.config.ss_save_result_txt)

        ###########################################################################
        # 2 Save checkpoint
        if epoch % self.config.ss_save_epoch_freq == 0:
            Tools.print()
            save_file_name = Tools.new_dir(
                os.path.join(self.config.ss_model_dir, "ss_{}.pth".format(epoch)))
            torch.save(self.net.state_dict(), save_file_name)
            Tools.print("Save Model to {}".format(save_file_name),
                        txt_path=self.config.ss_save_result_txt)
            Tools.print()
            pass
        ###########################################################################

        ###########################################################################
        # 3 Evaluate at the epoch boundary
        if epoch % self.config.ss_eval_epoch_freq == 0:
            self.eval_ss(epoch=epoch)
            pass
        ###########################################################################
        pass

    # Final Save
    Tools.print()
    save_file_name = Tools.new_dir(
        os.path.join(self.config.ss_model_dir,
                     "ss_final_{}.pth".format(self.config.ss_epoch_num)))
    torch.save(self.net.state_dict(), save_file_name)
    Tools.print("Save Model to {}".format(save_file_name),
                txt_path=self.config.ss_save_result_txt)
    Tools.print()

    self.eval_ss(epoch=self.config.ss_epoch_num)
    pass
def load_model(self, model_file_name):
    """Restore model weights from a checkpoint file.

    Loads onto self.device; strict=False so extra/missing keys in the
    checkpoint are tolerated rather than raising.

    :param model_file_name: path of the checkpoint to load.
    """
    state_dict = torch.load(model_file_name, map_location=self.device)
    self.model.load_state_dict(state_dict, strict=False)
    Tools.print('Load Model: {}'.format(model_file_name))
    pass
def load_model(self):
    """Load IC model weights from Config.ic_dir if the checkpoint exists.

    Fix: the original silently did nothing when the file was missing,
    which made a failed load indistinguishable from a successful one in
    the logs; the miss path is now reported explicitly. Training from
    scratch still proceeds either way.
    """
    if os.path.exists(Config.ic_dir):
        self.ic_model.load_state_dict(torch.load(Config.ic_dir))
        Tools.print("load ic model success from {}".format(Config.ic_dir))
    else:
        Tools.print("ic model checkpoint not found: {}".format(Config.ic_dir))
    pass
def test(self, model_file=None, is_train_loader=False):
    """Evaluate the SOD/GCN model over a full loader.

    Accumulates BCE losses, MAE, and precision/recall at th_num thresholds
    for both heads (SOD output and GCN output), then reduces them to an
    F-beta score (beta^2 = 0.3) and returns the maximum over thresholds.

    :param model_file: optional checkpoint to load before evaluating.
    :param is_train_loader: evaluate on the train loader instead of test.
    :return: (avg_loss, avg_loss1, avg_loss2, avg_mae, score.max(),
              avg_mae2, score2.max())

    NOTE(review): head-1 stats are normalized by batch count (tr_num) while
    head-2 stats use sample count (nb_data) — looks intentional but confirm.
    """
    self.model.eval()
    if model_file:
        self.load_model(model_file_name=model_file)

    Tools.print()
    th_num = 25

    # Accumulators; prec/recall start at 1e-6 to avoid division by zero.
    epoch_test_loss, epoch_test_loss1, epoch_test_loss2, nb_data = 0, 0, 0, 0
    epoch_test_mae, epoch_test_mae2 = 0.0, 0.0
    epoch_test_prec, epoch_test_recall = np.zeros(
        shape=(th_num, )) + 1e-6, np.zeros(shape=(th_num, )) + 1e-6
    epoch_test_prec2, epoch_test_recall2 = np.zeros(
        shape=(th_num, )) + 1e-6, np.zeros(shape=(th_num, )) + 1e-6

    loader = self.train_loader if is_train_loader else self.test_loader
    tr_num = len(loader)
    with torch.no_grad():
        for i, (images, _, labels_sod, batched_graph, batched_pixel_graph,
                segments, _, _) in enumerate(loader):
            # Data: move image tensor, graph labels, and graph index tensors
            # to the evaluation device.
            images = images.float().to(self.device)
            labels = batched_graph.y.to(self.device)
            labels_sod = torch.unsqueeze(torch.Tensor(labels_sod),
                                         dim=1).to(self.device)
            batched_graph.batch = batched_graph.batch.to(self.device)
            batched_graph.edge_index = batched_graph.edge_index.to(
                self.device)
            batched_pixel_graph.batch = batched_pixel_graph.batch.to(
                self.device)
            batched_pixel_graph.edge_index = batched_pixel_graph.edge_index.to(
                self.device)
            batched_pixel_graph.data_where = batched_pixel_graph.data_where.to(
                self.device)

            # Forward: keep only the two sigmoid outputs (GCN head, SOD head).
            _, gcn_logits_sigmoid, _, _, sod_logits_sigmoid = self.model.forward(
                images, batched_graph, batched_pixel_graph)

            loss1 = self.loss_bce(gcn_logits_sigmoid, labels)
            loss2 = self.loss_bce(sod_logits_sigmoid, labels_sod)
            loss = loss1 + loss2

            labels_val = labels.cpu().detach().numpy()
            labels_sod_val = labels_sod.cpu().detach().numpy()
            gcn_logits_sigmoid_val = gcn_logits_sigmoid.cpu().detach(
            ).numpy()
            sod_logits_sigmoid_val = sod_logits_sigmoid.cpu().detach(
            ).numpy()

            # Stat
            nb_data += images.size(0)
            epoch_test_loss += loss.detach().item()
            epoch_test_loss1 += loss1.detach().item()
            epoch_test_loss2 += loss2.detach().item()

            # cal 1: SOD head — MAE and thresholded precision/recall.
            mae = self._eval_mae(sod_logits_sigmoid_val, labels_sod_val)
            prec, recall = self._eval_pr(sod_logits_sigmoid_val,
                                         labels_sod_val, th_num)
            epoch_test_mae += mae
            epoch_test_prec += prec
            epoch_test_recall += recall

            # cal 2: GCN head — same metrics against the graph labels.
            mae2 = self._eval_mae(gcn_logits_sigmoid_val, labels_val)
            prec2, recall2 = self._eval_pr(gcn_logits_sigmoid_val,
                                           labels_val, th_num)
            epoch_test_mae2 += mae2
            epoch_test_prec2 += prec2
            epoch_test_recall2 += recall2

            # Print
            if i % self.test_print_freq == 0:
                Tools.print(
                    "{:4d}-{:4d} loss={:.4f}({:.4f}+{:.4f})-{:.4f}({:.4f}+{:.4f}) "
                    "sod-mse={:.4f}({:.4f}) gcn-mse={:.4f}({:.4f})".format(
                        i, len(loader), loss.detach().item(),
                        loss1.detach().item(), loss2.detach().item(),
                        epoch_test_loss / (i + 1),
                        epoch_test_loss1 / (i + 1),
                        epoch_test_loss2 / (i + 1), mae,
                        epoch_test_mae / (i + 1), mae2,
                        epoch_test_mae2 / nb_data))
                pass
            pass
        pass

    # Results: reduce accumulators; F-beta with beta^2 = 0.3 per head.
    avg_loss, avg_loss1, avg_loss2 = epoch_test_loss / tr_num, epoch_test_loss1 / tr_num, epoch_test_loss2 / tr_num

    avg_mae, avg_prec, avg_recall = epoch_test_mae / tr_num, epoch_test_prec / tr_num, epoch_test_recall / tr_num
    score = (1 + 0.3) * avg_prec * avg_recall / (0.3 * avg_prec + avg_recall)

    avg_mae2, avg_prec2, avg_recall2 = epoch_test_mae2 / nb_data, epoch_test_prec2 / nb_data, epoch_test_recall2 / nb_data
    score2 = (1 + 0.3) * avg_prec2 * avg_recall2 / (0.3 * avg_prec2 +
                                                    avg_recall2)
    return avg_loss, avg_loss1, avg_loss2, avg_mae, score.max(
    ), avg_mae2, score2.max()
# _data_root_path = '/home/ubuntu/ALISURE/data/mnist' _data_root_path = "/private/alishuo/mnist" _root_ckpt_dir = "./ckpt2/dgl/4_DGL_CONV-mnist3-Adam/{}".format( "GCNNet-C1") _batch_size = 64 _image_size = 28 _sp_size = 4 _train_print_freq = 100 _test_print_freq = 100 _num_workers = 6 _use_gpu = True _gpu_id = "0" # _gpu_id = "1" Tools.print( "ckpt:{} batch size:{} image size:{} sp size:{} workers:{} gpu:{}". format(_root_ckpt_dir, _batch_size, _image_size, _sp_size, _num_workers, _gpu_id)) runner = RunnerSPE(data_root_path=_data_root_path, root_ckpt_dir=_root_ckpt_dir, batch_size=_batch_size, image_size=_image_size, sp_size=_sp_size, train_print_freq=_train_print_freq, test_print_freq=_test_print_freq, num_workers=_num_workers, use_gpu=_use_gpu, gpu_id=_gpu_id) runner.train(150) pass
def train(self):
    """Train the prototypical network.

    Adjusts the learning rate manually each epoch, clips gradients at 0.5,
    and every Config.val_freq epochs validates and saves the weights when
    validation accuracy beats the best seen so far.
    """
    Tools.print()
    Tools.print("Training...")

    for epoch in range(1, 1 + Config.train_epoch):
        self.proto_net.train()
        Tools.print()
        epoch_loss = 0.0

        # Manual LR schedule for this epoch.
        pn_lr = self.adjust_learning_rate(self.proto_net_optim, epoch,
                                          Config.first_epoch, Config.t_epoch,
                                          Config.learning_rate)
        Tools.print('Epoch: [{}] pn_lr={}'.format(epoch, pn_lr))

        for task_data, task_labels, task_index in tqdm(self.task_train_loader):
            task_data = RunnerTool.to_cuda(task_data)
            task_labels = RunnerTool.to_cuda(task_labels)

            # Forward: prototype-based episode predictions.
            out = self.proto(task_data)

            # Episode loss.
            loss = self.loss(out, task_labels)
            epoch_loss += loss.item()

            # Backward with gradient clipping, then update.
            self.proto_net.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.proto_net.parameters(), 0.5)
            self.proto_net_optim.step()
            pass

        # Epoch summary.
        Tools.print("{:6} loss:{:.3f}".format(
            epoch, epoch_loss / len(self.task_train_loader)))

        # Periodic validation; keep the best-performing checkpoint.
        if epoch % Config.val_freq == 0:
            Tools.print()
            Tools.print("Test {} {} .......".format(epoch, Config.model_name))
            self.proto_net.eval()
            val_accuracy = self.test_tool.val(episode=epoch, is_print=True)
            if val_accuracy > self.best_accuracy:
                self.best_accuracy = val_accuracy
                torch.save(self.proto_net.state_dict(), Config.pn_dir)
                Tools.print("Save networks for epoch: {}".format(epoch))
                pass
            pass
        pass
    pass
def train_val_pipeline(self):
    """Full train/validate/test pipeline with early stopping.

    Per epoch: trains, evaluates on val and test, steps the plateau
    scheduler on val loss, checkpoints, and logs to TensorBoard. Stops
    early when the LR drops below params.min_lr or the wall-clock budget
    (params.max_time hours) is exhausted. Finishes with a final evaluation
    on all three splits.
    """
    t0, per_epoch_time = time.time(), []
    epoch_train_losses, epoch_val_losses, epoch_train_accs, epoch_val_accs = [], [], [], []
    for epoch in range(self.params.epochs):
        start = time.time()
        Tools.print()
        Tools.print("Start Epoch {}".format(epoch))

        # One training pass plus val/test evaluation.
        epoch_train_loss, epoch_train_acc = self.train_epoch(
            self.train_loader)
        epoch_val_loss, epoch_val_acc = self.evaluate_network(
            self.val_loader)
        epoch_test_loss, epoch_test_acc = self.evaluate_network(
            self.test_loader)

        # Plateau scheduler keyed on validation loss; then checkpoint.
        self.scheduler.step(epoch_val_loss)
        self.save_checkpoint(self.model, self.params.root_ckpt_dir, epoch)

        epoch_train_losses.append(epoch_train_loss)
        epoch_val_losses.append(epoch_val_loss)
        epoch_train_accs.append(epoch_train_acc)
        epoch_val_accs.append(epoch_val_acc)

        # TensorBoard scalars.
        self.writer.add_scalar('train/_loss', epoch_train_loss, epoch)
        self.writer.add_scalar('val/_loss', epoch_val_loss, epoch)
        self.writer.add_scalar('train/_acc', epoch_train_acc, epoch)
        self.writer.add_scalar('val/_acc', epoch_val_acc, epoch)
        self.writer.add_scalar('learning_rate',
                               self.optimizer.param_groups[0]['lr'], epoch)

        per_epoch_time.append(time.time() - start)
        Tools.print(
            "time={:.4f}, lr={:.4f}, loss={:.4f}/{:.4f}/{:.4f}, acc={:.4f}/{:.4f}/{:.4f}"
            .format(time.time() - start,
                    self.optimizer.param_groups[0]['lr'], epoch_train_loss,
                    epoch_val_loss, epoch_test_loss, epoch_train_acc,
                    epoch_val_acc, epoch_test_acc))

        # Stop training: LR floor reached (plateau scheduler bottomed out).
        if self.optimizer.param_groups[0]['lr'] < self.params.min_lr:
            Tools.print()
            Tools.print("\n!! LR EQUAL TO MIN LR SET.")
            break

        # Stop training: wall-clock budget exhausted.
        if time.time() - t0 > self.params.max_time * 3600:
            Tools.print()
            Tools.print(
                "Max_time for training elapsed {:.2f} hours, so stopping".
                format(self.params.max_time))
            break
        pass

    # Final evaluation on all splits.
    _, val_acc = self.evaluate_network(self.val_loader)
    _, test_acc = self.evaluate_network(self.test_loader)
    _, train_acc = self.evaluate_network(self.train_loader)

    Tools.print()
    Tools.print("Val Accuracy: {:.4f}".format(val_acc))
    Tools.print("Test Accuracy: {:.4f}".format(test_acc))
    Tools.print("Train Accuracy: {:.4f}".format(train_acc))
    Tools.print("TOTAL TIME TAKEN: {:.4f}s".format(time.time() - t0))
    Tools.print("AVG TIME PER EPOCH: {:.4f}s".format(
        np.mean(per_epoch_time)))

    self.writer.close()
    pass