示例#1
0
文件: train.py 项目: lvcat/LEDNet
    def training(self):
        """Train ``self.net`` with periodic logging, checkpointing and validation.

        Every batch drawn from ``self.train_loader`` counts as one iteration.
        ``self.args.max_iter`` is used for the ETA, the final checkpoint name
        and the per-iteration timing, so the loader is presumably expected to
        yield exactly that many batches (TODO confirm against its sampler).

        Per iteration: step the LR scheduler, run forward/backward on the sum
        of the losses returned by ``self.criterion`` and step the optimizer.
        Every ``self.args.log_step`` iterations the LR, elapsed time, ETA and
        the loss reduced via ``ptutil.reduce_loss_dict`` are logged. Rank 0
        saves a checkpoint every ``per_iter * save_epoch`` iterations (skipped
        when that product is not positive) and once more after the loop.
        Validation runs every ``per_iter * eval_epochs`` iterations (not on the
        last one) and, unless ``self.args.skip_eval`` is set, after training.
        """
        self.net.train()
        save_to_disk = ptutil.get_rank() == 0
        start_training_time = time.time()
        trained_time = 0
        tic = time.time()
        end = time.time()
        iteration, max_iter = 0, self.args.max_iter
        save_iter = self.args.per_iter * self.args.save_epoch
        eval_iter = self.args.per_iter * self.args.eval_epochs

        logger.info(
            "Start training, total epochs {:3d} = total iteration: {:6d}".
            format(self.args.epochs, max_iter))

        for image, target in self.train_loader:
            iteration += 1
            # NOTE(review): the scheduler is stepped before optimizer.step();
            # PyTorch >= 1.1 built-in schedulers expect the reverse order.
            # Kept as-is so the existing LR schedule is unchanged.
            self.scheduler.step()
            self.optimizer.zero_grad()
            image, target = image.to(self.device), target.to(self.device)
            outputs = self.net(image)
            loss_dict = self.criterion(outputs, target)
            # Reduce losses over all GPUs for logging purposes only; the
            # backward pass below uses the local, unreduced losses.
            loss_dict_reduced = ptutil.reduce_loss_dict(loss_dict)
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())

            loss = sum(loss for loss in loss_dict.values())
            loss.backward()
            self.optimizer.step()
            trained_time += time.time() - end
            end = time.time()
            if iteration % self.args.log_step == 0:
                eta_seconds = int(
                    (trained_time / iteration) * (max_iter - iteration))
                log_str = [
                    "Iteration {:06d} , Lr: {:.5f}, Cost: {:.2f}s, Eta: {}".
                    format(iteration, self.optimizer.param_groups[0]['lr'],
                           time.time() - tic,
                           str(datetime.timedelta(seconds=eta_seconds))),
                    "total_loss: {:.3f}".format(losses_reduced.item())
                ]
                logger.info(', '.join(log_str))
                tic = time.time()
            # Periodic checkpoint on rank 0; a non-positive interval disables
            # it instead of raising ZeroDivisionError.
            if save_to_disk and save_iter > 0 and iteration % save_iter == 0:
                model_path = os.path.join(
                    self.args.save_dir,
                    "{}_iter_{:06d}.pth".format('LEDNet', iteration))
                self.save_model(model_path)
            # Evaluate during training to track whether performance improves.
            # The last iteration is skipped; the post-training evaluation
            # below runs unless skip_eval is set.
            if (self.args.eval_epochs > 0 and iteration % eval_iter == 0
                    and iteration != max_iter):
                metrics = self.validate()
                ptutil.synchronize()
                pixAcc, mIoU = ptutil.accumulate_metric(metrics)
                if pixAcc is not None:
                    logger.info('pixAcc: {:.4f}, mIoU: {:.4f}'.format(
                        pixAcc, mIoU))
                # validate() presumably switches to eval mode; restore it.
                self.net.train()
        if save_to_disk:
            model_path = os.path.join(
                self.args.save_dir,
                "{}_iter_{:06d}.pth".format('LEDNet', max_iter))
            self.save_model(model_path)
        # Compute total training time (guarding against max_iter == 0).
        total_training_time = int(time.time() - start_training_time)
        total_time_str = str(datetime.timedelta(seconds=total_training_time))
        logger.info("Total training time: {} ({:.4f} s / it)".format(
            total_time_str, total_training_time / max(max_iter, 1)))
        # Evaluate after training.
        if not self.args.skip_eval:
            metrics = self.validate()
            ptutil.synchronize()
            pixAcc, mIoU = ptutil.accumulate_metric(metrics)
            if pixAcc is not None:
                logger.info(
                    'After training, pixAcc: {:.4f}, mIoU: {:.4f}'.format(
                        pixAcc, mIoU))
