コード例 #1
0
    def train(self):
        """Joint training of the prototype network and the IC model.

        Workflow:
          1. Init pass (no gradients): run the frozen networks over all
             training tasks once to seed the pseudo-labels held by
             ``self.produce_class``.
          2. Per epoch: forward both nets, combine the FSL loss and the IC
             loss, backward, and step both optimizers, then both schedulers.
          3. Every ``Config.val_freq`` epochs: validate and checkpoint both
             networks whenever FSL validation accuracy improves.
        """
        Tools.print()
        Tools.print("Training...")

        # Init Update: seed the pseudo-label statistics with one eval pass.
        # Nothing is backpropagated here, so run under no_grad (saves memory
        # and matches the init pass of the IC-only trainer).
        self.proto_net.eval()
        self.ic_model.eval()
        Tools.print("Init label .......")
        self.produce_class.reset()
        with torch.no_grad():
            for task_data, task_labels, task_index in tqdm(self.task_train_loader):
                ic_labels = RunnerTool.to_cuda(task_index[:, -1])
                task_data, task_labels = RunnerTool.to_cuda(task_data), RunnerTool.to_cuda(task_labels)
                log_p_y, query_features = self.proto(task_data)
                ic_out_logits, ic_out_l2norm = self.ic_model(query_features)
                self.produce_class.cal_label(ic_out_l2norm, ic_labels)
                pass
            pass
        Tools.print("Epoch: {}/{}".format(self.produce_class.count, self.produce_class.count_2))

        for epoch in range(Config.train_epoch):
            self.proto_net.train()
            self.ic_model.train()

            Tools.print()
            self.produce_class.reset()
            all_loss, all_loss_fsl, all_loss_ic = 0.0, 0.0, 0.0
            for task_data, task_labels, task_index in tqdm(self.task_train_loader):
                ic_labels = RunnerTool.to_cuda(task_index[:, -1])
                task_data, task_labels = RunnerTool.to_cuda(task_data), RunnerTool.to_cuda(task_labels)

                ###########################################################################
                # 1 calculate features
                log_p_y, query_features = self.proto(task_data)
                ic_out_logits, ic_out_l2norm = self.ic_model(query_features)

                # 2 pseudo-labels: read targets BEFORE cal_label folds the
                # current batch into the table (order matters).
                ic_targets = self.produce_class.get_label(ic_labels)
                self.produce_class.cal_label(ic_out_l2norm, ic_labels)

                # 3 loss: soft-label cross-entropy for FSL plus IC loss,
                # each scaled by its Config ratio.
                loss_fsl = -(log_p_y * task_labels).sum() / task_labels.sum() * Config.loss_fsl_ratio
                loss_ic = self.ic_loss(ic_out_logits, ic_targets) * Config.loss_ic_ratio
                loss = loss_fsl + loss_ic
                all_loss += loss.item()
                all_loss_fsl += loss_fsl.item()
                all_loss_ic += loss_ic.item()

                # 4 backward: both networks are updated from the joint loss.
                self.proto_net.zero_grad()
                self.ic_model.zero_grad()
                loss.backward()
                # torch.nn.utils.clip_grad_norm_(self.proto_net.parameters(), 0.5)
                # torch.nn.utils.clip_grad_norm_(self.ic_model.parameters(), 0.5)
                self.proto_net_optim.step()
                self.ic_model_optim.step()
                ###########################################################################
                pass

            ###########################################################################
            # print epoch statistics, then advance both LR schedulers
            Tools.print("{:6} loss:{:.3f} fsl:{:.3f} ic:{:.3f} lr:{}".format(
                epoch + 1, all_loss / len(self.task_train_loader), all_loss_fsl / len(self.task_train_loader),
                all_loss_ic / len(self.task_train_loader), self.proto_net_scheduler.get_last_lr()))
            Tools.print("Train: [{}] {}/{}".format(epoch, self.produce_class.count, self.produce_class.count_2))
            self.proto_net_scheduler.step()
            self.ic_model_scheduler.step()
            ###########################################################################

            ###########################################################################
            # Val: checkpoint both nets on best FSL validation accuracy
            if epoch % Config.val_freq == 0:
                self.proto_net.eval()
                self.ic_model.eval()

                self.test_tool_ic.val(epoch=epoch)
                val_accuracy = self.test_tool_fsl.val(episode=epoch, is_print=True)

                if val_accuracy > self.best_accuracy:
                    self.best_accuracy = val_accuracy
                    torch.save(self.proto_net.state_dict(), Config.pn_dir)
                    torch.save(self.ic_model.state_dict(), Config.ic_dir)
                    Tools.print("Save networks for epoch: {}".format(epoch))
                    pass
                pass
            ###########################################################################
            pass

        pass
コード例 #2
0
    def _train_epoch(self):
        """Run one training epoch of the SOD GCN and return averaged metrics.

        Returns:
            (avg_loss, avg_loss1, avg_loss2, avg_mae, max_f_score,
             avg_mae2, max_f_score2): the first metric group describes the
            fused SOD output, the second the per-node GCN output.
        """
        # NOTE(review): the model is put into eval() mode even though this is
        # a *training* epoch -- presumably to freeze BatchNorm/Dropout
        # statistics. Confirm this is intentional.
        self.model.eval()

        # Statistics accumulators (th_num thresholds for precision/recall;
        # the 1e-6 seed avoids division by zero in the F-score below).
        th_num = 25
        epoch_loss, epoch_loss1, epoch_loss2, nb_data = 0, 0, 0, 0
        epoch_mae, epoch_prec, epoch_recall = 0.0, np.zeros(
            shape=(th_num, )) + 1e-6, np.zeros(shape=(th_num, )) + 1e-6
        epoch_mae2, epoch_prec2, epoch_recall2 = 0.0, np.zeros(
            shape=(th_num, )) + 1e-6, np.zeros(shape=(th_num, )) + 1e-6

        # Run. Gradients are accumulated over iter_size mini-batches before
        # each optimizer step (effective batch = iter_size graphs).
        iter_size = 10
        self.model.zero_grad()
        tr_num = len(self.train_loader)
        for i, (images, _, labels_sod, batched_graph, batched_pixel_graph,
                segments, _, _) in enumerate(self.train_loader):
            # Data: move tensors and graph attributes to the target device.
            images = images.float().to(self.device)
            labels = batched_graph.y.to(self.device)
            labels_sod = torch.unsqueeze(torch.Tensor(labels_sod),
                                         dim=1).to(self.device)
            batched_graph.batch = batched_graph.batch.to(self.device)
            batched_graph.edge_index = batched_graph.edge_index.to(self.device)

            batched_pixel_graph.batch = batched_pixel_graph.batch.to(
                self.device)
            batched_pixel_graph.edge_index = batched_pixel_graph.edge_index.to(
                self.device)
            batched_pixel_graph.data_where = batched_pixel_graph.data_where.to(
                self.device)

            gcn_logits, gcn_logits_sigmoid, _, sod_logits, sod_logits_sigmoid = self.model.forward(
                images, batched_graph, batched_pixel_graph)

            # Two supervision signals: fused SOD map vs. graph-node labels.
            loss_fuse1 = F.binary_cross_entropy_with_logits(sod_logits,
                                                            labels_sod,
                                                            reduction='sum')
            loss_fuse2 = F.binary_cross_entropy_with_logits(gcn_logits,
                                                            labels,
                                                            reduction='sum')
            # loss = (loss_fuse1 + loss_fuse2) / iter_size
            # NOTE(review): only loss_fuse1 is divided by iter_size while
            # loss_fuse2 is weighted 2x per accumulation step -- looks like a
            # deliberate experiment setting; confirm before changing.
            loss = loss_fuse1 / iter_size + 2 * loss_fuse2
            # loss = loss_fuse1 / iter_size

            loss.backward()

            # Step once every iter_size batches. NOTE(review): when tr_num is
            # not a multiple of iter_size the final partial accumulation is
            # never stepped (cleared by zero_grad() at the next epoch).
            if (i + 1) % iter_size == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()
                pass

            labels_val = labels.cpu().detach().numpy()
            labels_sod_val = labels_sod.cpu().detach().numpy()
            gcn_logits_sigmoid_val = gcn_logits_sigmoid.cpu().detach().numpy()
            sod_logits_sigmoid_val = sod_logits_sigmoid.cpu().detach().numpy()

            # Stat
            nb_data += images.size(0)
            epoch_loss += loss.detach().item()
            epoch_loss1 += loss_fuse1.detach().item()
            epoch_loss2 += loss_fuse2.detach().item()

            # cal 1: MAE + precision/recall for the fused SOD output
            mae = self._eval_mae(sod_logits_sigmoid_val, labels_sod_val)
            prec, recall = self._eval_pr(sod_logits_sigmoid_val,
                                         labels_sod_val, th_num)
            epoch_mae += mae
            epoch_prec += prec
            epoch_recall += recall

            # cal 2: same metrics for the raw GCN node output
            mae2 = self._eval_mae(gcn_logits_sigmoid_val, labels_val)
            prec2, recall2 = self._eval_pr(gcn_logits_sigmoid_val, labels_val,
                                           th_num)
            epoch_mae2 += mae2
            epoch_prec2 += prec2
            epoch_recall2 += recall2

            # Print periodic progress (running averages in parentheses).
            if i % self.train_print_freq == 0:
                Tools.print(
                    "{:4d}-{:4d} loss={:.4f}({:.4f}+{:.4f})-{:.4f}({:.4f}+{:.4f}) "
                    "sod-mse={:.4f}({:.4f}) gcn-mse={:.4f}({:.4f})".format(
                        i, tr_num,
                        loss.detach().item(),
                        loss_fuse1.detach().item(),
                        loss_fuse2.detach().item(), epoch_loss / (i + 1),
                        epoch_loss1 / (i + 1), epoch_loss2 / (i + 1), mae,
                        epoch_mae / (i + 1), mae2, epoch_mae2 / nb_data))
                pass
            pass

        # Results. NOTE(review): group 1 is normalized by tr_num (batches)
        # but group 2 by nb_data (images) -- inconsistent on purpose? Verify.
        avg_loss, avg_loss1, avg_loss2 = epoch_loss / tr_num, epoch_loss1 / tr_num, epoch_loss2 / tr_num

        avg_mae, avg_prec, avg_recall = epoch_mae / tr_num, epoch_prec / tr_num, epoch_recall / tr_num
        # F-measure with beta^2 = 0.3, maximized over thresholds below.
        score = (1 + 0.3) * avg_prec * avg_recall / (0.3 * avg_prec +
                                                     avg_recall)
        avg_mae2, avg_prec2, avg_recall2 = epoch_mae2 / nb_data, epoch_prec2 / nb_data, epoch_recall2 / nb_data
        score2 = (1 + 0.3) * avg_prec2 * avg_recall2 / (0.3 * avg_prec2 +
                                                        avg_recall2)

        return avg_loss, avg_loss1, avg_loss2, avg_mae, score.max(
        ), avg_mae2, score2.max()
