Code Example #1
def train(train_loader, model, optimizer, epoch):
    # switch to train mode
    model.train()
    pbar = tqdm(enumerate(train_loader))
    for batch_idx, data in pbar:
        data_a, data_p = data
        if args.cuda:
            data_a, data_p  = data_a.float().cuda(), data_p.float().cuda()
            data_a, data_p = Variable(data_a), Variable(data_p)
        rot_LAFs, inv_rotmat = get_random_rotation_LAFs(data_a, math.pi)
        scale = Variable(0.9 + 0.3 * torch.rand(data_a.size(0), 1, 1))
        if args.cuda:
            scale = scale.cuda()
        rot_LAFs[:,0:2,0:2] = rot_LAFs[:,0:2,0:2] * scale.expand(data_a.size(0),2,2)
        shift_w, shift_h = get_random_shifts_LAFs(data_a, 2, 2)
        rot_LAFs[:,0,2] = rot_LAFs[:,0,2] + shift_w / float(data_a.size(3))
        rot_LAFs[:,1,2] = rot_LAFs[:,1,2] + shift_h / float(data_a.size(2))
        data_a_rot = extract_patches(data_a,  rot_LAFs, PS = data_a.size(2))
        st = int((data_p.size(2) - model.PS)/2)
        fin = st + model.PS

        data_p_crop = data_p[:,:, st:fin, st:fin].contiguous()
        data_a_rot_crop = data_a_rot[:,:, st:fin, st:fin].contiguous()
        out_a_rot = model(data_a_rot_crop, True)
        out_p = model(data_p_crop, True)
        out_a = model(data_a[:, :, st:fin, st:fin].contiguous(), True)
        out_p_rotated = torch.bmm(inv_rotmat, out_p)

        ###### Apply the predicted transforms and compute SIFT descriptors
        out_patches_a_crop = extract_and_crop_patches_by_predicted_transform(data_a_rot, out_a_rot, crop_size = model.PS)
        out_patches_p_crop = extract_and_crop_patches_by_predicted_transform(data_p, out_p, crop_size = model.PS)

        desc_a = descriptor(out_patches_a_crop)
        desc_p = descriptor(out_patches_p_crop)
        descr_dist = torch.sqrt(((desc_a - desc_p)**2).view(data_a.size(0), -1).sum(dim=1) + 1e-6).mean()
        # Geometric distance between the predicted frames and the rotated-back positives, averaged over the batch.
        geom_dist = torch.sqrt(((out_a_rot - out_p_rotated)**2).view(-1, 4).sum(dim=1) + 1e-8).mean()
        if args.loss == 'HardNet':
            loss = loss_HardNet(desc_a, desc_p)
        elif args.loss == 'HardNetDetach':
            loss = loss_HardNetDetach(desc_a, desc_p)
        elif args.loss == 'Geom':
            loss = geom_dist
        elif args.loss == 'PosDist':
            loss = descr_dist
        else:
            print('Unknown loss function')
            sys.exit(1)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        adjust_learning_rate(optimizer)
        if batch_idx % args.log_interval == 0:
            pbar.set_description(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.4f}, {:.4f},{:.4f}'.format(
                    epoch, batch_idx * len(data_a), len(train_loader.dataset),
                           100. * batch_idx / len(train_loader),
                    float(loss.detach().cpu().numpy()), float(geom_dist.detach().cpu().numpy()), float(descr_dist.detach().cpu().numpy())))
    torch.save({'epoch': epoch + 1, 'state_dict': model.state_dict()},
               '{}/checkpoint_{}.pth'.format(LOG_DIR,epoch))
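
These train() functions rely on module-level globals such as args, descriptor and LOG_DIR, so they are not standalone. A minimal, hypothetical driver loop for a function with this signature might look like the sketch below; the optimizer type, hyperparameters and epoch count are placeholders, not the repository's configuration.

import torch.optim as optim

def run_training(train_loader, model, n_epochs=10, lr=0.1, wd=1e-4):
    # Hypothetical driver: the actual scripts build the loaders and optimizer
    # from command-line args before calling train() once per epoch.
    optimizer = optim.SGD(model.parameters(), lr=lr,
                          momentum=0.9, weight_decay=wd)
    for epoch in range(n_epochs):
        train(train_loader, model, optimizer, epoch)
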
Code Example #2
def train(train_loader, model, optimizer, epoch, logger, load_triplets  = False):
    # switch to train mode
    model.train()
    pbar = tqdm(enumerate(train_loader))
    for batch_idx, data in pbar:
        #print( data)
        if load_triplets:
            data_a, data_p, data_n = data
            if args.cuda:
                data_a, data_p, data_n  = data_a.cuda(), data_p.cuda(), data_n.cuda()
            data_a, data_p, data_n = Variable(data_a), Variable(data_p), Variable(data_n)
            out_a, out_p, out_n = model(data_a), model(data_p), model(data_n)
            loss = loss_random_sampling(out_a, out_p, out_n,
                                margin=args.margin,
                                anchor_swap=args.anchorswap,
                                loss_type = args.loss)
        else:
            data_a, data_p = data
            if args.cuda:
                data_a, data_p = data_a.cuda(), data_p.cuda()
            data_a, data_p = Variable(data_a), Variable(data_p)
            out_a, out_p = model(data_a), model(data_p)
            #hardnet loss
            if args.batch_reduce == 'L2Net':
                loss = loss_L2Net(out_a, out_p, column_row_swap = True, anchor_swap =args.anchorswap, margin = args.margin, loss_type = args.loss)
            else:
                loss = loss_HardNet(out_a, out_p,
                                    margin=args.margin, column_row_swap = True, 
                                    anchor_swap=args.anchorswap,
                                    anchor_ave=args.anchorave,
                                    batch_reduce = args.batch_reduce,
                                    loss_type = args.loss)
        if args.decor:
            loss += CorrelationPenaltyLoss()(out_a)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        adjust_learning_rate(optimizer)
        if (args.enable_logging):
            logger.log_value('loss', loss.item()).step()

        if batch_idx % args.log_interval == 0:
            pbar.set_description(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data_a), len(train_loader.dataset),
                           100. * batch_idx / len(train_loader),
                    loss.item()))

    torch.save({'epoch': epoch + 1, 'state_dict': model.state_dict()},
               '{}/checkpoint_{}.pth'.format(LOG_DIR, epoch))
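
adjust_learning_rate(optimizer) is called after every optimizer step rather than once per epoch, which suggests a per-step schedule. In HardNet-style training scripts this is typically a linear decay of the learning rate toward zero over the planned number of steps; the sketch below illustrates that idea using assumed args fields (args.lr, args.batch_size, args.n_triplets, args.epochs) and is not necessarily the exact implementation used here.

def adjust_learning_rate(optimizer):
    # Linear decay toward zero over the planned number of gradient steps.
    # args.lr, args.batch_size, args.n_triplets and args.epochs are assumed fields.
    for group in optimizer.param_groups:
        group['step'] = group.get('step', 0.0) + 1.0
        progress = group['step'] * float(args.batch_size) / (args.n_triplets * float(args.epochs))
        group['lr'] = args.lr * max(0.0, 1.0 - progress)
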
Code Example #3
def train(train_loader, model, optimizer, epoch, logger, load_triplets  = False):
    # switch to train mode
    model.train()
    pbar = tqdm(enumerate(train_loader))
    for batch_idx, data in pbar:
        if load_triplets:
            data_a, data_p, data_n = data
        else:
            data_a, data_p = data

        if args.cuda:
            data_a, data_p = data_a.cuda(), data_p.cuda()
        data_a, data_p = Variable(data_a), Variable(data_p)
        # Compute descriptors for the anchor and positive patches.
        out_a = model(data_a)
        out_p = model(data_p)
        if load_triplets:
            if args.cuda:
                data_n = data_n.cuda()
            data_n = Variable(data_n)
            out_n = model(data_n)

        if args.batch_reduce == 'L2Net':
            loss = loss_L2Net(out_a, out_p, anchor_swap = args.anchorswap,
                    margin = args.margin, loss_type = args.loss)
        elif args.batch_reduce == 'random_global':
            loss = loss_random_sampling(out_a, out_p, out_n,
                margin=args.margin,
                anchor_swap=args.anchorswap,
                loss_type = args.loss)
        else:
            loss = loss_HardNet(out_a, out_p,
                            margin=args.margin,
                            anchor_swap=args.anchorswap,
                            anchor_ave=args.anchorave,
                            batch_reduce = args.batch_reduce,
                            loss_type = args.loss)

        if args.decor:
            loss += CorrelationPenaltyLoss()(out_a)
            
        if args.gor:
            loss += args.alpha*global_orthogonal_regularization(out_a, out_n)
            
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        adjust_learning_rate(optimizer)
        if batch_idx % args.log_interval == 0:
            pbar.set_description(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data_a), len(train_loader.dataset),
                           100. * batch_idx / len(train_loader),
                    loss.item()))

    if (args.enable_logging):
        logger.log_value('loss', loss.item()).step()

    try:
        os.stat('{}{}'.format(args.model_dir,suffix))
    except:
        os.makedirs('{}{}'.format(args.model_dir,suffix))

    torch.save({'epoch': epoch + 1, 'state_dict': model.state_dict()},
               '{}{}/checkpoint_{}.pth'.format(args.model_dir,suffix,epoch))
    del loss, data_p, data_a, data, out_a, out_p
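
Most of these examples default to loss_HardNet, which mines the hardest non-matching descriptor inside the batch for each anchor/positive pair. The following is a minimal sketch of that hardest-in-batch triplet-margin idea, written for L2-normalized descriptors; it illustrates the technique and is not the repository's loss_HardNet, which additionally supports anchor_swap, anchor_ave and several batch_reduce modes.

import torch

def hardest_in_batch_loss(anchor, positive, margin=1.0):
    # anchor, positive: (N, D) L2-normalized descriptors; row i of each is a matching pair.
    dist = torch.cdist(anchor, positive)                   # (N, N) pairwise L2 distances
    pos = torch.diag(dist)                                 # distances of the matching pairs
    # Mask the matching pairs before mining the hardest (closest) negatives.
    eye = torch.eye(dist.size(0), dtype=torch.bool, device=dist.device)
    neg = dist.masked_fill(eye, float('inf'))
    hardest = torch.min(neg.min(dim=1).values, neg.min(dim=0).values)  # mine over rows and columns
    return torch.clamp(margin + pos - hardest, min=0.0).mean()
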
Code Example #4
def train(train_loader, model, optimizer, epoch):
    # switch to train mode
    model.train()
    pbar = tqdm(enumerate(train_loader))
    for batch_idx, data in pbar:
        data_a, data_p = data
        if args.cuda:
            data_a, data_p = data_a.float().cuda(), data_p.float().cuda()
            data_a, data_p = Variable(data_a), Variable(data_p)
        st = int((data_p.size(2) - model.PS) / 2)
        fin = st + model.PS
        ep1 = epoch
        while str(ep1) not in tilt_schedule.keys():
            ep1 -= 1
            if ep1 < 0:
                break
        max_tilt = tilt_schedule[str(ep1)]
        data_a_aff_crop, data_a_aff, rot_LAFs_a, inv_rotmat_a, inv_TA_a = extract_random_LAF(
            data_a, math.pi, max_tilt, model.PS)
        if 'Rot' not in args.arch:
            data_p_aff_crop, data_p_aff, rot_LAFs_p, inv_rotmat_p, inv_TA_p = extract_random_LAF(
                data_p, rot_LAFs_a, max_tilt, model.PS)
        else:
            data_p_aff_crop, data_p_aff, rot_LAFs_p, inv_rotmat_p, inv_TA_p = extract_random_LAF(
                data_p, math.pi, max_tilt, model.PS)
        if inv_rotmat_p is None:
            inv_rotmat_p = inv_rotmat_a
        out_a_aff, out_p_aff = model(data_a_aff_crop,
                                     True), model(data_p_aff_crop, True)
        #out_a_aff_back = torch.bmm(torch.bmm(out_a_aff, inv_TA_a),  inv_rotmat_a)
        #out_p_aff_back = torch.bmm(torch.bmm(out_p_aff, inv_TA_p),  inv_rotmat_p)
        ###### Get descriptors
        out_patches_a_crop = extract_and_crop_patches_by_predicted_transform(
            data_a_aff, out_a_aff, crop_size=model.PS)
        out_patches_p_crop = extract_and_crop_patches_by_predicted_transform(
            data_p_aff, out_p_aff, crop_size=model.PS)
        desc_a = descriptor(out_patches_a_crop)
        desc_p = descriptor(out_patches_p_crop)
        descr_dist = torch.sqrt((
            (desc_a - desc_p)**2).view(data_a.size(0), -1).sum(dim=1) +
                                1e-6).mean()
        #geom_dist = torch.sqrt(((out_a_aff_back - out_p_aff_back)**2 ).view(-1,4).sum(dim=1) + 1e-8).mean()
        if args.loss == 'HardNet':
            loss = loss_HardNet(desc_a, desc_p)
        elif args.loss == 'HardNegC':
            loss = loss_HardNegC(desc_a, desc_p)
        #elif args.loss == 'Geom':
        #    loss = geom_dist;
        elif args.loss == 'PosDist':
            loss = descr_dist
        else:
            print('Unknown loss function')
            sys.exit(1)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        adjust_learning_rate(optimizer)
        if batch_idx % args.log_interval == 0:
            pbar.set_description(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.4f},{:.4f}'.
                format(epoch, batch_idx * len(data_a),
                       len(train_loader.dataset),
                       100. * batch_idx / len(train_loader),
                       float(loss.detach().cpu().numpy()),
                       float(descr_dist.detach().cpu().numpy())))
    torch.save({
        'epoch': epoch + 1,
        'state_dict': model.state_dict()
    }, '{}/checkpoint_{}.pth'.format(LOG_DIR, epoch))
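
Example #4 reads max_tilt from a module-level tilt_schedule dict keyed by the epoch number as a string, walking back to the most recent earlier epoch that has an entry (so the schedule should at least contain a '0' key). Example #6 below hardcodes the progression 3.0 → 4.0 → 4.5 → 4.8 over epochs, which suggests a schedule of roughly this shape; the exact keys and values here are illustrative, not taken from the repository.

# Hypothetical tilt schedule: the maximum affine tilt allowed from each epoch onward.
tilt_schedule = {'0': 3.0, '2': 4.0, '4': 4.5, '6': 4.8}
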
Code Example #5
File: CDbin.py  Project: Shelfcol/CDbin
def train(test_loader, train_loader, model, optimizer, epoch, logger, load_triplets  = False):
    # switch to train mode

    model.train()
    pbar = tqdm(enumerate(train_loader))
    for batch_idx, data in pbar:
        if load_triplets:
            data_a, data_p, data_n = data
        else:
            data_a, data_p = data
        # print("data_a",data_a.size())
        if args.cuda:
            data_a, data_p = data_a.cuda(), data_p.cuda()
        data_a, data_p = Variable(data_a), Variable(data_p)
        out_a, out_p = model(data_a), model(data_p)

        if load_triplets:
            if args.cuda:
                data_n = data_n.cuda()
            data_n = Variable(data_n)
            out_n = model(data_n)

        if args.batch_reduce == 'L2Net':
            loss = loss_L2Net(out_a, out_p, anchor_swap = args.anchorswap,
                    margin = args.margin, loss_type = args.loss)
        elif args.batch_reduce == 'random_global':
            loss = loss_random_sampling(out_a, out_p, out_n,
                margin=args.margin,
                anchor_swap=args.anchorswap,
                loss_type = args.loss)
        else:
            loss = loss_HardNet(out_a, out_p,
                            margin=args.margin,
                            anchor_swap=args.anchorswap,
                            anchor_ave=args.anchorave,
                            batch_reduce = args.batch_reduce,
                            loss_type = args.loss)

        if args.decor:
            loss += args.cor_weights * CorrelationPenaltyLoss()(out_a)
        if args.gor:
            loss += args.alpha * global_orthogonal_regularization(out_a, out_n)
        if args.evendis:
            loss += args.even_weights * Even_distributeLoss()(out_a)
        if args.quan:
            loss += args.quan_weights * QuantilizeLoss(args.quan_scale)(out_a)
            
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if not args.constantlr:
            adjust_learning_rate(optimizer)
        if batch_idx % args.log_interval == 0:
            pbar.set_description(
                'Train Epoch: {} [{}/{} ({:.0f}%)]lr:{:f} \tLoss_T: {:.6f}'.format(
                    epoch, batch_idx * len(data_a), len(train_loader.dataset),
                           100. * batch_idx / len(train_loader),
                           optimizer.param_groups[0]['lr'], loss.item()))

    if (args.enable_logging):
        logger.log_value('loss', loss.data.item()).step()

    try:
        os.stat('{}{}'.format(args.model_dir,suffix))
    except:
        os.makedirs('{}{}'.format(args.model_dir,suffix))

    torch.save({'epoch': epoch + 1, 'optimizer':optimizer.state_dict()
                ,'state_dict': model.state_dict()},
               '{}{}/checkpoint_{}{}.pth'.format(args.model_dir,suffix,newstart,epoch))
    # torch.save(model,'{}{}/checkpoint_{}.pth'.format(args.model_dir,suffix,epoch))
    print("model {}{}/checkpoint_{}{}.pth is saved".format(args.model_dir,suffix,newstart,epoch))
    if (args.enable_logging):
        logger.log_value(test_loader['name']+'loss is:', loss.item())
    return loss.data.item()
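
CDbin targets binary descriptors, so Example #5 adds an even-distribution term and a quantization term on top of the metric loss. The actual QuantilizeLoss is not shown on this page; a common quantization penalty of this kind simply pushes each descriptor dimension toward ±scale, as in the hedged sketch below (an illustration of the general technique, not CDbin's implementation).

import torch
import torch.nn as nn

class QuantizationPenalty(nn.Module):
    # Penalize descriptor values that stray from +scale or -scale,
    # so that binarizing with sign(x) loses little information at test time.
    def __init__(self, scale=1.0):
        super().__init__()
        self.scale = scale

    def forward(self, x):
        # x: (N, D) real-valued descriptors.
        return ((x.abs() - self.scale) ** 2).mean()
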
Code Example #6
File: train_AffNet.py  Project: toanhvu/affnet
def train(train_loader, model, optimizer, epoch):
    # switch to train mode
    model.train()
    pbar = tqdm(enumerate(train_loader))
    for batch_idx, data in pbar:
        data_a, data_p = data
        if args.cuda:
            data_a, data_p = data_a.float().cuda(), data_p.float().cuda()
            data_a, data_p = Variable(data_a), Variable(data_p)
        st = int((data_p.size(2) - model.PS) / 2)
        fin = st + model.PS
        #
        #
        max_tilt = 3.0
        if epoch > 1:
            max_tilt = 4.0
        if epoch > 3:
            max_tilt = 4.5
        if epoch > 5:
            max_tilt = 4.8
        rot_LAFs_a, inv_rotmat_a = get_random_rotation_LAFs(data_a, math.pi)
        aff_LAFs_a, inv_TA_a = get_random_norm_affine_LAFs(data_a, max_tilt)
        aff_LAFs_a[:, 0:2, 0:2] = torch.bmm(rot_LAFs_a[:, 0:2, 0:2],
                                            aff_LAFs_a[:, 0:2, 0:2])
        data_a_aff = extract_patches(data_a, aff_LAFs_a, PS=data_a.size(2))
        data_a_aff_crop = data_a_aff[:, :, st:fin, st:fin].contiguous()
        aff_LAFs_p, inv_TA_p = get_random_norm_affine_LAFs(data_p, max_tilt)
        aff_LAFs_p[:, 0:2, 0:2] = torch.bmm(rot_LAFs_a[:, 0:2, 0:2],
                                            aff_LAFs_p[:, 0:2, 0:2])
        data_p_aff = extract_patches(data_p, aff_LAFs_p, PS=data_p.size(2))
        data_p_aff_crop = data_p_aff[:, :, st:fin, st:fin].contiguous()
        out_a_aff, out_p_aff = model(data_a_aff_crop,
                                     True), model(data_p_aff_crop, True)
        out_p_aff_back = torch.bmm(inv_TA_p, out_p_aff)
        out_a_aff_back = torch.bmm(inv_TA_a, out_a_aff)
        ###### Apply the predicted transforms and compute SIFT descriptors
        out_patches_a_crop = extract_and_crop_patches_by_predicted_transform(
            data_a_aff, out_a_aff, crop_size=model.PS)
        out_patches_p_crop = extract_and_crop_patches_by_predicted_transform(
            data_p_aff, out_p_aff, crop_size=model.PS)
        desc_a = descriptor(out_patches_a_crop)
        desc_p = descriptor(out_patches_p_crop)
        descr_dist = torch.sqrt((
            (desc_a - desc_p)**2).view(data_a.size(0), -1).sum(dim=1) + 1e-6)
        descr_loss = loss_HardNet(desc_a, desc_p, anchor_swap=True)
        geom_dist = torch.sqrt((
            (out_a_aff_back - out_p_aff_back)**2).view(-1, 4).mean(dim=1) +
                               1e-8)
        if args.merge == 'sum':
            loss = descr_loss
        elif args.merge == 'mul':
            loss = descr_loss
        else:
            print('Unknown merge option')
            sys.exit(1)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        adjust_learning_rate(optimizer)
        if batch_idx % 2 == 0:
            pbar.set_description(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.4f}, {},{:.4f}'.
                format(epoch, batch_idx * len(data_a),
                       len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.item(),
                       geom_dist.mean().item(),
                       descr_dist.mean().item()))
    torch.save({
        'epoch': epoch + 1,
        'state_dict': model.state_dict()
    }, '{}/checkpoint_{}.pth'.format(LOG_DIR, epoch))
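
Note that in Example #6 both the 'sum' and 'mul' branches assign the descriptor loss alone; geom_dist is computed and shown in the progress bar but never optimized. If the two terms were actually merged, the branches would presumably look something like the sketch below, which is a hypothetical reading of the option names rather than the repository's code.

def merge_losses(descr_loss, geom_dist, mode='sum'):
    # descr_loss: scalar descriptor loss; geom_dist: per-sample geometric frame distances.
    if mode == 'sum':
        return descr_loss + geom_dist.mean()
    if mode == 'mul':
        return descr_loss * geom_dist.mean()
    raise ValueError('Unknown merge option: {}'.format(mode))
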
Code Example #7
File: HardNet.py  Project: keeeeenw/hardnet
    def train(self, train_loader, model, optimizer, epoch, logger, load_triplets  = False):
        print("Training model")
        # switch to train mode
        model.train()
        pbar = tqdm(enumerate(train_loader))
        for batch_idx, data in pbar:
            if load_triplets:
                data_a, data_p, data_n = data
            else:
                data_a, data_p = data

            if self.args.cuda:
                data_a, data_p  = data_a.cuda(), data_p.cuda()
                data_a, data_p = Variable(data_a), Variable(data_p)
            out_a = model(data_a)
            out_p = model(data_p)
            if load_triplets:
                data_n  = data_n.cuda()
                data_n = Variable(data_n)
                out_n = model(data_n)

            if self.args.loss == 'qht':
                loss = loss_SOSNet(out_a, out_p,
                                   batch_reduce=self.args.batch_reduce,
                                   no_cuda=self.args.no_cuda)
            else:
                if self.args.batch_reduce == 'L2Net':
                    loss = loss_L2Net(out_a, out_p, anchor_swap = self.args.anchorswap,
                            margin = self.args.margin, loss_type = self.args.loss)
                elif self.args.batch_reduce == 'random_global':
                    loss = loss_random_sampling(out_a, out_p, out_n,
                        margin=self.args.margin,
                        anchor_swap=self.args.anchorswap,
                        loss_type = self.args.loss)
                else:
                    loss = loss_HardNet(out_a, out_p,
                                    margin=self.args.margin,
                                    anchor_swap=self.args.anchorswap,
                                    anchor_ave=self.args.anchorave,
                                    batch_reduce = self.args.batch_reduce,
                                    loss_type = self.args.loss,
                                    no_cuda = self.args.no_cuda)

            if self.args.decor:
                loss += CorrelationPenaltyLoss()(out_a)
                
            if self.args.gor:
                loss += self.args.alpha*global_orthogonal_regularization(out_a, out_n)
            
            if self.print_summary:
                with torch.no_grad():
                    # We can only do this here because the inputs are only moved
                    # to CUDA above.
                    summary(model, input_size=(1, self.args.imageSize, self.args.imageSize))
                self.print_summary = False
                
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if self.change_lr:
                self.adjust_learning_rate(optimizer)
            if batch_idx % self.args.log_interval == 0:
                pbar.set_description(
                    'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        epoch, batch_idx * len(data_a), len(train_loader.dataset),
                            100. * batch_idx / len(train_loader),
                        loss.item()))

        if (self.args.enable_logging):
            logger.log_value('loss', loss.item()).step()

        try:
            os.stat('{}{}'.format(self.args.model_dir,self.suffix))
        except:
            os.makedirs('{}{}'.format(self.args.model_dir,self.suffix))

        torch.save({'epoch': epoch + 1, 'state_dict': model.state_dict()},
                '{}{}/checkpoint_{}.pth'.format(self.args.model_dir,self.suffix,epoch))
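
Several of the examples (here via self.args.decor) add CorrelationPenaltyLoss()(out_a), the L2Net-style compactness term that discourages correlation between descriptor dimensions across the batch. A minimal sketch of that kind of penalty is shown below; it illustrates the idea and is not claimed to match the repository's CorrelationPenaltyLoss.

import torch
import torch.nn as nn

class CorrelationPenalty(nn.Module):
    # Penalize off-diagonal entries of the batch covariance of the descriptors,
    # encouraging the D output dimensions to be decorrelated.
    def forward(self, x):
        # x: (N, D) batch of descriptors.
        xc = x - x.mean(dim=0, keepdim=True)
        cov = xc.t() @ xc / (x.size(0) - 1)
        off_diag = cov - torch.diag(torch.diag(cov))
        return (off_diag ** 2).sum() / x.size(1)
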
Code Example #8
File: HardNet_exp.py  Project: areslp/hardnet.pytorch
def train(train_loader, model, optimizer, epoch, logger, load_triplets=False):

    # switch to train mode
    model.train()
    pbar = tqdm(enumerate(train_loader))
    for batch_idx, data in pbar:
        if load_triplets:
            data_a, data_p, data_n = data
        else:
            data_a, data_p = data

        if args.cuda:
            data_a, data_p = data_a.cuda(), data_p.cuda()
        data_a, data_p = Variable(data_a), Variable(data_p)
        out_a, out_p = model(data_a), model(data_p)

        # load_triplets=False for L2Net and HardNet; these two form their triplets from the batch data
        if load_triplets:
            if args.cuda:
                data_n = data_n.cuda()
            data_n = Variable(data_n)
            out_n = model(data_n)

        # for comparison with L2Net and random_global
        if args.batch_reduce == 'L2Net':
            loss = loss_L2Net(out_a,
                              out_p,
                              anchor_swap=args.anchorswap,
                              margin=args.margin,
                              loss_type=args.loss)
        elif args.batch_reduce == 'random_global':
            # use random negative patch samples from the dataset
            loss = loss_random_sampling(out_a,
                                        out_p,
                                        out_n,
                                        margin=args.margin,
                                        anchor_swap=args.anchorswap,
                                        loss_type=args.loss)
        else:
            loss = loss_HardNet(out_a,
                                out_p,
                                margin=args.margin,
                                anchor_swap=args.anchorswap,
                                anchor_ave=args.anchorave,
                                batch_reduce=args.batch_reduce,
                                loss_type=args.loss)

        # E2 loss in L2Net for descriptor component correlation
        if args.decor:
            loss += CorrelationPenaltyLoss()(out_a)

        # gor for HardNet
        if args.gor:
            loss += args.alpha * global_orthogonal_regularization(out_a, out_n)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        adjust_learning_rate(optimizer, args)
        if batch_idx % args.log_interval == 0:
            pbar.set_description(
                'Train Epoch: {} [{}/{} ({:.0f}%)]  Loss: {:.6f}'.format(
                    epoch, batch_idx * len(data_a), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item()))
            if (args.enable_logging):
                logger.log_string(
                    'logs',
                    'Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f}'.format(
                        epoch, batch_idx * len(data_a),
                        len(train_loader.dataset),
                        100. * batch_idx / len(train_loader), loss.item()))

    try:
        os.stat('{}{}'.format(args.model_dir, suffix))
    except:
        os.makedirs('{}{}'.format(args.model_dir, suffix))

    torch.save({
        'epoch': epoch + 1,
        'state_dict': model.state_dict()
    }, '{}{}/checkpoint_{}.pth'.format(args.model_dir, suffix, epoch))
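
When args.gor is set, Examples #3, #5, #7 and #8 add args.alpha * global_orthogonal_regularization(out_a, out_n), the "spread-out" regularizer that pushes the statistics of anchor/negative similarities toward those of random points on the unit sphere. The sketch below shows the usual form of that term (mean similarity near zero, second moment near 1/D); treat it as an illustration rather than the exact function used in these repositories.

import torch

def global_orthogonal_regularization_sketch(anchor, negative):
    # anchor, negative: (N, D) L2-normalized descriptors of non-matching pairs.
    dim = anchor.size(1)
    sim = (anchor * negative).sum(dim=1)          # cosine similarity per pair
    m1 = sim.mean()                               # should approach 0
    m2 = sim.pow(2).mean()                        # should approach 1/D
    return m1.pow(2) + torch.clamp(m2 - 1.0 / dim, min=0.0)
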