示例#2
0
文件: base.py 项目: JarvisLL/ACE
    def training(self):
        """Train ``self.net`` with periodic logging, checkpointing and validation.

        Every batch drawn from ``self.train_loader`` counts as one iteration;
        ``self.max_iter`` drives the ETA, the final checkpoint name and the
        per-iteration timing. When ``DATASET.IMG_TRANSFORM`` is false, images
        are permuted with ``(0, 3, 1, 2)`` (presumably channels-last to
        channels-first) before the forward pass. The summed losses from
        ``self.criterion`` are back-propagated, through ``amp.scale_loss`` when
        ``TRAIN.MIXED_PRECISION`` is set.

        Rank 0 writes a checkpoint every ``per_iter * TRAIN.SAVE_EPOCH``
        iterations (skipped when that product is not positive) and once after
        the loop. Validation runs every ``per_iter * TRAIN.EVAL_EPOCHS``
        iterations (not on the last one). Whenever the reported mIoU ties or
        beats the best so far, rank 0 writes a ``*_best.pth`` checkpoint.
        Unless ``TRAIN.SKIP_EVAL`` is set, a final validation runs after
        training.
        """
        self.net.train()
        save_to_disk = ptutil.get_rank() == 0
        start_training_time = time.time()
        trained_time = 0
        best_miou = 0
        tic = time.time()
        end = time.time()
        iteration, max_iter = 0, self.max_iter
        save_iter = self.per_iter * self.config.TRAIN.SAVE_EPOCH
        eval_iter = self.per_iter * self.config.TRAIN.EVAL_EPOCHS
        # Checkpoints are named "<model>_<seg loss>_<dataset>_<suffix>.pth".
        name_prefix = "{}_{}_{}".format(self.config.MODEL.NAME,
                                        self.config.TRAIN.SEG_LOSS,
                                        self.config.DATASET.NAME)
        self.logger.info("Start training, total epochs {:3d} = total iteration: {:6d}".format(self.config.TRAIN.EPOCHS, max_iter))
        for image, target in self.train_loader:
            iteration += 1
            self.scheduler.step()
            self.optimizer.zero_grad()
            image, target = image.to(self.device, dtype=self.dtype), target.to(self.device)
            if self.config.DATASET.IMG_TRANSFORM == False:
                image = image.permute(0, 3, 1, 2)
            outputs = self.net(image)
            loss_dict = self.criterion(outputs, target)
            # Reduced copy is for logging only; backward uses local losses.
            loss_dict_reduced = ptutil.reduce_loss_dict(loss_dict)
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())
            loss = sum(loss for loss in loss_dict.values())
            if self.config.TRAIN.MIXED_PRECISION:
                with amp.scale_loss(loss, self.optimizer) as scale_loss:
                    scale_loss.backward()
            else:
                loss.backward()

            self.optimizer.step()
            trained_time += time.time() - end
            end = time.time()
            if iteration % self.config.TRAIN.LOG_STEP == 0:
                eta_seconds = int((trained_time / iteration) * (max_iter - iteration))
                log_str = ["Iteration {:06d} , Lr: {:.5f}, Cost: {:.2f}s, Eta: {}"
                               .format(iteration, self.optimizer.param_groups[0]['lr'], time.time() - tic,
                                       str(datetime.timedelta(seconds=eta_seconds))),
                           "total_loss: {:.3f}".format(losses_reduced.item())]
                self.logger.info(', '.join(log_str))
                tic = time.time()
            # Periodic checkpoint on rank 0; a non-positive interval disables
            # it instead of raising ZeroDivisionError.
            if save_to_disk and save_iter > 0 and iteration % save_iter == 0:
                model_path = os.path.join(self.config.TRAIN.SAVE_DIR,
                                          "{}_iter_{:06d}.pth".format(name_prefix, iteration))
                ptutil.save_model(self.net, model_path, self.logger)
            if self.config.TRAIN.EVAL_EPOCHS > 0 and iteration % eval_iter == 0 and iteration != max_iter:
                metrics = ptutil.validate(self.net, self.valid_loader, self.metric, self.device, self.config)
                ptutil.synchronize()
                pixAcc, mIoU = ptutil.accumulate_metric(metrics)
                if mIoU is not None and mIoU >= best_miou:
                    best_miou = mIoU
                    # Only rank 0 writes files, like the periodic checkpoints,
                    # so ranks never race on the same path.
                    if save_to_disk:
                        model_path = os.path.join(self.config.TRAIN.SAVE_DIR,
                                                  "{}_best.pth".format(name_prefix))
                        ptutil.save_model(self.net, model_path, self.logger)
                if pixAcc is not None:
                    self.logger.info('pixAcc: {:.4f}, mIoU: {:.4f}'.format(pixAcc, mIoU))
                # validate() presumably switches to eval mode; restore it.
                self.net.train()
        if save_to_disk:
            model_path = os.path.join(self.config.TRAIN.SAVE_DIR,
                                      "{}_iter_{:06d}.pth".format(name_prefix, max_iter))
            ptutil.save_model(self.net, model_path, self.logger)
        total_training_time = int(time.time() - start_training_time)
        total_time_str = str(datetime.timedelta(seconds=total_training_time))
        self.logger.info("Total training time: {} ({:.4f} s / it)".format(
            total_time_str, total_training_time / max(max_iter, 1)))
        # Evaluate after training.
        if not self.config.TRAIN.SKIP_EVAL:
            metrics = ptutil.validate(self.net, self.valid_loader, self.metric, self.device, self.config)
            ptutil.synchronize()
            pixAcc, mIoU = ptutil.accumulate_metric(metrics)
            if pixAcc is not None:
                self.logger.info('After training, pixAcc: {:.4f}, mIoU: {:.4f}'.format(pixAcc, mIoU))