コード例 #3
0
 def load_model(self):
     """Restore proto-net weights from ``Config.pn_dir`` when a checkpoint exists."""
     if not os.path.exists(Config.pn_dir):
         return
     self.proto_net.load_state_dict(torch.load(Config.pn_dir))
     Tools.print("load proto net success from {}".format(Config.pn_dir))
class Config(object):
    """Static experiment configuration, evaluated at import time.

    NOTE(review): this class body has side effects on import -- it sets
    CUDA_VISIBLE_DEVICES, creates checkpoint directories via Tools.new_dir,
    and prints the model name and data root.
    """
    gpu_id = "0,1,2,3"
    gpu_num = len(gpu_id.split(","))
    # str() is redundant here (gpu_id is already a str) but harmless.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

    num_workers = 32

    #######################################################################################
    # IC (instance/clustering) stage settings.
    ic_ratio = 1
    ic_knn = 100
    ic_out_dim = 2048
    ic_val_freq = 10

    # ic_resnet, ic_modify_head, ic_net_name = resnet18, False, "res18"
    ic_resnet, ic_modify_head, ic_net_name = resnet34, True, "res34_head"

    ic_learning_rate = 0.01
    ic_train_epoch = 1200
    ic_first_epoch, ic_t_epoch = 400, 200
    # 64*4 per GPU, scaled by the number of visible GPUs.
    ic_batch_size = 64 * 4 * gpu_num

    ic_adjust_learning_rate = RunnerTool.adjust_learning_rate1
    #######################################################################################

    ###############################################################################################
    # FSL (few-shot learning) stage settings.
    fsl_num_way = 5
    fsl_num_shot = 1

    fsl_episode_size = 15
    fsl_test_episode = 600

    # fsl_matching_net, fsl_net_name, fsl_batch_size = MatchingNet(hid_dim=64, z_dim=64), "conv4", 96
    fsl_matching_net, fsl_net_name, fsl_batch_size = ResNet12Small(
        avg_pool=True, drop_rate=0.1), "resnet12", 32

    fsl_learning_rate = 0.01
    # Per-GPU batch size scaled by the number of GPUs.
    fsl_batch_size = fsl_batch_size * gpu_num

    # fsl_val_freq = 5
    # fsl_train_epoch = 200
    # fsl_lr_schedule = [100, 150]

    # fsl_val_freq = 2
    # fsl_train_epoch = 100
    # fsl_lr_schedule = [50, 80]
    fsl_val_freq = 2
    fsl_train_epoch = 50
    fsl_lr_schedule = [30, 40]
    ###############################################################################################

    # Run identifier used to name the checkpoint files below.
    model_name = "{}_{}_{}_{}_{}_{}_{}_{}_{}_{}".format(
        gpu_id.replace(",", ""), ic_net_name, ic_train_epoch, ic_batch_size,
        ic_out_dim, fsl_net_name, fsl_train_epoch, fsl_num_way, fsl_num_shot,
        fsl_batch_size)

    # Dataset location: Linux server paths with a fallback, else Windows.
    if "Linux" in platform.platform():
        data_root = '/mnt/4T/Data/data/UFSL/tiered-imagenet'
        if not os.path.isdir(data_root):
            data_root = '/media/ubuntu/4T/ALISURE/Data/UFSL/tiered-imagenet'
    else:
        data_root = "F:\\data\\tiered-imagenet"

    ###############################################################################################
    # Small-scale debug overrides (currently disabled).
    # ic_batch_size = 16
    # fsl_batch_size = 16
    # ic_train_epoch = 2
    # ic_first_epoch, ic_t_epoch = 1, 1
    # ic_val_freq = 1
    # fsl_train_epoch = 8
    # fsl_lr_schedule = [4, 6]
    # fsl_test_episode = 20
    # fsl_val_freq = 2
    # data_root = os.path.join(data_root, "small")
    ###############################################################################################

    # Checkpoint paths (directories are created here, at import time).
    _root_path = "../tiered_imagenet/models_mn/two_ic_ufsl_2net_res_sgd_acc_duli_nete"
    mn_dir = Tools.new_dir("{}/{}_mn.pkl".format(_root_path, model_name))
    ic_dir = Tools.new_dir("{}/{}_ic.pkl".format(_root_path, model_name))
    # ic_dir_checkpoint = None
    # ic_dir_checkpoint = "../tiered_imagenet/models/ic_res_xx/3_resnet_18_64_2048_1_1900_300_200_False_ic.pkl"
    ic_dir_checkpoint = "../tiered_imagenet/models_mn/two_ic_ufsl_2net_res_sgd_acc_duli_nete/123_res34_head_1200_384_2048_conv4_100_5_1_288_ic.pkl"

    Tools.print(model_name)
    Tools.print(data_root)
    pass
コード例 #5
0
    def __init__(self,
                 data_root_path,
                 down_ratio=4,
                 sp_size=4,
                 train_print_freq=100,
                 test_print_freq=50,
                 root_ckpt_dir="./ckpt2/norm3",
                 lr=None,
                 num_workers=8,
                 use_gpu=True,
                 gpu_id="1",
                 has_bn=True,
                 normalize=True,
                 residual=False,
                 concat=True,
                 weight_decay=0.0,
                 is_sgd=False):
        """Build datasets, loaders, model, optimizer and loss for one run.

        ``lr`` is a ``[[start_epoch, lr], ...]`` schedule; when None a
        default schedule is chosen depending on ``is_sgd``. The remaining
        flags parameterise the datasets and ``MyGCNNet``.
        """
        self.train_print_freq = train_print_freq
        self.test_print_freq = test_print_freq

        self.device = gpu_setup(use_gpu=use_gpu, gpu_id=gpu_id)
        self.root_ckpt_dir = Tools.new_dir(root_ckpt_dir)

        # Datasets and loaders: one (collated) graph sample per batch.
        self.train_dataset = MyDataset(data_root_path=data_root_path, is_train=True,
                                       down_ratio=down_ratio, sp_size=sp_size)
        self.test_dataset = MyDataset(data_root_path=data_root_path, is_train=False,
                                      down_ratio=down_ratio, sp_size=sp_size)
        self.train_loader = DataLoader(self.train_dataset, batch_size=1, shuffle=True,
                                       num_workers=num_workers,
                                       collate_fn=self.train_dataset.collate_fn)
        self.test_loader = DataLoader(self.test_dataset, batch_size=1, shuffle=False,
                                      num_workers=num_workers,
                                      collate_fn=self.test_dataset.collate_fn)

        self.model = MyGCNNet(has_bn=has_bn, normalize=normalize,
                              residual=residual, concat=concat).to(self.device)

        # Optimise only trainable parameters; choose the optimizer family
        # and its default LR schedule together.
        trainable_params = filter(lambda p: p.requires_grad, self.model.parameters())
        if is_sgd:
            self.lr_s = lr if lr is not None else [[0, 0.01], [50, 0.001], [90, 0.0001]]
            self.optimizer = torch.optim.SGD(trainable_params, lr=self.lr_s[0][1],
                                             momentum=0.9, weight_decay=weight_decay)
        else:
            self.lr_s = lr if lr is not None else [[0, 0.001], [50, 0.0001], [90, 0.00001]]
            self.optimizer = torch.optim.Adam(trainable_params, lr=self.lr_s[0][1],
                                              weight_decay=weight_decay)

        Tools.print("Total param: {} lr_s={} Optimizer={}".format(
            self._view_model_param(self.model), self.lr_s, self.optimizer))
        self._print_network(self.model)

        self.loss_class = nn.BCELoss().to(self.device)
        pass
コード例 #6
0
import os
import time
import numpy as np
import tensorflow as tf
from keras import layers
import keras.backend as k
from keras import callbacks
from keras.models import Model
from keras.optimizers import Adam
from alisuretool.Tools import Tools
import keras.backend.tensorflow_backend as ktf
from sklearn.neighbors import NearestNeighbors

# Visualisation switches; note Tools.new_dir creates VIS_DIR as a side
# effect at import time.
IS_VIS = True
VIS_DIR = Tools.new_dir("./data/test_2019_11_19")


