Exemple #1
0
def train(epoch, model, criterion, optimizer, train_loader, args):
    """Run one training epoch and print running loss/timing statistics.

    Args:
        epoch: Zero-based epoch index; it is printed as ``epoch + 1``.
        model: Network called as ``model(inputs)`` to produce embeddings.
        criterion: Callable invoked as ``criterion(embed_feat, labels)``;
            its result is used as the loss (``backward()`` and ``item()``
            are called on it).
        optimizer: Optimizer that is zeroed and stepped once per batch.
        train_loader: Sized iterable yielding ``(inputs, labels)`` batches.
        args: Options object; ``print_freq`` and ``orth_reg`` are read.
    """
    losses = AverageMeter()
    batch_time = AverageMeter()

    num_batches = len(train_loader)
    # Print every `print_freq` batches (never less often than once per
    # epoch). The lower clamp keeps print_freq == 0 from causing a
    # modulo-by-zero in the reporting condition below.
    freq = max(1, min(args.print_freq, num_batches))

    end = time.time()
    for i, (inputs, labels) in enumerate(train_loader):
        # Plain tensors are moved to the GPU directly; the deprecated
        # Variable wrapper is no longer required for autograd.
        inputs = inputs.cuda()
        labels = labels.cuda()

        optimizer.zero_grad()

        embed_feat = model(inputs)
        loss = criterion(embed_feat, labels)

        # orth_reg is defined elsewhere; it receives the current loss and
        # returns the loss that is actually optimised.
        if args.orth_reg != 0:
            loss = orth_reg(net=model, loss=loss, cof=args.orth_reg)

        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        losses.update(loss.item())

        if (i + 1) % freq == 0 or (i + 1) == num_batches:
            print('Epoch: [{0:03d}][{1}/{2}]\t'
                  'Time {batch_time.avg:.3f}\t'
                  'Loss {loss.avg:.8f} \t'.format(epoch + 1,
                                                  i + 1,
                                                  num_batches,
                                                  batch_time=batch_time,
                                                  loss=losses))

        # One-off banner after the very first optimisation step.
        if epoch == 0 and i == 0:
            print('-- HA-HA-HA-HA-AH-AH-AH-AH --')
Exemple #2
0
def train(epoch, model, criterion, optimizer, train_loader, args):
    """Run one training epoch and print loss, accuracy and similarity stats.

    Args:
        epoch: Zero-based epoch index; it is printed as ``epoch + 1``.
        model: Network called as ``model(inputs)`` to produce embeddings.
        criterion: Callable invoked as ``criterion(embed_feat, labels)`` and
            unpacked into four values ``(loss, inter_, dist_ap, dist_an)``.
            The last three are only fed to the running-average meters shown
            as "Accuracy", "Pos" and "Neg".
        optimizer: Optimizer that is zeroed and stepped once per batch.
        train_loader: Sized iterable yielding ``(inputs, labels)`` batches.
        args: Options object; ``print_freq`` and ``orth_reg`` are read.
    """
    losses = AverageMeter()
    batch_time = AverageMeter()
    accuracy = AverageMeter()
    pos_sims = AverageMeter()
    neg_sims = AverageMeter()

    num_batches = len(train_loader)
    # Print every `print_freq` batches (never less often than once per
    # epoch). The lower clamp keeps print_freq == 0 from causing a
    # modulo-by-zero in the reporting condition below.
    freq = max(1, min(args.print_freq, num_batches))

    end = time.time()
    for i, (inputs, labels) in enumerate(train_loader):
        # Plain tensors are moved to the GPU directly; the deprecated
        # Variable wrapper is no longer required for autograd.
        inputs = inputs.cuda()
        labels = labels.cuda()

        optimizer.zero_grad()

        embed_feat = model(inputs)
        loss, inter_, dist_ap, dist_an = criterion(embed_feat, labels)

        # orth_reg is defined elsewhere; it receives the current loss and
        # returns the loss that is actually optimised.
        if args.orth_reg != 0:
            loss = orth_reg(net=model, loss=loss, cof=args.orth_reg)

        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        losses.update(loss.item())
        accuracy.update(inter_)
        pos_sims.update(dist_ap)
        neg_sims.update(dist_an)

        if (i + 1) % freq == 0 or (i + 1) == num_batches:
            print('Epoch: [{0:03d}][{1}/{2}]\t'
                  'Time {batch_time.avg:.3f}\t'
                  'Loss {loss.avg:.4f} \t'
                  'Accuracy {accuracy.avg:.4f} \t'
                  'Pos {pos.avg:.4f}\t'
                  'Neg {neg.avg:.4f} \t'.format(epoch + 1,
                                                i + 1,
                                                num_batches,
                                                batch_time=batch_time,
                                                loss=losses,
                                                accuracy=accuracy,
                                                pos=pos_sims,
                                                neg=neg_sims))

        # One-off banner after the very first optimisation step.
        if epoch == 0 and i == 0:
            print('-- HA-HA-HA-HA-AH-AH-AH-AH --')