示例#3
0
    def training(self):
        """Train the detector ``self.net`` with periodic logging, checkpointing and validation.

        Each batch from ``self.train_loader`` is one iteration. Index 0 is the
        image, indices 1-5 are the fixed targets and index 6 is the ground-truth
        boxes. All of them are moved to ``self.device`` and passed to
        ``self.net``, which returns a dict of losses. These are summed,
        back-propagated and applied with ``self.optimizer``.

        Every ``self.args.log_step`` iterations the LR, elapsed time, ETA,
        total reduced loss and each individual reduced loss are logged. Rank 0
        saves a checkpoint every ``per_iter * save_epoch`` iterations (skipped
        when that product is not positive) and once after the loop. Validation
        runs every ``per_iter * eval_epoch`` iterations, except on the last one.
        """
        self.net.train()
        save_to_disk = ptutil.get_rank() == 0
        start_training_time = time.time()
        trained_time = 0
        tic = time.time()
        end = time.time()
        iteration, max_iter = 0, self.args.max_iter
        save_iter = self.args.per_iter * self.args.save_epoch
        eval_iter = self.args.per_iter * self.args.eval_epoch

        logger.info("Start training, total epochs {:3d} = total iteration: {:6d}".format(self.args.epochs, max_iter))

        # TODO: add mixup
        for batch in self.train_loader:
            iteration += 1
            self.scheduler.step()
            image = batch[0].to(self.device)
            fixed_targets = [batch[it].to(self.device) for it in range(1, 6)]
            gt_boxes = batch[6].to(self.device)

            self.optimizer.zero_grad()
            loss_dict = self.net(image, gt_boxes, *fixed_targets)
            # Reduce losses over all GPUs for logging purposes only; the
            # backward pass uses the local, unreduced losses.
            loss_dict_reduced = ptutil.reduce_loss_dict(loss_dict)
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())

            loss = sum(loss for loss in loss_dict.values())
            loss.backward()
            self.optimizer.step()
            trained_time += time.time() - end
            end = time.time()
            if iteration % self.args.log_step == 0:
                eta_seconds = int((trained_time / iteration) * (max_iter - iteration))
                log_str = ["Iteration {:06d} , Lr: {:.5f}, Cost: {:.2f}s, Eta: {}"
                               .format(iteration, self.optimizer.param_groups[0]['lr'], time.time() - tic,
                                       str(datetime.timedelta(seconds=eta_seconds))),
                           "total_loss: {:.3f}".format(losses_reduced.item())]
                log_str.extend("{}: {:.3f}".format(loss_name, loss_item.item())
                               for loss_name, loss_item in loss_dict_reduced.items())
                logger.info(', '.join(log_str))
                tic = time.time()
            # Periodic checkpoint on rank 0; a non-positive interval disables
            # it instead of raising ZeroDivisionError.
            if save_to_disk and save_iter > 0 and iteration % save_iter == 0:
                model_path = os.path.join(self.args.save_dir, "{}_iter_{:06d}.pth"
                                          .format(self.save_prefix, iteration))
                self.save_model(model_path)
            # Evaluate during training to track whether the metrics improve.
            if self.args.eval_epoch > 0 and iteration % eval_iter == 0 and iteration != max_iter:
                metrics = self.validate()
                ptutil.synchronize()
                names, values = ptutil.accumulate_metric(metrics)
                if names is not None:
                    log_str = ['{}: {:.5f}'.format(k, v) for k, v in zip(names, values)]
                    logger.info('\n'.join(log_str))
                # validate() presumably switches to eval mode; restore it.
                self.net.train()
        if save_to_disk:
            model_path = os.path.join(self.args.save_dir, "{}_iter_{:06d}.pth"
                                      .format(self.save_prefix, max_iter))
            self.save_model(model_path)

        # Compute training time (guarding against max_iter == 0).
        total_training_time = int(time.time() - start_training_time)
        total_time_str = str(datetime.timedelta(seconds=total_training_time))
        logger.info(
            "Total training time: {} ({:.4f} s / it)".format(total_time_str, total_training_time / max(max_iter, 1)))