class SketchyData(object):
    def __init__(self, n_x=4096, n_y=4096, n_z=1024, n_d=2048):
        """Record the feature dimensionalities used by the sketchy-data pipeline."""
        # Plain value holders -- no validation, no derived state.
        self.n_x, self.n_y, self.n_z, self.n_d = n_x, n_y, n_z, n_d
        pass

    # 读取所有的数据
    @staticmethod
    def _load_data(data_path_alisure="./data/alisure",
                   data_path_ext="./data/ZSSBIR_data"):
        # image
        image_vgg_features = np.load(
    def train(self):
        """Train the IC (instance classification) model on its own.

        Seeds the ``self.produce_class`` pseudo-label table with a no-grad
        init pass, then trains for ``Config.ic_train_epoch`` epochs,
        validating every ``Config.ic_val_freq`` epochs and checkpointing the
        best validation accuracy.
        """
        Tools.print()
        Tools.print("Training...")
        best_accuracy = 0.0

        # Init Update: one eval pass to initialise the pseudo-label table.
        # NOTE(review): the "{}" in the message below is a leftover, unfilled
        # format placeholder; it prints literally.
        # NOTE(review): `finally: pass` makes this try block a no-op wrapper.
        try:
            self.ic_model.eval()
            Tools.print("Init label {} .......")
            self.produce_class.reset()
            with torch.no_grad():
                for image, label, idx in tqdm(self.ic_train_loader):
                    image, idx = RunnerTool.to_cuda(image), RunnerTool.to_cuda(
                        idx)
                    ic_out_logits, ic_out_l2norm = self.ic_model(image)
                    self.produce_class.cal_label(ic_out_l2norm, idx)
                    pass
                pass
            Tools.print("Epoch: {}/{}".format(self.produce_class.count,
                                              self.produce_class.count_2))
        finally:
            pass

        for epoch in range(1, 1 + Config.ic_train_epoch):
            self.ic_model.train()

            Tools.print()
            ic_lr = self.adjust_learning_rate(self.ic_model_optim, epoch,
                                              Config.ic_first_epoch,
                                              Config.ic_t_epoch,
                                              Config.ic_learning_rate)
            Tools.print('Epoch: [{}] ic_lr={}'.format(epoch, ic_lr))

            all_loss = 0.0
            self.produce_class.reset()
            for image, label, idx in tqdm(self.ic_train_loader):
                image, label, idx = RunnerTool.to_cuda(
                    image), RunnerTool.to_cuda(label), RunnerTool.to_cuda(idx)

                ###########################################################################
                # 1 calculate features
                ic_out_logits, ic_out_l2norm = self.ic_model(image)

                # 2 calculate labels: read targets BEFORE cal_label folds the
                # current batch into the table (order matters).
                ic_targets = self.produce_class.get_label(idx)
                self.produce_class.cal_label(ic_out_l2norm, idx)

                # 3 loss
                loss = self.ic_loss(ic_out_logits, ic_targets)
                all_loss += loss.item()

                # 4 backward
                self.ic_model.zero_grad()
                loss.backward()
                self.ic_model_optim.step()
                ###########################################################################
                pass

            ###########################################################################
            # print epoch statistics
            Tools.print("{:6} loss:{:.3f}".format(
                epoch, all_loss / len(self.ic_train_loader)))
            Tools.print("Train: [{}] {}/{}".format(epoch,
                                                   self.produce_class.count,
                                                   self.produce_class.count_2))
            ###########################################################################

            ###########################################################################
            # Val: checkpoint on best IC validation accuracy
            if epoch % Config.ic_val_freq == 0:
                self.ic_model.eval()

                val_accuracy = self.test_tool_ic.val(epoch=epoch)
                if val_accuracy > best_accuracy:
                    best_accuracy = val_accuracy
                    torch.save(self.ic_model.state_dict(), Config.ic_dir)
                    Tools.print("Save networks for epoch: {}".format(epoch))
                    pass
                pass
            ###########################################################################
            pass

        pass
コード例 #8
0
    # NOTE(review): this looks like the body of an `if __name__ == '__main__':`
    # guard -- the header is outside this view.
    _data_root_path = "/mnt/4T/ALISURE/DUTS"

    _train_print_freq = 1000
    _test_print_freq = 1000
    _num_workers = 10
    _use_gpu = True

    # _gpu_id = "0"
    # _gpu_id = "1"
    _gpu_id = "2"
    # _gpu_id = "3"

    _epochs = 30  # Super Param Group 1
    _is_sgd = False
    _weight_decay = 5e-4
    # _lr = [[0, 5e-05], [20, 5e-06]]
    _lr = [[0, 1e-5], [20, 1e-6]]  # Adam schedule: [[start_epoch, lr], ...]

    _sp_size, _down_ratio = 4, 4

    # Checkpoints are separated per GPU id.
    _root_ckpt_dir = "./ckpt/PYG_ResNet1_NoBN/{}".format(_gpu_id)
    Tools.print("epochs:{} ckpt:{} sp size:{} down_ratio:{} workers:{} gpu:{} is_sgd:{} weight_decay:{}".format(
        _epochs, _root_ckpt_dir, _sp_size, _down_ratio, _num_workers, _gpu_id, _is_sgd, _weight_decay))

    # Build the runner and train from scratch.
    runner = RunnerSPE(data_root_path=_data_root_path, root_ckpt_dir=_root_ckpt_dir,
                       sp_size=_sp_size, is_sgd=_is_sgd, lr=_lr, down_ratio=_down_ratio, weight_decay=_weight_decay,
                       train_print_freq=_train_print_freq, test_print_freq=_test_print_freq,
                       num_workers=_num_workers, use_gpu=_use_gpu, gpu_id=_gpu_id)
    runner.train(_epochs, start_epoch=0)
    pass
コード例 #9
0
    def train_mlc(self, start_epoch=0, model_file_name=None):
        """Train the multi-label classification (MLC) network.

        Args:
            start_epoch: epoch index to resume from.
            model_file_name: optional checkpoint to load before training.

        Per epoch: adjust the LR, train with BCE loss, log the mean loss,
        checkpoint every ``mlc_save_epoch_freq`` epochs and evaluate every
        ``mlc_eval_epoch_freq`` epochs; a final checkpoint and evaluation
        follow the loop.
        """

        if model_file_name is not None:
            Tools.print("Load model form {}".format(model_file_name),
                        txt_path=self.config.mlc_save_result_txt)
            self.load_model(model_file_name)
            pass

        # Baseline evaluation before any training.
        self.eval_mlc(epoch=0)

        for epoch in range(start_epoch, self.config.mlc_epoch_num):
            Tools.print()
            self._adjust_learning_rate(
                self.optimizer,
                epoch,
                lr=self.config.mlc_lr,
                change_epoch=self.config.mlc_change_epoch)
            Tools.print('Epoch:{:03d}, lr={:.6f}'.format(
                epoch, self.optimizer.param_groups[0]['lr']),
                        txt_path=self.config.mlc_save_result_txt)

            ###########################################################################
            # 1 Train the model for one epoch.
            all_loss = 0.0
            self.net.train()
            self.dataset_mlc_train.reset()
            for i, (inputs,
                    labels) in tqdm(enumerate(self.data_loader_mlc_train),
                                    total=len(self.data_loader_mlc_train)):
                inputs, labels = inputs.type(
                    torch.FloatTensor).cuda(), labels.cuda()
                self.optimizer.zero_grad()

                result = self.net(inputs)
                loss = self.bce_loss(result, labels)

                loss.backward()
                self.optimizer.step()

                all_loss += loss.item()
                pass
            ###########################################################################

            Tools.print("[E:{:3d}/{:3d}] mlc loss:{:.4f}".format(
                epoch, self.config.mlc_epoch_num,
                all_loss / len(self.data_loader_mlc_train)),
                        txt_path=self.config.mlc_save_result_txt)

            # 2 Save a checkpoint every mlc_save_epoch_freq epochs.
            if epoch % self.config.mlc_save_epoch_freq == 0:
                Tools.print()
                save_file_name = Tools.new_dir(
                    os.path.join(self.config.mlc_model_dir,
                                 "mlc_{}.pth".format(epoch)))
                torch.save(self.net.state_dict(), save_file_name)
                Tools.print("Save Model to {}".format(save_file_name),
                            txt_path=self.config.mlc_save_result_txt)
                Tools.print()
                pass
            ###########################################################################

            ###########################################################################
            # 3 Evaluate every mlc_eval_epoch_freq epochs.
            if epoch % self.config.mlc_eval_epoch_freq == 0:
                self.eval_mlc(epoch=epoch)
                pass
            ###########################################################################

            pass

        # Final Save
        Tools.print()
        save_file_name = Tools.new_dir(
            os.path.join(self.config.mlc_model_dir,
                         "mlc_final_{}.pth".format(self.config.mlc_epoch_num)))
        torch.save(self.net.state_dict(), save_file_name)
        Tools.print("Save Model to {}".format(save_file_name),
                    txt_path=self.config.mlc_save_result_txt)
        Tools.print()

        self.eval_mlc(epoch=self.config.mlc_epoch_num)
        pass
コード例 #10
0
    def main(self):
        """Embed features with t-SNE (cached as a pickle) and save the figure.

        The fitted 2-D embedding is cached at ``self.config.result_pkl``;
        when the cache exists it is loaded instead of re-fitting.
        """
        tsne = TSNE(n_components=2)

        Tools.print("begin to fit_transform {}".format(self.config.result_png))
        if os.path.exists(self.config.result_pkl):
            Tools.print("exist pkl, and now to load")
            result = Tools.read_from_pkl(self.config.result_pkl)
        else:
            Tools.print("not exist pkl, and now to fit")
            data, label = self.deal_feature()
            result = {"fit": tsne.fit_transform(data), "label": label}
            Tools.write_to_pkl(self.config.result_pkl, result)

        Tools.print("begin to embedding")
        # plot_embedding draws onto the current pyplot figure.
        fig = self.plot_embedding(result["fit"], result["label"])

        Tools.print("begin to save")
        plt.savefig(self.config.result_png)
コード例 #11
0
            Tools.new_dir(one.replace("train_final", "train_final_resize")))
        pass
    pass


# Resize all CAM pngs to their source-image size, fanned out over a pool.
all_image = glob(os.path.join(cam_path, "**/*.png"), recursive=True)
split_id = 1000