Exemple #3
0
def train(epoch, model, criterion, optimizer, train_loader, args):
    """Run one training epoch on (anchor, positive, negative) batches.

    Args:
        epoch: Zero-based epoch index; it is printed as ``epoch + 1``.
        model: Network applied separately to the anchor, positive and
            negative inputs of each batch.
        criterion: Callable invoked as
            ``criterion(anchor_emb, poss_emb, negs_emb)``; its result is used
            as the loss (``backward()`` and ``item()`` are called on it).
        optimizer: Optimizer that is zeroed and stepped once per batch.
        train_loader: Sized iterable yielding
            ``(inputs, poss, negs, labels)`` batches.
        args: Options object; ``print_freq`` and ``orth_reg`` are read.
    """
    losses = AverageMeter()
    batch_time = AverageMeter()

    num_batches = len(train_loader)
    # Print every `print_freq` batches (never less often than once per
    # epoch). The lower clamp keeps print_freq == 0 from causing a
    # modulo-by-zero in the reporting condition below.
    freq = max(1, min(args.print_freq, num_batches))

    end = time.time()
    for i, (inputs, poss, negs, labels) in enumerate(train_loader):
        inputs = inputs.cuda()
        # Labels are moved to the GPU as before even though the loss below
        # does not consume them.
        labels = labels.cuda()
        poss = poss.cuda()
        negs = negs.cuda()

        optimizer.zero_grad()

        # Three separate forward passes, kept separate on purpose: merging
        # them into one batch could change batch-dependent layers.
        anchor_emb = model(inputs)
        poss_emb = model(poss)
        negs_emb = model(negs)

        loss = criterion(anchor_emb, poss_emb, negs_emb)

        # orth_reg is defined elsewhere; it receives the current loss and
        # returns the loss that is actually optimised.
        if args.orth_reg != 0:
            loss = orth_reg(net=model, loss=loss, cof=args.orth_reg)

        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # `.item()` on the tensor itself; going through `.data` is a legacy
        # spelling that bypasses autograd tracking for no benefit here.
        losses.update(loss.item())

        if (i + 1) % freq == 0 or (i + 1) == num_batches:
            print('Epoch: [{0:03d}][{1}/{2}]\t'
                  'Time {batch_time.avg:.3f}\t'
                  'Loss {loss.avg:.4f} \t'.format(
                      epoch + 1, i + 1, num_batches,
                      batch_time=batch_time, loss=losses))

        # One-off banner after the very first optimisation step.
        if epoch == 0 and i == 0:
            print('-- HA-HA-HA-HA-AH-AH-AH-AH --')