示例#4
0
    def training(self):
        """Jointly train the segmentation network and the style-transfer generator.

        For every source batch from ``self.train_loader``, a target-domain
        batch is drawn from ``self.target_loader``, which is restarted when it
        is exhausted. Both images are encoded by the frozen
        ``self.feature_extracted`` network. The source features are
        re-normalised to the target's per-channel mean/std (AdaIN), and
        ``self.generator`` decodes them into a generated image.

        * Generator losses (``self.gen_criterion``) match the generated image's
          features to the AdaIN features, and its feature mean/std to the
          target's.
        * Segmentation losses (``self.criterion``) are computed on the source
          image and on the generated image. A KL term
          (``self.kl_criterion``, weight 10) ties the prediction on the
          generated image to the detached prediction on the source image.

        Gradients are accumulated and both optimizers step every 8 iterations.
        Rank 0 checkpoints both networks every ``per_iter * TRAIN.SAVE_EPOCH``
        iterations (skipped when not positive) and after the loop. Validation
        runs every ``per_iter * TRAIN.EVAL_EPOCH`` iterations (not on the last
        one), with rank 0 keeping ``*_best.pth`` checkpoints. Unless
        ``TRAIN.SKIP_EVAL`` is set, a final validation runs after training.
        """
        self.seg_net.train()
        self.generator.train()
        # The feature extractor is only a fixed loss network: keep it in eval
        # mode and exclude its parameters from gradient computation.
        self.feature_extracted.eval()
        for param in self.feature_extracted.parameters():
            param.requires_grad = False

        save_to_disk = ptutil.get_rank() == 0
        start_training_time = time.time()
        trained_time = 0
        best_miou = 0
        # Per-channel normalisation constants shaped (1, 3, 1, 1) so they
        # broadcast over image batches (presumably the ImageNet mean/std).
        # Created on self.device, the same device the images are moved to.
        mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32,
                            device=self.device).view(1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32,
                           device=self.device).view(1, 3, 1, 1)
        # Small constant keeping the AdaIN division well defined.
        eps = 0.00001
        # Number of iterations whose gradients are accumulated per step.
        accum_steps = 8
        # Weight of the KL consistency term in the segmentation loss.
        kl_weight = 10
        tic = time.time()
        end = time.time()
        iteration, max_iter = 0, self.max_iter
        save_iter = self.per_iter * self.config.TRAIN.SAVE_EPOCH
        eval_iter = self.per_iter * self.config.TRAIN.EVAL_EPOCH
        # Checkpoint name prefixes; all saves use the MODEL.* names.
        seg_prefix = "{}_{}_{}".format(self.config.MODEL.SEG_NET,
                                       self.config.TRAIN.SEG_LOSS,
                                       self.config.DATASET.NAME)
        gen_prefix = "{}_{}_{}".format(self.config.MODEL.TARGET_GENERATOR,
                                       self.config.TRAIN.SEG_LOSS,
                                       self.config.DATASET.NAME)
        self.logger.info(
            "Start training, total epochs {:3d} = total iteration: {:6d}".
            format(self.config.TRAIN.EPOCHS, max_iter))
        for source_image, label in self.train_loader:
            iteration += 1
            self.scheduler.step()
            self.gen_scheduler.step()
            source_image = source_image.to(self.device, dtype=self.dtype)
            label = label.to(self.device)

            # Next target-domain batch. The iterator is created lazily if it
            # is not set yet, and restarted once the target loader is
            # exhausted. Other loading errors propagate instead of being
            # silently retried.
            if getattr(self, 'target_trainloader_iter', None) is None:
                self.target_trainloader_iter = enumerate(self.target_loader)
            try:
                _, batch = next(self.target_trainloader_iter)
            except StopIteration:
                self.target_trainloader_iter = enumerate(self.target_loader)
                _, batch = next(self.target_trainloader_iter)
            target_image = batch.to(self.device, dtype=self.dtype)

            if self.config.DATASET.IMG_TRANSFORM == False:
                # Raw images: permute to channel-first (presumably from NHWC)
                # and normalise (presumably from 0-255) for the extractor.
                source_image = source_image.permute(0, 3, 1, 2)
                target_image = target_image.permute(0, 3, 1, 2)
                source_image_norm = (((source_image / 255) - mean) / std)
                target_image_norm = (((target_image / 255) - mean) / std)
            else:
                # NOTE(review): images are presumably normalised already by the
                # dataset transform, yet gen_image below is de-normalised to the
                # 0-255 scale before seg_net -- confirm this branch is intended.
                source_image_norm = source_image
                target_image_norm = target_image
            source_feature = self.feature_extracted(source_image_norm)
            target_feature = self.feature_extracted(target_image_norm)

            # Per-channel statistics over dims (2, 3) (presumably H, W). The
            # *_var tensors hold the standard deviation (torch.std), not the
            # variance.
            target_feature_mean = torch.mean(target_feature, (2, 3),
                                             keepdim=True)
            target_feature_var = torch.std(target_feature, (2, 3),
                                           keepdim=True)
            source_feature_mean = torch.mean(source_feature, (2, 3),
                                             keepdim=True)
            source_feature_var = torch.std(source_feature, (2, 3),
                                           keepdim=True)

            # AdaIN: re-normalise source features to the target statistics.
            adain_feature = (
                (source_feature - source_feature_mean) /
                (source_feature_var + eps)) * (
                    target_feature_var + eps) + target_feature_mean
            gen_image_norm = self.generator(adain_feature)
            # De-normalise the generated image back to the 0-255 scale.
            gen_image = ((gen_image_norm * std) + mean) * 255

            gen_image_feature = self.feature_extracted(gen_image_norm)
            gen_image_feature_mean = torch.mean(gen_image_feature, (2, 3),
                                                keepdim=True)
            gen_image_feature_var = torch.std(gen_image_feature, (2, 3),
                                              keepdim=True)
            # Generator losses: generated features <-> adain_feature,
            # generated feature mean <-> target mean, std <-> target std.
            loss_feature_dict = self.gen_criterion(gen_image_feature,
                                                   adain_feature)
            loss_mean_dict = self.gen_criterion(gen_image_feature_mean,
                                                target_feature_mean)
            loss_var_dict = self.gen_criterion(gen_image_feature_var,
                                               target_feature_var)

            loss_feature = sum(loss for loss in loss_feature_dict.values())
            loss_feature_dict_reduced = ptutil.reduce_loss_dict(
                loss_feature_dict)
            loss_feature_reduced = sum(
                loss for loss in loss_feature_dict_reduced.values())

            loss_mean = sum(loss for loss in loss_mean_dict.values())
            loss_mean_dict_reduced = ptutil.reduce_loss_dict(loss_mean_dict)
            loss_mean_reduced = sum(
                loss for loss in loss_mean_dict_reduced.values())

            loss_var = sum(loss for loss in loss_var_dict.values())
            loss_var_dict_reduced = ptutil.reduce_loss_dict(loss_var_dict)
            loss_var_reduced = sum(loss
                                   for loss in loss_var_dict_reduced.values())

            loss_gen = loss_feature + loss_mean + loss_var
            # Segmentation loss on the source image.
            outputs = self.seg_net(source_image)
            source_seg_loss_dict = self.criterion(outputs, label)
            # Segmentation loss on the generated image.
            gen_outputs = self.seg_net(gen_image)
            gen_seg_loss_dict = self.criterion(gen_outputs, label)
            # KL consistency: the source prediction is detached so it acts as
            # a fixed target and receives no gradient from this term.
            outputs = outputs.detach()
            kl_loss_dict = self.kl_criterion(gen_outputs, outputs)

            # Reduce losses over all GPUs for logging purposes.
            source_seg_loss_dict_reduced = ptutil.reduce_loss_dict(
                source_seg_loss_dict)
            source_seg_losses_reduced = sum(
                loss for loss in source_seg_loss_dict_reduced.values())
            source_seg_loss = sum(loss
                                  for loss in source_seg_loss_dict.values())
            gen_seg_loss_dict_reduced = ptutil.reduce_loss_dict(
                gen_seg_loss_dict)
            gen_seg_losses_reduced = sum(
                loss for loss in gen_seg_loss_dict_reduced.values())
            gen_seg_loss = sum(loss for loss in gen_seg_loss_dict.values())
            kl_loss_dict_reduced = ptutil.reduce_loss_dict(kl_loss_dict)
            kl_losses_reduced = sum(loss
                                    for loss in kl_loss_dict_reduced.values())
            kl_loss = sum(loss for loss in kl_loss_dict.values())
            loss_seg = source_seg_loss + gen_seg_loss + kl_loss * kl_weight
            if self.config.TRAIN.MIXED_PRECISION:
                # loss_seg also back-propagates into the generator (through
                # gen_image), so the first backward pass must retain the
                # shared graph for the second one.
                with amp.scale_loss(loss_gen, self.gen_optimizer,
                                    loss_id=1) as errGen_scale:
                    errGen_scale.backward(retain_graph=True)
                with amp.scale_loss(loss_seg, self.optimizer,
                                    loss_id=2) as errSeg_scale:
                    errSeg_scale.backward()
            else:
                loss = loss_gen + loss_seg
                loss.backward()

            # Gradient accumulation: step and clear both optimizers every
            # accum_steps iterations. NOTE(review): losses are not divided by
            # accum_steps, and gradients of a trailing partial window are
            # never applied -- confirm both are intended.
            if iteration % accum_steps == 0:
                self.optimizer.step()
                self.gen_optimizer.step()
                self.optimizer.zero_grad()
                self.gen_optimizer.zero_grad()
            trained_time += time.time() - end
            end = time.time()
            if iteration % self.config.TRAIN.LOG_STEP == 0:
                eta_seconds = int(
                    (trained_time / iteration) * (max_iter - iteration))
                log_str = [
                    "Iteration {:06d} , Lr: {:.5f}, Cost: {:.2f}s, Eta: {}".
                    format(iteration, self.optimizer.param_groups[0]['lr'],
                           time.time() - tic,
                           str(datetime.timedelta(seconds=eta_seconds))),
                    "source_seg_loss: {:.6f}, gen_seg_loss:{:.6f}, kl_loss:{:.6f}"
                    .format(source_seg_losses_reduced.item(),
                            gen_seg_losses_reduced.item(),
                            kl_losses_reduced.item() * kl_weight),
                    "feature_loss:{:.6f}, mean_loss:{:.6f}, var_loss:{:.6f}".
                    format(loss_feature_reduced.item(),
                           loss_mean_reduced.item(), loss_var_reduced.item())
                ]
                self.logger.info(', '.join(log_str))
                tic = time.time()
            # Periodic checkpoint of both networks on rank 0; a non-positive
            # interval disables it instead of raising ZeroDivisionError.
            if save_to_disk and save_iter > 0 and iteration % save_iter == 0:
                model_path = os.path.join(
                    self.seg_dir,
                    "{}_iter_{:06d}.pth".format(seg_prefix, iteration))
                ptutil.save_model(self.seg_net, model_path, self.logger)
                generator_path = os.path.join(
                    self.generator_dir,
                    '{}_iter_{:06d}.pth'.format(gen_prefix, iteration))
                ptutil.save_model(self.generator, generator_path, self.logger)
            # Evaluate during training to track whether performance improves
            # (skipped on the last iteration).
            if (self.config.TRAIN.EVAL_EPOCH > 0 and iteration % eval_iter == 0
                    and iteration != max_iter):
                metrics = ptutil.validate(self.seg_net, self.valid_loader,
                                          self.metric, self.device,
                                          self.config)
                ptutil.synchronize()
                pixAcc, mIoU = ptutil.accumulate_metric(metrics)
                if mIoU is not None and mIoU >= best_miou:
                    best_miou = mIoU
                    # Only rank 0 writes files, like the periodic checkpoints.
                    if save_to_disk:
                        model_path = os.path.join(
                            self.seg_dir, "{}_best.pth".format(seg_prefix))
                        ptutil.save_model(self.seg_net, model_path,
                                          self.logger)
                        generator_path = os.path.join(
                            self.generator_dir,
                            '{}_best.pth'.format(gen_prefix))
                        ptutil.save_model(self.generator, generator_path,
                                          self.logger)
                if pixAcc is not None:
                    self.logger.info('pixAcc: {:.4f}, mIoU: {:.4f}'.format(
                        pixAcc, mIoU))
                # validate() presumably switches seg_net to eval mode.
                self.seg_net.train()
        if save_to_disk:
            model_path = os.path.join(
                self.seg_dir,
                "{}_iter_{:06d}.pth".format(seg_prefix, max_iter))
            ptutil.save_model(self.seg_net, model_path, self.logger)
            generator_path = os.path.join(
                self.generator_dir,
                '{}_iter_{:06d}.pth'.format(gen_prefix, max_iter))
            ptutil.save_model(self.generator, generator_path, self.logger)
        # Compute training time (guarding against max_iter == 0).
        total_training_time = int(time.time() - start_training_time)
        total_time_str = str(datetime.timedelta(seconds=total_training_time))
        self.logger.info("Total training time: {} ({:.4f} s / it)".format(
            total_time_str, total_training_time / max(max_iter, 1)))
        # Evaluate after training.
        if not self.config.TRAIN.SKIP_EVAL:
            metrics = ptutil.validate(self.seg_net, self.valid_loader,
                                      self.metric, self.device, self.config)
            ptutil.synchronize()
            pixAcc, mIoU = ptutil.accumulate_metric(metrics)
            if pixAcc is not None:
                self.logger.info(
                    'After training, pixAcc: {:.4f}, mIoU: {:.4f}'.format(
                        pixAcc, mIoU))