pools = multiprocessing.Pool(processes=multiprocessing.cpu_count())
# NOTE(review): chunk size is len(all_image) // split_id, so when
# len % split_id exceeds one chunk some tail images get no async task;
# the serial "Check" loop below catches anything missed. Also,
# apply_async results are never collected, so worker errors pass silently.
for i in range(split_id + 1):
    now_image = all_image[len(all_image) // split_id * i:len(all_image) //
                          split_id * (i + 1)]
    pools.apply_async(main, args=(i, now_image))
    pass

Tools.print("begin")
pools.close()
pools.join()
Tools.print("over")

# Check: serially redo any image the pool did not produce output for.
for one in tqdm(all_image):
    if not os.path.exists(one.replace("train_final", "train_final_resize")):
        # Map the CAM png back to its ImageNet-DET JPEG to get the size.
        im = Image.open(
            one.replace(
                cam_path,
                os.path.join(voc12_root,
                             "ILSVRC2017_DET/ILSVRC/Data/DET")).replace(
                                 ".png", ".JPEG"))
        Image.open(one).resize(im.size, resample=Image.NEAREST).save(
            Tools.new_dir(one.replace("train_final", "train_final_resize")))
import os
import numpy as np
from glob import glob
from tqdm import tqdm
from alisuretool.Tools import Tools
from deal_data_0_global_info import get_data_root_path, get_project_path


if __name__ == '__main__':
    # Merge per-image person-detection results into the image info list and
    # cache the combined [labels, path] records.
    data_root = get_data_root_path()
    deal_dir = os.path.join(data_root, "deal")

    info_list = Tools.read_from_pkl(os.path.join(deal_dir, "image_info_list2.pkl"))
    person_list = Tools.read_from_pkl(os.path.join(deal_dir, "person2.pkl"))

    results = []
    for info, person in tqdm(zip(info_list, person_list), total=len(info_list)):
        path = info["image_path"]
        # Sanity check: warn when the image file is missing or the two lists
        # are misaligned (person[1] should echo the same image path).
        if not os.path.exists(path) or path != person[1]:
            Tools.print(path)
        # Unique class ids of all objects; class id 124 is appended when a
        # person was detected (person[0] == 1).
        labels = [one[2] for one in info["object"]]
        if person[0] == 1:
            labels.append(124)
        results.append([list(set(labels)), path])

    Tools.write_to_pkl(_path=os.path.join(deal_dir, "image_info_list_change_person2.pkl"),
                       _data=results)
    pass
コード例 #13
0
    def train(self):
        """Joint training loop: few-shot relation network plus instance-classification (IC) head.

        Per epoch: refresh the pseudo-labels held by ``produce_class``, minimize
        the weighted sum of the FSL and IC losses with per-network gradient
        clipping and optimizer steps, advance the LR schedulers, and
        periodically validate, checkpointing on a new best accuracy.
        """
        Tools.print()
        Tools.print("Training...")

        # Init Update: one pass over the training set in eval mode to seed the
        # pseudo-labels in produce_class before any optimization happens.
        try:
            self.feature_encoder.eval()
            self.relation_network.eval()
            self.ic_model.eval()
            Tools.print("Init label {} .......")
            self.produce_class.reset()
            for task_data, task_labels, task_index, task_ok in tqdm(
                    self.task_train_loader):
                # The last column of task_index is used as the IC label id.
                ic_labels = RunnerTool.to_cuda(task_index[:, -1])
                task_data, task_labels = RunnerTool.to_cuda(
                    task_data), RunnerTool.to_cuda(task_labels)
                relations, query_features = self.compare_fsl(task_data)
                ic_out_logits, ic_out_l2norm = self.ic_model(query_features)
                self.produce_class.cal_label(ic_out_l2norm, ic_labels)
                pass
            Tools.print("Epoch: {}/{}".format(self.produce_class.count,
                                              self.produce_class.count_2))
        finally:
            pass

        for epoch in range(Config.train_epoch):
            self.feature_encoder.train()
            self.relation_network.train()
            self.ic_model.train()

            Tools.print()
            self.produce_class.reset()
            Tools.print(self.task_train.classes)
            is_ok_total, is_ok_acc = 0, 0
            all_loss, all_loss_fsl, all_loss_ic = 0.0, 0.0, 0.0
            for task_data, task_labels, task_index, task_ok in tqdm(
                    self.task_train_loader):
                ic_labels = RunnerTool.to_cuda(task_index[:, -1])
                task_data, task_labels = RunnerTool.to_cuda(
                    task_data), RunnerTool.to_cuda(task_labels)

                ###########################################################################
                # 1 calculate features
                relations, query_features = self.compare_fsl(task_data)
                ic_out_logits, ic_out_l2norm = self.ic_model(query_features)

                # 2 read the current pseudo-labels for this batch, THEN update
                # them with the fresh features (order matters: targets come
                # from the previous pass).
                ic_targets = self.produce_class.get_label(ic_labels)
                self.produce_class.cal_label(ic_out_l2norm, ic_labels)

                # 3 loss: weighted sum of the few-shot loss and the IC loss
                loss_fsl = self.fsl_loss(relations,
                                         task_labels) * Config.loss_fsl_ratio
                loss_ic = self.ic_loss(ic_out_logits,
                                       ic_targets) * Config.loss_ic_ratio
                loss = loss_fsl + loss_ic
                all_loss += loss.item()
                all_loss_fsl += loss_fsl.item()
                all_loss_ic += loss_ic.item()

                # 4 backward: clip each network's gradients to 0.5, then step
                # each network's own optimizer.
                self.feature_encoder.zero_grad()
                self.relation_network.zero_grad()
                self.ic_model.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    self.feature_encoder.parameters(), 0.5)
                torch.nn.utils.clip_grad_norm_(
                    self.relation_network.parameters(), 0.5)
                torch.nn.utils.clip_grad_norm_(self.ic_model.parameters(), 0.5)
                self.feature_encoder_optim.step()
                self.relation_network_optim.step()
                self.ic_model_optim.step()

                # is ok: running count of set task_ok flags vs. total flags
                is_ok_acc += torch.sum(torch.cat(task_ok))
                is_ok_total += torch.prod(
                    torch.tensor(torch.cat(task_ok).shape))
                ###########################################################################
                pass

            ###########################################################################
            # print per-epoch mean losses, ok-ratio and current learning rate
            Tools.print(
                "{:6} loss:{:.3f} fsl:{:.3f} ic:{:.3f} ok:{:.3f}({}/{}) lr:{}".
                format(epoch + 1, all_loss / len(self.task_train_loader),
                       all_loss_fsl / len(self.task_train_loader),
                       all_loss_ic / len(self.task_train_loader),
                       int(is_ok_acc) / int(is_ok_total),
                       is_ok_acc, is_ok_total,
                       self.feature_encoder_scheduler.get_last_lr()))
            Tools.print("Train: [{}] {}/{}".format(epoch,
                                                   self.produce_class.count,
                                                   self.produce_class.count_2))
            self.feature_encoder_scheduler.step()
            self.relation_network_scheduler.step()
            self.ic_model_scheduler.step()
            ###########################################################################

            ###########################################################################
            # Val: evaluate periodically; save all three networks on a new best
            if epoch % Config.val_freq == 0:
                self.feature_encoder.eval()
                self.relation_network.eval()
                self.ic_model.eval()

                self.test_tool_ic.val(epoch=epoch)
                val_accuracy = self.test_tool_fsl.val(episode=epoch,
                                                      is_print=True)

                if val_accuracy > self.best_accuracy:
                    self.best_accuracy = val_accuracy
                    torch.save(self.feature_encoder.state_dict(),
                               Config.fe_dir)
                    torch.save(self.relation_network.state_dict(),
                               Config.rn_dir)
                    torch.save(self.ic_model.state_dict(), Config.ic_dir)
                    Tools.print("Save networks for epoch: {}".format(epoch))
                    pass
                pass
            ###########################################################################
            pass

        pass
コード例 #14
0
result_name = "08-14-2"
# Capture from camera index 1 (the second attached camera).
cap = cv2.VideoCapture(1)

kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
background_sub_tractor_mog2 = cv2.createBackgroundSubtractorMOG2()