Exemple #4
0
        # Body of the per-batch training step; the enclosing loop headers
        # are outside this excerpt.
        # get the inputs
        inputs, labels = data
        # break
        # wrap them in Variable
        inputs = Variable(inputs.cuda())
        labels = Variable(labels).cuda()

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        embed_feat = model(inputs)

        # loss = criterion(embed_feat, labels)
        # The criterion result is unpacked into the loss plus three extra
        # statistics that are only used for the per-epoch report below.
        loss, inter_, dist_ap, dist_an = criterion(embed_feat, labels)
        # Apply orth_reg (defined elsewhere) only for a non-negligible
        # coefficient.
        if args.orth_cof > 1e-9:
            loss = orth_reg(model, loss, cof=args.orth_cof)
        loss.backward()
        optimizer.step()
        # NOTE(review): `loss.data[0]` is the pre-0.4 PyTorch idiom and is
        # rejected for 0-dim tensors by later releases; `loss.item()` is the
        # replacement — confirm the targeted PyTorch version.
        running_loss += loss.data[0]
    # print(epoch)
    # Per-epoch report: running_loss is the accumulated loss, whereas
    # inter_, dist_ap and dist_an hold the values of the most recent batch.
    print(
        '[epoch %05d]\t loss: %.7f \t prec: %.3f \t pos-dist: %.3f \tneg-dist: %.3f'
        % (epoch + 1, running_loss, inter_, dist_ap, dist_an))
    # Periodic checkpoint: the whole model object is saved, named by epoch.
    if epoch % args.save_step == 0:
        torch.save(model, os.path.join(log_dir, '%d_model.pkl' % epoch))

# Final checkpoint after the training loop, named by the last epoch index.
torch.save(model, os.path.join(log_dir, '%d_model.pkl' % epoch))

