Code example #1
def train(args):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model, optimizer = get_model(args)
    model.train()
    train_loader, _ = data_loader(args)

    with open(os.path.join(args.snapshot_dir, 'train_record.csv'), 'a') as fw:
        config = json.dumps(vars(args), indent=4, separators=(',', ':'))
        fw.write(config)
        fw.write('#epoch,loss,pred@1,pred@5\n')

    total_epoch = args.epoch
    global_counter = args.global_counter
    current_epoch = args.current_epoch
    end = time.time()
    max_iter = total_epoch * len(train_loader)
    print('Max iter:', max_iter)
    while current_epoch < total_epoch:
        model.train()
        losses.reset()
        top1.reset()
        top5.reset()
        batch_time.reset()
        res = my_optim.reduce_lr(args, optimizer, current_epoch)

        if res:
            # The handle opened in the header block above is already closed;
            # re-open the record file to append the new learning rate.
            with open(os.path.join(args.snapshot_dir, 'train_record.csv'),
                      'a') as fw:
                for g in optimizer.param_groups:
                    out_str = 'Epoch:%d, %f\n' % (current_epoch, g['lr'])
                    fw.write(out_str)

        steps_per_epoch = len(train_loader)
        for idx, dat in enumerate(train_loader):
            img_path, img, label = dat
            global_counter += 1
            img, label = img.cuda(), label.cuda()
            img_var, label_var = Variable(img), Variable(label)

            logits = model(img_var, label_var)
            # get_loss returns a one-element tuple; the trailing comma unpacks it.
            loss_val, = model.module.get_loss(logits, label_var)

            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()

            if args.onehot != 'True':
                logits1 = torch.squeeze(logits[0])
                prec1_1, prec5_1 = Metrics.accuracy(logits1.data,
                                                    label.long(),
                                                    topk=(1, 5))
                top1.update(prec1_1[0], img.size()[0])
                top5.update(prec5_1[0], img.size()[0])

            # loss_val is a 0-dim tensor; .item() replaces the old .data[0] indexing.
            losses.update(loss_val.data.item(), img.size()[0])
            batch_time.update(time.time() - end)

            end = time.time()
            if global_counter % 1000 == 0:
                losses.reset()
                top1.reset()
                top5.reset()

            if global_counter % args.disp_interval == 0:
                # Calculate ETA
                eta_seconds = (
                    (total_epoch - current_epoch) * steps_per_epoch +
                    (steps_per_epoch - idx)) * batch_time.avg
                eta_str = "{:0>8}".format(
                    datetime.timedelta(seconds=int(eta_seconds)))
                eta_seconds_epoch = steps_per_epoch * batch_time.avg
                eta_str_epoch = "{:0>8}".format(
                    datetime.timedelta(seconds=int(eta_seconds_epoch)))
                print(
                    'Epoch: [{0}][{1}/{2}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'ETA {eta_str}({eta_str_epoch})\t'
                    'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                    'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                    'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                        current_epoch,
                        global_counter % len(train_loader),
                        len(train_loader),
                        batch_time=batch_time,
                        eta_str=eta_str,
                        eta_str_epoch=eta_str_epoch,
                        loss=losses,
                        top1=top1,
                        top5=top5))

        if current_epoch % 1 == 0:
            save_checkpoint(args, {
                'epoch': current_epoch,
                'arch': 'resnet',
                'global_counter': global_counter,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()
            },
                            is_best=False,
                            filename='%s_epoch_%d_glo_step_%d.pth.tar' %
                            (args.dataset, current_epoch, global_counter))

        with open(os.path.join(args.snapshot_dir, 'train_record.csv'),
                  'a') as fw:
            fw.write('%d,%.4f,%.3f,%.3f\n' %
                     (current_epoch, losses.avg, top1.avg, top5.avg))

        current_epoch += 1
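All of these snippets track running statistics through an AverageMeter helper that the listing itself never defines. A minimal sketch of the interface they rely on (reset, update(value, n) and the val/avg attributes), assuming the conventional weighted-running-average implementation popularized by the PyTorch ImageNet example:

class AverageMeter:
    """Keeps the latest value and a running average weighted by batch size."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        # n is typically the batch size, so avg is a per-sample average.
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count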
Code example #2
def train(args):
    batch_time = AverageMeter()
    losses = AverageMeter()

    total_epoch = args.epoch
    global_counter = args.global_counter
    current_epoch = args.current_epoch

    train_loader = train_data_loader_iam(args)
    max_step = total_epoch * len(train_loader)
    args.max_step = max_step
    print('Max step:', max_step)

    model, optimizer, criterion = get_model(args)
    print(model)
    model.train()
    end = time.time()

    while current_epoch < total_epoch:
        model.train()
        losses.reset()
        batch_time.reset()
        res = my_optim.reduce_lr(args, optimizer, current_epoch)
        steps_per_epoch = len(train_loader)

        for idx, dat in enumerate(train_loader):
            img, label = dat
            label = label.cuda(non_blocking=True)
            logits = model(img)

            if len(logits.shape) == 1:
                logits = logits.reshape(label.shape)
            loss_val = criterion(logits, label)

            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()

            losses.update(loss_val.data.item(), img.size()[0])
            batch_time.update(time.time() - end)
            end = time.time()

            global_counter += 1
            if global_counter % 1000 == 0:
                losses.reset()

            if global_counter % args.disp_interval == 0:
                print('Epoch: [{}][{}/{}]\t'
                      'LR: {:.5f}\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                          current_epoch,
                          global_counter % len(train_loader),
                          len(train_loader),
                          optimizer.param_groups[0]['lr'],
                          loss=losses))

        if current_epoch == args.epoch - 1:
            save_checkpoint(args, {
                'epoch': current_epoch,
                'global_counter': global_counter,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()
            },
                            is_best=False,
                            filename='%s_epoch_%d.pth' %
                            (args.dataset, current_epoch))
        current_epoch += 1
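Every example delegates the learning-rate schedule to my_optim.reduce_lr(args, optimizer, current_epoch) and treats a truthy return value as "the rate was just changed". The helper is not shown; a plausible sketch, assuming a step schedule driven by an args.decay_points setting and a fixed 0.1 decay factor (both details are assumptions, not taken from the listing):

def reduce_lr(args, optimizer, epoch, factor=0.1):
    """Step schedule: multiply the LR by `factor` at the epochs listed in
    args.decay_points (assumed to be a comma-separated string such as '20,30')."""
    decay_points = [int(p) for p in args.decay_points.split(',') if p.strip()]
    if epoch not in decay_points:
        return False
    for group in optimizer.param_groups:
        group['lr'] *= factor
    return True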
Code example #3
def train(args):
    batch_time = AverageMeter()
    losses = AverageMeter()
    
    total_epoch = args.epoch
    global_counter = args.global_counter
    current_epoch = args.current_epoch

    train_loader, val_loader = train_data_loader(args)
    max_step = total_epoch*len(train_loader)
    args.max_step = max_step 
    print('Max step:', max_step)
    
    model, optimizer = get_model(args)
    
    model.train()
    print(model)
    end = time.time()

    while current_epoch < total_epoch:
        model.train()
        losses.reset()
        batch_time.reset()
        res = my_optim.reduce_lr(args, optimizer, current_epoch)
        steps_per_epoch = len(train_loader)

        index = 0  
        for idx, dat in enumerate(train_loader):
            
            (img_name1, img1, label1,
             img_name2, img2, label2,
             img_name3, img3, label3) = dat
            label1 = label1.cuda(non_blocking=True)
            label2 = label2.cuda(non_blocking=True)
            label3 = label3.cuda(non_blocking=True)            
            
            x11, x1, x22, x2, x33, x3 = model(img1, img2, img3, current_epoch, label1, index)
            index += 1

            loss_train = 0.4 * (F.multilabel_soft_margin_loss(x11, label1) + F.multilabel_soft_margin_loss(x22, label2)
                    + F.multilabel_soft_margin_loss(x33, label3)) + (F.multilabel_soft_margin_loss(x1, label1)
                    + F.multilabel_soft_margin_loss(x2, label2) + F.multilabel_soft_margin_loss(x3, label3))

            optimizer.zero_grad()
            loss_train.backward()
            optimizer.step()

            # `img` is not defined in this scope; use the batch size of the first input.
            losses.update(loss_train.data.item(), img1.size(0))
            batch_time.update(time.time() - end)
            end = time.time()
            
            global_counter += 1
            if global_counter % 1000 == 0:
                losses.reset()

            if global_counter % args.disp_interval == 0:
                print('Epoch: [{}][{}/{}]\t'
                      'LR: {:.5f}\t' 
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                        current_epoch, global_counter % len(train_loader), len(train_loader),
                        optimizer.param_groups[0]['lr'], loss=losses))

        if current_epoch == args.epoch-1:
            save_checkpoint(args,
                        {
                            'epoch': current_epoch,
                            'global_counter': global_counter,
                            'state_dict':model.state_dict(),
                            'optimizer':optimizer.state_dict()
                        }, is_best=False,
                        filename='%s_epoch_%d.pth' %(args.dataset, current_epoch))
        current_epoch += 1
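Example #3 weights the auxiliary heads (x11, x22, x33) by 0.4 and adds the unweighted main-head losses. A small self-contained illustration of how F.multilabel_soft_margin_loss combines logits with multi-hot targets, using made-up values and the same 0.4 weighting pattern:

import torch
import torch.nn.functional as F

# Two samples, three classes; targets are multi-hot (several classes can be 1).
logits_main = torch.tensor([[2.0, -1.0, 0.5], [0.1, 1.5, -0.3]])
logits_aux = torch.tensor([[1.0, -0.5, 0.2], [0.0, 1.0, -0.1]])
labels = torch.tensor([[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])

# Auxiliary head down-weighted by 0.4, as in the snippet above.
loss = 0.4 * F.multilabel_soft_margin_loss(logits_aux, labels) \
       + F.multilabel_soft_margin_loss(logits_main, labels)
print(loss.item())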
Code example #4
def train(args):
    batch_time = AverageMeter()
    losses = AverageMeter()
    loss_cls = AverageMeter()
    loss_dist = AverageMeter()
    loss_aux = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model, optimizer = get_model(args)
    model.train()
    train_loader, _ = data_loader(args)

    with open(os.path.join(args.snapshot_dir, 'train_record.csv'), 'a') as fw:
        config = json.dumps(vars(args), indent=4, separators=(',', ':'))
        fw.write(config)
        fw.write('#epoch,loss,pred@1,pred@5\n')

    total_epoch = args.epoch
    global_counter = args.global_counter
    current_epoch = args.current_epoch
    end = time.time()
    max_iter = total_epoch * len(train_loader)
    print('Max iter:', max_iter)

    while current_epoch < total_epoch:
        model.train()
        losses.reset()
        loss_cls.reset()
        loss_dist.reset()
        loss_aux.reset()
        top1.reset()
        top5.reset()
        batch_time.reset()
        res = my_optim.reduce_lr(args, optimizer, current_epoch)

        if res:
            for g in optimizer.param_groups:
                out_str = 'Epoch:%d, %f\n' % (current_epoch, g['lr'])
                with open(os.path.join(args.snapshot_dir, 'train_record.csv'),
                          'a') as fw:
                    fw.write(out_str)

        steps_per_epoch = len(train_loader)
        for idx, dat in enumerate(train_loader):
            img_path, img, label = dat
            global_counter += 1
            img, label = img.cuda(), label.cuda()
            img_var, label_var = Variable(img), Variable(label)

            logits = model(img_var, label_var)
            loss_list = model.module.get_loss(logits, label_var)
            loss_val = loss_list[0]

            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()

            losses.update(loss_val.data.item(), img.size(0))
            loss_cls.update(loss_list[1].data.item(), img.size(0))
            loss_dist.update(loss_list[2].data.item(), img.size(0))
            loss_aux.update(loss_list[3].data.item(), img.size(0))
            batch_time.update(time.time() - end)

            end = time.time()
            if global_counter % 1000 == 0:
                losses.reset()
                top1.reset()
                top5.reset()
                loss_cls.reset()
                loss_dist.reset()
                loss_aux.reset()
                batch_time.reset()

            if global_counter % args.disp_interval == 0:
                # Calculate ETA
                # eta_seconds = ((total_epoch - current_epoch)*steps_per_epoch + (steps_per_epoch - idx))*batch_time.avg
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Loss_cls {loss_c.val:.4f} ({loss_c.avg:.4f})\t'
                      'Loss_dist {loss_d.val:.4f} ({loss_d.avg:.4f})\t'
                      'Loss_aux {loss_aux.val:.4f} ({loss_aux.avg:.4f})\t'.
                      format(current_epoch,
                             global_counter % len(train_loader),
                             len(train_loader),
                             batch_time=batch_time,
                             loss=losses,
                             loss_c=loss_cls,
                             loss_d=loss_dist,
                             loss_aux=loss_aux))

        if current_epoch % args.save_interval == 0:
            model_state_dict = model.module.state_dict()
            save_checkpoint(args, {
                'epoch': current_epoch,
                'arch': 'resnet',
                'global_counter': global_counter,
                'state_dict': model_state_dict,
                'optimizer': optimizer.state_dict(),
                'center_feat_bank': model.module.center_feat_bank
            },
                            is_best=False,
                            filename='%s_epoch_%d_glo_step_%d.pth' %
                            (args.dataset, current_epoch, global_counter))

        with open(os.path.join(args.snapshot_dir, 'train_record.csv'),
                  'a') as fw:
            fw.write('%d,%.4f,%.3f,%.3f\n' %
                     (current_epoch, losses.avg, top1.avg, top5.avg))

        current_epoch += 1
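Checkpointing in all of the examples goes through a save_checkpoint(args, state, is_best, filename) helper that is not part of the listing. A minimal sketch, assuming it is a thin wrapper around torch.save that writes into args.snapshot_dir (the model_best copy on is_best is an assumption):

import os
import shutil
import torch

def save_checkpoint(args, state, is_best=False, filename='checkpoint.pth.tar'):
    """Serialize the training state dict into args.snapshot_dir."""
    path = os.path.join(args.snapshot_dir, filename)
    torch.save(state, path)
    if is_best:
        # Keep a separate copy of the best checkpoint seen so far.
        shutil.copyfile(path, os.path.join(args.snapshot_dir, 'model_best.pth.tar'))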
Code example #5
def train(args):
    batch_time = AverageMeter()
    lossCos = AverageMeter()
    losses = AverageMeter()
    loss_root = AverageMeter()
    loss_parent = AverageMeter()
    loss_child = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    top1_parent = AverageMeter()
    top5_parent = AverageMeter()
    top1_root = AverageMeter()
    top5_root = AverageMeter()
    model, optimizer = get_model(args)
    model.train()
    train_loader, _, _ = data_loader(args)

    with open(os.path.join(args.snapshot_dir, 'train_record.csv'), 'a') as fw:
        config = json.dumps(vars(args), indent=4, separators=(',', ':'))
        fw.write(config)
        fw.write('#epoch \t loss \t pred@1 \t pred@5\n')

    # construct writer
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)
    writer = SummaryWriter(log_dir=args.log_dir)

    total_epoch = args.epoch
    global_counter = args.global_counter
    current_epoch = args.current_epoch
    end = time.time()
    max_iter = total_epoch * len(train_loader)
    print('Max iter:', max_iter)
    while current_epoch < total_epoch:
        model.train()
        lossCos.reset()
        losses.reset()
        loss_root.reset()
        loss_parent.reset()
        loss_child.reset()
        top1.reset()
        top5.reset()
        top1_parent.reset()
        top5_parent.reset()
        top1_root.reset()
        top5_root.reset()

        batch_time.reset()
        res = my_optim.reduce_lr(args, optimizer, current_epoch)

        if res:
            with open(os.path.join(args.snapshot_dir, 'train_record.csv'),
                      'a') as fw:
                for g in optimizer.param_groups:
                    out_str = 'Epoch:%d, %f\n' % (current_epoch, g['lr'])
                    fw.write(out_str)

        steps_per_epoch = len(train_loader)
        for idx, dat in enumerate(train_loader):
            img_path, img, label = dat
            global_counter += 1
            img = img.cuda()
            root_label, parent_label, child_label = (label[0].cuda(),
                                                     label[1].cuda(),
                                                     label[2].cuda())
            img_var = Variable(img)
            root_label_var = Variable(root_label)
            parent_label_var = Variable(parent_label)
            child_label_var = Variable(child_label)

            logits = model(img_var)
            loss_val, loss_root_val, loss_parent_val, loss_child_val, lossCos_val = model.module.get_loss(
                logits, root_label_var, parent_label_var, child_label_var)

            # write into tensorboard
            writer.add_scalar('loss_val', loss_val, global_counter)

            # network parameter update
            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()

            if args.onehot != 'True':
                logits5 = torch.squeeze(logits[-1])
                prec1, prec5 = evaluate.accuracy(logits5.data,
                                                 child_label.long(),
                                                 topk=(1, 5))
                top1.update(prec1[0], img.size()[0])
                top5.update(prec5[0], img.size()[0])
                logits4 = torch.squeeze(logits[-2])
                prec1_4, prec5_4 = evaluate.accuracy(logits4.data,
                                                     parent_label.long(),
                                                     topk=(1, 5))
                top1_parent.update(prec1_4[0], img.size()[0])
                top5_parent.update(prec5_4[0], img.size()[0])
                logits3 = torch.squeeze(logits[-3])
                prec1_3, prec5_3 = evaluate.accuracy(logits3.data,
                                                     root_label.long(),
                                                     topk=(1, 5))
                top1_root.update(prec1_3[0], img.size()[0])
                top5_root.update(prec5_3[0], img.size()[0])

            losses.update(loss_val.data, img.size()[0])
            loss_root.update(loss_root_val.data, img.size()[0])
            loss_parent.update(loss_parent_val.data, img.size()[0])
            loss_child.update(loss_child_val.data, img.size()[0])
            lossCos.update(lossCos_val.data, img.size()[0])
            batch_time.update(time.time() - end)

            end = time.time()

            if global_counter % args.disp_interval == 0:
                # Calculate ETA
                eta_seconds = (
                    (total_epoch - current_epoch) * steps_per_epoch +
                    (steps_per_epoch - idx)) * batch_time.avg
                eta_str = "{:0>8}".format(
                    datetime.timedelta(seconds=int(eta_seconds)))
                eta_seconds_epoch = steps_per_epoch * batch_time.avg
                eta_str_epoch = "{:0>8}".format(
                    datetime.timedelta(seconds=int(eta_seconds_epoch)))
                print(
                    'Epoch: [{0}][{1}/{2}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'ETA {eta_str}({eta_str_epoch})\t'
                    'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                    'Loss cos {lossCos.val:.4f} ({lossCos.avg:.4f})\t'
                    'Loss root {loss_root.val:.4f} ({loss_root.avg:.4f})\t'
                    'Loss parent {loss_parent.val:.4f} ({loss_parent.avg:.4f})\t'
                    'Loss child {loss_child.val:.4f} ({loss_child.avg:.4f})\t'
                    'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                    'Prec@5 {top5.val:.3f} ({top5.avg:.3f})\t'
                    'parent: Prec@1 {top1_parent.val:.3f} ({top1_parent.avg:.3f})\t'
                    'Prec@5 {top5_parent.val:.3f} ({top5_parent.avg:.3f})\t'
                    'root: Prec@1 {top1_root.val:.3f} ({top1_root.avg:.3f})\t'
                    'Prec@5 {top5_root.val:.3f} ({top5_root.avg:.3f})'.format(
                        current_epoch,
                        global_counter % len(train_loader),
                        len(train_loader),
                        batch_time=batch_time,
                        eta_str=eta_str,
                        eta_str_epoch=eta_str_epoch,
                        loss=losses,
                        loss_root=loss_root,
                        loss_parent=loss_parent,
                        loss_child=loss_child,
                        top1=top1,
                        top5=top5,
                        top1_parent=top1_parent,
                        top5_parent=top5_parent,
                        top1_root=top1_root,
                        top5_root=top5_root,
                        lossCos=lossCos,
                    ))

        plotter.plot('rootLoss', 'train', current_epoch, loss_root.avg)
        plotter.plot('childLoss', 'train', current_epoch, loss_child.avg)
        plotter.plot('parentLoss', 'train', current_epoch, loss_parent.avg)
        plotter.plot('cosLoss', 'train', current_epoch, lossCos.avg)
        plotter.plot('top1', 'train', current_epoch, top1.avg)
        plotter.plot('top5', 'train', current_epoch, top5.avg)
        plotter.plot('parent Top1', 'train', current_epoch, top1_parent.avg)
        plotter.plot('parent Top5', 'train', current_epoch, top5_parent.avg)
        plotter.plot('root Top1', 'train', current_epoch, top1_root.avg)
        plotter.plot('root Top5', 'train', current_epoch, top5_root.avg)

        current_epoch += 1
        if current_epoch % 50 == 0:
            save_checkpoint(args, {
                'epoch': current_epoch,
                'arch': 'resnet',
                'global_counter': global_counter,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()
            },
                            is_best=False,
                            filename='%s_epoch_%d_glo_step_%d.pth.tar' %
                            (args.dataset, current_epoch, global_counter))

        with open(os.path.join(args.snapshot_dir, 'train_record.csv'),
                  'a') as fw:
            fw.write(
                '%d \t %.4f \t  %.4f \t  %.4f \t %.4f \t  %.4f \t %.3f \t %.3f\t %.3f \t %.3f\t %.3f \t %.3f\n'
                %
                (current_epoch, losses.avg, lossCos.avg, loss_root.avg,
                 loss_parent.avg, loss_child.avg, top1_root.avg, top5_root.avg,
                 top1_parent.avg, top5_parent.avg, top1.avg, top5.avg))

        losses.reset()
        loss_root.reset()
        loss_parent.reset()
        loss_child.reset()
        top1.reset()
        top5.reset()
        top1_parent.reset()
        top5_parent.reset()
        top1_root.reset()
        top5_root.reset()
        lossCos.reset()
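Examples #1, #5 and #7 compute Prec@1/Prec@5 through Metrics.accuracy / evaluate.accuracy, which are also external to the listing. A minimal sketch in the spirit of the standard top-k accuracy recipe, assuming single-label integer targets and percentage outputs:

import torch

def accuracy(output, target, topk=(1,)):
    """Top-k accuracy (in percent) for a [batch, classes] score tensor."""
    maxk = max(topk)
    batch_size = target.size(0)

    # Indices of the maxk highest-scoring classes per sample, transposed so
    # rows correspond to ranks and columns to samples.
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res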
Code example #6
File: train.py  Project: won-bae/MCIS_wsss
def train(args):
    batch_time = AverageMeter()
    losses = AverageMeter()
    losses1 = AverageMeter()
    losses2 = AverageMeter()
    losses2_1 = AverageMeter()
    losses2_2 = AverageMeter()
    losses3_1 = AverageMeter()
    losses3_2 = AverageMeter()
    losses4_1 = AverageMeter()
    losses4_2 = AverageMeter()
    losses1_comple = AverageMeter()
    losses2_comple = AverageMeter()

    total_epoch = args.epoch

    train_loader, val_loader = train_data_loader_siamese_more_augumentation(
        args)
    # train_loader, val_loader = train_data_loader_normal_resize(args)
    max_step = total_epoch * len(train_loader)
    args.max_step = max_step
    print('Max step:', max_step)

    model, optimizer = get_model(args)
    print(model)

    global_counter = args.global_counter
    print("here: ", global_counter)
    current_epoch = args.current_epoch

    model.train()
    end = time.time()

    while current_epoch < total_epoch:

        losses.reset()
        losses1.reset()
        losses2.reset()
        losses2_1.reset()
        losses2_2.reset()
        losses3_1.reset()
        losses3_2.reset()
        losses4_1.reset()
        losses4_2.reset()

        losses1_comple.reset()
        losses2_comple.reset()

        batch_time.reset()
        res = my_optim.reduce_lr(args, optimizer, current_epoch)
        steps_per_epoch = len(train_loader)

        validate(model, val_loader)
        model.train()  ## prepare for training
        index = 0
        for idx, dat in enumerate(train_loader):
            _, _, input1, input2, input1_transforms, label1, label2 = dat

            # print(type(input1_transforms),len(input1_transforms),input1_transforms[0].size())
            if random.random() < 0.0:
                # print(input1.size())
                input1 = hide_patch(input1)
                input2 = hide_patch(input2)
                input1_transforms = [hide_patch(i) for i in input1_transforms]

            img = [input1, input2]
            label = torch.cat([label1, label2])

            img2 = [input1, input1_transforms[0]]
            img3 = [input1, input1_transforms[1]]
            img4 = [input1, input1_transforms[2]]

            # print(input1.size(),input2.size(),img.size())
            # print(torch.max(input1),torch.min(input1))

            # print(label.size(),img.size())

            # label = label.cuda(non_blocking=True)
            # label1 = label1.cuda(non_blocking=True)
            # label2 = label2.cuda(non_blocking=True)
            label = label.cuda()
            label1 = label1.cuda()
            label2 = label2.cuda()

            label_new = label1 + label2
            label_new[label_new != 2] = 0
            label_new[label_new == 2] = 1

            label1_comple = label1 - label_new
            label2_comple = label2 - label_new

            assert (label1_comple >= 0).all() and (label2_comple >= 0).all()

            label_new = torch.cat([label_new, label_new])

            # print(label1[0],label2[0],label_new[0])

            logits, co_logits = model(img, current_epoch, label, None)
            logits2, co_logits2 = model(img2, current_epoch, label, None)
            logits3, co_logits3 = model(img3, current_epoch, label, None)
            logits4, co_logits4 = model(img4, current_epoch, label, None)

            index += args.batch_size

            if logits is None:
                print("here")
                continue

            if len(logits.shape) == 1:
                logits = logits.reshape(label.shape)

            # print(logits.size(),label.size(),img.size())
            # loss_val1 = F.multilabel_soft_margin_loss(logits[:input1.size(0)], label[:input1.size(0)])
            loss_val1 = F.multilabel_soft_margin_loss(logits, label)
            loss_val2 = F.multilabel_soft_margin_loss(
                co_logits[:2 * input1.size(0)], label_new)

            loss_val1_comple = F.multilabel_soft_margin_loss(
                co_logits[2 * input1.size(0):3 * input1.size(0)],
                label1_comple)
            loss_val2_comple = F.multilabel_soft_margin_loss(
                co_logits[3 * input1.size(0):], label2_comple)

            loss_val2_1 = F.multilabel_soft_margin_loss(
                logits2, torch.cat([label1, label1]))
            loss_val2_2 = F.multilabel_soft_margin_loss(
                co_logits2[:2 * input1.size(0)], torch.cat([label1, label1]))
            loss_val3_1 = F.multilabel_soft_margin_loss(
                logits3, torch.cat([label1, label1]))
            loss_val3_2 = F.multilabel_soft_margin_loss(
                co_logits3[:2 * input1.size(0)], torch.cat([label1, label1]))
            loss_val4_1 = F.multilabel_soft_margin_loss(
                logits4, torch.cat([label1, label1]))
            loss_val4_2 = F.multilabel_soft_margin_loss(
                co_logits4[:2 * input1.size(0)], torch.cat([label1, label1]))
            # print(loss_val,loss_val2)

            ## use co-attention
            loss_val = loss_val1 + loss_val2 + loss_val2_1 + loss_val2_2 + loss_val3_1 + loss_val3_2 + loss_val4_1 + loss_val4_2
            ## don't use co-attention
            # loss_val=loss_val1+loss_val2_1+loss_val3_1+loss_val4_1
            # loss_val=loss_val4_1+loss_val4_2
            # print(loss_val)
            # print(logits.size())
            # print(logits[0])
            # print(label[0])

            if current_epoch >= 2:
                if (label1_comple > 0).any():
                    loss_val = loss_val + loss_val1_comple
                if (label2_comple > 0).any():
                    loss_val = loss_val + loss_val2_comple

            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()

            # print(loss_val.data.item())
            losses.update(loss_val.data.item(),
                          input1.size()[0] + input2.size()[0])
            losses1.update(loss_val1.data.item(),
                           input1.size()[0] + input2.size()[0])
            losses2.update(loss_val2.data.item(),
                           input1.size()[0] + input2.size()[0])
            losses2_1.update(loss_val2_1.data.item(),
                             input1.size()[0] + input2.size()[0])
            losses2_2.update(loss_val2_2.data.item(),
                             input1.size()[0] + input2.size()[0])
            losses3_1.update(loss_val3_1.data.item(),
                             input1.size()[0] + input2.size()[0])
            losses3_2.update(loss_val3_2.data.item(),
                             input1.size()[0] + input2.size()[0])
            losses4_1.update(loss_val4_1.data.item(),
                             input1.size()[0] + input2.size()[0])
            losses4_2.update(loss_val4_2.data.item(),
                             input1.size()[0] + input2.size()[0])

            if (label1_comple > 0).any():
                losses1_comple.update(loss_val1_comple.data.item(),
                                      input1.size()[0] + input2.size()[0])
            if (label2_comple > 0).any():
                losses2_comple.update(loss_val2_comple.data.item(),
                                      input1.size()[0] + input2.size()[0])

            batch_time.update(time.time() - end)
            end = time.time()

            global_counter += 1
            if global_counter % 1000 == 0:
                losses.reset()
                losses1.reset()
                losses2.reset()
                losses2_1.reset()
                losses2_2.reset()
                losses3_1.reset()
                losses3_2.reset()
                losses4_1.reset()
                losses4_2.reset()

                losses1_comple.reset()
                losses2_comple.reset()

            if global_counter % args.disp_interval == 0:
                print('Epoch: [{}][{}/{}]\t'
                      'LR: {:.5f}\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                          current_epoch,
                          global_counter % len(train_loader),
                          len(train_loader),
                          optimizer.param_groups[0]['lr'],
                          loss=losses))
                print(losses.avg, losses1.avg, losses2.avg, losses2_1.avg,
                      losses2_2.avg, losses3_1.avg, losses3_2.avg,
                      losses4_1.avg, losses4_2.avg, losses1_comple.avg,
                      losses2_comple.avg)

        # if current_epoch == args.epoch-1:
        save_checkpoint(args, {
            'epoch': current_epoch,
            'global_counter': global_counter,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict()
        },
                        is_best=False,
                        filename='%s_epoch_%d.pth' %
                        (args.dataset, current_epoch))
        current_epoch += 1
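The label bookkeeping in example #6 builds the multi-hot intersection of the two images' label sets (label_new) and the per-image complements used for the extra co-attention losses. The same arithmetic on made-up 5-class labels:

import torch

label1 = torch.tensor([1., 0., 1., 1., 0.])
label2 = torch.tensor([0., 0., 1., 1., 1.])

# Classes present in both images: entries where the sum equals 2.
label_new = label1 + label2
label_new[label_new != 2] = 0
label_new[label_new == 2] = 1       # -> [0., 0., 1., 1., 0.]

# Classes unique to each image; non-negative, as the assert in the snippet checks.
label1_comple = label1 - label_new  # -> [1., 0., 0., 0., 0.]
label2_comple = label2 - label_new  # -> [0., 0., 0., 0., 1.]
print(label_new, label1_comple, label2_comple)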
Code example #7
File: train_cam_spa.py  Project: Panxjia/SPA_CVPR2021
def train(args):
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus_str

    # for reproducibility
    if args.seed is not None:
        np.random.seed(args.seed)
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        cudnn.benchmark = False
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')
    else:
        cudnn.benchmark = True

    print('Running parameters:\n')
    print(json.dumps(vars(args), indent=4, separators=(',', ':')))

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)
    with open(os.path.join(args.snapshot_dir, 'train_record.csv'), 'a') as fw:
        config = json.dumps(vars(args), indent=4, separators=(',', ':'))
        fw.write(config)

    log_head = '#epoch \t loss \t pred@1 \t pred@5'
    batch_time = AverageMeter()
    losses = AverageMeter()
    if args.ram:
        losses_ra = AverageMeter()
        log_head += ' \t loss_ra'
    log_head += '\n'
    with open(os.path.join(args.snapshot_dir, 'train_record.csv'), 'a') as fw:
        fw.write(log_head)
    top1 = AverageMeter()
    top5 = AverageMeter()
    args.device = torch.device('cuda') if args.gpus[0] >= 0 else torch.device(
        'cpu')
    model, optimizer = get_model(args)

    model.train()
    train_loader = data_loader(args)

    # construct writer
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)
    writer = SummaryWriter(log_dir=args.log_dir)

    total_epoch = args.epoch
    global_counter = args.global_counter
    current_epoch = args.current_epoch
    end = time.time()
    max_iter = total_epoch * len(train_loader)
    print('Max iter:', max_iter)
    while current_epoch < total_epoch:
        model.train()
        losses.reset()
        if args.ram:
            losses_ra.reset()

        top1.reset()
        top5.reset()
        batch_time.reset()
        res = my_optim.reduce_lr(args, optimizer, current_epoch)

        if res:
            with open(os.path.join(args.snapshot_dir, 'train_record.csv'),
                      'a') as fw:
                for g in optimizer.param_groups:
                    out_str = 'Epoch:%d, %f\n' % (current_epoch, g['lr'])
                    fw.write(out_str)

        steps_per_epoch = len(train_loader)
        for idx, dat in enumerate(train_loader):
            img_path, img, label = dat
            global_counter += 1
            img, label = img.to(args.device), label.to(args.device)

            logits, _, _ = model(img)

            loss_val, loss_ra = model.module.get_loss(logits,
                                                      label,
                                                      epoch=current_epoch,
                                                      ram_start=args.ram_start)

            # write into tensorboard
            writer.add_scalar('loss_val', loss_val, global_counter)

            # network parameter update
            optimizer.zero_grad()
            # if args.mixp:
            #     with amp.scale_loss(loss_val, optimizer) as scaled_loss:
            #         scaled_loss.backward()
            # else:
            loss_val.backward()
            optimizer.step()

            logits = torch.mean(torch.mean(logits, dim=2), dim=2)
            if args.onehot != 'True':
                prec1, prec5 = evaluate.accuracy(logits.data,
                                                 label.long(),
                                                 topk=(1, 5))
                top1.update(prec1[0], img.size()[0])
                top5.update(prec5[0], img.size()[0])

            losses.update(loss_val.data, img.size()[0])
            if args.ram:
                losses_ra.update(loss_ra.data, img.size()[0])
            batch_time.update(time.time() - end)

            end = time.time()
            if global_counter % args.disp_interval == 0:
                # Calculate ETA
                eta_seconds = (
                    (total_epoch - current_epoch) * steps_per_epoch +
                    (steps_per_epoch - idx)) * batch_time.avg
                eta_str = "{:0>8}".format(
                    str(datetime.timedelta(seconds=int(eta_seconds))))
                eta_seconds_epoch = steps_per_epoch * batch_time.avg
                eta_str_epoch = "{:0>8}".format(
                    str(datetime.timedelta(seconds=int(eta_seconds_epoch))))
                log_output= 'Epoch: [{0}][{1}/{2}] \t ' \
                            'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t ' \
                            'ETA {eta_str}({eta_str_epoch})\t ' \
                            'Loss {loss.val:.4f} ({loss.avg:.4f})\t ' \
                            'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t ' \
                            'Prec@5 {top5.val:.3f} ({top5.avg:.3f})\t'.format(current_epoch,
                            global_counter % len(train_loader), len(train_loader), batch_time=batch_time,
                            eta_str=eta_str, eta_str_epoch=eta_str_epoch, loss=losses, top1=top1, top5=top5)
                if args.ram:
                    log_output += 'Loss_ra {loss_ra.val:.4f} ({loss_ra.avg:.4f})\t'.format(
                        loss_ra=losses_ra)
                print(log_output)
                writer.add_scalar('top1', top1.avg, global_counter)
                writer.add_scalar('top5', top5.avg, global_counter)

        current_epoch += 1
        if current_epoch % 10 == 0:
            save_checkpoint(args, {
                'epoch': current_epoch,
                'arch': args.arch,
                'global_counter': global_counter,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()
            },
                            is_best=False,
                            filename='%s_epoch_%d.pth.tar' %
                            (args.dataset, current_epoch))

        with open(os.path.join(args.snapshot_dir, 'train_record.csv'),
                  'a') as fw:
            log_output = '{} \t {:.4f} \t {:.3f} \t {:.3f} \t'.format(
                current_epoch, losses.avg, top1.avg, top5.avg)
            if args.ram:
                log_output += '{:.4f}'.format(losses_ra.avg)
            log_output += '\n'
            fw.write(log_output)

        losses.reset()
        if args.ram:
            losses_ra.reset()
        top1.reset()
        top5.reset()
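Example #7 turns the class activation maps into classification scores by averaging over both spatial dimensions with two nested torch.mean calls. A quick check that this matches a single mean over dims 2 and 3 (the tensor shape is illustrative):

import torch

# Illustrative CAM-style output: batch of 2, 20 classes, 14x14 spatial maps.
logits = torch.randn(2, 20, 14, 14)

pooled_nested = torch.mean(torch.mean(logits, dim=2), dim=2)  # as in the snippet
pooled_direct = logits.mean(dim=(2, 3))

assert torch.allclose(pooled_nested, pooled_direct)
print(pooled_nested.shape)  # torch.Size([2, 20])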