def train(train_loader, model, optimizer, epoch, logger):
    # switch to train mode
    model.train()
    pbar = tqdm(enumerate(train_loader))
    for batch_idx, (data_a, data_p, data_n) in pbar:

        if args.cuda:
            data_a, data_p, data_n = data_a.cuda(), data_p.cuda(), data_n.cuda()

        data_a, data_p, data_n = Variable(data_a), Variable(data_p), Variable(data_n)

        out_a, out_p, out_n = model(data_a), model(data_p), model(data_n)

        # HardNet loss
        loss = loss_random_sampling(out_a, out_p, out_n, margin=args.margin)

        if args.decor:
            loss += CorrelationPenaltyLoss()(out_a)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        adjust_learning_rate(optimizer)
        if logger is not None:
            logger.log_value('loss', loss.item()).step()

        if batch_idx % args.log_interval == 0:
            pbar.set_description(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data_a), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item()))

    torch.save({
        'epoch': epoch + 1,
        'state_dict': model.state_dict()
    }, '{}/checkpoint_{}.pth'.format(LOG_DIR, epoch))
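
# adjust_learning_rate() is called once per optimizer step above but is not
# defined in this snippet. A minimal sketch of a per-step linear decay,
# assuming a precomputed total step count; the function and argument names
# here are illustrative assumptions, not the repository's exact helper.
def adjust_learning_rate_sketch(optimizer, step, total_steps, base_lr):
    # Decay the learning rate linearly from base_lr towards zero.
    new_lr = base_lr * max(0.0, 1.0 - float(step) / float(total_steps))
    for group in optimizer.param_groups:
        group['lr'] = new_lr
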
def train(train_loader, model, optimizer, epoch, logger, load_triplets=False):
    # switch to train mode
    model.train()
    pbar = tqdm(enumerate(train_loader))
    for batch_idx, data in pbar:
        if load_triplets:
            data_a, data_p, data_n = data
        else:
            data_a, data_p = data

        if args.cuda:
            data_a, data_p = data_a.cuda(), data_p.cuda()
        data_a, data_p = Variable(data_a), Variable(data_p)
        out_a = model(data_a)
        out_p = model(data_p)
        if load_triplets:
            if args.cuda:
                data_n = data_n.cuda()
            data_n = Variable(data_n)
            out_n = model(data_n)

        if args.batch_reduce == 'L2Net':
            loss = loss_L2Net(out_a, out_p, anchor_swap = args.anchorswap,
                    margin = args.margin, loss_type = args.loss)
        elif args.batch_reduce == 'random_global':
            loss = loss_random_sampling(out_a, out_p, out_n,
                margin=args.margin,
                anchor_swap=args.anchorswap,
                loss_type = args.loss)
        else:
            loss = loss_HardNet(out_a, out_p,
                            margin=args.margin,
                            anchor_swap=args.anchorswap,
                            anchor_ave=args.anchorave,
                            batch_reduce = args.batch_reduce,
                            loss_type = args.loss)

        if args.decor:
            loss += CorrelationPenaltyLoss()(out_a)
            
        if args.gor:
            loss += args.alpha*global_orthogonal_regularization(out_a, out_n)
            
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        adjust_learning_rate(optimizer)
        if batch_idx % args.log_interval == 0:
            pbar.set_description(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data_a), len(train_loader.dataset),
                           100. * batch_idx / len(train_loader),
                    loss.item()))

    if (args.enable_logging):
        logger.log_value('loss', loss.item()).step()

    try:
        os.stat('{}{}'.format(args.model_dir, suffix))
    except OSError:
        os.makedirs('{}{}'.format(args.model_dir, suffix))

    torch.save({'epoch': epoch + 1, 'state_dict': model.state_dict()},
               '{}{}/checkpoint_{}.pth'.format(args.model_dir,suffix,epoch))
    del loss, data_p, data_a, data, out_a, out_p
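
# global_orthogonal_regularization() (the args.gor term above) is not shown
# in this snippet. A minimal sketch of the usual "spread-out descriptors"
# formulation based on the first and second moments of anchor/negative dot
# products; treat it as an illustrative assumption, not the exact helper.
import torch

def global_orthogonal_regularization_sketch(anchor, negative):
    # anchor, negative: L2-normalized descriptors of shape (batch, dim)
    dot = torch.sum(anchor * negative, dim=1)  # similarity of each anchor/negative pair
    dim = anchor.size(1)
    # first moment squared plus the positive part of (second moment - 1/dim)
    return dot.mean().pow(2) + torch.clamp(dot.pow(2).mean() - 1.0 / dim, min=0.0)
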
def train(test_loader, train_loader, model, optimizer, epoch, logger, load_triplets=False):
    # switch to train mode

    model.train()
    pbar = tqdm(enumerate(train_loader))
    for batch_idx, data in pbar:
        if load_triplets:
            data_a, data_p, data_n = data
        else:
            data_a, data_p = data
        # print("data_a",data_a.size())
        if args.cuda:
            data_a, data_p = data_a.cuda(), data_p.cuda()
        data_a, data_p = Variable(data_a), Variable(data_p)
        out_a, out_p = model(data_a), model(data_p)

        if load_triplets:
            if args.cuda:
                data_n = data_n.cuda()
            data_n = Variable(data_n)
            out_n = model(data_n)

        if args.batch_reduce == 'L2Net':
            loss = loss_L2Net(out_a, out_p, anchor_swap = args.anchorswap,
                    margin = args.margin, loss_type = args.loss)
        elif args.batch_reduce == 'random_global':
            loss = loss_random_sampling(out_a, out_p, out_n,
                margin=args.margin,
                anchor_swap=args.anchorswap,
                loss_type = args.loss)
        else:
            loss = loss_HardNet(out_a, out_p,
                            margin=args.margin,
                            anchor_swap=args.anchorswap,
                            anchor_ave=args.anchorave,
                            batch_reduce = args.batch_reduce,
                            loss_type = args.loss)

        if args.decor:
            loss += args.cor_weights * CorrelationPenaltyLoss()(out_a)
        if args.gor:
            loss += args.alpha * global_orthogonal_regularization(out_a, out_n)
        if args.evendis:
            loss += args.even_weights * Even_distributeLoss()(out_a)
        if args.quan:
            loss += args.quan_weights * QuantilizeLoss(args.quan_scale)(out_a)
            
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if not args.constantlr:
            adjust_learning_rate(optimizer)
        if batch_idx % args.log_interval == 0:
            pbar.set_description(
                'Train Epoch: {} [{}/{} ({:.0f}%)]lr:{:f} \tLoss_T: {:.6f}'.format(
                    epoch, batch_idx * len(data_a), len(train_loader.dataset),
                           100. * batch_idx / len(train_loader),
                           optimizer.param_groups[0]['lr'], loss.item()))

    if args.enable_logging:
        logger.log_value('loss', loss.item()).step()

    try:
        os.stat('{}{}'.format(args.model_dir, suffix))
    except OSError:
        os.makedirs('{}{}'.format(args.model_dir, suffix))

    torch.save({'epoch': epoch + 1,
                'optimizer': optimizer.state_dict(),
                'state_dict': model.state_dict()},
               '{}{}/checkpoint_{}{}.pth'.format(args.model_dir, suffix, newstart, epoch))
    # torch.save(model,'{}{}/checkpoint_{}.pth'.format(args.model_dir,suffix,epoch))
    print("model {}{}/checkpoint_{}{}.pth is saved".format(args.model_dir, suffix, newstart, epoch))
    if args.enable_logging:
        logger.log_value(test_loader['name'] + ' loss', loss.item())
    return loss.item()
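
# CorrelationPenaltyLoss() (the args.decor term used in the functions above)
# is referenced but not defined in this snippet. A minimal sketch of an
# L2Net-style decorrelation penalty on the off-diagonal entries of the
# descriptor correlation matrix over the batch; this is an illustrative
# assumption, not the repository's exact implementation.
import torch
import torch.nn as nn

class CorrelationPenaltySketch(nn.Module):
    def forward(self, desc):
        # desc: (batch, dim) descriptor matrix
        centered = desc - desc.mean(dim=0, keepdim=True)
        cov = centered.t() @ centered / desc.size(0)      # (dim, dim) covariance
        off_diag = cov - torch.diag(torch.diag(cov))      # zero out the diagonal
        return off_diag.pow(2).sum() / desc.size(1)
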
    def train(self, train_loader, model, optimizer, epoch, logger, load_triplets=False):
        print("Training model")
        # switch to train mode
        model.train()
        pbar = tqdm(enumerate(train_loader))
        for batch_idx, data in pbar:
            if load_triplets:
                data_a, data_p, data_n = data
            else:
                data_a, data_p = data

            if self.args.cuda:
                data_a, data_p = data_a.cuda(), data_p.cuda()
            data_a, data_p = Variable(data_a), Variable(data_p)
            out_a = model(data_a)
            out_p = model(data_p)
            if load_triplets:
                if self.args.cuda:
                    data_n = data_n.cuda()
                data_n = Variable(data_n)
                out_n = model(data_n)

            if self.args.loss == 'qht':
                loss = loss_SOSNet(out_a, out_p,
                                   batch_reduce=self.args.batch_reduce,
                                   no_cuda=self.args.no_cuda)
            else:
                if self.args.batch_reduce == 'L2Net':
                    loss = loss_L2Net(out_a, out_p, anchor_swap = self.args.anchorswap,
                            margin = self.args.margin, loss_type = self.args.loss)
                elif self.args.batch_reduce == 'random_global':
                    loss = loss_random_sampling(out_a, out_p, out_n,
                        margin=self.args.margin,
                        anchor_swap=self.args.anchorswap,
                        loss_type = self.args.loss)
                else:
                    loss = loss_HardNet(out_a, out_p,
                                    margin=self.args.margin,
                                    anchor_swap=self.args.anchorswap,
                                    anchor_ave=self.args.anchorave,
                                    batch_reduce = self.args.batch_reduce,
                                    loss_type = self.args.loss,
                                    no_cuda = self.args.no_cuda)

            if self.args.decor:
                loss += CorrelationPenaltyLoss()(out_a)
                
            if self.args.gor:
                loss += self.args.alpha*global_orthogonal_regularization(out_a, out_n)
            
            if self.print_summary:
                with torch.no_grad():
                    # We can only do this here because the inputs are only
                    # moved to CUDA tensors above.
                    summary(model, input_size=(1, self.args.imageSize, self.args.imageSize))
                self.print_summary = False
                
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if self.change_lr:
                self.adjust_learning_rate(optimizer)
            if batch_idx % self.args.log_interval == 0:
                pbar.set_description(
                    'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        epoch, batch_idx * len(data_a), len(train_loader.dataset),
                            100. * batch_idx / len(train_loader),
                        loss.item()))

        if (self.args.enable_logging):
            logger.log_value('loss', loss.item()).step()

        try:
            os.stat('{}{}'.format(self.args.model_dir, self.suffix))
        except OSError:
            os.makedirs('{}{}'.format(self.args.model_dir, self.suffix))

        torch.save({'epoch': epoch + 1, 'state_dict': model.state_dict()},
                '{}{}/checkpoint_{}.pth'.format(self.args.model_dir,self.suffix,epoch))
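
# A minimal sketch of the hardest-in-batch mining behind loss_HardNet as used
# above: build the anchor/positive distance matrix, pick the closest
# non-matching descriptor as the negative for each anchor, and apply a triplet
# margin loss. anchor_swap, anchor_ave and the batch_reduce variants are
# omitted, so treat this as an illustrative assumption rather than the exact
# helper implementation.
import torch

def loss_hardnet_sketch(anchor, positive, margin=1.0):
    # anchor, positive: L2-normalized descriptors of shape (batch, dim)
    dist = torch.cdist(anchor, positive)                  # all pairwise distances
    pos = torch.diag(dist)                                # matching-pair distances
    # push the diagonal far away so a positive is never selected as negative
    masked = dist + torch.eye(dist.size(0), device=dist.device) * 1e5
    hardest_neg = masked.min(dim=1)[0]                    # hardest negative per anchor
    return torch.clamp(margin + pos - hardest_neg, min=0.0).mean()
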
def train(train_loader, model, optimizer, epoch, logger, load_triplets=False):
    # switch to train mode
    model.train()
    pbar = tqdm(enumerate(train_loader))
    for batch_idx, data in pbar:
        #print( data)
        if load_triplets:
            data_a, data_p, data_n = data
            if args.cuda:
                data_a, data_p, data_n = data_a.cuda(), data_p.cuda(
                ), data_n.cuda()
            data_a, data_p, data_n = Variable(data_a), Variable(
                data_p), Variable(data_n)
            out_a, out_p, out_n = model(data_a), model(data_p), model(data_n)
            loss = loss_random_sampling(out_a,
                                        out_p,
                                        out_n,
                                        margin=args.margin,
                                        anchor_swap=args.anchorswap,
                                        loss_type=args.loss)
        else:
            data_a, data_p = data
            if args.cuda:
                data_a, data_p = data_a.cuda(), data_p.cuda()
            data_a, data_p = Variable(data_a), Variable(data_p)
            out_a, out_p = model(data_a), model(data_p)
            # HardNet loss
            if args.batch_reduce == 'L2Net':
                loss = loss_L2Net(out_a,
                                  out_p,
                                  column_row_swap=True,
                                  anchor_swap=args.anchorswap,
                                  margin=args.margin,
                                  loss_type=args.loss)
            else:
                loss = loss_HardNet(out_a,
                                    out_p,
                                    margin=args.margin,
                                    column_row_swap=True,
                                    anchor_swap=args.anchorswap,
                                    anchor_ave=args.anchorave,
                                    batch_reduce=args.batch_reduce,
                                    loss_type=args.loss)
        if args.decor:
            loss += CorrelationPenaltyLoss()(out_a)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        adjust_learning_rate(optimizer)
        if args.enable_logging:
            logger.log_value('loss', loss.item()).step()

        if batch_idx % args.log_interval == 0:
            pbar.set_description(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data_a), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item()))

    torch.save({
        'epoch': epoch + 1,
        'state_dict': model.state_dict()
    }, '{}/checkpoint_{}.pth'.format(LOG_DIR, epoch))
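
# A minimal sketch of the random-sampling triplet loss used in the
# load_triplets branch above: a plain triplet margin loss on the given
# (anchor, positive, negative) descriptors. anchor_swap and the different
# loss_type options are omitted; this is an illustrative assumption, not the
# repository's exact helper.
import torch
import torch.nn.functional as F

def loss_random_sampling_sketch(anchor, positive, negative, margin=1.0):
    d_pos = F.pairwise_distance(anchor, positive)   # distance to the matching patch
    d_neg = F.pairwise_distance(anchor, negative)   # distance to a random negative
    return torch.clamp(margin + d_pos - d_neg, min=0.0).mean()
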
def train(train_loader, model, optimizer, epoch, logger, load_triplets=False):

    # switch to train mode
    model.train()
    pbar = tqdm(enumerate(train_loader))
    for batch_idx, data in pbar:
        if load_triplets:
            data_a, data_p, data_n = data
        else:
            data_a, data_p = data

        if args.cuda:
            data_a, data_p = data_a.cuda(), data_p.cuda()
        data_a, data_p = Variable(data_a), Variable(data_p)
        out_a, out_p = model(data_a), model(data_p)

        # load_triplets=False for L2Net and HardNet: these two mine their
        # negatives from the other pairs in the batch
        if load_triplets:
            if args.cuda:
                data_n = data_n.cuda()
            data_n = Variable(data_n)
            out_n = model(data_n)

        # for the comparison with L2Net and random_global
        if args.batch_reduce == 'L2Net':
            loss = loss_L2Net(out_a,
                              out_p,
                              anchor_swap=args.anchorswap,
                              margin=args.margin,
                              loss_type=args.loss)
        elif args.batch_reduce == 'random_global':
            # using random negative patch samples from the dataset
            loss = loss_random_sampling(out_a,
                                        out_p,
                                        out_n,
                                        margin=args.margin,
                                        anchor_swap=args.anchorswap,
                                        loss_type=args.loss)
        else:
            loss = loss_HardNet(out_a,
                                out_p,
                                margin=args.margin,
                                anchor_swap=args.anchorswap,
                                anchor_ave=args.anchorave,
                                batch_reduce=args.batch_reduce,
                                loss_type=args.loss)

        # E2 loss in L2Net penalizing descriptor component correlation
        if args.decor:
            loss += CorrelationPenaltyLoss()(out_a)

        # global orthogonal regularization (GOR) for HardNet
        if args.gor:
            loss += args.alpha * global_orthogonal_regularization(out_a, out_n)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        adjust_learning_rate(optimizer, args)
        if batch_idx % args.log_interval == 0:
            pbar.set_description(
                'Train Epoch: {} [{}/{} ({:.0f}%)]  Loss: {:.6f}'.format(
                    epoch, batch_idx * len(data_a), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item()))
            if (args.enable_logging):
                logger.log_string(
                    'logs',
                    'Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f}'.format(
                        epoch, batch_idx * len(data_a),
                        len(train_loader.dataset),
                        100. * batch_idx / len(train_loader), loss.item()))

    try:
        os.stat('{}{}'.format(args.model_dir, suffix))
    except OSError:
        os.makedirs('{}{}'.format(args.model_dir, suffix))

    torch.save({
        'epoch': epoch + 1,
        'state_dict': model.state_dict()
    }, '{}{}/checkpoint_{}.pth'.format(args.model_dir, suffix, epoch))
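
# The checkpoints saved above store only 'epoch' and 'state_dict' (plus, in
# one variant, the optimizer state). A minimal resume sketch matching that
# format; the path handling and map_location choice are assumptions.
import torch

def load_checkpoint_sketch(model, path, optimizer=None):
    checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    if optimizer is not None and 'optimizer' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
    return checkpoint.get('epoch', 0)   # epoch to resume from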