print('Finished Training')
Exemple #5
0
def train(epoch, model, criterion, optimizer, train_loader, args):
    """Run one training epoch with an optional rotation self-supervision loss.

    The embedding loss is computed from the first three channels of every
    input unless ``args.rot_only`` is set. When ``args.self_supervision_rot``
    is truthy, a 4-way classification loss on a random subset of the batch is
    added, weighted by that same value.

    The test-set rotation loss is not computed here; "Loss_rot_test" is kept
    in the log line (always 0) so the output format stays unchanged.

    NOTE: with ``rot_only`` set, ``self_supervision_rot`` falsy and
    ``orth_reg`` equal to 0 the loss stays the plain integer 0 and
    ``loss.backward()`` fails; that configuration is not meaningful.

    Args:
        epoch: Zero-based epoch index; it is printed as ``epoch + 1``.
        model: Network called as ``model(x, rot=False)`` for embeddings and
            ``model(x, rot=True)`` for rotation scores.
        criterion: Callable invoked as ``criterion(embed_feat, labels)`` and
            unpacked into ``(loss, inter_, dist_ap, dist_an)``.
        optimizer: Optimizer that is zeroed and stepped once per batch.
        train_loader: Sized iterable yielding ``(inputs, labels)`` batches
            of 4-D inputs.
        args: Options object; ``print_freq``, ``rot_batch``, ``rot_only``,
            ``dim``, ``self_supervision_rot`` and ``orth_reg`` are read.
    """
    losses = AverageMeter()
    batch_time = AverageMeter()
    accuracy = AverageMeter()
    pos_sims = AverageMeter()
    neg_sims = AverageMeter()

    num_batches = len(train_loader)
    # Print every `print_freq` batches (never less often than once per
    # epoch). The lower clamp keeps print_freq == 0 from causing a
    # modulo-by-zero in the reporting condition below.
    freq = max(1, min(args.print_freq, num_batches))

    # Loop-invariant objects are built once instead of once per batch.
    cross_entropy = nn.CrossEntropyLoss()
    labels_rot = None
    if args.self_supervision_rot:
        # Four classes per sampled image, repeated `rot_batch` times.
        labels_rot = torch.LongTensor([0, 1, 2, 3] * args.rot_batch).cuda()

    loss_rot_test = 0.0

    end = time.time()
    for i, (inputs, labels) in enumerate(train_loader):
        # Assumes NCHW layout — TODO confirm; the view below only relies on
        # the last two sizes being passed through in the same order.
        num_samples, _, height, width = inputs.size()

        inputs_1 = inputs[:, 0:3, :, :]
        # The random subset is drawn every batch, even when the rotation
        # loss is disabled, so the NumPy RNG stream is consumed identically.
        picked = np.random.choice(range(num_samples), args.rot_batch)
        # Presumably the extra channels hold rotated copies that are
        # unstacked into separate 3-channel samples — verify against the
        # dataset.
        inputs_2 = inputs[picked, :, :, :].view(-1, 3, height, width)

        # Plain tensors are moved to the GPU directly; the deprecated
        # Variable wrapper is no longer required for autograd.
        inputs_1 = inputs_1.cuda()
        inputs_2 = inputs_2.cuda()
        labels = labels.cuda()

        optimizer.zero_grad()

        if not args.rot_only:
            embed_feat = model(inputs_1, rot=False)
            if args.dim % 64 != 0:
                # Plain classification loss; no extra statistics available.
                loss = cross_entropy(embed_feat, labels)
                inter_, dist_ap, dist_an = 0, 0, 0
            else:
                loss, inter_, dist_ap, dist_an = criterion(embed_feat, labels)
        else:
            loss, inter_, dist_ap, dist_an = 0, 0, 0, 0

        loss_rot = torch.zeros(1)
        if args.self_supervision_rot:
            score = model(inputs_2, rot=True)
            loss_rot = cross_entropy(score, labels_rot)
            # Out-of-place add: avoids mutating a tensor that is part of the
            # autograd graph and also works when `loss` is the integer 0.
            loss = loss + args.self_supervision_rot * loss_rot

        # orth_reg is defined elsewhere; it receives the current loss and
        # returns the loss that is actually optimised.
        if args.orth_reg != 0:
            loss = orth_reg(net=model, loss=loss, cof=args.orth_reg)

        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # In rotation-only mode the embedding meters are left untouched.
        if not args.rot_only:
            losses.update(loss.item())
            accuracy.update(inter_)
            pos_sims.update(dist_ap)
            neg_sims.update(dist_an)

        if (i + 1) % freq == 0 or (i + 1) == num_batches:
            print('Epoch: [{0:03d}][{1}/{2}]\t'
                  'Time {batch_time.avg:.3f}\t'
                  'Loss {loss.avg:.4f} \t'
                  'Loss_rot {loss_rot:.4f} \t'
                  'Loss_rot_test {loss_rot_test:.4f} \t'
                  'accuracy {accuracy.avg:.4f} \t'
                  'Pos {pos.avg:.4f}\t'
                  'Neg {neg.avg:.4f} \t'.format(
                      epoch + 1, i + 1, num_batches,
                      batch_time=batch_time,
                      loss=losses,
                      loss_rot=loss_rot.item(),
                      loss_rot_test=loss_rot_test,
                      accuracy=accuracy,
                      pos=pos_sims,
                      neg=neg_sims))
