def __init__(self, args, data, label_flag=None, v=None, logger=None):
        self.args = args
        self.batch_size = args.batch_size
        self.data_workers = 6

        self.data = data
        self.label_flag = label_flag

        self.num_class = data.num_class
        self.num_task = args.batch_size
        self.num_to_select = 0

        #GNN
        self.gnnModel = models.create('gnn', args).cuda()
        self.projector = ProjectNetwork(self.args, 800, 4096).cuda()
        self.classifier = Classifier(self.args).cuda()
        self.meter = meter(args.num_class)
        self.v = v

        # CE for node
        if args.loss == 'focal':
            self.criterionCE = FocalLoss().cuda()
        elif args.loss == 'nll':
            self.criterionCE = nn.NLLLoss(reduction='mean').cuda()

        # BCE for edge
        self.criterion = nn.BCELoss(reduction='mean').cuda()
        self.global_step = 0
        self.logger = logger
        self.val_acc = 0
        self.threshold = args.threshold
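
Note on the loss setup above: FocalLoss and nn.NLLLoss are used interchangeably for the node term, so the focal loss here presumably operates on log-probabilities the same way NLLLoss does. A minimal sketch of a focal loss with that interface, as an assumption rather than the repository's actual implementation:

import torch
import torch.nn as nn

class FocalLossSketch(nn.Module):
    """Focal loss over log-probabilities; drop-in for nn.NLLLoss (hypothetical)."""

    def __init__(self, gamma=2.0, reduction='mean'):
        super().__init__()
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, log_probs, target):
        # log_probs: (N, C) log-probabilities, target: (N,) class indices
        logpt = log_probs.gather(1, target.unsqueeze(1)).squeeze(1)
        pt = logpt.exp()
        loss = -((1.0 - pt) ** self.gamma) * logpt  # down-weight easy examples
        return loss.mean() if self.reduction == 'mean' else loss.sum()
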
def evaluate(model, dev_dataset, print_report=True):
    logging.info("evaluate....................")
    model.eval()
    criterion = FocalLoss(gamma=2)
    dev_sampler = SequentialSampler(dev_dataset)
    dev_dataloader = DataLoader(dev_dataset,
                                batch_size=config.batch_size,
                                sampler=dev_sampler,
                                collate_fn=collate_fn)
    dev_loss = 0
    predict_all = np.array([], dtype=int)
    labels_all = np.array([], dtype=int)
    with torch.no_grad():
        for i, batch in enumerate(dev_dataloader):
            # Placeholder defaults for non-BERT models (overwritten below when "bert" is in the model name)
            attention_mask, token_type_ids = torch.tensor([1]), torch.tensor([1])
            if "bert" in config.model_name:
                token_ids, attention_mask, token_type_ids, labels, tokens = batch
            else:
                token_ids, labels, tokens = batch
            if config.use_cuda:
                token_ids = token_ids.to(config.device)
                attention_mask = attention_mask.to(config.device)
                token_type_ids = token_type_ids.to(config.device)
                labels = labels.to(config.device)
            outputs = model((token_ids, attention_mask, token_type_ids))
            # loss = F.cross_entropy(outputs, labels)
            loss = criterion(outputs, labels)
            dev_loss += loss
            labels = labels.data.cpu().numpy()
            predicts = torch.max(outputs.data, 1)[1].cpu().numpy()
            labels_all = np.append(labels_all, labels)
            predict_all = np.append(predict_all, predicts)

    acc = metrics.accuracy_score(labels_all, predict_all)
    if print_report:
        report = metrics.classification_report(
            labels_all,
            predict_all,
            target_names=config.label2id.keys(),
            digits=4)
        confusion = metrics.confusion_matrix(labels_all, predict_all)
        logging.info("result: \n%s" % report)
        logging.info("Confusion matrix: \n%s" % confusion)
        return acc, dev_loss / len(dev_dataloader), report, confusion
    return acc, dev_loss / len(dev_dataloader)
Example #3
    def test(loader):
        G.eval()
        F1.eval()
        test_loss = 0
        correct = 0
        size = 0
        num_class = len(class_list)
        output_all = np.zeros((0, num_class))

        # Setting the loss function to be used for the classification loss
        if args.loss == 'CE':
            criterion = nn.CrossEntropyLoss().to(device)
        if args.loss == 'FL':
            criterion = FocalLoss(alpha=1, gamma=args.gamma).to(device)
        if args.loss == 'CBFL':
            # Calculating the list having the number of examples per class which is going to be used in the CB focal loss
            beta = args.beta
            effective_num = 1.0 - np.power(beta, class_num_list)
            per_cls_weights = (1.0 - beta) / np.array(effective_num)
            per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(
                class_num_list)
            per_cls_weights = torch.FloatTensor(per_cls_weights).to(device)
            criterion = CBFocalLoss(weight=per_cls_weights,
                                    gamma=args.gamma).to(device)

        confusion_matrix = torch.zeros(num_class, num_class)
        with torch.no_grad():
            for batch_idx, data_t in enumerate(loader):
                im_data_t.data.resize_(data_t[0].size()).copy_(data_t[0])
                gt_labels_t.data.resize_(data_t[1].size()).copy_(data_t[1])
                feat = G(im_data_t)
                output1 = F1(feat)
                output_all = np.r_[output_all, output1.data.cpu().numpy()]
                size += im_data_t.size(0)
                pred1 = output1.data.max(1)[1]
                for t, p in zip(gt_labels_t.view(-1), pred1.view(-1)):
                    confusion_matrix[t.long(), p.long()] += 1
                correct += pred1.eq(gt_labels_t.data).cpu().sum()
                test_loss += criterion(output1, gt_labels_t) / len(loader)
        np.save("cf_target.npy", confusion_matrix)
        #print(confusion_matrix)
        print('\nTest set: Average loss: {:.4f}, '
              'Accuracy: {}/{} F1 ({:.0f}%)\n'.format(test_loss, correct, size,
                                                      100. * correct / size))
        return test_loss.data, 100. * float(correct) / size
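
The CBFL branch above implements the "effective number of samples" reweighting behind the class-balanced focal loss: each class weight is proportional to (1 - beta) / (1 - beta ** n_c), and the weights are then renormalized to sum to the number of classes. A small standalone illustration with made-up class counts:

import numpy as np

class_num_list = [500, 50, 5]   # hypothetical per-class sample counts
beta = 0.99
effective_num = 1.0 - np.power(beta, class_num_list)
per_cls_weights = (1.0 - beta) / effective_num
per_cls_weights = per_cls_weights / per_cls_weights.sum() * len(class_num_list)
print(per_cls_weights)  # roughly [0.13, 0.32, 2.56]: rarer classes get larger weights
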
Example #4
def train_model(model, train_dataloader, valid_dataloader, fold=None):
    num_epochs = 30

    print("| -- Training Model -- |")

    model_path = os.path.join("model", f"model_{fold}.pt")
    best_model_path = os.path.join("model", f"best_model_{fold}.pt")

    model, start_epoch, max_iout = load_state(model, best_model_path)

    writer = SummaryWriter()

    # define loss function (criterion) and optimizer
    # criterion = FocalLoss(gamma=2, logits=True)
    criterion = FocalLoss(gamma=2)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    # best_model_wts = copy.deepcopy(model.state_dict())  # deepcopy important

    for epoch in range(start_epoch, num_epochs + 1):
        lr = optimizer.param_groups[0]['lr']
        print(f"| Epoch [{epoch:3d}/{num_epochs:3d}], lr={lr}")

        train_metrics = train(model, train_dataloader, epoch, num_epochs,
                              criterion, optimizer, False)
        valid_metrics = evaluate(model, valid_dataloader, epoch, criterion,
                                 False)

        writer.add_scalar('loss/train', train_metrics["train_loss"], epoch)
        writer.add_scalar('iout/train', train_metrics["train_iout"], epoch)
        writer.add_scalar('loss/valid', valid_metrics["valid_loss"], epoch)
        writer.add_scalar('iout/valid', valid_metrics["valid_iout"], epoch)

        valid_iout = valid_metrics["valid_iout"]

        save_state(model, epoch, valid_iout, model_path)

        if valid_iout > max_iout:
            max_iout = valid_iout
            # best_model_wts = copy.deepcopy(model.state_dict())
            shutil.copyfile(model_path, best_model_path)
            print(f"|- Save model, Epoch #{epoch}, max_iout: {max_iout:.4f}")
Example #5
File: train.py Project: zjbit/MONet
 def __init__(self, Net=MONet, load_parameters=False, is_distributed=False):
     self.is_distributed = is_distributed
     self.device = torch.device(
         'cuda:1' if torch.cuda.is_available() else 'cpu')
     print('Preparing to train the network on device %s' % self.device)
     '''Instantiate the model and load parameters'''
     self.net = Net().train()
     if os.path.exists('weights/net.pth') and load_parameters:
         self.net.load_state_dict(torch.load('weights/net.pth'))
         print('Model parameters loaded successfully')
     elif not load_parameters:
         print('Model parameters not loaded')
     else:
         raise RuntimeError('Model parameters are not loaded')
     self.net = self.net.to(self.device)
     print('Model initialization complete')
     '''Instantiate the dataset and the data loader'''
     self.data_set = DataSet()
     self.data_loader = DataLoader(self.data_set,
                                   len(self.data_set),
                                   True,
                                   num_workers=2)
     '''Instantiate the loss functions'''
     self.f_loss = FocalLoss(2, 0.8)
     self.f_loss_mc = FocalLossManyClassification(6)
     # self.f_loss_mc = nn.CrossEntropyLoss()
     self.mse_loss = torch.nn.MSELoss().to(self.device)
     print('Loss functions initialized')
     '''Instantiate the optimizer'''
     self.optimizer = torch.optim.Adam(self.net.parameters())
     if os.path.exists('optimizer.pth') and load_parameters:
         self.optimizer.load_state_dict(torch.load('optimizer.pth'))
         print('Optimizer parameters loaded successfully')
     elif load_parameters:
         raise Warning('Failed to load optimizer parameters correctly')
     else:
         print('Optimizer initialized')
     self.sigmoid = torch.nn.Sigmoid()
     '''Enable distributed (data-parallel) training'''
     if is_distributed:
         self.net = nn.parallel.DataParallel(self.net)
 def make_loss_fn(self):
     for task in self.active_tasks:
         loss_fn = self.cfg[task]['loss']
         if loss_fn == 'cross_entropy2d':
             self.loss_fn[task] = nn.CrossEntropyLoss(
                 self.weights[task].cuda(), ignore_index=255)
         elif loss_fn == 'weighted_binary_cross_entropy':
             self.loss_fn[task] = WeightedBinaryCrossEntropy()
         elif loss_fn == 'weighted_multi_class_binary_cross_entropy':
             self.loss_fn[task] = WeightedMultiClassBinaryCrossEntropy()
         elif loss_fn == 'focalloss':
             self.loss_fn[task] = FocalLoss()
         elif loss_fn == 'dualityfocalloss':
             self.loss_fn[task] = DualityFocalLoss()
         elif loss_fn == 'dualityceloss':
             self.loss_fn[task] = DualityCELoss(self.weights[task].cuda(),
                                                ignore_index=255)
         elif loss_fn == 'huber_loss':
             self.loss_fn[task] = HuberLoss()
         elif loss_fn == 'bbox_loss':
             self.loss_fn[task] = BboxLoss()
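
make_loss_fn expects a per-task config mapping each active task to a loss name, plus per-task class weights for the weighted losses. A hypothetical configuration and call consistent with the dispatch above (the attribute shapes are assumptions, not the project's documented schema):

import torch

# `trainer` stands in for an instance of the class that defines make_loss_fn
trainer.active_tasks = ['semantic', 'depth']
trainer.cfg = {'semantic': {'loss': 'cross_entropy2d'},
               'depth': {'loss': 'huber_loss'}}
trainer.weights = {'semantic': torch.ones(19)}  # per-class CE weights
trainer.loss_fn = {}
trainer.make_loss_fn()  # fills trainer.loss_fn with one criterion per task
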
Example #7
    def __init__(self, args, data, step=0, label_flag=None, v=None, logger=None):
        self.args = args
        self.batch_size = args.batch_size
        self.data_workers = 6

        self.step = step
        self.data = data
        self.label_flag = label_flag

        self.num_class = data.num_class
        self.num_task = args.batch_size
        self.num_to_select = 0

        self.model = models.create(args.arch, args)
        self.model = nn.DataParallel(self.model).cuda()

        #GNN
        self.gnnModel = models.create('gnn', args)
        self.gnnModel = nn.DataParallel(self.gnnModel).cuda()

        self.meter = meter(args.num_class)
        self.v = v

        # CE for node classification
        if args.loss == 'focal':
            self.criterionCE = FocalLoss().cuda()
        elif args.loss == 'nll':
            self.criterionCE = nn.NLLLoss(reduction='mean').cuda()

        # BCE for edge
        self.criterion = nn.BCELoss(reduction='mean').cuda()
        self.global_step = 0
        self.logger = logger
        self.val_acc = 0
        self.threshold = args.threshold

        if self.args.discriminator:
            self.discriminator = Discriminator(self.args.in_features)
            self.discriminator = nn.DataParallel(self.discriminator).cuda()
Example #8
    def infer(loader):
        G.eval()
        F1.eval()
        test_loss = 0
        correct = 0
        size = 0
        num_class = len(class_list)
        output_all = np.zeros((0, num_class))

        # Setting the loss function to be used for the classification loss
        if args.loss == 'CE':
            criterion = nn.CrossEntropyLoss().to(device)
        if args.loss == 'FL':
            criterion = FocalLoss(alpha=1, gamma=1).to(device)
        if args.loss == 'CBFL':
            # Calculating the list having the number of examples per class which is going to be used in the CB focal loss
            beta = 0.99
            effective_num = 1.0 - np.power(beta, class_num_list)
            per_cls_weights = (1.0 - beta) / np.array(effective_num)
            per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(
                class_num_list)
            per_cls_weights = torch.FloatTensor(per_cls_weights).to(device)
            criterion = CBFocalLoss(weight=per_cls_weights,
                                    gamma=0.5).to(device)
        # defining a nested list to store the cosine similarity (or distances) of the vectors from the class prototypes
        class_dist_list = []
        for i in range(num_class):
            empty_dists = []
            class_dist_list.append(empty_dists)

        confusion_matrix = torch.zeros(num_class, num_class)
        # iterating through the elements of the batch in the dataloader
        with torch.no_grad():
            for batch_idx, data_t in enumerate(loader):
                im_data_t.data.resize_(data_t[0].size()).copy_(data_t[0])
                gt_labels_t.data.resize_(data_t[1].size()).copy_(data_t[1])
                feat = G(im_data_t)
                output1 = F1(feat)
                output_all = np.r_[output_all, output1.data.cpu().numpy()]
                size += im_data_t.size(0)
                pred1 = output1.data.max(1)[1]
                # filling the elements of the confusion matrix
                for t, p in zip(gt_labels_t.view(-1), pred1.view(-1)):
                    confusion_matrix[t.long(), p.long()] += 1
                correct += pred1.eq(gt_labels_t.data).cpu().sum()
                test_loss += criterion(output1, gt_labels_t) / len(loader)
                pred1 = pred1.cpu().numpy()
                dists = output1.data.max(1)[0]
                dists = dists.cpu().numpy()

                # forming the lists of the distances of the predicted labels and the class prototype
                for label, dist in zip(pred1, dists):
                    label = int(label)
                    class_dist_list[label].append(dist)

        # sorting the distances in ascending order for each of the classes, also finding a threshold for similarity of each class
        summ = 0
        class_dist_threshold_list = []
        for class_ in range(len(class_dist_list)):
            class_dist_list[class_].sort()
            l = len(class_dist_list[class_])
            tenth = l / 10
            idx_tenth = math.ceil(tenth)
            class_dist_threshold_list.append(
                class_dist_list[class_][idx_tenth])

        print('\nTest set: Average loss: {:.4f}, '
              'Accuracy: {}/{} F1 ({:.2f}%)\n'.format(test_loss, correct, size,
                                                      100. * correct / size))
        return test_loss.data, 100. * float(
            correct) / size, class_dist_threshold_list
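
The threshold loop above keeps, for each class, the sorted score at index ceil(len/10), i.e. roughly the 10th percentile of that class's predicted confidences. A compact near-equivalent using numpy (illustrative only; unlike the original indexing, it also tolerates classes that received no predictions):

import numpy as np

class_dist_threshold_list = [
    np.percentile(dists, 10) if len(dists) > 0 else 0.0
    for dists in class_dist_list
]
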
Example #9
                              pin_memory=True)

    val_dataset = IDRND_3D_dataset(mode=config['mode'].replace('train', 'val'),
                                   use_face_detection=str2bool(
                                       config['use_face_detection']),
                                   double_loss_mode=True)
    val_loader = DataLoader(val_dataset,
                            batch_size=config['batch_size'],
                            shuffle=True,
                            num_workers=4,
                            drop_last=False)

    model = DenchikModel(n_features=32).to(device)
    summary(model, (3, 5, 224, 224), device="cuda")

    criterion = FocalLoss().to(device)
    criterion4class = CrossEntropyLoss().to(device)

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=config['learning_rate'],
                                 weight_decay=config['weight_decay'])
    scheduler = ExponentialLR(optimizer, gamma=0.8)

    shutil.rmtree(config['log_path'])
    os.mkdir(config['log_path'])
    writer = SummaryWriter(log_dir=config['log_path'])

    global_step = 0
    for epoch in trange(config['number_epochs']):
        model.train()
        train_bar = tqdm(train_loader)
Example #10
    def train(class_dist_threshold_list):
        G.train()
        F1.train()
        optimizer_g = optim.SGD(params,
                                momentum=0.9,
                                weight_decay=0.0005,
                                nesterov=True)
        optimizer_f = optim.SGD(list(F1.parameters()),
                                lr=1.0,
                                momentum=0.9,
                                weight_decay=0.0005,
                                nesterov=True)

        def zero_grad_all():
            optimizer_g.zero_grad()
            optimizer_f.zero_grad()

        param_lr_g = []
        for param_group in optimizer_g.param_groups:
            param_lr_g.append(param_group["lr"])
        param_lr_f = []
        for param_group in optimizer_f.param_groups:
            param_lr_f.append(param_group["lr"])

        # Setting the loss function to be used for the classification loss
        if args.loss == 'CE':
            criterion = nn.CrossEntropyLoss().to(device)
        if args.loss == 'FL':
            criterion = FocalLoss(alpha=1, gamma=args.gamma).to(device)
        if args.loss == 'CBFL':
            # Calculating the list having the number of examples per class which is going to be used in the CB focal loss
            beta = args.beta
            effective_num = 1.0 - np.power(beta, class_num_list)
            per_cls_weights = (1.0 - beta) / np.array(effective_num)
            per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(
                class_num_list)
            per_cls_weights = torch.FloatTensor(per_cls_weights).to(device)
            criterion = CBFocalLoss(weight=per_cls_weights,
                                    gamma=args.gamma).to(device)

        all_step = args.steps
        data_iter_s = iter(source_loader)
        data_iter_t = iter(target_loader)
        data_iter_t_unl = iter(target_loader_unl)
        len_train_source = len(source_loader)
        len_train_target = len(target_loader)
        len_train_target_semi = len(target_loader_unl)
        best_acc = 0
        counter = 0
        """
        x = torch.load("./freezed_models/alexnet_p2r.ckpt.best.pth.tar")
        G.load_state_dict(x['G_state_dict'])
        F1.load_state_dict(x['F1_state_dict'])
        optimizer_f.load_state_dict(x['optimizer_f'])
        optimizer_g.load_state_dict(x['optimizer_g'])
        """
        reg_weight = args.reg
        for step in range(all_step):
            optimizer_g = inv_lr_scheduler(param_lr_g,
                                           optimizer_g,
                                           step,
                                           init_lr=args.lr)
            optimizer_f = inv_lr_scheduler(param_lr_f,
                                           optimizer_f,
                                           step,
                                           init_lr=args.lr)
            lr = optimizer_f.param_groups[0]['lr']
            # condition for restarting the iteration for each of the data loaders
            if step % len_train_target == 0:
                data_iter_t = iter(target_loader)
            if step % len_train_target_semi == 0:
                data_iter_t_unl = iter(target_loader_unl)
            if step % len_train_source == 0:
                data_iter_s = iter(source_loader)
            data_t = next(data_iter_t)
            data_t_unl = next(data_iter_t_unl)
            data_s = next(data_iter_s)
            with torch.no_grad():
                im_data_s.resize_(data_s[0].size()).copy_(data_s[0])
                gt_labels_s.resize_(data_s[1].size()).copy_(data_s[1])
                im_data_t.resize_(data_t[0].size()).copy_(data_t[0])
                gt_labels_t.resize_(data_t[1].size()).copy_(data_t[1])
                im_data_tu.resize_(data_t_unl[0].size()).copy_(data_t_unl[0])

            zero_grad_all()
            if args.uda == 1:
                data = im_data_s
                target = gt_labels_s
            else:
                data = torch.cat((im_data_s, im_data_t), 0)
                target = torch.cat((gt_labels_s, gt_labels_t), 0)
            #print(data.shape)
            output = G(data)
            out1 = F1(output)
            if args.attribute is not None:
                if args.net == 'resnet34':
                    reg_loss = regularizer(F1.fc3.weight, att)
                    loss = criterion(out1, target) + reg_weight * reg_loss
                else:
                    reg_loss = regularizer(F1.fc2.weight, att)
                    loss = criterion(out1, target) + reg_weight * reg_loss
            else:
                reg_loss = torch.tensor(0)
                loss = criterion(out1, target)

            if args.attribute is not None:
                if step % args.save_interval == 0 and step != 0:
                    reg_weight = 0.5 * reg_weight
                    print("Reduced Reg weight to: ", reg_weight)

            loss.backward(retain_graph=True)
            optimizer_g.step()
            optimizer_f.step()
            zero_grad_all()
            if not args.method == 'S+T':
                output = G(im_data_tu)
                if args.method == 'ENT':
                    loss_t = entropy(F1, output, args.lamda)
                    #print(loss_t.cpu().data.item())
                    loss_t.backward()
                    optimizer_f.step()
                    optimizer_g.step()
                elif args.method == 'MME':
                    loss_t = adentropy(F1, output, args.lamda,
                                       class_dist_threshold_list)
                    loss_t.backward()
                    optimizer_f.step()
                    optimizer_g.step()
                else:
                    raise ValueError('Method cannot be recognized.')
                log_train = 'S {} T {} Train Ep: {} lr{} \t ' \
                            'Loss Classification: {:.6f} Reg: {:.6f} Loss T {:.6f} ' \
                            'Method {}\n'.format(args.source, args.target,
                                                step, lr, loss.data, reg_weight*reg_loss.data,
                                                -loss_t.data, args.method)
            else:
                log_train = 'S {} T {} Train Ep: {} lr{} \t ' \
                            'Loss Classification: {:.6f} Reg: {:.6f} Method {}\n'.\
                    format(args.source, args.target,
                        step, lr, loss.data, reg_weight * reg_loss.data,
                        args.method)
            G.zero_grad()
            F1.zero_grad()
            zero_grad_all()
            if step % args.log_interval == 0:
                print(log_train)
            if step % args.save_interval == 0 and step > 0:
                loss_val, acc_val = test(target_loader_val)
                loss_test, acc_test = test(target_loader_test)
                G.train()
                F1.train()
                if acc_val >= best_acc:
                    best_acc = acc_val
                    best_acc_test = acc_test
                    counter = 0
                else:
                    counter += 1
                if args.early:
                    if counter > args.patience:
                        break
                print('best acc test %f best acc val %f' %
                      (best_acc_test, acc_val))
                print('record %s' % record_file)
                with open(record_file, 'a') as f:
                    f.write('step %d best %f final %f \n' %
                            (step, best_acc_test, acc_val))
                G.train()
                F1.train()
                #saving model as a checkpoint dict having many things
                if args.save_check:
                    print('saving model')
                    is_best = counter == 0
                    save_mymodel(
                        args, {
                            'step': step,
                            'arch': args.net,
                            'G_state_dict': G.state_dict(),
                            'F1_state_dict': F1.state_dict(),
                            'best_acc_test': best_acc_test,
                            'optimizer_g': optimizer_g.state_dict(),
                            'optimizer_f': optimizer_f.state_dict(),
                        }, is_best, time_stamp)
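
inv_lr_scheduler is assumed by the loop above; in MME-style training code it is typically an inverse decay of the form lr = init_lr * (1 + gamma * step) ** (-power), with every parameter group keeping its own relative multiplier. A sketch under that assumption:

def inv_lr_scheduler(param_lr, optimizer, step, init_lr=0.001,
                     gamma=0.0001, power=0.75):
    # Inverse-decay schedule: the global lr shrinks with the step count while
    # each parameter group keeps its original relative multiplier (param_lr).
    lr = init_lr * (1 + gamma * step) ** (-power)
    for multiplier, group in zip(param_lr, optimizer.param_groups):
        group['lr'] = lr * multiplier
    return optimizer
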
Example #11
    def __init__(self, args):
        self.args = args

        # Define Saver
        if args.distributed:
            if args.local_rank == 0:
                self.saver = Saver(args)
                self.saver.save_experiment_config()
                # Define Tensorboard Summary
                self.summary = TensorboardSummary(self.saver.experiment_dir)
                self.writer = self.summary.create_summary()
        else:
            self.saver = Saver(args)
            self.saver.save_experiment_config()
            # Define Tensorboard Summary
            self.summary = TensorboardSummary(self.saver.experiment_dir)
            self.writer = self.summary.create_summary()

        # PATH = args.path
        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)
        # self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        # Define network
        model = SCNN(nclass=self.nclass, backbone=args.backbone, output_stride=args.out_stride,
                     cuda=args.cuda, extension=args.ext)


        # Define Optimizer
        # optimizer = torch.optim.SGD(model.parameters(),args.lr, momentum=args.momentum,
        #                             weight_decay=args.weight_decay, nesterov=args.nesterov)
        optimizer = torch.optim.Adam(model.parameters(), args.lr, weight_decay=args.weight_decay)

        # model, optimizer = amp.initialize(model,optimizer,opt_level="O1")

        # Define Criterion
        weight = None
        # criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        # self.criterion = SegmentationCELosses(weight=weight, cuda=args.cuda)
        # self.criterion = SegmentationCELosses(weight=weight, cuda=args.cuda)
        # self.criterion = FocalLoss(gamma=0, alpha=[0.2, 0.98], img_size=512*512)
        self.criterion1 = FocalLoss(gamma=5, alpha=[0.2, 0.98], img_size=512 * 512)
        self.criterion2 = disc_loss(delta_v=0.5, delta_d=3.0, param_var=1.0, param_dist=1.0,
                                    param_reg=0.001, EMBEDDING_FEATS_DIMS=21,image_shape=[512,512])

        self.model, self.optimizer = model, optimizer
        
        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.val_loader), local_rank=args.local_rank)

        # Using cuda
        if args.cuda:
            self.model = self.model.cuda()
            if args.distributed:
                self.model = DistributedDataParallel(self.model)
            # self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            # patch_replication_callback(self.model)


        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            filename = 'checkpoint.pth.tar'
            args.resume = os.path.join(self.saver.experiment_dir, filename)
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # if args.cuda:
            #     self.model.module.load_state_dict(checkpoint['state_dict'])
            # else:
            self.model.load_state_dict(checkpoint['state_dict'])
            # if not args.ft:
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
def train(epoch=400):
    # Create the metric evaluator (8 classes)
    evaluator = Evaluator(8)

    # Best mIoU seen so far, initialized to 0
    best_pred = 0.0

    # TensorBoard writer for logging
    writer = SummaryWriter(cfg.LOG_DIR)

    # Select the GPU
    device = torch.device(0)

    # Build the datasets and data loaders
    train_dataset = LaneDataset(csv_file=cfg.TRAIN_CSV_FILE,
                                transform=transforms.Compose([
                                    ImageAug(),
                                    DeformAug(),
                                    CutOut(64, 0.5),
                                    ToTensor()
                                ]))
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=cfg.BATCHES,
                                  shuffle=cfg.TRAIN_SHUFFLE,
                                  num_workers=cfg.DATA_WORKERS,
                                  drop_last=True)
    val_dataset = LaneDataset(csv_file=cfg.VAL_CSV_FILE,
                              transform=transforms.Compose([ToTensor()]))
    val_dataloader = DataLoader(val_dataset,
                                batch_size=cfg.BATCHES,
                                shuffle=cfg.VAL_TEST_SHUFFLE,
                                num_workers=cfg.DATA_WORKERS)

    # Build the model
    model = DeepLabV3p()
    model = model.to(device)

    # Loss function and optimizer
    if cfg.LOSS == 'ce':
        criterion = nn.CrossEntropyLoss().to(device)
    elif cfg.LOSS == 'focal':
        criterion = FocalLoss().to(device)
    elif cfg.LOSS == 'focalTversky':
        criterion = FocalTversky_loss().to(device)

    optimizer = opt.Adam(model.parameters(), lr=cfg.TRAIN_LR)

    for epo in range(epoch):
        # Training phase
        train_loss = 0
        model.train()
        for index, batch_item in enumerate(train_dataloader):
            image, mask = batch_item['image'].to(
                device), batch_item['mask'].to(device)
            optimizer.zero_grad()
            output = model(image)
            loss = criterion(output, mask)
            loss.backward()
            # Extract the scalar loss value
            iter_loss = loss.item()
            train_loss += iter_loss
            optimizer.step()

            if np.mod(index, 8) == 0:
                line = 'epoch {}, {}/{}, train loss is {}'.format(
                    epo, index, len(train_dataloader), iter_loss)
                print(line)
                with open(os.path.join(cfg.LOG_DIR, 'log.txt'), 'a') as f:
                    f.write(line)
                    f.write('\r\n')

        # Validation phase
        val_loss = 0
        model.eval()
        with torch.no_grad():
            for index, batch_item in enumerate(val_dataloader):
                image, mask = batch_item['image'].to(
                    device), batch_item['mask'].to(device)

                optimizer.zero_grad()
                output = model(image)
                loss = criterion(output, mask)
                iter_loss = loss.item()
                val_loss += iter_loss

                # Record metrics for evaluation
                pred = output.cpu().numpy()
                mask = mask.cpu().numpy()
                pred = np.argmax(pred, axis=1)
                evaluator.add_batch(mask, pred)

        line_epoch = 'epoch train loss = %.3f, epoch val loss = %.3f' % (
            train_loss / len(train_dataloader), val_loss / len(val_dataloader))
        print(line_epoch)
        with open(os.path.join(cfg.LOG_DIR, 'log.txt'), 'a') as f:
            f.write(line_epoch)
            f.write('\r\n')

        ACC = evaluator.Pixel_Accuracy()
        mIoU = evaluator.Mean_Intersection_over_Union()

        # TensorBoard logging
        writer.add_scalar('train_loss', train_loss / len(train_dataloader),
                          epo)
        writer.add_scalar('val_loss', val_loss / len(val_dataloader), epo)
        writer.add_scalar('Acc', ACC, epo)
        writer.add_scalar('mIoU', mIoU, epo)

        # After each validation pass, save the model whenever mIoU improves
        new_pred = mIoU
        if new_pred > best_pred:
            best_pred = new_pred
            save_path = os.path.join(
                cfg.MODEL_SAVE_DIR,
                '{}_{}_{}_{}_{}.pth'.format(cfg.BACKBONE, cfg.LAYERS,
                                            cfg.NORM_LAYER, cfg.LOSS, epo))

            torch.save(model.state_dict(), save_path)
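
Evaluator is not shown here; it presumably accumulates a confusion matrix over batches, with pixel accuracy as trace / total and per-class IoU as diag / (row_sum + col_sum - diag). A minimal sketch under that assumption:

import numpy as np

class EvaluatorSketch:
    """Confusion-matrix metrics for semantic segmentation (assumed behaviour)."""

    def __init__(self, num_class):
        self.num_class = num_class
        self.confusion = np.zeros((num_class, num_class), dtype=np.int64)

    def add_batch(self, gt, pred):
        # gt, pred: integer label arrays of identical shape
        mask = (gt >= 0) & (gt < self.num_class)
        idx = self.num_class * gt[mask].astype(int) + pred[mask]
        counts = np.bincount(idx, minlength=self.num_class ** 2)
        self.confusion += counts.reshape(self.num_class, self.num_class)

    def Pixel_Accuracy(self):
        return np.diag(self.confusion).sum() / self.confusion.sum()

    def Mean_Intersection_over_Union(self):
        diag = np.diag(self.confusion)
        union = self.confusion.sum(1) + self.confusion.sum(0) - diag
        return (diag / np.maximum(union, 1)).mean()
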
Example #13
                                use_face_detection=False)
    val_loader = DataLoader(val_dataset,
                            batch_size=config['batch_size'],
                            shuffle=True,
                            num_workers=4,
                            drop_last=False)

    # model = Model(base_model = fishnet99())
    model = Model(base_model=EfficientNet.from_pretrained('efficientnet-b3'))
    #model = Model(base_model=resnet34(pretrained=True))
    summary(model, (3, 224, 224), device="cpu")

    dataowner = DataOwner(train_loader, val_loader, None)
    # criterion = torch.nn.BCELoss()
    # criterion = WeightedBCELoss(weights=[0.49, 0.51])  # does not learn
    criterion = FocalLoss()

    shutil.rmtree('../output/logs')
    os.mkdir('../output/logs')

    keker = Keker(
        model=model,
        dataowner=dataowner,
        criterion=criterion,
        target_key="label",
        metrics={
            "acc": bce_accuracy,
            "idrnd_score": idrnd_score_pytorch,
            "FAR": far_score,
            # "roc_auc": roc_auc,
            "FRR": frr_score
Example #14
                                                path_g=path_g,
                                                path_g2l=path_g2l,
                                                path_l2g=path_l2g)

###################################
num_epochs = args.num_epochs
learning_rate = args.lr
lamb_fmreg = args.lamb_fmreg

optimizer = get_optimizer(model, mode, learning_rate=learning_rate)

scheduler = LR_Scheduler('poly', learning_rate, num_epochs,
                         len(dataloader_train))
##################################

criterion1 = FocalLoss(gamma=3, ignore=0)
criterion2 = nn.CrossEntropyLoss()
criterion3 = lovasz_softmax
criterion = lambda x, y: criterion1(x, y)
# criterion = lambda x,y: 0.5*criterion1(x, y) + 0.5*criterion3(x, y)
mse = nn.MSELoss()

if not evaluation:
    writer = SummaryWriter(log_dir=log_path + task_name)
    f_log = open(log_path + task_name + ".log", 'w')

trainer = Trainer(criterion, optimizer, n_class, size_g, size_p,
                  sub_batch_size, mode, lamb_fmreg)
evaluator = Evaluator(n_class, size_g, size_p, sub_batch_size, mode, test)

best_pred = 0.0
Example #15
def train(args, model):
    writer = SummaryWriter(comment=args.model)

    crop_size = None
    if args.crop_size:
        crop_size = args.crop_size.split('x')
        crop_size = tuple([int(x) for x in crop_size])
    if args.dataset == 'CelebA':
        dataset = CelebA(args.datadir,
                         resize=args.resize,
                         argument=args.argument)
    elif args.dataset == 'Figaro':
        dataset = Figaro(args.datadir,
                         resize=args.resize,
                         crop_size=crop_size,
                         argument=args.argument)
    elif args.dataset == 'Our':
        dataset = OurDataset(args.datadir,
                             resize=args.resize,
                             argument=args.argument)
    elif args.dataset == 'CelebAMaskHQ':
        dataset = CelebAMaskHQ(args.datadir,
                               resize=args.resize,
                               argument=args.argument)
    elif args.dataset == 'SpringFace':
        dataset = SpringFace(args.datadir,
                             resize=args.resize,
                             argument=args.argument)
    elif args.dataset == 'SpringHair':
        dataset = SpringHair(args.datadir,
                             resize=args.resize,
                             crop_size=crop_size,
                             argument=args.argument)
    else:
        print('Fail to find the dataset')
        raise ValueError

    num_train = int(args.train_val_rate * len(dataset))
    num_val = len(dataset) - num_train
    train_dataset, val_dataset = random_split(dataset, [num_train, num_val])
    worker_init_fn = lambda _: np.random.seed(
        int(torch.initial_seed()) % (2**32 - 1))
    train_loader = DataLoader(train_dataset,
                              num_workers=args.num_workers,
                              batch_size=args.batch_size,
                              shuffle=True,
                              worker_init_fn=worker_init_fn)

    device = torch.device("cuda" if args.gpu else "cpu")

    writer.add_graph(model,
                     torch.zeros(args.batch_size, 3, 218, 178).to(device))

    model.train()

    if args.model == 'unet':
        criterion = nn.CrossEntropyLoss().to(device)
    elif args.model == 'denseunet':
        # criterion = nn.CrossEntropyLoss().to(device)
        criterion = FocalLoss(gamma=2).to(device)
    else:
        print('Fail to find the net')
        raise ValueError

    optimizer = Adam(model.parameters(), lr=args.lr)

    max_mean_iu = -999

    n_steps = 0
    for i_epoch in range(args.num_epochs):
        model.train()

        for step, (images, labels, _) in enumerate(train_loader):
            # print(step)
            inputs = images.to(device)
            targets = labels.to(device)
            # print('input', inputs.shape)
            # print('target', targets.shape)
            outputs = model(inputs).squeeze(1)
            # print('output', outputs.shape)

            optimizer.zero_grad()

            loss = criterion(outputs, targets)
            loss.backward()

            writer.add_scalar('/train/loss', loss, n_steps)

            optimizer.step()
            n_steps += 1

            if n_steps % args.iters_eval == 0 or \
                    (i_epoch == args.num_epochs - 1 and step == len(train_loader) - 1):
                result = evaluate_segment(args, model, val_dataset)
                time_stamp = time.strftime("%Y-%m-%d %H:%M:%S",
                                           time.localtime())
                print(
                    f'epoch = {i_epoch}, iter = {n_steps},  train_loss = {loss} ---{time_stamp}'
                )
                print(result)

                for key in result.keys():
                    writer.add_scalar(f'/val/{key}', result[key], i_epoch)

                for i in range(6):
                    img, label, img_name = val_dataset[i]
                    writer.add_text('val_img_name', img_name, i_epoch)

                    output = model(img.unsqueeze(0).to(device))
                    mask = torch.argmax(output, 1)

                    img = UnNormalize(mean=[.485, .456, .406],
                                      std=[.229, .224, .225])(img)
                    img_masked = overlay_segmentation_mask(img,
                                                           mask,
                                                           inmode='tensor',
                                                           outmode='tensor')

                    mask = mask * 127
                    label = label * 127

                    writer.add_image(f'/val/image/{img_name}', img, n_steps)
                    writer.add_image(f'/val/label/{img_name}',
                                     label.unsqueeze(0), n_steps)
                    writer.add_image(f'/val/image_masked/{img_name}',
                                     img_masked, n_steps)
                    writer.add_image(f'/val/mask/{img_name}', mask, n_steps)

                if result['mean_iu'] > max_mean_iu:
                    max_mean_iu = result['mean_iu']
                    torch.save(model.state_dict(), args.save_path)
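
UnNormalize above is presumably the inverse of torchvision's Normalize, applied so the logged images look natural again. A minimal sketch assuming that behaviour:

class UnNormalize:
    """Undo torchvision.transforms.Normalize on a (C, H, W) tensor (assumed)."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        # Work on a copy so the normalized tensor used elsewhere stays untouched
        out = tensor.clone()
        for channel, m, s in zip(out, self.mean, self.std):
            channel.mul_(s).add_(m)
        return out
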
Example #16
def main():

    config = "config/cocostuff.yaml"
    cuda = True
    device = torch.device("cuda" if cuda and torch.cuda.is_available() else "cpu")

    if cuda:
        current_device = torch.cuda.current_device()
        print("Running on", torch.cuda.get_device_name(current_device))
    else:
        print("Running on CPU")

    # Configuration
    CONFIG = Dict(yaml.safe_load(open(config)))
    CONFIG.SAVE_DIR = osp.join(CONFIG.SAVE_DIR, CONFIG.EXPERIENT)
    CONFIG.LOGNAME = osp.join(CONFIG.SAVE_DIR, "log.txt")
    if not os.path.exists(CONFIG.SAVE_DIR):
        os.mkdir(CONFIG.SAVE_DIR)

    # Dataset
    dataset = MultiDataSet(cropSize=50, inSize=300, testFlag=False, preload=True)
    # dataset = MultiDataSet(
    #     CONFIG.ROOT,
    #     CONFIG.CROPSIZE,
    #     CONFIG.INSIZE,
    #     preload=False
    # )

    # DataLoader
    if CONFIG.RESAMPLEFLAG:
        batchSizeResample = CONFIG.BATCH_SIZE
        CONFIG.BATCH_SIZE = 1

    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=CONFIG.BATCH_SIZE,
        num_workers=CONFIG.NUM_WORKERS,
        shuffle=True,
    )
    loader_iter = iter(loader)

    # Model
    model = fpn(CONFIG.N_CLASSES)
    model = nn.DataParallel(model)

    # read old version
    if CONFIG.ITER_START != 1:
        load_network(CONFIG.SAVE_DIR, model, "SateFPN", "latest")
        print("load previous model succeed, training start from iteration {}".format(CONFIG.ITER_START))
    model.to(device)

    # Optimizer
    optimizer = {
        "sgd": torch.optim.SGD(
            # cf lr_mult and decay_mult in train.prototxt
            params=[
                {
                    "params": model.parameters(),
                    "lr": CONFIG.LR,
                    "weight_decay": CONFIG.WEIGHT_DECAY,
                }
            ],
            momentum=CONFIG.MOMENTUM,
        )
    }.get(CONFIG.OPTIMIZER)

    # Loss definition
    # criterion = FocalLoss(device, gamma=2)
    criterion = FocalLoss(gamma=2)
    criterion.to(device)

    #visualizer
    # vis = Visualizer(CONFIG.DISPLAYPORT)

    model.train()
    iter_start_time = time.time()
    for iteration in range(CONFIG.ITER_START, CONFIG.ITER_MAX + 1):
        # Set a learning rate
        poly_lr_scheduler(
            optimizer=optimizer,
            init_lr=CONFIG.LR,
            iter=iteration - 1,
            lr_decay_iter=CONFIG.LR_DECAY,
            max_iter=CONFIG.ITER_MAX,
            power=CONFIG.POLY_POWER,
        )

        # Clear gradients (ready to accumulate)
        optimizer.zero_grad()

        iter_loss = 0
        for i in range(1, CONFIG.ITER_SIZE + 1):
            if not CONFIG.RESAMPLEFLAG:
                try:
                    data, target = next(loader_iter)
                except StopIteration:
                    loader_iter = iter(loader)
                    data, target = next(loader_iter)
            else:
                cntFrame = 0
                # clDataStart = time.time()
                clCnt = 0
                while cntFrame < batchSizeResample:
                    clCnt += 1
                    try:
                        dataOne, targetOne = next(loader_iter)
                    except StopIteration:
                        loader_iter = iter(loader)
                        dataOne, targetOne = next(loader_iter)

                    hist = np.bincount(targetOne.numpy().flatten(), minlength=7)
                    hist = hist / np.sum(hist)
                    if np.nanmax(hist) <= 0.70:
                        if cntFrame == 0:
                            data = dataOne
                            target = targetOne
                        else:
                            data = torch.cat([data, dataOne])
                            target = torch.cat([target, targetOne])
                        cntFrame += 1
                # print("collate data takes %.2f sec, collect %d time" % (time.time() - clDataStart, clCnt))

            # Image
            data = data.to(device)

            # Propagate forward
            output = model(data)
            # Loss
            loss = 0
            # Resize target for {100%, 75%, 50%, Max} outputs
            target_ = resize_target(target, output.size(2))
            # classmap = class_to_target(target_, CONFIG.N_CLASSES)
            # target_ = label_bluring(classmap)  # soft crossEntropy target
            target_ = torch.from_numpy(target_).long()
            target_ = target_.to(device)
            # Compute crossentropy loss
            loss += criterion(output, target_)
            # Backpropagate (just compute gradients wrt the loss)
            loss /= float(CONFIG.ITER_SIZE)
            loss.backward()

            iter_loss += float(loss)

        # Update weights with accumulated gradients
        optimizer.step()
        # Visualizer and Summery Writer
        if iteration % CONFIG.ITER_TF == 0:
            print("itr {}, loss is {}".format(iteration, iter_loss))
            # print("itr {}, loss is {}".format(iteration, iter_loss), file=open(CONFIG.LOGNAME, "a"))  #
            # print("time taken for each iter is %.3f" % ((time.time() - iter_start_time)/iteration))
            # vis.drawLine(torch.FloatTensor([iteration]), torch.FloatTensor([iter_loss]))
            # vis.displayImg(inputImgTransBack(data), classToRGB(outputs[3][0].to("cpu").max(0)[1]),
            #                classToRGB(target[0].to("cpu")))
        # Save a model
        if iteration % CONFIG.ITER_SNAP == 0:
            save_network(CONFIG.SAVE_DIR, model, "SateFPN", iteration)

        # Save a model
        if iteration % 100 == 0:
            save_network(CONFIG.SAVE_DIR, model, "SateFPN", "latest")

    save_network(CONFIG.SAVE_DIR, model, "SateFPN", "final")
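
poly_lr_scheduler is assumed here; the usual "poly" policy multiplies the base learning rate by (1 - iter / max_iter) ** power, updated every lr_decay_iter iterations. A sketch under that assumption, matching the call signature used above:

def poly_lr_scheduler(optimizer, init_lr, iter, lr_decay_iter=1,
                      max_iter=100000, power=0.9):
    # `iter` shadows the builtin but mirrors the keyword used at the call site
    if iter % lr_decay_iter != 0 or iter > max_iter:
        return
    new_lr = init_lr * (1 - iter / max_iter) ** power
    for group in optimizer.param_groups:
        group['lr'] = new_lr  # polynomial decay applied to every parameter group
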
Example #17
								  output_shape=config['image_resolution'])
	train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, num_workers=8,
							  pin_memory=True, drop_last=True)

	val_dataset = IDRND_dataset(mode=config['mode'].replace('train', 'val'), use_face_detection=str2bool(config['use_face_detection']),
								double_loss_mode=True, output_shape=config['image_resolution'])
	val_loader = DataLoader(val_dataset, batch_size=config['batch_size'], shuffle=True, num_workers=4, drop_last=False)

	model = DoubleLossModelTwoHead(base_model=EfficientNet.from_pretrained('efficientnet-b3')).to(device)
	#model.load_state_dict(torch.load(f"../cross_val/models_weights/pretrained/EF_{8}_1.5062978111598622_0.9967353313006619.pth"))
	model.load_state_dict(
		torch.load(f"/media/danil/Data/Kaggle/IDRND_facial_antispoofing_challenge_v2/output/models/DoubleModelTwoHead/DoubleModel_17_0.01802755696873344.pth"))

	summary(model, (3, config['image_resolution'], config['image_resolution']), device='cuda')

	criterion = FocalLoss(add_weight=False).to(device)
	criterion4class = CrossEntropyLoss().to(device)

	steps_per_epoch = train_loader.__len__()
	optimizer = torch.optim.Adam(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])
	swa = SWA(optimizer, swa_start=config['swa_start'] * steps_per_epoch, swa_freq=int(config['swa_freq'] * steps_per_epoch), swa_lr=config['learning_rate'] / 10)
	scheduler = ExponentialLR(swa, gamma=0.85)

	shutil.rmtree(config['log_path'])
	os.mkdir(config['log_path'])
	train_writer = SummaryWriter(os.path.join(config['log_path'], "train"))
	val_writer = SummaryWriter(os.path.join(config['log_path'], "val"))

	global_step = 0
	for epoch in trange(config['number_epochs']):
		model.train()
Example #18
def eval_epoch(net, epoch, data_loader, eval_log):
    # set to eval
    net.eval()
    total_loss = 0.0
    data_process = tqdm(data_loader)
    # init confusion matrix with all zeros
    confusion_matrix = {
        "TP": {i: 0 for i in range(8)},
        "TN": {i: 0 for i in range(8)},
        "FP": {i: 0 for i in range(8)},
        "FN": {i: 0 for i in range(8)}
    }

    for batch_item in data_process:
        # get batch data
        batch_image, batch_label = batch_item['image'], batch_item['label']
        if cfg.MULTI_GPU:
            batch_image, batch_label = batch_image.cuda(device=cfg.DEVICE_LIST[0]), \
                                       batch_label.cuda(device=cfg.DEVICE_LIST[0])
        else:
            batch_image, batch_label = batch_image.cuda(), \
                                       batch_label.cuda()
        # forward to get output
        batch_out = net(batch_image)
        # compute loss
        # batch_loss = CrossEntropyLoss(cfg.CLASS_NUM)(batch_out, batch_label)
        batch_loss = FocalLoss(cfg.CLASS_NUM)(batch_out, batch_label)
        total_loss += batch_loss.detach().item()

        # get prediction, shape and value type same as batch_label
        pred = torch.argmax(F.softmax(batch_out, dim=1), dim=1)
        # compute confusion matrix using batch data
        confusion_matrix = update_confusion_matrix(pred, batch_label,
                                                   confusion_matrix)
        # print batch result
        data_process.set_description_str("epoch:{}".format(epoch))
        data_process.set_postfix_str("batch_loss:{:.4f}".format(batch_loss))

    eval_loss = total_loss / len(data_loader)

    # compute metric
    epoch_ious = compute_iou(confusion_matrix)
    epoch_m_iou = compute_mean(epoch_ious)
    epoch_precisions = compute_precision(confusion_matrix)
    epoch_m_precision = compute_mean(epoch_precisions)
    epoch_recalls = compute_recall(confusion_matrix)
    epoch_m_recall = compute_mean(epoch_recalls)

    # print eval iou every epoch
    print('mean iou: {} \n'.format(epoch_m_iou))
    for i in range(8):
        print_string = "class '{}' iou : {:.4f} \n".format(i, epoch_ious[i])
        print(print_string)

    # make log string
    log_values = [epoch, eval_loss, epoch_m_iou] + \
                 epoch_ious + [epoch_m_precision] + \
                 epoch_precisions + [epoch_m_recall] + epoch_recalls
    log_values = [str(v) for v in log_values]
    log_string = ','.join(log_values)
    eval_log.write(log_string + '\n')
    eval_log.flush()
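
compute_iou and its companions consume the per-class TP/FP/FN/TN dictionary built above; per class, IoU = TP / (TP + FP + FN), precision = TP / (TP + FP) and recall = TP / (TP + FN). A sketch consistent with that layout (an assumption about helpers not shown in this snippet):

def compute_iou(cm):
    # cm is the {"TP": {...}, "FP": {...}, "FN": {...}, "TN": {...}} dict above
    return [cm["TP"][i] / max(cm["TP"][i] + cm["FP"][i] + cm["FN"][i], 1)
            for i in range(8)]

def compute_precision(cm):
    return [cm["TP"][i] / max(cm["TP"][i] + cm["FP"][i], 1) for i in range(8)]

def compute_recall(cm):
    return [cm["TP"][i] / max(cm["TP"][i] + cm["FN"][i], 1) for i in range(8)]

def compute_mean(values):
    return sum(values) / len(values)
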
Example #19
def main():

    # using multi process to load data when cuda is available
    kwargs = {
        'num_workers': 4,
        'pin_memory': True
    } if torch.cuda.is_available() else {}
    # set image augment
    augments = [ImageAug(), ToTensor()]
    # get dataset and iterable dataloader
    train_dataset = LaneSegTrainDataset("train.csv",
                                        transform=transforms.Compose(augments))
    eval_dataset = LaneSegTrainDataset("eval.csv",
                                       transform=transforms.Compose(
                                           [ToTensor()]))
    if cfg.MULTI_GPU:
        train_batch_size = cfg.TRAIN_BATCH_SIZE * len(cfg.DEVICE_LIST)
        eval_batch_size = cfg.EVAL_BATCH_SIZE * len(cfg.DEVICE_LIST)

    else:
        train_batch_size = cfg.TRAIN_BATCH_SIZE
        eval_batch_size = cfg.EVAL_BATCH_SIZE
    train_data_batch = DataLoader(train_dataset,
                                  batch_size=train_batch_size,
                                  shuffle=True,
                                  drop_last=True,
                                  **kwargs)
    eval_data_batch = DataLoader(eval_dataset,
                                 batch_size=eval_batch_size,
                                 shuffle=False,
                                 drop_last=False,
                                 **kwargs)

    # define model
    if cfg.MODEL == 'deeplabv3+':
        net = Deeplabv3plus(class_num=cfg.CLASS_NUM, normal=cfg.NORMAL)
    elif cfg.MODEL == 'unet':
        net = UNetv1(class_num=cfg.CLASS_NUM, normal=cfg.NORMAL)
    else:
        net = UNetv1(class_num=cfg.CLASS_NUM, normal=cfg.NORMAL)

    # use cuda if available
    if torch.cuda.is_available():
        if cfg.MULTI_GPU:
            net = torch.nn.DataParallel(net, device_ids=cfg.DEVICE_LIST)
            net = net.cuda(device=cfg.DEVICE_LIST[0])
        else:
            net = net.cuda()
        # load pretrained weights
        if cfg.PRE_TRAINED:
            checkpoint = torch.load(
                os.path.join(cfg.LOG_DIR, cfg.PRE_TRAIN_WEIGHTS))
            net.load_state_dict(checkpoint['state_dict'])

    # define optimizer
    # optimizer = torch.optim.SGD(net.parameters(),
    #                             lr=cfg.BASE_LR,
    #                             momentum=0.9,
    #                             weight_decay=cfg.WEIGHT_DECAY)
    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=cfg.BASE_LR,
                                 weight_decay=cfg.WEIGHT_DECAY)
    # criterion = CrossEntropyLoss(cfg.CLASS_NUM)
    criterion = FocalLoss(cfg.CLASS_NUM)
    # define log file
    train_log = open(
        os.path.join(cfg.LOG_DIR, "train_log_{}.csv".format(cfg.TRAIN_NUMBER)),
        'w')
    train_log_title = "epoch,average loss\n"
    train_log.write(train_log_title)
    train_log.flush()
    eval_log = open(
        os.path.join(cfg.LOG_DIR, "eval_log_{}.csv".format(cfg.TRAIN_NUMBER)),
        'w')
    eval_log_title = "epoch,average_loss, mean_iou, iou_0,iou_1,iou_2,iou_3,iou_4,iou_5,iou_6,iou_7, " \
                     "mean_precision, precision_0,precision_1,precision_2,precision_3, precision_4," \
                     "precision_5,precision_6,precision_7, mean_recall, recall_0,recall_1,recall_2," \
                     "recall_3,recall_4,recall_5,recall_6,recall_7\n"
    eval_log.write(eval_log_title)
    eval_log.flush()

    # train and test epoch by epoch
    for epoch in range(cfg.EPOCHS):
        print('current epoch learning rate: {}'.format(cfg.BASE_LR))
        train_epoch(net, epoch, train_data_batch, optimizer, criterion,
                    train_log)
        # save model
        if epoch != cfg.EPOCHS - 1:
            torch.save({'state_dict': net.state_dict()},
                       os.path.join(
                           cfg.LOG_DIR,
                           "laneNet_{0}_{1}th_epoch_{2}.pt".format(
                               cfg.MODEL, cfg.TRAIN_NUMBER, epoch)))
        else:
            torch.save({'state_dict': net.state_dict()},
                       os.path.join(
                           cfg.LOG_DIR, "laneNet_{0}_{1}th.pt".format(
                               cfg.MODEL, cfg.TRAIN_NUMBER)))
        eval_epoch(net, epoch, eval_data_batch, eval_log)

    train_log.close()
    eval_log.close()
def train(model_name, optim='adam'):
    train_dataset = PretrainDataset(output_shape=config['image_resolution'])
    train_loader = DataLoader(train_dataset,
                              batch_size=config['batch_size'],
                              shuffle=True,
                              num_workers=8,
                              pin_memory=True,
                              drop_last=True)

    val_dataset = IDRND_dataset_CV(fold=0,
                                   mode=config['mode'].replace('train', 'val'),
                                   double_loss_mode=True,
                                   output_shape=config['image_resolution'])
    val_loader = DataLoader(val_dataset,
                            batch_size=config['batch_size'],
                            shuffle=True,
                            num_workers=4,
                            drop_last=False)

    if model_name == 'EF':
        model = DoubleLossModelTwoHead(base_model=EfficientNet.from_pretrained(
            'efficientnet-b3')).to(device)
        model.load_state_dict(
            torch.load(
                f"../models_weights/pretrained/{model_name}_{4}_2.0090592697255896_1.0.pth"
            ))
    elif model_name == 'EFGAP':
        model = DoubleLossModelTwoHead(
            base_model=EfficientNetGAP.from_pretrained('efficientnet-b3')).to(
                device)
        model.load_state_dict(
            torch.load(
                f"../models_weights/pretrained/{model_name}_{4}_2.3281182915644134_1.0.pth"
            ))

    criterion = FocalLoss(add_weight=False).to(device)
    criterion4class = CrossEntropyLoss().to(device)

    if optim == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=config['learning_rate'],
                                     weight_decay=config['weight_decay'])
    elif optim == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=config['learning_rate'],
                                    weight_decay=config['weight_decay'],
                                    nesterov=False)
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    momentum=0.9,
                                    lr=config['learning_rate'],
                                    weight_decay=config['weight_decay'],
                                    nesterov=True)

    steps_per_epoch = len(train_loader) - 15
    swa = SWA(optimizer,
              swa_start=config['swa_start'] * steps_per_epoch,
              swa_freq=int(config['swa_freq'] * steps_per_epoch),
              swa_lr=config['learning_rate'] / 10)
    scheduler = ExponentialLR(swa, gamma=0.9)
    # scheduler = StepLR(swa, step_size=5*steps_per_epoch, gamma=0.5)

    global_step = 0
    for epoch in trange(10):
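        # The first 5 epochs only advance the LR schedule and skip training,
        # presumably to fast-forward past the epochs already covered by the
        # pretrained checkpoint loaded above.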
        if epoch < 5:
            scheduler.step()
            continue
        model.train()
        train_bar = tqdm(train_loader)
        train_bar.set_description_str(desc=f"N epochs - {epoch}")

        for step, batch in enumerate(train_bar):
            global_step += 1
            image = batch['image'].to(device)
            label4class = batch['label0'].to(device)
            label = batch['label1'].to(device)

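            # Two-head forward pass: `output4class` is scored with cross-entropy
            # against `label0`, `output` with focal loss against `label1`; the
            # total loss is their equal-weight sum.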
            output4class, output = model(image)
            loss4class = criterion4class(output4class, label4class)
            loss = criterion(output.squeeze(), label)
            swa.zero_grad()
            total_loss = loss4class * 0.5 + loss * 0.5
            total_loss.backward()
            swa.step()
            train_writer.add_scalar(tag="learning_rate",
                                    scalar_value=scheduler.get_lr()[0],
                                    global_step=global_step)
            train_writer.add_scalar(tag="BinaryLoss",
                                    scalar_value=loss.item(),
                                    global_step=global_step)
            train_writer.add_scalar(tag="SoftMaxLoss",
                                    scalar_value=loss4class.item(),
                                    global_step=global_step)
            train_bar.set_postfix_str(f"Loss = {loss.item()}")
            try:
                train_writer.add_scalar(tag="idrnd_score",
                                        scalar_value=idrnd_score_pytorch(
                                            label, output),
                                        global_step=global_step)
                train_writer.add_scalar(tag="far_score",
                                        scalar_value=far_score(label, output),
                                        global_step=global_step)
                train_writer.add_scalar(tag="frr_score",
                                        scalar_value=frr_score(label, output),
                                        global_step=global_step)
                train_writer.add_scalar(tag="accuracy",
                                        scalar_value=bce_accuracy(
                                            label, output),
                                        global_step=global_step)
            except Exception:
                pass

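        # Every other epoch past `swa_start` (and at the final epoch), swap the
        # SWA-averaged weights into the model, refresh the BatchNorm running
        # statistics on the training data, then swap the live weights back in.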
        if (epoch > config['swa_start']
                and epoch % 2 == 0) or (epoch == config['number_epochs'] - 1):
            swa.swap_swa_sgd()
            swa.bn_update(train_loader, model, device)
            swa.swap_swa_sgd()

        scheduler.step()
        evaluate(model, val_loader, epoch, model_name)
Example #21
    def __init__(self, **kwargs):
        
        if kwargs['use_gpu']:
            print('Using GPU')
            self.device = t.device('cuda')
        else:
            self.device = t.device('cpu')

        # data
        if not os.path.exists(kwargs['root_dir'] + '/train_image_list'):
            image_list = glob(kwargs['root_dir'] + '/img_train/*')
            image_list = sorted(image_list)

            mask_list = glob(kwargs['root_dir'] + '/lab_train/*')
            mask_list = sorted(mask_list)

        else:
            image_list = open(kwargs['root_dir'] + '/train_image_list', 'r').readlines()
            image_list = [line.strip() for line in image_list]
            image_list = sorted(image_list)
            mask_list = open(kwargs['root_dir'] + '/train_label_list', 'r').readlines()
            mask_list = [line.strip() for line in mask_list]
            mask_list = sorted(mask_list)

        print(image_list[-5:], mask_list[-5:])

        if kwargs['augs']:
            augs = Augment(get_augmentations())
        else:
            augs = None

        self.train_loader, self.val_loader = build_loader(image_list, mask_list, kwargs['test_size'], augs, preprocess, \
            kwargs['num_workers'], kwargs['batch_size'])

        self.model = build_model(kwargs['in_channels'], kwargs['num_classes'], kwargs['model_name']).to(self.device)

        if kwargs['resume']:
            try:
                self.model.load_state_dict(t.load(kwargs['resume']))
                
            except Exception as e:
                self.model.load_state_dict(t.load(kwargs['resume'])['model'])

            print(f'load model from {kwargs["resume"]} successfully')
        if kwargs['loss'] == 'CE':
            self.criterion = nn.CrossEntropyLoss().to(self.device)
        elif kwargs['loss'] == 'FL':
            self.criterion = FocalLoss().to(self.device)
        else:
            raise NotImplementedError
        if kwargs['use_sgd']:
            self.optimizer = SGD(self.model.parameters(), lr=kwargs['lr'], momentum=kwargs['momentum'], nesterov=True, weight_decay=float(kwargs['weight_decay'])) 

        else:
            self.optimizer = Adam(self.model.parameters(), lr=kwargs['lr'], weight_decay=float(kwargs['weight_decay']))

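        # Cosine-annealing schedule that restarts every 100 steps, doubling the
        # period after each restart (T_mult=2) and never letting the learning
        # rate fall below 1e-6.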
        self.lr_planner = CosineAnnealingWarmRestarts(self.optimizer, 100, T_mult=2, eta_min=1e-6, verbose=True)

        log_dir = os.path.join(kwargs['log_dir'], datetime.now().strftime("%Y-%m-%d-%H-%M-%S"))

        self.writer = SummaryWriter(log_dir, comment=f"LR-{kwargs['lr']}_BatchSize-{kwargs['batch_size']}_ModelName-{kwargs['model_name']}")

        self.name = kwargs['model_name']
        self.epoch = kwargs['epoch']

        self.start_epoch = kwargs['start_epoch']

        self.val_interval = kwargs['val_interval']

        self.log_interval = kwargs['log_interval']

        self.num_classes = kwargs['num_classes']

        self.checkpoints = kwargs['checkpoints']

        self.color_dicts = kwargs['color_dicts']


        # , format='%Y-%m-%d %H:%M:%S',
        logging.basicConfig(filename=log_dir + '/log.log', level=logging.INFO, datefmt='%Y-%m-%d %H:%M:%S')

        s = '\n\t\t'
        for k, v in kwargs.items():
            s += f"{k}: \t{v}\n\t\t"

        logging.info(s)
Example #22
        else:
            img, label = sample
        iter_batch_size = img.size(0)
        if (iter_batch_size == 1):
            continue
        img = img.permute(0, 3, 1, 2).float() / 255.
        label = label.view(-1, 512, 512).long()
        img, label = img.cuda(), label.cuda()

        size_in = img.size()
        encoder_feat, out = model(img)
        if model_name == "DeepLab":
            #loss = criterion(out,label)
            if boundary_flag:
                seg_loss = FocalLoss(out[:, :-2, :, :],
                                     label,
                                     gamma=1.7,
                                     alpha=0.5)
                #boundary_loss = criterion(out[:,-2:,:,:], boundary_label.long().cuda())
                boundary_loss = FocalLoss(out[:, -2:, :, :],
                                          boundary_label.long().cuda(),
                                          gamma=2.5,
                                          alpha=0.5)
                inter_loss = FocalLoss(encoder_feat,
                                       boundary_label_small.long().cuda(),
                                       gamma=2.5,
                                       alpha=0.5)
                loss = seg_loss + boundary_loss + inter_loss
            else:
                loss = FocalLoss(out, label, gamma=2, alpha=0.5)
        else:
            aux = F.interpolate(out[0],
Example #23
    def __init__(self, args):
        self.args = args
        self.mode = args.mode
        self.epochs = args.epochs
        self.dataset = args.dataset
        self.data_path = args.data_path
        self.train_crop_size = args.train_crop_size
        self.eval_crop_size = args.eval_crop_size
        self.stride = args.stride
        self.batch_size = args.train_batch_size
        self.train_data = AerialDataset(crop_size=self.train_crop_size,
                                        dataset=self.dataset,
                                        data_path=self.data_path,
                                        mode='train')
        self.train_loader = DataLoader(self.train_data,
                                       batch_size=self.batch_size,
                                       shuffle=True,
                                       num_workers=2)
        self.eval_data = AerialDataset(dataset=self.dataset,
                                       data_path=self.data_path,
                                       mode='val')
        self.eval_loader = DataLoader(self.eval_data,
                                      batch_size=1,
                                      shuffle=False,
                                      num_workers=2)

        if self.dataset == 'Potsdam':
            self.num_of_class = 6
            self.epoch_repeat = get_test_times(6000, 6000,
                                               self.train_crop_size,
                                               self.train_crop_size)
        elif self.dataset == 'UDD5':
            self.num_of_class = 5
            self.epoch_repeat = get_test_times(4000, 3000,
                                               self.train_crop_size,
                                               self.train_crop_size)
        elif self.dataset == 'UDD6':
            self.num_of_class = 6
            self.epoch_repeat = get_test_times(4000, 3000,
                                               self.train_crop_size,
                                               self.train_crop_size)
        else:
            raise NotImplementedError

        if args.model == 'FCN':
            self.model = models.FCN8(num_classes=self.num_of_class)
        elif args.model == 'DeepLabV3+':
            self.model = models.DeepLab(num_classes=self.num_of_class,
                                        backbone='resnet')
        elif args.model == 'GCN':
            self.model = models.GCN(num_classes=self.num_of_class)
        elif args.model == 'UNet':
            self.model = models.UNet(num_classes=self.num_of_class)
        elif args.model == 'ENet':
            self.model = models.ENet(num_classes=self.num_of_class)
        elif args.model == 'D-LinkNet':
            self.model = models.DinkNet34(num_classes=self.num_of_class)
        else:
            raise NotImplementedError

        if args.loss == 'CE':
            self.criterion = CrossEntropyLoss2d()
        elif args.loss == 'LS':
            self.criterion = LovaszSoftmax()
        elif args.loss == 'F':
            self.criterion = FocalLoss()
        elif args.loss == 'CE+D':
            self.criterion = CE_DiceLoss()
        else:
            raise NotImplementedError

        self.schedule_mode = args.schedule_mode
        self.optimizer = opt.AdamW(self.model.parameters(), lr=args.lr)
        if self.schedule_mode == 'step':
            self.scheduler = opt.lr_scheduler.StepLR(self.optimizer,
                                                     step_size=30,
                                                     gamma=0.1)
        elif self.schedule_mode == 'miou' or self.schedule_mode == 'acc':
            self.scheduler = opt.lr_scheduler.ReduceLROnPlateau(self.optimizer,
                                                                mode='max',
                                                                patience=10,
                                                                factor=0.1)
        elif self.schedule_mode == 'poly':
            iters_per_epoch = len(self.train_loader)
            self.scheduler = Poly(self.optimizer,
                                  num_epochs=args.epochs,
                                  iters_per_epoch=iters_per_epoch)
        else:
            raise NotImplementedError

        self.evaluator = Evaluator(self.num_of_class)

        self.model = nn.DataParallel(self.model)

        self.cuda = args.cuda
        if self.cuda is True:
            self.model = self.model.cuda()

        self.resume = args.resume
        self.finetune = args.finetune
        assert not (self.resume is not None and self.finetune is not None)

        if self.resume is not None:
            print("Loading existing model...")
            if self.cuda:
                checkpoint = torch.load(args.resume)
            else:
                checkpoint = torch.load(args.resume, map_location='cpu')
            self.model.load_state_dict(checkpoint['parameters'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.scheduler.load_state_dict(checkpoint['scheduler'])
            self.start_epoch = checkpoint['epoch'] + 1
            #start from next epoch
        elif self.finetune is not None:
            print("Loading existing model...")
            if self.cuda:
                checkpoint = torch.load(args.finetune)
            else:
                checkpoint = torch.load(args.finetune, map_location='cpu')
            self.model.load_state_dict(checkpoint['parameters'])
            self.start_epoch = checkpoint['epoch'] + 1
        else:
            self.start_epoch = 1
        if self.mode == 'train':
            self.writer = SummaryWriter(comment='-' + self.dataset + '_' +
                                        self.model.__class__.__name__ + '_' +
                                        args.loss)
        self.init_eval = args.init_eval
def train(model, train_dataset, dev_dataset, test_dataset):
    start_time = time.time()
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=config.batch_size,
                                  sampler=train_sampler,
                                  collate_fn=collate_fn)
    model.train()
    if config.use_cuda:
        model.to(config.device)
    if "bert" not in config.model_name:
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=config.learning_rate)
    else:
        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
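        # Standard BERT fine-tuning setup: bias and LayerNorm parameters are
        # excluded from weight decay.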
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.02
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.0
        }]
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=config.learning_rate,
                          eps=config.eps)

    t_total = len(
        train_dataloader
    ) // config.gradient_accumulation_steps * config.num_train_epochs
    warmup_steps = int(t_total * config.warmup_proportion)
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=warmup_steps,
                                                num_training_steps=t_total)
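    # Linear warmup over the first `warmup_proportion` of the total optimizer
    # steps, followed by a linear decay of the learning rate to zero.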
    criterion = FocalLoss(gamma=2)
    # Train!
    logging.info("***** Running training *****")
    logging.info("  Num examples = %d", len(train_dataset))
    logging.info("  Num Epochs = %d", config.num_train_epochs)
    logging.info("  batch size = %d", config.batch_size)
    logging.info("  Num batches = %d",
                 config.num_train_epochs * len(train_dataloader))
    logging.info("  device: {}".format(config.device))

    total_batch = 0
    dev_best_acc = float('-inf')
    last_improve = 0
    flag = False

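    # If checkpoints from an earlier run exist, reload the one with the highest
    # dev accuracy (encoded in the checkpoint_<step>_<acc>.ckpt filename) and
    # continue training from there.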
    checkpoints = [
        path for path in os.listdir(config.save_path)
        if path.startswith("checkpoint")
    ]
    if checkpoints:
        print(checkpoints)
        checkpoints = sorted(map(lambda x: os.path.splitext(x)[0].split("_"),
                                 checkpoints),
                             key=lambda x: float(x[2]))[-1]
        dev_best_acc = float(checkpoints[-1]) / 100
        model_path = os.path.join(config.save_path,
                                  "_".join(checkpoints) + ".ckpt")
        model.load_state_dict(torch.load(model_path))
        logging.info("继续训练, {}".format(model_path))
        logging.info("最大准确率: {}".format(dev_best_acc))

    for epoch in range(config.num_train_epochs):
        logging.info('Epoch [{}/{}]'.format(epoch + 1,
                                            config.num_train_epochs))
        for i, batch in enumerate(train_dataloader):
            attention_mask, token_type_ids = torch.tensor([1]), torch.tensor(
                [1])
            if "bert" in config.model_name:
                token_ids, attention_mask, token_type_ids, labels, tokens = batch
            else:
                token_ids, labels, tokens = batch
            if i < 1:
                logging.info("tokens: {}\n ".format(tokens))
                logging.info("token_ids: {}\n ".format(token_ids))
                logging.info("token_ids shape: {}\n ".format(token_ids.shape))
                logging.info("attention_mask: {}\n".format(attention_mask))
                logging.info("token_type_ids: {}\n".format(token_type_ids))
                logging.info("labels: {}\n".format(labels))
                logging.info("labels shape: {}\n".format(labels.shape))
            if config.use_cuda:
                token_ids = token_ids.to(config.device)
                attention_mask = attention_mask.to(config.device)
                token_type_ids = token_type_ids.to(config.device)
                labels = labels.to(config.device)
            outputs = model((token_ids, attention_mask, token_type_ids))
            # loss = F.cross_entropy(outputs, labels)
            loss = criterion(outputs, labels)
            if config.gradient_accumulation_steps > 1:
                loss = loss / config.gradient_accumulation_steps

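            # Gradients are accumulated over `gradient_accumulation_steps`
            # mini-batches; clipping, the optimizer step, the scheduler step and
            # the gradient reset run only when a full accumulation window completes.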
            loss.backward()
            total_batch += 1
            if total_batch % config.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               config.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()

            if total_batch % config.logging_step == 0:
                true = labels.data.cpu()
                predicts = torch.max(outputs.data, 1)[1].cpu()
                train_acc = metrics.accuracy_score(true, predicts)
                output = evaluate(model, dev_dataset)
                dev_acc, dev_loss = output[0], output[1]
                if dev_acc > dev_best_acc:
                    logging.info("saving model..........")
                    torch.save(
                        model.state_dict(),
                        os.path.join(
                            config.save_path,
                            "checkpoint_{}_{:.2f}.ckpt".format(
                                total_batch, dev_acc * 100)))
                    improve = '*'
                    last_improve = total_batch
                    dev_best_acc = dev_acc
                else:
                    improve = ''
                time_dif = get_time_dif(start_time)
                msg = 'Iter:{0}, Train Loss:{1:.4f}, Train Acc:{2:.2f}%, Val Loss:{3:.4f}, Val Acc:{4:.2f}%, Time:{5} {6}'
                logging.info(
                    msg.format(total_batch, loss.item(), train_acc * 100,
                               dev_loss, dev_acc * 100, time_dif, improve))
            if total_batch - last_improve > config.patient:
                logging.info(
                    "No optimization for a long time, auto-stopping...")
                flag = True
                break
        if flag:
            break
    torch.save(model.state_dict(), os.path.join(config.save_path,
                                                "model.ckpt"))
    if config.predict:
        predict(model, test_dataset)
Example #25
# img_list = glob(image_dir + '/*')

# img_list = sorted(img_list)
# lab_list = glob(label_dir + '/*')

# lab_list = sorted(lab_list)

# dataset = SatImageDataset(img_list, preprocess=preprocess)
# dataloader = DataLoader(dataset, batch_size=4, shuffle=False)

# tbar = tqdm(enumerate(dataloader), desc='infer')
# for i, (img_path, img) in tbar:
#     if i > 5:
#         break
#     print(len(img_path), img_path[1], img.shape)
# writer = SummaryWriter('./temp_logs')
# print(content['data']['color_dicts'])
# for i, (image, label) in enumerate(dataloader):
#     if i > 5:
#         break

#     label = vis_mask(label, content['data']['color_dicts'])
#     label = make_grid(label, nrow=2, padding=10)
#     writer.add_image('img', label, i)

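# Quick sanity check: focal loss on random 3-class logits of shape (4, 3, 256, 256)
# against random integer targets of shape (4, 256, 256).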
criterion = FocalLoss().cuda()
a = t.rand(4, 3, 256, 256).cuda()
b = t.randint(0, 3, (4, 256, 256)).type(t.long).cuda()
loss = criterion(a, b)
print(loss)
Example #26
path_g = os.path.join(model_path, args.path_g)
path_g2l = os.path.join(model_path, args.path_g2l)
path_l2g = os.path.join(model_path, args.path_l2g)
model, global_fixed = create_model_load_weights(n_class, mode, evaluation, path_g=path_g, path_g2l=path_g2l, path_l2g=path_l2g)

###################################
num_epochs = args.num_epochs
learning_rate = args.lr
lamb_fmreg = args.lamb_fmreg

optimizer = get_optimizer(model, mode, learning_rate=learning_rate)

scheduler = LR_Scheduler('poly', learning_rate, num_epochs, len(dataloader_train))
##################################

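# Focal loss down-weights well-classified examples, FL(p_t) = -(1 - p_t)^gamma * log(p_t);
# gamma=3 concentrates the gradient on hard pixels. Only criterion1 is used by the
# criterion lambda below; the 50/50 focal + lovasz mix is left commented out.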
criterion1 = FocalLoss(gamma=3)
criterion2 = nn.CrossEntropyLoss()
criterion3 = lovasz_softmax
criterion = lambda x,y: criterion1(x, y)
# criterion = lambda x,y: 0.5*criterion1(x, y) + 0.5*criterion3(x, y)
mse = nn.MSELoss()

if not evaluation:
    
    writer = SummaryWriter(log_dir=os.path.join(log_path, task_name))
    f_log = open(os.path.join(log_path, task_name + ".log"), 'w')

trainer = Trainer(criterion, optimizer, n_class, size_g, size_p, sub_batch_size, mode, lamb_fmreg)
evaluator = Evaluator(n_class, size_g, size_p, sub_batch_size, mode, test)

best_pred = 0.0
Example #27
    def train():
        G.train()
        F1.train()
        optimizer_g = optim.SGD(params,
                                momentum=0.9,
                                weight_decay=0.0005,
                                nesterov=True)
        optimizer_f = optim.SGD(list(F1.parameters()),
                                lr=1.0,
                                momentum=0.9,
                                weight_decay=0.0005,
                                nesterov=True)

        # Loading the states of the two optimizers
        optimizer_g.load_state_dict(main_dict['optimizer_g'])
        optimizer_f.load_state_dict(main_dict['optimizer_f'])
        print("Loaded optimizer states")

        def zero_grad_all():
            optimizer_g.zero_grad()
            optimizer_f.zero_grad()

        param_lr_g = []
        for param_group in optimizer_g.param_groups:
            param_lr_g.append(param_group["lr"])
        param_lr_f = []
        for param_group in optimizer_f.param_groups:
            param_lr_f.append(param_group["lr"])

        # Setting the loss function to be used for the classification loss
        if args.loss == 'CE':
            criterion = nn.CrossEntropyLoss().to(device)
        if args.loss == 'FL':
            criterion = FocalLoss(alpha=1, gamma=1).to(device)
        if args.loss == 'CBFL':
            # Calculating the list having the number of examples per class which is going to be used in the CB focal loss
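            # Weights are the inverse effective number of samples per class,
            # (1 - beta) / (1 - beta^n), normalized to sum to the number of
            # classes (class-balanced loss, Cui et al.).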
            beta = 0.99
            effective_num = 1.0 - np.power(beta, class_num_list)
            per_cls_weights = (1.0 - beta) / np.array(effective_num)
            per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(
                class_num_list)
            per_cls_weights = torch.FloatTensor(per_cls_weights).to(device)
            criterion = CBFocalLoss(weight=per_cls_weights,
                                    gamma=0.5).to(device)

        all_step = args.steps
        data_iter_s = iter(source_loader)
        data_iter_t = iter(target_loader)
        data_iter_t_unl = iter(target_loader_unl)
        len_train_source = len(source_loader)
        len_train_target = len(target_loader)
        len_train_target_semi = len(target_loader_unl)
        best_acc = 0
        counter = 0
        for step in range(all_step):
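            # Anneal both optimizers' learning rates with the inverse-decay
            # schedule before each step.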
            optimizer_g = inv_lr_scheduler(param_lr_g,
                                           optimizer_g,
                                           step,
                                           init_lr=args.lr)
            optimizer_f = inv_lr_scheduler(param_lr_f,
                                           optimizer_f,
                                           step,
                                           init_lr=args.lr)
            lr = optimizer_f.param_groups[0]['lr']
            # condition for restarting the iteration for each of the data loaders
            if step % len_train_target == 0:
                data_iter_t = iter(target_loader)
            if step % len_train_target_semi == 0:
                data_iter_t_unl = iter(target_loader_unl)
            if step % len_train_source == 0:
                data_iter_s = iter(source_loader)
            data_t = next(data_iter_t)
            data_t_unl = next(data_iter_t_unl)
            data_s = next(data_iter_s)
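            # Copy the current source/target batches into the pre-allocated GPU
            # tensors in place (resize_ + copy_ avoids reallocating them each step).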
            with torch.no_grad():
                im_data_s.resize_(data_s[0].size()).copy_(data_s[0])
                gt_labels_s.resize_(data_s[1].size()).copy_(data_s[1])
                im_data_t.resize_(data_t[0].size()).copy_(data_t[0])
                gt_labels_t.resize_(data_t[1].size()).copy_(data_t[1])
                im_data_tu.resize_(data_t_unl[0].size()).copy_(data_t_unl[0])

            zero_grad_all()
            data = torch.cat((im_data_s, im_data_t), 0)
            target = torch.cat((gt_labels_s, gt_labels_t), 0)
            output = G(data)
            out1 = F1(output)
            loss = criterion(out1, target)
            loss.backward(retain_graph=True)
            optimizer_g.step()
            optimizer_f.step()
            zero_grad_all()
            # list of the weights and image paths in this batch
            img_paths = list(data_t_unl[2])
            df1 = df.loc[df['img'].isin(img_paths)]
            df1 = df1['weight']
            weight_list = list(df1)

            if not args.method == 'S+T':
                output = G(im_data_tu)
                if args.method == 'ENT':
                    loss_t = entropy(F1, output, args.lamda)
                    loss_t.backward()
                    optimizer_f.step()
                    optimizer_g.step()
                elif args.method == 'MME':
                    loss_t = adentropy(F1, output, args.lamda, weight_list)
                    loss_t.backward()
                    optimizer_f.step()
                    optimizer_g.step()
                else:
                    raise ValueError('Method cannot be recognized.')
                log_train = 'S {} T {} Train Ep: {} lr{} \t ' \
                            'Loss Classification: {:.6f} Loss T {:.6f} ' \
                            'Method {}\n'.format(args.source, args.target,
                                                step, lr, loss.data,
                                                -loss_t.data, args.method)
            else:
                log_train = 'S {} T {} Train Ep: {} lr{} \t ' \
                            'Loss Classification: {:.6f} Method {}\n'.\
                    format(args.source, args.target,
                        step, lr, loss.data,
                        args.method)
            G.zero_grad()
            F1.zero_grad()
            zero_grad_all()
            if step % args.log_interval == 0:
                print(log_train)
            if step % args.save_interval == 0 and step > 0:
                loss_val, acc_val = test(target_loader_val)
                loss_test, acc_test = test(target_loader_test)
                G.train()
                F1.train()
                if acc_test >= best_acc:
                    best_acc = acc_test
                    best_acc_test = acc_test
                    counter = 0
                else:
                    counter += 1
                if args.early:
                    if counter > args.patience:
                        break
                print('best acc test %f best acc val %f' %
                      (best_acc_test, acc_val))
                print('record %s' % record_file)
                with open(record_file, 'a') as f:
                    f.write('step %d best %f final %f \n' %
                            (step, best_acc_test, acc_val))
                G.train()
                F1.train()
                #saving model as a checkpoint dict having many things
                if args.save_check:
                    print('saving model')
                    is_best = True if counter == 0 else False
                    save_mymodel(
                        args, {
                            'step': step,
                            'arch': args.net,
                            'G_state_dict': G.state_dict(),
                            'F1_state_dict': F1.state_dict(),
                            'best_acc_test': best_acc_test,
                            'optimizer_g': optimizer_g.state_dict(),
                            'optimizer_f': optimizer_f.state_dict(),
                        }, is_best)