Code example #1
Score: 0
    def training(self, epoch):
        """Run one training epoch of the segmentation network.

        Combines a supervised segmentation loss on source-domain batches with
        an unsupervised flip-consistency ("instance") loss on target-domain
        batches. Per-iteration learning rate and per-epoch loss totals are
        plotted to visdom when enabled.

        Args:
            epoch: Zero-based epoch index, used for the lr scheduler and for
                the x-axis of the visdom plots.
        """
        # Only the three accumulators actually updated in the loop are kept;
        # the old bn/entropy/adv/discriminator sums belonged to a disabled
        # adversarial branch and were dead code.
        train_loss = 0.0      # running seg + instance loss, shown on the tqdm bar
        seg_loss_sum = 0.0    # epoch total of the supervised loss
        ins_loss_sum = 0.0    # epoch total of the instance-consistency loss
        self.model.train()
        tbar = tqdm(self.train_loader)
        target_train_iterator = iter(self.target_train_loader)
        for i, sample in enumerate(tbar):
            itr = epoch * len(self.train_loader) + i
            if self.visdom:
                self.vis.line(
                    X=torch.tensor([itr]),
                    Y=torch.tensor([self.optimizer.param_groups[0]['lr']]),
                    win='lr',
                    opts=dict(title='lr', xlabel='iter', ylabel='lr'),
                    update='append' if itr > 0 else None)
            A_image, A_target = sample['image'], sample['label']

            # Draw one batch from the target-domain loader, restarting its
            # iterator whenever it is exhausted (the two loaders may differ
            # in length).
            try:
                target_sample = next(target_train_iterator)
            except StopIteration:
                target_train_iterator = iter(self.target_train_loader)
                target_sample = next(target_train_iterator)

            B_image = target_sample['image']
            B_target = target_sample['label']
            B_image_pair = target_sample['image_pair']

            if self.args.cuda:
                A_image, A_target = A_image.cuda(), A_target.cuda()
                B_image = B_image.cuda()
                B_target = B_target.cuda()
                B_image_pair = B_image_pair.cuda()

            self.scheduler(self.optimizer, i, epoch, self.best_pred_source,
                           self.best_pred_target)

            A_output, A_feat, A_low_feat = self.model(A_image)
            B_output, B_feat, B_low_feat = self.model(B_image)
            B_output_pair, B_feat_pair, B_low_feat_pair = self.model(B_image_pair)
            # Un-flip the paired outputs along the width axis so they align
            # spatially with B_output before the consistency loss.
            B_output_pair = flip(B_output_pair, dim=-1)
            B_feat_pair = flip(B_feat_pair, dim=-1)
            B_low_feat_pair = flip(B_low_feat_pair, dim=-1)

            self.optimizer.zero_grad()

            # Supervised segmentation loss (source domain).
            seg_loss = self.criterion(A_output, A_target)
            # Unsupervised flip-consistency loss (target domain), weighted 0.1.
            ins_loss = 0.1 * self.instance_loss(B_output, B_output_pair)
            main_loss = seg_loss + ins_loss
            main_loss.backward()

            self.optimizer.step()

            seg_loss_sum += seg_loss.item()
            ins_loss_sum += ins_loss.item()
            train_loss += seg_loss.item() + ins_loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))

        # NOTE(review): `i` and `A_image` leak out of the loop; this raises
        # NameError if the loader is empty — presumably never the case here.
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.config.batch_size + A_image.data.shape[0]))
        print('Seg Loss: %.3f' % seg_loss_sum)
        print('Ins Loss: %.3f' % ins_loss_sum)

        if self.visdom:
            # Both curves share the 'train_loss' window, distinguished by name.
            self.vis.line(X=torch.tensor([epoch]),
                          Y=torch.tensor([seg_loss_sum]),
                          win='train_loss',
                          name='Seg_loss',
                          opts=dict(title='loss',
                                    xlabel='epoch',
                                    ylabel='loss'),
                          update='append' if epoch > 0 else None)
            self.vis.line(X=torch.tensor([epoch]),
                          Y=torch.tensor([ins_loss_sum]),
                          win='train_loss',
                          name='Ins_loss',
                          opts=dict(title='loss',
                                    xlabel='epoch',
                                    ylabel='loss'),
                          update='append' if epoch > 0 else None)