count = 0
while True:
    ret, frame = cap.read()
    if ret:
        # Foreground mask via MOG2 background subtraction, denoised by a
        # morphological opening with a 3x3 ellipse kernel.
        mask = background_sub_tractor_mog2.apply(frame)
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)

        # Motion detected when enough foreground pixels survive the opening.
        if np.sum(mask // 125) > 1000:
            count += 1
            Tools.print("{} {}".format(Tools.get_format_time(), count))
            # Skip the first few detections (background-model warm-up), then
            # save every subsequent motion frame; count never resets.
            if count > 5:
                file_name = "{}_{}.png".format(
                    Tools.get_format_time().replace(":", "_"), count)
                result_file_name = os.path.join(result_dir, result_name,
                                                file_name)
                cv2.imwrite(Tools.new_dir(result_file_name), frame)
            pass

        cv2.imshow('frame', frame)
        cv2.imshow('mask', mask)

        # BUG FIX: cv2.waitKey returns an int; comparing it to the string "q"
        # was always False, so the loop could never be quit from the keyboard.
        if cv2.waitKey(30) & 0xff == ord("q"):
            break
    else:
        break

# Release the camera and close the preview windows on exit.
cap.release()
cv2.destroyAllWindows()
コード例 #15
0
 def load_model(self, model_file_name):
     """Restore model weights from *model_file_name*; strict=False tolerates missing/extra keys."""
     state_dict = torch.load(model_file_name)
     self.model.load_state_dict(state_dict, strict=False)
     Tools.print('Load Model: {}'.format(model_file_name))
コード例 #16
0
    def train(self):
        """Train the ProtoNet backbone with a combined triple + mixup loss.

        Each batch carries ``num`` images per sample; a second "mixed" view is
        built by blending every image 50/50 with its cyclically-shifted
        neighbor, and both views pass through the network together.
        """
        Tools.print()
        Tools.print("Training...")

        for epoch in range(Config.train_epoch):
            self.proto_net.train()

            Tools.print()
            all_loss, all_loss_triple, all_loss_mixup = 0.0, 0.0, 0.0
            for task_tuple, inputs, task_index in tqdm(self.task_train_loader):
                batch_size, num, c, w, h = inputs.shape

                # beta = np.random.beta(1, 1, [batch_size, num])  # 64, 3
                # Fixed mixing coefficient 0.5 (the random Beta draw above is disabled).
                beta = np.zeros([batch_size, num]) + 0.5  # 64, 3

                # Weights for the originals and their shifted partners: (batch, 2*num).
                beta_lambda = np.hstack([beta, 1 - beta])  # 64, 6
                # Broadcast the per-image weights across the (c, w, h) image dims.
                beta_lambda_tile = np.tile(beta_lambda[..., None, None, None],
                                           [c, w, h])

                # Mixed view: each image blended with the next image in its
                # sample (cyclic shift by one along dim 1).
                inputs_1 = torch.cat(
                    [inputs, inputs[:, 1:, ...], inputs[:, 0:1, ...]],
                    dim=1) * beta_lambda_tile
                inputs_1 = (inputs_1[:, 0:num, ...] +
                            inputs_1[:, num:, ...]).float()
                # Stack original + mixed views and flatten to one image batch.
                now_inputs = torch.cat([inputs, inputs_1],
                                       dim=1).view(-1, c, w, h)
                now_inputs = RunnerTool.to_cuda(now_inputs)

                # 1 calculate features
                net_out = self.proto_net(now_inputs)
                net_out = self.norm(net_out)

                _, out_c, out_w, out_h = net_out.shape
                # Restore the per-sample grouping: (batch, 2*num, c', w', h').
                z = net_out.view(batch_size, -1, out_c, out_w, out_h)

                # 2 calculate loss
                loss, loss_triple, loss_mixup = self.mixup_loss(
                    z, beta_lambda=beta_lambda)
                all_loss += loss.item()
                all_loss_mixup += loss_mixup.item()
                all_loss_triple += loss_triple.item()

                # 3 backward: clip gradients to 0.5 before stepping
                self.proto_net.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.proto_net.parameters(),
                                               0.5)
                self.proto_net_optim.step()
                ###########################################################################
                pass

            ###########################################################################
            # print per-epoch mean losses and the current learning rate
            Tools.print(
                "{:6} loss:{:.3f} triple:{:.3f} mixup:{:.3f} lr:{}".format(
                    epoch + 1, all_loss / len(self.task_train_loader),
                    all_loss_triple / len(self.task_train_loader),
                    all_loss_mixup / len(self.task_train_loader),
                    self.proto_net_scheduler.get_last_lr()))

            self.proto_net_scheduler.step()
            ###########################################################################

            ###########################################################################
            # Val: periodic evaluation; checkpoint on a new best accuracy
            if epoch % Config.val_freq == 0:
                Tools.print()
                Tools.print("Test {} {} .......".format(
                    epoch, Config.model_name))
                self.proto_net.eval()

                val_accuracy = self.test_tool.val(episode=epoch, is_print=True)
                if val_accuracy > self.best_accuracy:
                    self.best_accuracy = val_accuracy
                    torch.save(self.proto_net.state_dict(), Config.pn_dir)
                    Tools.print("Save networks for epoch: {}".format(epoch))
                    pass
                pass
            ###########################################################################
            pass

        pass
コード例 #17
0
# Curves 3 and 4 (ln1/ln2 and the x/y data series are defined earlier).
ln3, = plt.plot(x_data, y_data3, color='blue', linewidth=2.0, linestyle='-')
ln4, = plt.plot(x_data, y_data4, color='blue', linewidth=2.0, linestyle='-.')

# NOTE(review): the same 4 handles are repeated to fill 16 legend slots with
# placeholder "AAA"/"BBB" labels — presumably a layout template; replace with
# the real series names before final use.
plt.legend(handles=[
    ln1, ln2, ln3, ln4, ln1, ln2, ln3, ln4, ln1, ln2, ln3, ln4, ln1, ln2, ln3,
    ln4
],
           labels=[
               "AAA", "BBB", "AAA", "BBB", "AAA", "BBB", "AAA", "BBB", "AAA",
               "BBB", "AAA", "BBB", "AAA", "BBB", "AAA", "BBB"
           ],
           loc='best',
           ncol=2,
           fontsize=14)
plt.grid(linestyle='--')

# Fix the y-range and tick density.
plt.ylim(0.2, 1.0)
plt.locator_params("y", nbins=10)
# plt.xlabel('Shot', fontsize=16)
# plt.ylabel('Accuracy', fontsize=16)
plt.tick_params(labelsize=14)
# Tighten the figure margins.
plt.subplots_adjust(top=0.96,
                    bottom=0.06,
                    left=0.08,
                    right=0.98,
                    hspace=0,
                    wspace=0)

plt.savefig(Tools.new_dir(os.path.join("plot", "shot", "demo_shot.pdf")))
plt.show()
コード例 #18
0
    def train(self, epochs, start_epoch=0):
        """Train for ``epochs`` epochs, resuming from ``start_epoch``.

        Evaluates once before training, then per epoch: adjusts the learning
        rate, runs one training epoch, checkpoints, and tests — printing
        MAE/F-score for both the SOD and GCN heads plus the split losses.
        """
        # Baseline evaluation before any training (reported as epoch 0).
        test_loss, test_loss1, test_loss2, test_mae, test_score, test_mae2, test_score2 = self.test()
        Tools.print('E:{:2d}, Test  sod-mae-score={:.4f}-{:.4f} '
                    'gcn-mae-score={:.4f}-{:.4f} loss={:.4f}({:.4f}+{:.4f})'.format(
            0, test_mae, test_score, test_mae2, test_score2, test_loss, test_loss1, test_loss2))

        for epoch in range(start_epoch, epochs):
            Tools.print()
            Tools.print("Start Epoch {}".format(epoch))

            # Apply this epoch's learning-rate schedule, then report the LR.
            self._lr(epoch)
            Tools.print('Epoch:{:02d},lr={:.4f}'.format(epoch, self.optimizer.param_groups[0]['lr']))

            # One training epoch, then checkpoint and re-test.
            (train_loss, train_loss1, train_loss2,
             train_mae, train_score, train_mae2, train_score2) = self._train_epoch()
            self._save_checkpoint(self.model, self.root_ckpt_dir, epoch)
            test_loss, test_loss1, test_loss2, test_mae, test_score, test_mae2, test_score2 = self.test()

            Tools.print('E:{:2d}, Train sod-mae-score={:.4f}-{:.4f} '
                        'gcn-mae-score={:.4f}-{:.4f} loss={:.4f}({:.4f}+{:.4f})'.format(
                epoch, train_mae, train_score, train_mae2, train_score2, train_loss, train_loss1, train_loss2))
            Tools.print('E:{:2d}, Test  sod-mae-score={:.4f}-{:.4f} '
                        'gcn-mae-score={:.4f}-{:.4f} loss={:.4f}({:.4f}+{:.4f})'.format(
                epoch, test_mae, test_score, test_mae2, test_score2, test_loss, test_loss1, test_loss2))
            pass
        pass
コード例 #19
0
    def train(self):
        """Train the matching network; checkpoint whenever validation accuracy improves."""
        Tools.print()
        Tools.print("Training...")

        for epoch in range(1, 1 + Config.train_epoch):
            self.matching_net.train()
            Tools.print()

            epoch_loss = 0.0
            for task_data, task_labels, task_index in tqdm(self.task_train_loader):
                task_data = RunnerTool.to_cuda(task_data)
                task_labels = RunnerTool.to_cuda(task_labels)

                # Forward pass and loss.
                predicts = self.matching(task_data)
                batch_loss = self.loss(predicts, task_labels)
                epoch_loss += batch_loss.item()

                # Backward pass and parameter update.
                self.matching_net.zero_grad()
                batch_loss.backward()
                self.matching_net_optim.step()

            # Epoch summary, then advance the LR schedule.
            mean_loss = epoch_loss / len(self.task_train_loader)
            Tools.print("{:6} loss:{:.3f} lr:{}".format(
                epoch, mean_loss, self.matching_net_scheduler.get_last_lr()))
            self.matching_net_scheduler.step()

            # Periodic validation; keep only the best checkpoint.
            if epoch % Config.val_freq == 0:
                Tools.print()
                Tools.print("Test {} {} .......".format(epoch, Config.model_name))
                self.matching_net.eval()

                val_accuracy = self.test_tool.val(episode=epoch, is_print=True, has_test=False)
                if val_accuracy > self.best_accuracy:
                    self.best_accuracy = val_accuracy
                    torch.save(self.matching_net.state_dict(), Config.mn_dir)
                    Tools.print("Save networks for epoch: {}".format(epoch))