Exemple #6
0
def main(args):
    """Build the model, optimizer, loss and data loader, then train.

    Steps, in order:
      1. Redirect ``sys.stdout`` to a logger writing ``log.txt`` under
         ``checkpoints/<args.log_dir>``.
      2. Create the network from pretrained weights (or load a saved model
         from ``args.r``) and save the initial model.
      3. Train for epochs ``args.start`` .. ``args.epochs - 1``, printing one
         summary line per epoch and saving the model every
         ``args.save_step`` epochs.

    Args:
        args: Options object; ``log_dir``, ``r``, ``net``, ``dim``, ``init``,
            ``lr``, ``weight_decay``, ``loss``, ``alpha``, ``k``, ``data``,
            ``BatchSize``, ``num_instances``, ``nThreads``, ``start``,
            ``epochs``, ``orth`` and ``save_step`` are read.
    """
    # Save training logs
    log_dir = os.path.join('checkpoints', args.log_dir)
    mkdir_if_missing(log_dir)

    sys.stdout = logging.Logger(os.path.join(log_dir, 'log.txt'))
    display(args)

    # NOTE: torch.load unpickles arbitrary objects; only load checkpoint
    # files that come from a trusted source.
    if args.r is None:
        model = models.create(args.net, Embed_dim=args.dim)
        # load part of the model
        model_dict = model.state_dict()
        if args.net == 'bn':
            pretrained_dict = torch.load(
                'pretrained_models/bn_inception-239d2248.pth')
        else:
            pretrained_dict = torch.load(
                'pretrained_models/inception_v3_google-1a9a5a14.pth')

        # Keep only the pretrained entries whose names exist in this model.
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }

        model_dict.update(pretrained_dict)

        # orth init
        if args.init == 'orth':
            print('initialize the FC layer orthogonally')
            _, _, v = torch.svd(model_dict['Embed.linear.weight'])
            model_dict['Embed.linear.weight'] = v.t()

        # zero bias
        model_dict['Embed.linear.bias'] = torch.zeros(args.dim)

        model.load_state_dict(model_dict)
    else:
        # resume model
        model = torch.load(args.r)

    model = model.cuda()

    torch.save(model, os.path.join(log_dir, 'model.pkl'))
    print('initial model is save at %s' % log_dir)

    # fine tune the model: the learning rate for pre-trained parameter is 1/10
    # NOTE(review): 'lr_mult' is an extra key that torch.optim.Adam itself
    # does not read; unless other code applies it (e.g. a learning-rate
    # schedule), both groups train with args.lr — confirm the intent.
    new_param_ids = set(map(id, model.Embed.parameters()))

    new_params = [p for p in model.parameters() if id(p) in new_param_ids]

    base_params = [p for p in model.parameters() if id(p) not in new_param_ids]
    param_groups = [{
        'params': base_params,
        'lr_mult': 0.1
    }, {
        'params': new_params,
        'lr_mult': 1.0
    }]

    optimizer = torch.optim.Adam(param_groups,
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    criterion = losses.create(args.loss, alpha=args.alpha, k=args.k).cuda()

    dataset = DataSet.create(args.data, root=None, test=False)
    train_loader = torch.utils.data.DataLoader(
        dataset.train,
        batch_size=args.BatchSize,
        sampler=RandomIdentitySampler(dataset.train,
                                      num_instances=args.num_instances),
        drop_last=True,
        num_workers=args.nThreads)

    for epoch in range(args.start, args.epochs):
        running_loss = 0.0
        # Defaults so the per-epoch report still works if the loader yields
        # no batch (possible with drop_last=True on a small dataset).
        inter_ = dist_ap = dist_an = 0.0
        # The loop variable is named `batch` so it cannot shadow the dataset.
        for i, batch in enumerate(train_loader):
            inputs, labels = batch
            # Plain tensors are moved to the GPU directly; the deprecated
            # Variable wrapper is no longer required for autograd.
            inputs = inputs.cuda()
            labels = labels.cuda()

            optimizer.zero_grad()

            embed_feat = model(inputs)

            loss, inter_, dist_ap, dist_an = criterion(embed_feat, labels)
            # orth_reg is defined elsewhere; it receives the current loss
            # and returns the loss that is actually optimised.
            if args.orth > 0:
                loss = orth_reg(model, loss, cof=args.orth)
            loss.backward()
            optimizer.step()
            # `loss.data[0]` is rejected for 0-dim tensors by current
            # PyTorch; `.item()` is the supported way to read the scalar.
            running_loss += loss.item()
            if epoch == 0 and i == 0:
                print(50 * '#')
                print('Train Begin -- HA-HA-HA')

        # running_loss is the sum over the epoch; the other three values are
        # those of the most recent batch.
        print(
            '[Epoch %05d]\t Loss: %.3f \t Accuracy: %.3f \t Pos-Dist: %.3f \t Neg-Dist: %.3f'
            % (epoch + 1, running_loss, inter_, dist_ap, dist_an))

        if epoch % args.save_step == 0:
            torch.save(model, os.path.join(log_dir, '%d_model.pkl' % epoch))