Code example #2
Score: 0
File: train.py — Project: Oliver-ss/DomainAdaptation
    def training(self, epoch):
        """Run one epoch of adversarial domain-adaptation training.

        Trains the segmentation network with three terms — a supervised
        segmentation loss on source batches, a flip-consistency ("instance")
        loss on target batches, and an adversarial loss that pushes target
        predictions to fool the discriminator ``self.D`` — then trains the
        discriminator to separate source from target softmax outputs.
        Scalars are logged to TensorBoard via ``self.summary``.

        Args:
            epoch: Zero-based epoch index, used for lr scheduling and logging.
        """
        # Accumulators; bn_loss_sum and entropy_loss_sum are never updated here.
        train_loss, seg_loss_sum, bn_loss_sum, entropy_loss_sum, adv_loss_sum, d_loss_sum, ins_loss_sum = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
        self.model.train()
        # NOTE(review): bare `config` (module-level) here vs `self.config`
        # elsewhere in this method — confirm both refer to the same object.
        if config.freeze_bn:
            self.model.module.freeze_bn()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        target_train_iterator = iter(self.target_train_loader)
        for i, sample in enumerate(tbar):
            # Global iteration index for per-step TensorBoard scalars.
            itr = epoch * len(self.train_loader) + i
            #if self.visdom:
            #    self.vis.line(X=torch.tensor([itr]), Y=torch.tensor([self.optimizer.param_groups[0]['lr']]),
            #              win='lr', opts=dict(title='lr', xlabel='iter', ylabel='lr'),
            #              update='append' if itr>0 else None)
            self.summary.writer.add_scalar(
                'Train/lr', self.optimizer.param_groups[0]['lr'], itr)
            A_image, A_target = sample['image'], sample['label']

            # Get one batch from target domain, restarting the iterator when
            # the (possibly shorter) target loader is exhausted.
            try:
                target_sample = next(target_train_iterator)
            except StopIteration:
                target_train_iterator = iter(self.target_train_loader)
                target_sample = next(target_train_iterator)

            B_image, B_target, B_image_pair = target_sample[
                'image'], target_sample['label'], target_sample['image_pair']

            if self.args.cuda:
                A_image, A_target = A_image.cuda(), A_target.cuda()
                B_image, B_target, B_image_pair = B_image.cuda(
                ), B_target.cuda(), B_image_pair.cuda()

            # Step the lr schedule for both the segmentation and the
            # discriminator optimizers.
            self.scheduler(self.optimizer, i, epoch, self.best_pred_source,
                           self.best_pred_target, self.config.lr_ratio)
            self.scheduler(self.D_optimizer, i, epoch, self.best_pred_source,
                           self.best_pred_target, self.config.lr_ratio)

            A_output, A_feat, A_low_feat = self.model(A_image)
            B_output, B_feat, B_low_feat = self.model(B_image)
            B_output_pair, B_feat_pair, B_low_feat_pair = self.model(
                B_image_pair)
            # Un-flip the paired outputs along the last (width) axis so they
            # align spatially with B_output for the consistency loss.
            B_output_pair, B_feat_pair, B_low_feat_pair = flip(
                B_output_pair, dim=-1), flip(B_feat_pair,
                                             dim=-1), flip(B_low_feat_pair,
                                                           dim=-1)

            self.optimizer.zero_grad()
            self.D_optimizer.zero_grad()

            # Train seg network: freeze D so generator gradients don't reach it.
            for param in self.D.parameters():
                param.requires_grad = False

            # Supervised loss (source domain).
            seg_loss = self.criterion(A_output, A_target)
            main_loss = seg_loss

            # Unsupervised flip-consistency loss (target domain), weight 0.01.
            ins_loss = 0.01 * self.instance_loss(B_output, B_output_pair)
            main_loss += ins_loss

            # Adversarial loss: label target outputs as "source" so the
            # segmenter learns to fool D.
            # NOTE(review): F.softmax without an explicit dim= emits a
            # deprecation warning and guesses the axis — confirm intended dim.
            D_out = self.D(F.softmax(B_output))
            adv_loss = bce_loss(D_out, self.source_label)

            main_loss += self.config.lambda_adv * adv_loss
            main_loss.backward()

            # Train discriminator: unfreeze D, detach segmenter outputs so
            # only D receives gradients.
            for param in self.D.parameters():
                param.requires_grad = True
            A_output_detach = A_output.detach()
            B_output_detach = B_output.detach()
            # source half of the discriminator loss
            D_source = self.D(F.softmax(A_output_detach))
            source_loss = bce_loss(D_source, self.source_label)
            source_loss = source_loss / 2
            # target half of the discriminator loss
            D_target = self.D(F.softmax(B_output_detach))
            target_loss = bce_loss(D_target, self.target_label)
            target_loss = target_loss / 2
            d_loss = source_loss + target_loss
            d_loss.backward()

            self.optimizer.step()
            self.D_optimizer.step()

            seg_loss_sum += seg_loss.item()
            ins_loss_sum += ins_loss.item()
            adv_loss_sum += self.config.lambda_adv * adv_loss.item()
            d_loss_sum += d_loss.item()

            #train_loss += seg_loss.item() + self.config.lambda_adv * adv_loss.item()
            # NOTE(review): the progress-bar loss tracks seg_loss only, not the
            # full objective — confirm that is intentional.
            train_loss += seg_loss.item()
            self.summary.writer.add_scalar('Train/SegLoss', seg_loss.item(),
                                           itr)
            self.summary.writer.add_scalar('Train/InsLoss', ins_loss.item(),
                                           itr)
            self.summary.writer.add_scalar('Train/AdvLoss', adv_loss.item(),
                                           itr)
            self.summary.writer.add_scalar('Train/DiscriminatorLoss',
                                           d_loss.item(), itr)
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))

            # Show the results of the last iteration
            #if i == len(self.train_loader)-1:
        # Visualize the final batch of the epoch for both domains.
        print("Add Train images at epoch" + str(epoch))
        self.summary.visualize_image('Train-Source', self.config.dataset,
                                     A_image, A_target, A_output, epoch, 5)
        self.summary.visualize_image('Train-Target', self.config.target,
                                     B_image, B_target, B_output, epoch, 5)
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.config.batch_size + A_image.data.shape[0]))
        print('Loss: %.3f' % train_loss)