コード例 #20
0
    def inference_ss(self,
                     model_file_name=None,
                     data_loader=None,
                     save_path=None):
        """Run multi-input semantic-segmentation inference over ``data_loader``.

        For each (single) image: resize the label and the averaged predictions
        back to the original image size, update the streaming metrics, and —
        when ``save_path`` is given — write the source image, colorized
        label/prediction, and the raw prediction png.  Returns the metrics
        score dict (also logged to the configured result txt).
        """
        if model_file_name is not None:
            Tools.print("Load model form {}".format(model_file_name),
                        txt_path=self.config.ss_save_result_txt)
            self.load_model(model_file_name)
            pass

        # Raw predictions go to a sibling "<save_path>_final" directory.
        final_save_path = Tools.new_dir("{}_final".format(save_path))

        self.net.eval()
        metrics = StreamSegMetrics(self.config.ss_num_classes)
        with torch.no_grad():
            for i, (inputs, labels,
                    image_info_list) in tqdm(enumerate(data_loader),
                                             total=len(data_loader)):
                assert len(image_info_list) == 1

                # Ground-truth label resized to the original image size.
                # Images with any side >= max_size are interpolated on the CPU
                # to avoid large GPU allocations.
                max_size = 1000
                size = Image.open(image_info_list[0]).size
                basename = os.path.basename(image_info_list[0])
                final_name = os.path.join(final_save_path,
                                          basename.replace(".JPEG", ".png"))
                # Skip images whose final prediction already exists (resumable).
                if os.path.exists(final_name):
                    continue

                if size[0] < max_size and size[1] < max_size:
                    targets = F.interpolate(torch.unsqueeze(
                        labels[0].float().cuda(), dim=0),
                                            size=(size[1], size[0]),
                                            mode="nearest").detach().cpu()
                else:
                    targets = F.interpolate(torch.unsqueeze(labels[0].float(),
                                                            dim=0),
                                            size=(size[1], size[0]),
                                            mode="nearest")
                targets = targets[0].long().numpy()

                # Prediction: average the network outputs over all provided
                # inputs (presumably one per scale — confirm against the loader).
                outputs = 0
                for input_index, input_one in enumerate(inputs):
                    output_one = self.net(input_one.float().cuda())
                    if size[0] < max_size and size[1] < max_size:
                        outputs += F.interpolate(
                            output_one,
                            size=(size[1], size[0]),
                            mode="bilinear",
                            align_corners=False).detach().cpu()
                    else:
                        outputs += F.interpolate(output_one.detach().cpu(),
                                                 size=(size[1], size[0]),
                                                 mode="bilinear",
                                                 align_corners=False)
                        pass
                    pass
                outputs = outputs / len(inputs)
                preds = outputs.max(dim=1)[1].numpy()

                # Accumulate the streaming segmentation metrics.
                metrics.update(targets, preds)

                if save_path:
                    # Source image, colorized label ("_l.png"), colorized
                    # prediction (".png"), and raw grayscale prediction.
                    Image.open(image_info_list[0]).save(
                        os.path.join(save_path, basename))
                    DataUtil.gray_to_color(
                        np.asarray(targets[0], dtype=np.uint8)).save(
                            os.path.join(save_path,
                                         basename.replace(".JPEG", "_l.png")))
                    DataUtil.gray_to_color(np.asarray(
                        preds[0], dtype=np.uint8)).save(
                            os.path.join(save_path,
                                         basename.replace(".JPEG", ".png")))
                    Image.fromarray(np.asarray(
                        preds[0], dtype=np.uint8)).save(final_name)
                    pass
                pass
            pass

        score = metrics.get_results()
        Tools.print("{}".format(metrics.to_str(score)),
                    txt_path=self.config.ss_save_result_txt)
        return score
    def train(self):
        """Few-shot training of the matching network with periodic validation.

        Tracks the fraction of "ok" tasks (task_ok flags from the loader) per
        epoch and saves the weights whenever validation accuracy improves.
        """
        Tools.print()
        Tools.print("Training...")
        best_accuracy = 0.0

        for epoch in range(1, 1 + Config.fsl_train_epoch):
            self.matching_net.train()

            Tools.print()
            all_loss, is_ok_total, is_ok_acc = 0.0, 0, 0
            for task_data, task_labels, task_index, task_ok in tqdm(
                    self.task_train_loader):
                task_data, task_labels = RunnerTool.to_cuda(
                    task_data), RunnerTool.to_cuda(task_labels)

                ###########################################################################
                # 1 calculate features
                relations = self.matching(task_data)

                # 3 loss
                loss = self.fsl_loss(relations, task_labels)
                all_loss += loss.item()

                # 4 backward
                self.matching_net.zero_grad()
                loss.backward()
                self.matching_net_optim.step()

                # is ok: running count of set task_ok flags vs. total flags
                is_ok_acc += torch.sum(torch.cat(task_ok))
                is_ok_total += torch.prod(
                    torch.tensor(torch.cat(task_ok).shape))
                ###########################################################################
                pass

            ###########################################################################
            # print epoch loss, ok-ratio and the current learning rate
            Tools.print("{:6} loss:{:.3f} ok:{:.3f}({}/{}) lr:{}".format(
                epoch, all_loss / len(self.task_train_loader),
                int(is_ok_acc) / int(is_ok_total), is_ok_acc, is_ok_total,
                self.matching_net_scheduler.get_last_lr()))
            self.matching_net_scheduler.step()
            ###########################################################################

            ###########################################################################
            # Val: keep the best checkpoint by validation accuracy
            if epoch % Config.fsl_val_freq == 0:
                self.matching_net.eval()

                val_accuracy = self.test_tool_fsl.val(episode=epoch,
                                                      is_print=True,
                                                      has_test=False)

                if val_accuracy > best_accuracy:
                    best_accuracy = val_accuracy
                    torch.save(self.matching_net.state_dict(), Config.mn_dir)
                    Tools.print("Save networks for epoch: {}".format(epoch))
                    pass
                pass
            ###########################################################################
            pass

        pass
コード例 #22
0
    def __init__(self):
        """Configuration for semantic-segmentation training / evaluation / inference."""
        # self.gpu_id_1, self.gpu_id_4 = "0", "0, 1, 2, 3"
        # self.gpu_id_1, self.gpu_id_4 = "1", "0, 1, 2, 3"
        self.gpu_id_1, self.gpu_id_4 = "2", "0, 1, 2, 3"
        # self.gpu_id_1, self.gpu_id_4 = "3", "0, 1, 2, 3"

        # Workflow control
        self.only_train_ss = False  # whether to train the SS model
        self.is_balance_data = True
        self.only_eval_ss = False  # whether to evaluate the SS model
        self.only_inference_ss = True  # whether to run SS inference

        # 1 Training on the raw data
        # self.scales = (1.0, 0.5, 1.5)
        # self.model_file_name = "../../../WSS_Model_SS/4_DeepLabV3PlusResNet50_201_10_18_1_352/ss_final_10.pth"
        # self.eval_save_path = "../../../WSS_Model_SS_EVAL/4_DeepLabV3PlusResNet50_201_10_18_1_352/ss_final_10_scales"

        # 2 Training on the balanced data
        # self.scales = (1.0, 0.5, 1.5)
        # self.model_file_name = "../../../WSS_Model_SS/6_DeepLabV3PlusResNet50_201_10_18_1_352_balance/ss_8.pth"
        # self.eval_save_path = "../../../WSS_Model_SS_EVAL/6_DeepLabV3PlusResNet50_201_10_18_1_352_balance/ss_8_scales"

        # 3 Training on the balanced data
        # self.scales = (1.0, 0.5, 1.5, 2.0)
        # self.model_file_name = "../../../WSS_Model_SS/6_DeepLabV3PlusResNet50_201_10_18_1_352_balance/ss_8.pth"
        # self.eval_save_path = "../../../WSS_Model_SS_EVAL/6_DeepLabV3PlusResNet50_201_10_18_1_352_balance/ss_8_scales_4"

        # 4 Training on the balanced data
        # self.scales = (1.0, 0.75, 0.5, 1.25, 1.5, 1.75, 2.0)
        # self.model_file_name = "../../../WSS_Model_SS/6_DeepLabV3PlusResNet50_201_10_18_1_352_balance/ss_8.pth"
        # self.eval_save_path = "../../../WSS_Model_SS_EVAL/6_DeepLabV3PlusResNet50_201_10_18_1_352_balance/ss_8_scales_7"

        # 4 Training on the balanced data
        self.scales = (1.0, 0.75, 0.5, 1.25, 1.5, 1.75, 2.0)
        self.model_file_name = "../../../WSS_Model_SS/7_DeepLabV3PlusResNet101_201_10_18_1_352_balance/ss_7.pth"
        self.eval_save_path = "../../../WSS_Model_SS_EVAL/7_DeepLabV3PlusResNet101_201_10_18_1_352_balance/ss_7_scales_7"
        # self.model_file_name = "../../../WSS_Model_SS/7_DeepLabV3PlusResNet101_201_10_18_1_352_balance/ss_final_10.pth"
        # self.eval_save_path = "../../../WSS_Model_SS_EVAL/7_DeepLabV3PlusResNet101_201_10_18_1_352_balance/ss_10_scales_7"

        # Pseudo-labels generated by other methods
        # self.label_path = "/mnt/4T/ALISURE/USS/WSS_CAM/cam/1_CAMNet_200_32_256_0.5"
        # self.label_path = "/mnt/4T/ALISURE/USS/WSS_CAM/cam_4/1_200_32_256_0.5"
        # self.label_path = "/mnt/4T/ALISURE/USS/WSS_CAM/cam_4/2_1_200_32_256_0.5"
        self.label_path = "/media/ubuntu/4T/ALISURE/USS/ConTa/pseudo_mask/result/2/sem_seg"

        self.ss_num_classes = 201
        self.ss_epoch_num = 10
        self.ss_milestones = [5, 8]
        # NOTE(review): batch size = 6 * (gpu count - 1); with 4 GPUs that is
        # 18, not 24 — confirm the "- 1" is intended.
        self.ss_batch_size = 6 * (len(self.gpu_id_4.split(",")) - 1)
        self.ss_lr = 0.001
        self.ss_save_epoch_freq = 1
        self.ss_eval_epoch_freq = 1

        # Image size
        self.ss_size = 352
        self.output_stride = 16

        # Network
        self.Net = DeepLabV3Plus
        # self.arch, self.arch_name = deeplabv3_resnet50, "DeepLabV3PlusResNet50"
        self.arch, self.arch_name = deeplabv3plus_resnet101, "DeepLabV3PlusResNet101"

        self.data_root_path = self.get_data_root_path()
        # All four GPUs when training; a single GPU otherwise.
        os.environ["CUDA_VISIBLE_DEVICES"] = str(
            self.gpu_id_4) if self.only_train_ss else str(self.gpu_id_1)

        run_name = "7"
        self.model_name = "{}_{}_{}_{}_{}_{}_{}{}".format(
            run_name, self.arch_name, self.ss_num_classes, self.ss_epoch_num,
            self.ss_batch_size, self.ss_save_epoch_freq, self.ss_size,
            "_balance" if self.is_balance_data else "")
        Tools.print(self.model_name)

        self.ss_model_dir = "../../../WSS_Model_SS/{}".format(self.model_name)
        self.ss_save_result_txt = Tools.new_dir("{}/result.txt".format(
            self.ss_model_dir))
        pass
