Example #1
    def __init__(self):
        self.args = parser.parse_args()
        print(f"-----------{self.args.project_name}-----------")
        use_cuda = self.args.use_cuda and torch.cuda.is_available()
        if use_cuda:
            torch.cuda.manual_seed(self.args.seed)
        else:
            torch.manual_seed(self.args.seed)

        self.device = torch.device("cuda" if use_cuda else "cpu")

        kwargs = {'num_workers': 0, 'pin_memory': True} if use_cuda else {}
        '''
        Build the DataLoaders
        '''
        # TODO: the dataset needs to be re-prepared
        print("Create Dataloader")
        self.train_loader = DataLoader(Dataset2(),
                                       batch_size=1,
                                       shuffle=True,
                                       **kwargs)
        self.test_loader = DataLoader(Dataset2(),
                                      batch_size=1,
                                      shuffle=True,
                                      **kwargs)
        '''
        Define the model
        '''
        print("Create Model")
        self.model = OPN().to(self.device)
        #        self.model = nn.DataParallel(OPN())
        if use_cuda:
            # self.model = self.model.cuda()
            cudnn.benchmark = True
        '''
        Load pretrained model weights as needed
        '''

        # Pretrained VGG16 features, used for feature extraction in the perceptual (content/style) losses
        self.vgg = models.vgg16(pretrained=True).to(self.device).features

        for p in self.vgg.parameters():  # freeze the VGG feature extractor
            p.requires_grad = False
        try:
            if self.args.resume and self.args.pretrained_weight:
                self.model.load_state_dict(torch.load(os.path.join('OPN.pth')),
                                           strict=False)
                print("Model weights loaded successfully")
        except Exception as e:
            print(f"Failed to load model weights: {e}")
        '''
        CUDA acceleration
        '''
        if use_cuda:
            #   self.model = nn.DataParallel(self.model, device_ids=range(torch.cuda.device_count()))
            cudnn.benchmark = True
        '''
        Build the loss (objective) functions
        Choose the optimizer
        Choose the learning-rate schedule
        '''
        print("Establish the loss, optimizer and learning_rate function")
        self.loss_tv = TVLoss()
        self.loss_l1 = L1_Loss()
        # There are also a style loss and a content loss
        # self.optimizer = optim.SGD(
        #     params=self.model.parameters(),
        #     lr=self.args.lr,
        #     weight_decay=self.args.weight_decay,
        #     momentum=0.5
        # )
        self.optimizer = optim.Adam(
            params=self.model.parameters(),
            lr=0.001,
            betas=(0.9, 0.999),
            eps=1e-8,  # keeps the denominator from being zero
            weight_decay=0)
        # self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=5, eta_min=1e-5)
        '''
        Start model training
        '''
        print("Start training")
        for epoch in tqdm(range(1, self.args.epoch + 1)):
            self.train(epoch)
            if epoch % 20 == 0:
                self.test(epoch)

        torch.cuda.empty_cache()

        print("finish model training")
Example #2
frames = torch.from_numpy(np.transpose(frames, (3, 0, 1, 2)).copy()).float()
holes = torch.from_numpy(np.transpose(holes, (3, 0, 1, 2)).copy()).float()
dists = torch.from_numpy(np.transpose(dists, (3, 0, 1, 2)).copy()).float()
# remove holes: fill hole pixels with the ImageNet mean color
frames = frames * (1 - holes) + holes * torch.tensor(
    [0.485, 0.456, 0.406]).view(3, 1, 1, 1)
# valid (non-hole) area
valids = 1 - holes
# unsqueeze to batch size 1 to match the data layout; otherwise the tensors cannot be fed into the model
frames = frames.unsqueeze(0)
holes = holes.unsqueeze(0)
dists = dists.unsqueeze(0)
valids = valids.unsqueeze(0)

#################### Load Model
model = nn.DataParallel(OPN())
if torch.cuda.is_available():
    model.cuda()
model.load_state_dict(torch.load(os.path.join('OPN.pth')), strict=False)
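# Note: strict=False makes load_state_dict ignore missing or unexpected keys
# in the checkpoint instead of raising an error.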
model.eval()

################### Inference
# memory encoding
midx = list(range(0, T))
with torch.no_grad():
    mkey, mval, mhol = model(frames[:, :, midx], valids[:, :, midx],
                             dists[:, :, midx])
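# mkey / mval: encoded key and value feature maps for the memory frames;
# mhol: the matching hole masks at the encoder's reduced resolution
# (see the shape comment in Example #3 below).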

for f in range(T):
    # memory selection
    ridx = [i for i in range(len(midx)) if i != f]  # memory minus self
Example #3
class train(object):
    def __init__(self):
        self.args = parser.parse_args()
        print(f"-----------{self.args.project_name}-----------")
        use_cuda = self.args.use_cuda and torch.cuda.is_available()
        if use_cuda:
            torch.cuda.manual_seed(self.args.seed)
        else:
            torch.manual_seed(self.args.seed)

        self.device = torch.device("cuda" if use_cuda else "cpu")

        kwargs = {'num_workers': 0, 'pin_memory': True} if use_cuda else {}
        '''
        Build the DataLoaders
        '''
        # TODO: the dataset needs to be re-prepared
        print("Create Dataloader")
        self.train_loader = DataLoader(Dataset2(),
                                       batch_size=1,
                                       shuffle=True,
                                       **kwargs)
        self.test_loader = DataLoader(Dataset2(),
                                      batch_size=1,
                                      shuffle=True,
                                      **kwargs)
        '''
        Define the model
        '''
        print("Create Model")
        self.model = OPN().to(self.device)
        #        self.model = nn.DataParallel(OPN())
        if use_cuda:
            # self.model = self.model.cuda()
            cudnn.benchmark = True
        '''
        Load pretrained model weights as needed
        '''

        # Pretrained VGG16 features, used for feature extraction in the perceptual (content/style) losses
        self.vgg = models.vgg16(pretrained=True).to(self.device).features

        for p in self.vgg.parameters():  # freeze the VGG feature extractor
            p.requires_grad = False
        try:
            if self.args.resume and self.args.pretrained_weight:
                self.model.load_state_dict(torch.load(os.path.join('OPN.pth')),
                                           strict=False)
                print("Model weights loaded successfully")
        except Exception as e:
            print(f"Failed to load model weights: {e}")
        '''
        CUDA acceleration
        '''
        if use_cuda:
            #   self.model = nn.DataParallel(self.model, device_ids=range(torch.cuda.device_count()))
            cudnn.benchmark = True
        '''
        Build the loss (objective) functions
        Choose the optimizer
        Choose the learning-rate schedule
        '''
        print("Establish the loss, optimizer and learning_rate function")
        self.loss_tv = TVLoss()
        self.loss_l1 = L1_Loss()
        # There are also a style loss and a content loss (computed in train() below)
        # self.optimizer = optim.SGD(
        #     params=self.model.parameters(),
        #     lr=self.args.lr,
        #     weight_decay=self.args.weight_decay,
        #     momentum=0.5
        # )
        self.optimizer = optim.Adam(
            params=self.model.parameters(),
            lr=0.001,
            betas=(0.9, 0.999),
            eps=1e-8,  # keeps the denominator from being zero
            weight_decay=0)
        # self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=5, eta_min=1e-5)
        '''
        Start model training
        '''
        print("Start training")
        for epoch in tqdm(range(1, self.args.epoch + 1)):
            self.train(epoch)
            if epoch % 20 == 0:
                self.test(epoch)

        torch.cuda.empty_cache()

        print("finish model training")

    def train(self, epoch):
        self.model.train()
        for data in self.train_loader:

            self.content_loss = 0
            self.style_loss = 0

            midx = list(range(0, 5))
            # frames: corrupted images; valids: available (known) pixel regions; dists: regions to be filled
            frames, valids, dists, label = data
            frames, valids, dists, label = frames.to(self.device), valids.to(
                self.device), dists.to(self.device), label.to(self.device)
            # Every frame has been encoded; the resulting key and val have shape (1, 128, 5, 60, 106), hol has shape (1, 1, 5, 60, 106)
            mkey, mval, mhol = self.model(frames[:, :, midx],
                                          valids[:, :, midx],
                                          dists[:, :, midx])

            allloss = 0
            for f in range(5):
                loss = 0
                # For each frame, use the other 4 frames as references
                ridx = [i for i in range(len(midx)) if i != f]
                fkey, fval, fhol = (mkey[:, :, ridx], mval[:, :, ridx],
                                    mhol[:, :, ridx])
                # Image completion
                for r in range(5):
                    if r == 0:
                        # take the target frame as the starting point
                        comp = frames[:, :, f]
                        dist = dists[:, :, f]
                    # comp: the corrupted image, completed step by step
                    # valids: regions with no missing information
                    # dist: regions with missing information
                    '''
                    Guided by dist, repair the image iteratively, peeling 8 pixels inward per step;
                    valids is a {0, 1} mask marking hole vs. valid pixels.
                    comp is built on top of frame and is highly similar to it, so the loss is computed only on this part.
                    '''
                    comp, dist, peel = self.model(fkey, fval, fhol, comp,
                                                  valids[:, :, f], dist)
                    # On each iteration, minimize the L1 distance to the ground truth in both pixel space and deep-feature space.
                    loss += 100 * L1(comp, label[:, :, f], peel)
                    # loss += L1(comp, label[:, :, f], valids[:,:,f])
                    loss += 0.2 * self.loss_l1(comp, label[:, :, f],
                                               valids[:, :, midx])
                    # loss+=100*ll1(comp,frames[:,:,f])
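                # Note: training peels a fixed 5 steps per frame here, whereas
                # test() below keeps peeling until the remaining dist mask is empty.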

                # content loss
                content_features = get_features(frames[:, :, f], self.vgg)
                target_features = get_features(comp, self.vgg)
                self.content_loss = torch.mean(
                    torch.abs((target_features['conv4_2'] -
                               content_features['conv4_2'])))
                loss = loss + 0.05 * self.content_loss
                # style loss
                style_features = get_features(comp, self.vgg)
                style_grams = {
                    layer: gram_matrix(style_features[layer])
                    for layer in style_features
                }
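                # For reference, gram_matrix is assumed to follow the usual
                # style-transfer definition (a sketch, not the project's helper):
                #   _, d, h, w = feat.shape
                #   F = feat.view(d, h * w)
                #   G = F @ F.t()  # (d, d) channel-correlation (Gram) matrix
                # Note: style_features above is computed from comp, the same tensor
                # used for target_features, so style_gram equals target_gram and the
                # style term is effectively zero; a ground-truth reference such as
                # label[:, :, f] may have been intended.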
                '''Add the Gram-matrix loss from each layer'''
                for layer in style_weights:
                    target_feature = target_features[layer]
                    target_gram = gram_matrix(target_feature)
                    _, d, h, w = target_feature.shape
                    style_gram = style_grams[layer]
                    layer_style_loss = style_weights[layer] * torch.mean(
                        torch.abs((target_gram - style_gram)))
                    self.style_loss += layer_style_loss / (d * h * w)  # accumulate the per-layer style loss
                loss = loss + 120 * self.style_loss
                # tv loss
                loss += 0.01 * self.loss_tv(comp)
                allloss += loss
            self.optimizer.zero_grad()
            allloss.backward()
            self.optimizer.step()
        # self.scheduler.step()

    # print("epoch{}".format(epoch) + "  loss:{}".format(loss.cpu()))

    def test(self, epoch):
        self.model.eval()
        for frames, valids, dists, _ in self.test_loader:
            midx = list(range(0, 5))
            # frames, valids, dists = data
            frames, valids, dists = frames.to(self.device), valids.to(
                self.device), dists.to(self.device)
            with torch.no_grad():
                # First encode all 5 frames
                mkey, mval, mhol = self.model(frames[:, :, midx],
                                              valids[:, :, midx],
                                              dists[:, :, midx])
            # For each frame, use the other 4 frames as references
            for f in range(5):
                ridx = [i for i in range(len(midx)) if i != f]
                fkey, fval, fhol = (mkey[:, :, ridx], mval[:, :, ridx],
                                    mhol[:, :, ridx])
                # Image completion
                for r in range(999):
                    if r == 0:
                        comp = frames[:, :, f]
                        dist = dists[:, :, f]
                    with torch.no_grad():
                        comp, dist, peel = self.model(fkey, fval, fhol, comp,
                                                      valids[:, :, f], dist)

                    comp, dist = comp.detach(), dist.detach()
                    # Once the holes are completely filled, save the image and move on to the next frame
                    if torch.sum(dist).item() == 0:
                        break

                if self.args.save:

                    # visualize..
                    est = (comp[0].permute(1, 2, 0).detach().cpu().numpy() *
                           255.).astype(np.uint8)
                    true = (frames[0, :, f].permute(
                        1, 2, 0).detach().cpu().numpy() * 255.).astype(
                            np.uint8)  # h,w,3
                    mask = (dists[0, 0, f].detach().cpu().numpy() > 0).astype(
                        np.uint8)  # h,w,1
                    ov_true = overlay_davis(true,
                                            mask,
                                            colors=[[0, 0, 0], [100, 100, 0]],
                                            cscale=2,
                                            alpha=0.4)

                    canvas = np.concatenate([ov_true, est], axis=0)
                    save_path = os.path.join('Results')
                    if not os.path.exists(save_path):
                        os.makedirs(save_path)
                    canvas = Image.fromarray(canvas)
                    canvas.save(
                        os.path.join(save_path,
                                     'res_{}_{}.jpg'.format(epoch, f)))
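
train() above also relies on a masked L1 helper (the L1 function and the L1_Loss module) that is not included in the example. A minimal sketch, assuming it averages the absolute error only over the supplied mask; the name MaskedL1Sketch and the normalization are assumptions, not the project's actual implementation.

import torch
import torch.nn as nn

class MaskedL1Sketch(nn.Module):
    # L1 distance restricted to a masked region: |pred - target| is summed
    # over pixels where mask == 1 and normalized by the mask area.
    def forward(self, pred, target, mask):
        diff = torch.abs(pred - target) * mask
        return diff.sum() / (mask.sum() + 1e-8)
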
Example #4
          open(os.path.join(args.output_dir, 'args.json'), mode='w'),
          ensure_ascii=False,
          indent=4,
          sort_keys=True,
          separators=(',', ': '))

# Load the dataset
train_loader = DataLoader(VideoSortingClassificationMatTrainDataSet(
    frame_num=4, path_list=mat_loader(args.input_dir, args.input_mat)),
                          batch_size=batch_size,
                          shuffle=True)
train_iterate_len = len(train_loader)

# Initial setup
# Build the network (OPN)
Net = OPN(pretrained=args.use_pretrained_model)
criterion = torch.nn.CrossEntropyLoss()  # define the loss function
optimizer = torch.optim.Adam(Net.parameters(),
                             lr=args.learning_rate)  # define the optimizer (weight-update rule)
current_epoch = 0

# Create the log file
if not args.no_reset_log_file:
    with open(log_train_path, mode='w') as f:
        f.write('epoch,loss,accuracy,time,learning_rate\n')

# Adjust setup depending on whether CUDA is available
if args.use_cuda:
    criterion = criterion.cuda()
    Net = torch.nn.DataParallel(Net.cuda())
    device = 'cuda'
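
Example #4 is cut off before its training loop. A minimal sketch of one epoch under the setup above, assuming train_loader yields (inputs, labels) batches; this is an illustration, not the original script's loop.

for inputs, labels in train_loader:
    if args.use_cuda:
        inputs, labels = inputs.cuda(), labels.cuda()
    optimizer.zero_grad()              # clear gradients from the previous step
    outputs = Net(inputs)              # forward pass
    loss = criterion(outputs, labels)  # compute the cross-entropy loss
    loss.backward()                    # backpropagate
    optimizer.step()                   # update the weights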