コード例 #23
0
    _is_sgd = False
    _weight_decay = 5e-4
    _lr = [[0, 5e-5], [20, 5e-6]]

    _improved = True
    _has_bn = True
    _has_residual = True
    _is_normalize = True
    _concat = True

    _sp_size, _down_ratio = 3, 8

    _root_ckpt_dir = "./ckpt/PYG_ResNet_GCN_more/sigmoid_{}".format(_gpu_id)
    Tools.print(
        "epochs:{} ckpt:{} sp size:{} down_ratio:{} workers:{} gpu:{} has_residual:{} "
        "is_normalize:{} has_bn:{} improved:{} concat:{} is_sgd:{} weight_decay:{}"
        .format(_epochs, _root_ckpt_dir, _sp_size, _down_ratio, _num_workers,
                _gpu_id, _has_residual, _is_normalize, _has_bn, _improved,
                _concat, _is_sgd, _weight_decay))

    runner = RunnerSPE(data_root_path=_data_root_path,
                       root_ckpt_dir=_root_ckpt_dir,
                       sp_size=_sp_size,
                       is_sgd=_is_sgd,
                       lr=_lr,
                       residual=_has_residual,
                       normalize=_is_normalize,
                       down_ratio=_down_ratio,
                       has_bn=_has_bn,
                       concat=_concat,
                       weight_decay=_weight_decay,
                       train_print_freq=_train_print_freq,
コード例 #24
0
    def train_ss(self, start_epoch=0, model_file_name=None):
        """Train the semantic-segmentation network.

        Optionally warm-starts from ``model_file_name``.  Per epoch: one pass
        of cross-entropy training (with in-epoch evals roughly every 10% of
        the loader), periodic checkpointing and evaluation; ends with a final
        save plus evaluation.
        """
        if model_file_name is not None:
            Tools.print("Load model form {}".format(model_file_name),
                        txt_path=self.config.ss_save_result_txt)
            self.load_model(model_file_name)
            pass

        # self.eval_ss(epoch=0)

        for epoch in range(start_epoch, self.config.ss_epoch_num):
            Tools.print()
            Tools.print('Epoch:{:2d}, lr={:.6f} lr2={:.6f}'.format(
                epoch, self.optimizer.param_groups[0]['lr'],
                self.optimizer.param_groups[1]['lr']),
                        txt_path=self.config.ss_save_result_txt)

            ###########################################################################
            # 1 Train the model for one epoch
            all_loss = 0.0
            self.net.train()
            if self.config.is_balance_data:
                # Re-sample the balanced training set each epoch.
                self.dataset_ss_train.reset()
                pass
            for i, (inputs,
                    labels) in tqdm(enumerate(self.data_loader_ss_train),
                                    total=len(self.data_loader_ss_train)):
                inputs, labels = inputs.float().cuda(), labels.long().cuda()
                self.optimizer.zero_grad()

                result = self.net(inputs)
                loss = self.ce_loss(result, labels)

                loss.backward()
                self.optimizer.step()

                all_loss += loss.item()

                # Mid-epoch evaluation roughly every 10% of the loader.
                # NOTE(review): raises ZeroDivisionError when the loader has
                # fewer than 10 batches; also assumes eval_ss restores the
                # net's train mode — confirm.
                if (i + 1) % (len(self.data_loader_ss_train) // 10) == 0:
                    self.eval_ss(epoch=epoch)
                    pass
                pass
            self.scheduler.step()
            ###########################################################################

            Tools.print("[E:{:3d}/{:3d}] ss loss:{:.4f}".format(
                epoch, self.config.ss_epoch_num,
                all_loss / len(self.data_loader_ss_train)),
                        txt_path=self.config.ss_save_result_txt)

            ###########################################################################
            # 2 Save the checkpoint
            if epoch % self.config.ss_save_epoch_freq == 0:
                Tools.print()
                save_file_name = Tools.new_dir(
                    os.path.join(self.config.ss_model_dir,
                                 "ss_{}.pth".format(epoch)))
                torch.save(self.net.state_dict(), save_file_name)
                Tools.print("Save Model to {}".format(save_file_name),
                            txt_path=self.config.ss_save_result_txt)
                Tools.print()
                pass
            ###########################################################################

            ###########################################################################
            # 3 Evaluate the model
            if epoch % self.config.ss_eval_epoch_freq == 0:
                self.eval_ss(epoch=epoch)
                pass
            ###########################################################################

            pass

        # Final Save
        Tools.print()
        save_file_name = Tools.new_dir(
            os.path.join(self.config.ss_model_dir,
                         "ss_final_{}.pth".format(self.config.ss_epoch_num)))
        torch.save(self.net.state_dict(), save_file_name)
        Tools.print("Save Model to {}".format(save_file_name),
                    txt_path=self.config.ss_save_result_txt)
        Tools.print()

        self.eval_ss(epoch=self.config.ss_epoch_num)
        pass
コード例 #25
0
 def load_model(self, model_file_name):
     """Load weights from *model_file_name* onto self.device; strict=False tolerates key mismatches."""
     checkpoint = torch.load(model_file_name, map_location=self.device)
     self.model.load_state_dict(checkpoint, strict=False)
     Tools.print('Load Model: {}'.format(model_file_name))
コード例 #26
0
 def load_model(self):
     """Restore the IC model from Config.ic_dir when that checkpoint file exists."""
     if not os.path.exists(Config.ic_dir):
         return
     self.ic_model.load_state_dict(torch.load(Config.ic_dir))
     Tools.print("load ic model success from {}".format(Config.ic_dir))
コード例 #27
0
    def test(self, model_file=None, is_train_loader=False):
        """Evaluate the model on the test (or train) loader.

        Computes BCE losses for the GCN head and the SOD head, plus MAE and
        max-F-measure (beta^2 = 0.3, over th_num thresholds) for both outputs.

        :param model_file: optional checkpoint path loaded before evaluation.
        :param is_train_loader: if True, evaluate on the train loader instead.
        :return: (avg_loss, avg_loss1, avg_loss2, avg_mae_sod, max_f_sod,
                  avg_mae_gcn, max_f_gcn)
        """
        self.model.eval()

        if model_file:
            self.load_model(model_file_name=model_file)

        Tools.print()
        th_num = 25  # number of thresholds for the precision/recall sweep

        # Stats accumulators (1e-6 avoids divide-by-zero in the F-score).
        epoch_test_loss, epoch_test_loss1, epoch_test_loss2, nb_data = 0, 0, 0, 0
        epoch_test_mae, epoch_test_mae2 = 0.0, 0.0
        epoch_test_prec, epoch_test_recall = np.zeros(
            shape=(th_num, )) + 1e-6, np.zeros(shape=(th_num, )) + 1e-6
        epoch_test_prec2, epoch_test_recall2 = np.zeros(
            shape=(th_num, )) + 1e-6, np.zeros(shape=(th_num, )) + 1e-6

        loader = self.train_loader if is_train_loader else self.test_loader
        tr_num = len(loader)
        with torch.no_grad():
            for i, (images, _, labels_sod, batched_graph, batched_pixel_graph,
                    segments, _, _) in enumerate(loader):
                # Data: move images, graph labels, and graph tensors to device.
                images = images.float().to(self.device)
                labels = batched_graph.y.to(self.device)
                labels_sod = torch.unsqueeze(torch.Tensor(labels_sod),
                                             dim=1).to(self.device)
                batched_graph.batch = batched_graph.batch.to(self.device)
                batched_graph.edge_index = batched_graph.edge_index.to(
                    self.device)

                batched_pixel_graph.batch = batched_pixel_graph.batch.to(
                    self.device)
                batched_pixel_graph.edge_index = batched_pixel_graph.edge_index.to(
                    self.device)
                batched_pixel_graph.data_where = batched_pixel_graph.data_where.to(
                    self.device)

                # Forward: keep only the two sigmoid outputs (GCN and SOD heads).
                _, gcn_logits_sigmoid, _, _, sod_logits_sigmoid = self.model.forward(
                    images, batched_graph, batched_pixel_graph)

                # Total loss is the unweighted sum of both head losses.
                loss1 = self.loss_bce(gcn_logits_sigmoid, labels)
                loss2 = self.loss_bce(sod_logits_sigmoid, labels_sod)
                loss = loss1 + loss2

                labels_val = labels.cpu().detach().numpy()
                labels_sod_val = labels_sod.cpu().detach().numpy()
                gcn_logits_sigmoid_val = gcn_logits_sigmoid.cpu().detach(
                ).numpy()
                sod_logits_sigmoid_val = sod_logits_sigmoid.cpu().detach(
                ).numpy()

                # Stat
                nb_data += images.size(0)
                epoch_test_loss += loss.detach().item()
                epoch_test_loss1 += loss1.detach().item()
                epoch_test_loss2 += loss2.detach().item()

                # cal 1: SOD-head MAE and PR curve.
                mae = self._eval_mae(sod_logits_sigmoid_val, labels_sod_val)
                prec, recall = self._eval_pr(sod_logits_sigmoid_val,
                                             labels_sod_val, th_num)
                epoch_test_mae += mae
                epoch_test_prec += prec
                epoch_test_recall += recall

                # cal 2: GCN-head MAE and PR curve.
                mae2 = self._eval_mae(gcn_logits_sigmoid_val, labels_val)
                prec2, recall2 = self._eval_pr(gcn_logits_sigmoid_val,
                                               labels_val, th_num)
                epoch_test_mae2 += mae2
                epoch_test_prec2 += prec2
                epoch_test_recall2 += recall2

                # Print
                if i % self.test_print_freq == 0:
                    # NOTE(review): the running gcn-mse below divides by nb_data
                    # (sample count) while every other running average divides
                    # by (i + 1) (batch count) — confirm which is intended.
                    Tools.print(
                        "{:4d}-{:4d} loss={:.4f}({:.4f}+{:.4f})-{:.4f}({:.4f}+{:.4f}) "
                        "sod-mse={:.4f}({:.4f}) gcn-mse={:.4f}({:.4f})".format(
                            i, len(loader),
                            loss.detach().item(),
                            loss1.detach().item(),
                            loss2.detach().item(), epoch_test_loss / (i + 1),
                            epoch_test_loss1 / (i + 1),
                            epoch_test_loss2 / (i + 1), mae,
                            epoch_test_mae / (i + 1), mae2,
                            epoch_test_mae2 / nb_data))
                    pass
                pass
            pass

        # Result 1: per-batch averages of the losses and SOD metrics.
        avg_loss, avg_loss1, avg_loss2 = epoch_test_loss / tr_num, epoch_test_loss1 / tr_num, epoch_test_loss2 / tr_num

        avg_mae, avg_prec, avg_recall = epoch_test_mae / tr_num, epoch_test_prec / tr_num, epoch_test_recall / tr_num
        # F-measure with beta^2 = 0.3 (standard for saliency evaluation).
        score = (1 + 0.3) * avg_prec * avg_recall / (0.3 * avg_prec +
                                                     avg_recall)
        # NOTE(review): GCN metrics are normalized by nb_data (samples) while
        # the SOD metrics above are normalized by tr_num (batches) — confirm
        # this asymmetry is intended.
        avg_mae2, avg_prec2, avg_recall2 = epoch_test_mae2 / nb_data, epoch_test_prec2 / nb_data, epoch_test_recall2 / nb_data
        score2 = (1 + 0.3) * avg_prec2 * avg_recall2 / (0.3 * avg_prec2 +
                                                        avg_recall2)

        # Report the maximum F-measure over all thresholds.
        return avg_loss, avg_loss1, avg_loss2, avg_mae, score.max(
        ), avg_mae2, score2.max()
コード例 #28
0
    # Hyper-parameters for the MNIST superpixel GCN experiment.
    # _data_root_path = '/home/ubuntu/ALISURE/data/mnist'
    _data_root_path = "/private/alishuo/mnist"
    _root_ckpt_dir = "./ckpt2/dgl/4_DGL_CONV-mnist3-Adam/{}".format("GCNNet-C1")
    _batch_size = 64
    _image_size = 28
    _sp_size = 4  # superpixel size
    _train_print_freq = 100
    _test_print_freq = 100
    _num_workers = 6
    _use_gpu = True
    _gpu_id = "0"
    # _gpu_id = "1"

    # Log the configuration before building the runner.
    Tools.print(
        "ckpt:{} batch size:{} image size:{} sp size:{} workers:{} gpu:{}".
        format(_root_ckpt_dir, _batch_size, _image_size, _sp_size,
               _num_workers, _gpu_id))

    runner = RunnerSPE(data_root_path=_data_root_path,
                       root_ckpt_dir=_root_ckpt_dir,
                       batch_size=_batch_size,
                       image_size=_image_size,
                       sp_size=_sp_size,
                       train_print_freq=_train_print_freq,
                       test_print_freq=_test_print_freq,
                       num_workers=_num_workers,
                       use_gpu=_use_gpu,
                       gpu_id=_gpu_id)
    # Train for 150 epochs.
    runner.train(150)

    pass
コード例 #29
0
    def train(self):
        """Run the full ProtoNet training loop with periodic validation.

        Trains for Config.train_epoch epochs, adjusting the learning rate each
        epoch; every Config.val_freq epochs it validates and checkpoints the
        best model to Config.pn_dir.
        """
        Tools.print()
        Tools.print("Training...")

        for epoch in range(1, 1 + Config.train_epoch):
            self.proto_net.train()

            Tools.print()
            epoch_loss_sum = 0.0
            # Update the learning-rate schedule before this epoch's updates.
            pn_lr = self.adjust_learning_rate(self.proto_net_optim, epoch,
                                              Config.first_epoch,
                                              Config.t_epoch,
                                              Config.learning_rate)
            Tools.print('Epoch: [{}] pn_lr={}'.format(epoch, pn_lr))

            for batch_data, batch_labels, _batch_index in tqdm(
                    self.task_train_loader):
                batch_data = RunnerTool.to_cuda(batch_data)
                batch_labels = RunnerTool.to_cuda(batch_labels)

                # Forward pass: prototype-network output for this episode.
                logits = self.proto(batch_data)

                # Episode loss, accumulated for the epoch report.
                batch_loss = self.loss(logits, batch_labels)
                epoch_loss_sum += batch_loss.item()

                # Backward pass with gradient clipping for stability.
                self.proto_net.zero_grad()
                batch_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.proto_net.parameters(),
                                               0.5)
                self.proto_net_optim.step()
                pass

            # Report the mean episode loss for this epoch.
            Tools.print("{:6} loss:{:.3f}".format(
                epoch, epoch_loss_sum / len(self.task_train_loader)))

            # Periodic validation; keep the checkpoint with the best accuracy.
            if epoch % Config.val_freq == 0:
                Tools.print()
                Tools.print("Test {} {} .......".format(
                    epoch, Config.model_name))

                self.proto_net.eval()

                val_accuracy = self.test_tool.val(episode=epoch, is_print=True)
                if val_accuracy > self.best_accuracy:
                    self.best_accuracy = val_accuracy
                    torch.save(self.proto_net.state_dict(), Config.pn_dir)
                    Tools.print("Save networks for epoch: {}".format(epoch))
                    pass
                pass
            pass

        pass
コード例 #30
0
    def train_val_pipeline(self):
        """Train for up to self.params.epochs epochs with per-epoch evaluation.

        Each epoch: train, evaluate on the val and test splits, step the LR
        scheduler on the validation loss, checkpoint the model, and log losses,
        accuracies and the learning rate to TensorBoard. Training stops early
        when the learning rate decays below params.min_lr or the wall-clock
        budget (params.max_time hours) is exceeded. Ends with a final
        evaluation on all three splits and closes the summary writer.
        """
        t0, per_epoch_time = time.time(), []
        # Per-epoch histories (collected but only the writer/logs consume them
        # in the visible code).
        epoch_train_losses, epoch_val_losses, epoch_train_accs, epoch_val_accs = [], [], [], []
        for epoch in range(self.params.epochs):
            start = time.time()
            Tools.print()
            Tools.print("Start Epoch {}".format(epoch))

            epoch_train_loss, epoch_train_acc = self.train_epoch(
                self.train_loader)
            epoch_val_loss, epoch_val_acc = self.evaluate_network(
                self.val_loader)
            epoch_test_loss, epoch_test_acc = self.evaluate_network(
                self.test_loader)

            # Scheduler consumes the validation loss — presumably a
            # ReduceLROnPlateau-style scheduler; confirm where it is built.
            self.scheduler.step(epoch_val_loss)
            # Checkpoint every epoch (save_checkpoint manages the directory).
            self.save_checkpoint(self.model, self.params.root_ckpt_dir, epoch)

            epoch_train_losses.append(epoch_train_loss)
            epoch_val_losses.append(epoch_val_loss)
            epoch_train_accs.append(epoch_train_acc)
            epoch_val_accs.append(epoch_val_acc)

            # TensorBoard logging (test metrics are printed but not logged).
            self.writer.add_scalar('train/_loss', epoch_train_loss, epoch)
            self.writer.add_scalar('val/_loss', epoch_val_loss, epoch)
            self.writer.add_scalar('train/_acc', epoch_train_acc, epoch)
            self.writer.add_scalar('val/_acc', epoch_val_acc, epoch)
            self.writer.add_scalar('learning_rate',
                                   self.optimizer.param_groups[0]['lr'], epoch)

            per_epoch_time.append(time.time() - start)
            Tools.print(
                "time={:.4f}, lr={:.4f}, loss={:.4f}/{:.4f}/{:.4f}, acc={:.4f}/{:.4f}/{:.4f}"
                .format(time.time() - start,
                        self.optimizer.param_groups[0]['lr'], epoch_train_loss,
                        epoch_val_loss, epoch_test_loss, epoch_train_acc,
                        epoch_val_acc, epoch_test_acc))

            # Stop training
            if self.optimizer.param_groups[0]['lr'] < self.params.min_lr:
                Tools.print()
                Tools.print("\n!! LR EQUAL TO MIN LR SET.")
                break
            if time.time() - t0 > self.params.max_time * 3600:
                Tools.print()
                # NOTE(review): this message formats the max_time limit, not
                # the actual elapsed hours — confirm that is intended.
                Tools.print(
                    "Max_time for training elapsed {:.2f} hours, so stopping".
                    format(self.params.max_time))
                break

            pass

        # Final evaluation on all splits with the last model state.
        _, val_acc = self.evaluate_network(self.val_loader)
        _, test_acc = self.evaluate_network(self.test_loader)
        _, train_acc = self.evaluate_network(self.train_loader)

        Tools.print()
        Tools.print("Val Accuracy: {:.4f}".format(val_acc))
        Tools.print("Test Accuracy: {:.4f}".format(test_acc))
        Tools.print("Train Accuracy: {:.4f}".format(train_acc))
        Tools.print("TOTAL TIME TAKEN: {:.4f}s".format(time.time() - t0))
        Tools.print("AVG TIME PER EPOCH: {:.4f}s".format(
            np.mean(per_epoch_time)))

        self.writer.close()
        pass