def run_main(self):
        """"""
        """Fix the random seed"""
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        """RandAugment parameters"""
        if self.flag_randaug == 1:
            if self.rand_n == 0 and self.rand_m == 0:
                if self.n_model == 'ResNet':
                    self.rand_n = 2
                    self.rand_m = 9
                elif self.n_model == 'WideResNet':
                    if self.n_data == 'CIFAR-10':
                        self.rand_n = 3
                        self.rand_m = 5
                    elif self.n_data == 'CIFAR-100':
                        self.rand_n = 2
                        self.rand_m = 14
                    elif self.n_data == 'SVHN':
                        self.rand_n = 3
                        self.rand_m = 7
        """Dataset"""
        traintest_dataset = dataset.MyDataset_training(
            n_data=self.n_data,
            num_data=self.num_training_data,
            seed=self.seed,
            flag_randaug=self.flag_randaug,
            rand_n=self.rand_n,
            rand_m=self.rand_m,
            cutout=self.cutout)
        self.num_channel, self.num_classes, self.size_after_cnn, self.input_size, self.hidden_size = traintest_dataset.get_info(
            n_data=self.n_data)

        n_samples = len(traintest_dataset)
        if self.num_training_data == 0:
            self.num_training_data = n_samples
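        # num_training_data == 0 is interpreted as "use every available training sample"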

        if self.flag_traintest == 1:
            # train_size = self.num_classes * 100
            train_size = int(n_samples * 0.65)
            test_size = n_samples - train_size
            train_dataset, test_dataset = torch.utils.data.random_split(
                traintest_dataset, [train_size, test_size])

            train_sampler = None
            test_sampler = None
        else:
            train_dataset = traintest_dataset
            test_dataset = dataset.MyDataset_test(n_data=self.n_data)

            train_sampler = train_dataset.sampler
            test_sampler = test_dataset.sampler

        num_workers = 16
        train_shuffle = True
        test_shuffle = False

        if train_sampler:
            train_shuffle = False
        if test_sampler:
            test_shuffle = False
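        # DataLoader does not allow shuffle=True together with an explicit sampler,
        # so shuffling is disabled whenever a sampler is supplied.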
        self.train_loader = torch.utils.data.DataLoader(
            dataset=train_dataset,
            batch_size=self.batch_size_training,
            sampler=train_sampler,
            shuffle=train_shuffle,
            num_workers=num_workers,
            pin_memory=True)
        self.test_loader = torch.utils.data.DataLoader(
            dataset=test_dataset,
            batch_size=self.batch_size_test,
            sampler=test_sampler,
            shuffle=test_shuffle,
            num_workers=num_workers,
            pin_memory=True)
        """Transfer learning"""
        if self.flag_transfer == 1:
            pretrained = True
            num_classes = 1000
        else:
            pretrained = False
            num_classes = self.num_classes
        """Neural network model"""
        model = None
        if self.n_model == 'CNN':
            model = cnn.ConvNet(num_classes=self.num_classes,
                                num_channel=self.num_channel,
                                size_after_cnn=self.size_after_cnn,
                                n_aug=self.n_aug)
        elif self.n_model == 'ResNet':
            model = resnet.ResNet(n_data=self.n_data,
                                  depth=50,
                                  num_classes=self.num_classes,
                                  num_channel=self.num_channel,
                                  n_aug=self.n_aug,
                                  bottleneck=True)  # resnet50
            # model = resnet.ResNet(n_data=self.n_data, depth=200, num_classes=self.num_classes, num_channel=self.num_channel, bottleneck=True)  # resnet200
        elif self.n_model == 'WideResNet':
            # model = resnet.wide_resnet50_2(num_classes=self.num_classes, num_channel=self.num_channel)
            # model = WideResNet(depth=40, iden_factor=2, dropout_rate=0.0, num_classes=num_class, num_channel=self.num_channel)  # wresnet40_2
            model = wideresnet.WideResNet(depth=28,
                                          widen_factor=10,
                                          dropout_rate=0.0,
                                          num_classes=self.num_classes,
                                          num_channel=self.num_channel,
                                          n_aug=self.n_aug)  # wresnet28_10
        print(torch.__version__)
        """Transfer learning"""
        if self.flag_transfer == 1:
            for param in model.parameters():
                param.requires_grad = False

            num_features = model.fc.in_features
            model.fc = nn.Linear(num_features, self.num_classes)
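            # Only the newly created fc head is trainable; the frozen backbone acts
            # as a fixed feature extractor.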
        """Show paramters"""
        if self.show_params == 1:
            params = 0
            for p in model.parameters():
                if p.requires_grad:
                    params += p.numel()
            print('Number of trainable parameters: {}'.format(params))
        """GPU setting"""
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        model = model.to(device)
        if device == 'cuda':
            if self.gpu_multi == 1:
                model = torch.nn.DataParallel(model)
            torch.backends.cudnn.benchmark = True
            print('GPU={}'.format(torch.cuda.device_count()))
        """Loss function"""
        if self.lb_smooth > 0.0:
            criterion = objective.SmoothCrossEntropyLoss(self.lb_smooth)
        else:
            criterion = objective.SoftCrossEntropy()
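        # SmoothCrossEntropyLoss applies label smoothing with factor lb_smooth;
        # SoftCrossEntropy accepts soft (e.g. one-hot or mixed) label targets.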
        """Optimizer"""
        optimizer = None

        if self.opt == 0:  # Adam
            if self.flag_transfer == 1:
                optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.001)
            else:
                optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        elif self.opt == 1:  # SGD
            if self.n_model == 'ResNet':
                lr = 0.1
                weight_decay = 0.0001
            elif self.n_model == 'WideResNet':
                if self.n_data == 'SVHN':
                    lr = 0.005
                    weight_decay = 0.001
                else:
                    lr = 0.1
                    weight_decay = 0.0005
            else:
                lr = 0.1
                weight_decay = 0.0005

            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=lr,
                                        momentum=0.9,
                                        weight_decay=weight_decay,
                                        nesterov=True)
        if self.flag_lars == 1:
            from torchlars import LARS
            optimizer = LARS(optimizer)
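            # torchlars' LARS wraps the base optimizer and rescales each layer's
            # update by the ratio of parameter norm to gradient norm, which helps
            # with large-batch training.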
        """Learning rate scheduling"""
        scheduler = None
        if self.flag_lr_schedule == 2:
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=self.num_epochs, eta_min=0.)
        elif self.flag_lr_schedule == 3:
            if self.num_epochs == 90:
                scheduler = torch.optim.lr_scheduler.MultiStepLR(
                    optimizer, [30, 60, 80])
            elif self.num_epochs == 180:
                scheduler = torch.optim.lr_scheduler.MultiStepLR(
                    optimizer, [60, 120, 160])
            elif self.num_epochs == 270:
                scheduler = torch.optim.lr_scheduler.MultiStepLR(
                    optimizer, [90, 180, 240])

        if self.flag_warmup == 1:
            if self.n_model == 'ResNet':
                multiplier = 2
                total_epoch = 3
            elif self.n_model == 'WideResNet':
                multiplier = 2
                if self.n_data == 'SVHN':
                    total_epoch = 3
                else:
                    total_epoch = 5
            else:
                multiplier = 2
                total_epoch = 3

            scheduler = GradualWarmupScheduler(optimizer,
                                               multiplier=multiplier,
                                               total_epoch=total_epoch,
                                               after_scheduler=scheduler)
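            # GradualWarmupScheduler ramps the learning rate up to multiplier x the
            # base value over the first total_epoch epochs, then hands control to
            # after_scheduler (the cosine / multi-step schedule selected above).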
        """Initialize"""
        self.flag_noise = np.random.randint(0, 2, self.num_training_data)

        if self.flag_acc5 == 1:
            results = np.zeros((self.num_epochs, 6))
        else:
            results = np.zeros((self.num_epochs, 5))
        start_time = timeit.default_timer()

        t = 0
        # fixed_interval = 10
        fixed_interval = 1
        loss_fixed_all = np.zeros(self.num_epochs // fixed_interval)
        self.loss_training_batch = np.zeros(
            int(self.num_epochs *
                np.ceil(self.num_training_data / self.batch_size_training)))

        for epoch in range(self.num_epochs):
            """Training"""
            model.train()
            start_epoch_time = timeit.default_timer(
            )  # Get the start time of this epoch

            loss_each_all = np.zeros(self.num_training_data)
            loss_training_all = 0
            loss_test_all = 0
            """Learning rate scheduling"""
            # if self.flag_lr_schedule == 1:
            #     if self.num_epochs == 200:
            #         if epoch == 100:
            #             optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

            if self.flag_variance == 1:  # when computing the variance of training loss
                if epoch % fixed_interval == 0:
                    loss_fixed = np.zeros(len(self.train_loader))  # one loss per batch (handles a partial final batch)

                    for i, (images, labels,
                            index) in enumerate(self.train_loader):
                        if images.ndim == 3:  # grayscale batch without a channel dimension
                            images = images.reshape(images.shape[0], 1,
                                                    images.shape[1],
                                                    images.shape[2]).to(device)
                        else:
                            images = images.to(device)
                        labels = labels.to(device)

                        flag_onehot = 0
                        if self.flag_spa == 1:
                            outputs_fixed = model.forward(images)
                            if flag_onehot == 0:
                                labels = np.identity(self.num_classes)[np.array(
                                    labels.data.cpu())]
                        else:
                            outputs_fixed = model(images)
                            labels = np.identity(self.num_classes)[np.array(
                                labels.data.cpu())]
                        labels = util.to_var(torch.from_numpy(labels).float())

                        loss_fixed[i] = criterion.forward(
                            outputs_fixed, labels).item()
                    loss_fixed_all[t] = np.var(loss_fixed)
                    t += 1

            total_steps = len(self.train_loader)
            steps = 0
            num_training_data = 0
            for i, (images, labels, index) in enumerate(self.train_loader):
                steps += 1
                if images.ndim == 3:  # grayscale batch without a channel dimension
                    images = images.reshape(images.shape[0], 1,
                                            images.shape[1],
                                            images.shape[2]).to(device)
                else:
                    images = images.to(device)
                labels = labels.to(device)
                """Save images"""
                if self.save_images == 1:
                    util.save_images(images)
                """Get training loss for each sample by inputting data before applying data augmentation"""
                if self.flag_spa == 1:
                    outputs = model(images)
                    labels_spa = labels.clone()

                    if labels_spa.ndim == 1:
                        labels_spa = torch.eye(
                            self.num_classes,
                            device=device)[labels_spa].clone()  # To one-hot

                    loss_each = criterion.forward_each_example(
                        outputs, labels_spa)  # Loss for each sample
                    loss_each_all[index] = np.array(
                        loss_each.data.cpu())  # Put losses for target samples

                    self.flag_noise = util.flag_update(
                        loss_each_all,
                        self.judge_noise)  # Update the noise flag
                """Forward propagation"""
                if self.flag_spa == 1:
                    images, labels = util.self_paced_augmentation(
                        images=images,
                        labels=labels,
                        flag_noise=self.flag_noise,
                        index=np.array(index.data.cpu()),
                        n_aug=self.n_aug,
                        num_classes=self.num_classes)
                else:
                    images, labels = util.run_n_aug(
                        x=images,
                        y=labels,
                        n_aug=self.n_aug,
                        num_classes=self.num_classes)
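                # With self-paced augmentation, each sample is augmented according to
                # its flag_noise entry (updated above from per-sample losses);
                # otherwise the augmentation n_aug is applied to the whole batch.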

                outputs = model(images)

                if labels.ndim == 1:
                    labels = torch.eye(self.num_classes,
                                       device=device)[labels].clone()

                loss_training = criterion.forward(outputs, labels)
                loss_training_all += loss_training.item() * outputs.shape[0]
                # self.loss_training_batch[int(i + epoch * np.ceil(self.num_training_data / self.batch_size_training))] = loss_training * outputs.shape[0]
                num_training_data += images.shape[0]
                """Back propagation and update"""
                optimizer.zero_grad()
                loss_training.backward()
                optimizer.step()
                """When changing flag_noise randomly"""
                """
                if self.flag_spa == 1:
                    outputs = model.forward(x=images)

                    if labels.ndim == 1:
                        y_soft = torch.eye(self.num_classes, device='cuda')[labels]  # Convert to one-hot
                    loss_each = criterion.forward_each_example(outputs, y_soft)  # Loss for each sample
                    loss_each_all[index] = np.array(loss_each.data.cpu())  # Put losses for target samples
    
                    self.flag_noise = util.flag_update(loss_each_all, self.judge_noise)  # Update the noise flag
                """

            loss_training_each = loss_training_all / num_training_data
            # np.random.shuffle(self.flag_noise)
            """Test"""
            model.eval()

            with torch.no_grad():
                if self.flag_acc5 == 1:
                    top1 = list()
                    top5 = list()
                else:
                    correct = 0
                    total = 0

                num_test_data = 0
                for images, labels in self.test_loader:
                    if images.ndim == 3:  # grayscale batch without a channel dimension
                        images = images.reshape(images.shape[0], 1,
                                                images.shape[1],
                                                images.shape[2]).to(device)
                    else:
                        images = images.to(device)
                    labels = labels.to(device)

                    outputs = model(x=images)

                    if self.flag_acc5 == 1:
                        acc1, acc5 = util.accuracy(outputs.data,
                                                   labels.long(),
                                                   topk=(1, 5))
                        top1.append(acc1[0].item())
                        top5.append(acc5[0].item())
                    else:
                        _, predicted = torch.max(outputs.data, 1)
                        correct += (predicted == labels.long()).sum().item()
                        total += labels.size(0)

                    if labels.ndim == 1:
                        labels = torch.eye(self.num_classes,
                                           device=device)[labels]

                    loss_test = criterion.forward(outputs, labels)
                    loss_test_all += loss_test.item() * outputs.shape[0]
                    num_test_data += images.shape[0]
            """Compute test results"""
            top1_avg = 0
            top5_avg = 0
            test_accuracy = 0

            if self.flag_acc5 == 1:
                top1_avg = sum(top1) / float(len(top1))
                top5_avg = sum(top5) / float(len(top5))
            else:
                test_accuracy = 100.0 * correct / total

            loss_test_each = loss_test_all / num_test_data
            """Compute running time"""
            end_epoch_time = timeit.default_timer()
            epoch_time = end_epoch_time - start_epoch_time
            num_flag = np.sum(self.flag_noise == 1)
            """Learning rate scheduling"""
            if self.flag_lr_schedule > 1 and scheduler is not None:
                scheduler.step(epoch - 1 + float(steps) / total_steps)
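                # Passing a fractional epoch lets the warmup scheduler interpolate the
                # learning rate within an epoch; note that the built-in PyTorch
                # schedulers deprecate the epoch argument to step().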
            """Show results"""
            flag_log = 1
            if flag_log == 1:
                if self.flag_acc5 == 1:
                    print(
                        'Epoch [{}/{}], Train Loss: {:.4f}, Top1 Test Acc: {:.3f} %, Top5 Test Acc: {:.3f} %, Test Loss: {:.4f}, Epoch Time: {:.2f}s, Num_flag: {}'
                        .format(epoch + 1, self.num_epochs, loss_training_each,
                                top1_avg, top5_avg, loss_test_each, epoch_time,
                                num_flag))
                else:
                    print(
                        'Epoch [{}/{}], Train Loss: {:.4f}, Test Acc: {:.3f} %, Test Loss: {:.4f}, Epoch Time: {:.2f}s, Num_flag: {}'
                        .format(epoch + 1, self.num_epochs, loss_training_each,
                                test_accuracy, loss_test_each, epoch_time,
                                num_flag))

            if self.flag_wandb == 1:
                wandb.log({"loss_training_each": loss_training_each})
                wandb.log({"test_accuracy": test_accuracy})
                wandb.log({"loss_test_each": loss_test_each})
                wandb.log({"num_flag": num_flag})
                wandb.log({"epoch_time": epoch_time})

            if self.save_file == 1:
                if flag_log == 1:
                    if self.flag_acc5 == 1:
                        results[epoch][0] = loss_training_each
                        results[epoch][1] = top1_avg
                        results[epoch][2] = top5_avg
                        results[epoch][3] = loss_test_each
                        results[epoch][4] = num_flag
                        results[epoch][5] = epoch_time
                    else:
                        results[epoch][0] = loss_training_each
                        results[epoch][1] = test_accuracy
                        results[epoch][2] = loss_test_each
                        results[epoch][3] = num_flag
                        results[epoch][4] = epoch_time

        end_time = timeit.default_timer()

        flag_log = 1
        if flag_log == 1:
            print('Training ran for %.4f minutes' % ((end_time - start_time) / 60.))
        """Show accuracy"""
        top1_avg_max = np.max(results[:, 1])
        print('Best top-1 test accuracy: {:.3f}'.format(top1_avg_max))

        if flag_log == 1 and self.flag_acc5 == 1:
            top5_avg_max = np.max(results[:, 2])
            print('Best top-5 test accuracy: {:.3f}'.format(top5_avg_max))
        """Save files"""
        if self.save_file == 1:
            if self.flag_randaug == 1:
                np.savetxt(
                    'results/data_%s_model_%s_num_%s_randaug_%s_n_%s_m_%s_seed_%s_acc_%s.csv'
                    % (self.n_data, self.n_model, self.num_training_data,
                       self.flag_randaug, self.rand_n, self.rand_m, self.seed,
                       top1_avg_max),
                    results,
                    delimiter=',')
            else:
                if self.flag_spa == 1:
                    np.savetxt(
                        'results/data_%s_model_%s_judge_%s_aug_%s_num_%s_seed_%s_acc_%s.csv'
                        % (self.n_data, self.n_model, self.judge_noise,
                           self.n_aug, self.num_training_data, self.seed,
                           top1_avg_max),
                        results,
                        delimiter=',')
                else:
                    np.savetxt(
                        'results/data_%s_model_%s_aug_%s_num_%s_seed_%s_acc_%s.csv'
                        % (self.n_data, self.n_model, self.n_aug,
                           self.num_training_data, self.seed, top1_avg_max),
                        results,
                        delimiter=',')

            if self.flag_variance == 1:
                np.savetxt('results/loss_variance_judge_%s_aug_%s_acc_%s.csv' %
                           (self.judge_noise, self.n_aug, top1_avg_max),
                           loss_fixed_all,
                           delimiter=',')
Example #2
    def train(self,
              dataloader,
              temperature,
              ckpt_path,
              n_epochs=90,
              save_size=10):

        # trainers
        criterion = NTCrossEntropyLoss(temperature, self.batch_size,
                                       self.device).to(self.device)
        optimizer = LARS(torch.optim.SGD(self.model.parameters(), lr=4))
        # optimizer = optimizer.to(self.device)

        losses = []

        for epoch in range(n_epochs):
            with tqdm(total=len(dataloader)) as progress:
                running_loss = 0
                i = 0
                for (xis, xjs), _ in dataloader:
                    i += 1
                    optimizer.zero_grad()

                    xis = xis.to(self.device)
                    xjs = xjs.to(self.device)

                    # Get representations and projections
                    his, zis = self.model(xis)
                    hjs, zjs = self.model(xjs)

                    # normalize
                    zis = F.normalize(zis, dim=1)
                    zjs = F.normalize(zjs, dim=1)
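                    # L2-normalising both projections makes the similarity used by the
                    # NT-Xent criterion below effectively a cosine similarity, as in SimCLR.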

                    loss = criterion(zis, zjs)
                    running_loss += loss.item()

                    # optimize
                    loss.backward()
                    optimizer.step()

                    # update tqdm
                    progress.set_description('train loss:{:.4f}'.format(
                        loss.item()))
                    progress.update()

                    # record loss
                    if i % save_size == (save_size - 1):
                        losses.append(running_loss / save_size)
                        running_loss = 0

                # save model
                if epoch % 10 == 0:
                    torch.save(
                        {
                            'epoch': epoch,
                            'model_state_dict': self.model.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict(),
                            'loss': losses,
                        }, ckpt_path + f'{epoch}')

        return self.return_model(), losses
Example #3
def train(cfg, writer, logger):

    # Setup Augmentations
    augmentations = cfg.train.augment
    logger.info(f'using augments: {augmentations}')
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg.data.dataloader)
    data_path = cfg.data.path
    logger.info("data path: {}".format(data_path))

    t_loader = data_loader(
        data_path,
        data_format=cfg.data.format,
        norm=cfg.data.norm,
        split='train',
        split_root=cfg.data.split,
        augments=data_aug,
        logger=logger,
        log=cfg.data.log,
        ENL=cfg.data.ENL,
    )

    v_loader = data_loader(
        data_path,
        data_format=cfg.data.format,
        split='val',
        log=cfg.data.log,
        split_root=cfg.data.split,
        logger=logger,
        ENL=cfg.data.ENL,
    )

    train_data_len = len(t_loader)
    logger.info(
        f'num of train samples: {train_data_len} \nnum of val samples: {len(v_loader)}'
    )

    batch_size = cfg.train.batch_size
    epoch = cfg.train.epoch
    train_iter = int(np.ceil(train_data_len / batch_size) * epoch)
    logger.info(f'total train iter: {train_iter}')

    trainloader = data.DataLoader(t_loader,
                                  batch_size=batch_size,
                                  num_workers=cfg.train.n_workers,
                                  shuffle=True,
                                  persistent_workers=True,
                                  drop_last=True)

    valloader = data.DataLoader(
        v_loader,
        batch_size=cfg.test.batch_size,
        # persis
        num_workers=cfg.train.n_workers,
    )

    # Setup Model
    device = f'cuda:{cfg.train.gpu[0]}'
    model = get_model(cfg.model).to(device)
    input_size = (cfg.model.in_channels, 512, 512)
    logger.info(f"Using Model: {cfg.model.arch}")
    # logger.info(f'model summary: {summary(model, input_size=(input_size, input_size), is_complex=False)}')
    model = torch.nn.DataParallel(model, device_ids=cfg.gpu)  # automatic multi-GPU execution; works well

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in vars(cfg.train.optimizer).items()
        if k not in ('name', 'wrap')
    }
    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))
    if hasattr(cfg.train.optimizer,
               'wrap') and cfg.train.optimizer.wrap == 'lars':
        optimizer = LARS(optimizer=optimizer)
        logger.info(f'wrapped optimizer with {cfg.train.optimizer.wrap}')
    scheduler = get_scheduler(optimizer, cfg.train.lr)
    # loss_fn = get_loss_function(cfg)
    # logger.info(f"Using loss ,{str(cfg.train.loss)}")

    # load checkpoints
    val_cls_1_acc = 0
    best_cls_1_acc_now = 0
    best_cls_1_acc_iter_now = 0
    val_macro_OA = 0
    best_macro_OA_now = 0
    best_macro_OA_iter_now = 0
    start_iter = 0
    if cfg.train.resume is not None:
        if os.path.isfile(cfg.train.resume):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg.train.resume))

            # load model state
            checkpoint = torch.load(cfg.train.resume)
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            # best_cls_1_acc_now = checkpoint["best_cls_1_acc_now"]
            # best_cls_1_acc_iter_now = checkpoint["best_cls_1_acc_iter_now"]

            start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg.train.resume, checkpoint["epoch"]))

            # copy tensorboard files
            resume_src_dir = osp.split(cfg.train.resume)[0]
            # shutil.copytree(resume_src_dir, writer.get_logdir())
            for file in os.listdir(resume_src_dir):
                if not ('.log' in file or '.yml' in file
                        or '_last_model' in file):
                    # if 'events.out.tfevents' in file:
                    resume_dst_dir = writer.get_logdir()
                    fu.copy(
                        osp.join(resume_src_dir, file),
                        resume_dst_dir,
                    )

        else:
            logger.info("No checkpoint found at '{}'".format(cfg.train.resume))

    data_range = 255
    if cfg.data.log:
        data_range = np.log(data_range)
    # data_range /= 350

    # Setup Metrics
    running_metrics_val = runningScore(2)
    runing_metrics_train = runningScore(2)
    val_loss_meter = averageMeter()
    train_time_meter = averageMeter()
    train_loss_meter = averageMeter()
    val_psnr_meter = averageMeter()
    val_ssim_meter = averageMeter()

    # train
    it = start_iter
    train_start_time = time.time()
    train_val_start_time = time.time()
    model.train()
    while it < train_iter:
        for clean, noisy, _ in trainloader:
            it += 1

            noisy = noisy.to(device, dtype=torch.float32)
            # noisy /= 350
            mask1, mask2 = rand_pool.generate_mask_pair(noisy)
            noisy_sub1 = rand_pool.generate_subimages(noisy, mask1)
            noisy_sub2 = rand_pool.generate_subimages(noisy, mask2)
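            # Neighbor2Neighbor-style pairing: two randomly subsampled neighbouring
            # images of the same noisy input serve as the network input and its
            # training target (based on rand_pool's mask/subimage API).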

            # preparing for the regularization term
            with torch.no_grad():
                noisy_denoised = model(noisy)
            noisy_sub1_denoised = rand_pool.generate_subimages(
                noisy_denoised, mask1)
            noisy_sub2_denoised = rand_pool.generate_subimages(
                noisy_denoised, mask2)
            # print(rand_pool.operation_seed_counter)

            # for ii, param in enumerate(model.parameters()):
            #     if torch.sum(torch.isnan(param.data)):
            #         print(f'{ii}: nan parameters')

            # calculating the loss
            noisy_output = model(noisy_sub1)
            noisy_target = noisy_sub2
            if cfg.train.loss.gamma.const:
                gamma = cfg.train.loss.gamma.base
            else:
                gamma = it / train_iter * cfg.train.loss.gamma.base

            diff = noisy_output - noisy_target
            exp_diff = noisy_sub1_denoised - noisy_sub2_denoised
            loss1 = torch.mean(diff**2)
            loss2 = gamma * torch.mean((diff - exp_diff)**2)
            loss_all = loss1 + loss2
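            # loss1 is the reconstruction term between the two noisy sub-images;
            # loss2 is the gamma-weighted regularisation term that keeps their
            # difference consistent with the denoised prediction.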

            # loss1 = noisy_output - noisy_target
            # loss2 = torch.exp(noisy_target - noisy_output)
            # loss_all = torch.mean(loss1 + loss2)
            # zero accumulated gradients before computing new ones
            optimizer.zero_grad()
            loss_all.backward()

            # In PyTorch 1.1.0 and later, you should call `optimizer.step()` before `lr_scheduler.step()`
            optimizer.step()
            scheduler.step()

            # record the loss of the minibatch
            train_loss_meter.update(loss_all.item())  # .item() avoids keeping the graph alive
            train_time_meter.update(time.time() - train_start_time)
            writer.add_scalar('lr', optimizer.param_groups[0]['lr'], it)

            if it % 1000 == 0:
                writer.add_histogram('hist/pred', noisy_denoised, it)
                writer.add_histogram('hist/noisy', noisy, it)

                if cfg.data.simulate:
                    writer.add_histogram('hist/clean', clean, it)

            if cfg.data.simulate:
                pass

            # print interval
            if it % cfg.train.print_interval == 0:
                terminal_info = f"Iter [{it:d}/{train_iter:d}]  \
                                train Loss: {train_loss_meter.avg:.4f}  \
                                Time/Image: {train_time_meter.avg / cfg.train.batch_size:.4f}"

                logger.info(terminal_info)
                writer.add_scalar('loss/train_loss', train_loss_meter.avg, it)

                if cfg.data.simulate:
                    pass

                runing_metrics_train.reset()
                train_time_meter.reset()
                train_loss_meter.reset()

            # val interval
            if it % cfg.train.val_interval == 0 or \
               it == train_iter:
                val_start_time = time.time()
                model.eval()
                with torch.no_grad():
                    for clean, noisy, _ in valloader:
                        # noisy /= 350
                        # clean /= 350
                        noisy = noisy.to(device, dtype=torch.float32)
                        noisy_denoised = model(noisy)

                        if cfg.data.simulate:
                            clean = clean.to(device, dtype=torch.float32)
                            psnr = piq.psnr(clean,
                                            noisy_denoised,
                                            data_range=data_range)
                            ssim = piq.ssim(clean,
                                            noisy_denoised,
                                            data_range=data_range)
                            val_psnr_meter.update(psnr)
                            val_ssim_meter.update(ssim)

                        val_loss = torch.mean((noisy_denoised - noisy)**2)
                        val_loss_meter.update(val_loss)

                writer.add_scalar('loss/val_loss', val_loss_meter.avg, it)
                logger.info(
                    f"Iter [{it}/{train_iter}], val Loss: {val_loss_meter.avg:.4f} Time/Image: {(time.time()-val_start_time)/len(v_loader):.4f}"
                )
                val_loss_meter.reset()
                running_metrics_val.reset()

                if cfg.data.simulate:
                    writer.add_scalars('metrics/val', {
                        'psnr': val_psnr_meter.avg,
                        'ssim': val_ssim_meter.avg
                    }, it)
                    logger.info(
                        f'psnr: {val_psnr_meter.avg},\tssim: {val_ssim_meter.avg}'
                    )
                    val_psnr_meter.reset()
                    val_ssim_meter.reset()

                train_val_time = time.time() - train_val_start_time
                remain_time = train_val_time * (train_iter - it) / it
                m, s = divmod(remain_time, 60)
                h, m = divmod(m, 60)
                if s != 0:
                    train_time = "Remain train time = %d hours %d minutes %d seconds \n" % (
                        h, m, s)
                else:
                    train_time = "Remain train time : train completed.\n"
                logger.info(train_time)
                model.train()

            # save model
            if it % (train_iter / cfg.train.epoch * 10) == 0:
                ep = int(it / (train_iter / cfg.train.epoch))
                state = {
                    "epoch": it,
                    "model_state": model.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                    "scheduler_state": scheduler.state_dict(),
                }
                save_path = osp.join(writer.file_writer.get_logdir(),
                                     f"{ep}.pkl")
                torch.save(state, save_path)
                logger.info(f'saved model state dict at {save_path}')

            train_start_time = time.time()
Example #4
def train(cfg, writer, logger):
    
    # Setup random seeds to a determinated value for reproduction
    # seed = 1337
    # torch.manual_seed(seed)
    # torch.cuda.manual_seed(seed)
    # np.random.seed(seed)
    # random.seed(seed)
    # np.random.default_rng(seed)

    # Setup Augmentations
    augmentations = cfg.train.augment
    logger.info(f'using augments: {augmentations}')
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg.data.dataloader)
    data_path = cfg.data.path

    logger.info("Using dataset: {}".format(data_path))

    t_loader = data_loader(
        data_path,
        # transform=None,
        # time_shuffle = cfg.data.time_shuffle,
        # to_tensor=False,
        data_format=cfg.data.format,
        split=cfg.data.train_split,
        norm=cfg.data.norm,
        augments=data_aug
        )

    v_loader = data_loader(
        data_path,
        # transform=None,
        # time_shuffle = cfg.data.time_shuffle,
        # to_tensor=False,
        data_format=cfg.data.format,
        split=cfg.data.val_split,
        )
    train_data_len = len(t_loader)
    logger.info(f'num of train samples: {train_data_len} \nnum of val samples: {len(v_loader)}')

    batch_size = cfg.train.batch_size
    epoch = cfg.train.epoch
    train_iter = int(np.ceil(train_data_len / batch_size) * epoch)
    logger.info(f'total train iter: {train_iter}')

    trainloader = data.DataLoader(t_loader,
                                  batch_size=batch_size, 
                                  num_workers=cfg.train.n_workers, 
                                  shuffle=True,
                                  persistent_workers=True,
                                  drop_last=True)

    valloader = data.DataLoader(v_loader, 
                                batch_size=10, 
                                # persis
                                num_workers=cfg.train.n_workers,)

    # Setup Model
    device = f'cuda:{cfg.gpu[0]}'
    model = get_model(cfg.model, 2).to(device)
    input_size = (cfg.model.input_nbr, 512, 512)
    logger.info(f"Using Model: {cfg.model.arch}")
    # logger.info(f'model summary: {summary(model, input_size=(input_size, input_size), is_complex=True)}')
    model = torch.nn.DataParallel(model, device_ids=cfg.gpu)      #自动多卡运行,这个好用
    
    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {k:v for k, v in vars(cfg.train.optimizer).items()
                        if k not in ('name', 'wrap')}
    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))
    if hasattr(cfg.train.optimizer, 'wrap') and cfg.train.optimizer.wrap == 'lars':
        optimizer = LARS(optimizer=optimizer)
        logger.info(f'wrapped optimizer with {cfg.train.optimizer.wrap}')
    scheduler = get_scheduler(optimizer, cfg.train.lr)
    loss_fn = get_loss_function(cfg)
    logger.info(f"Using loss ,{str(cfg.train.loss)}")

    # load checkpoints
    val_cls_1_acc = 0
    best_cls_1_acc_now = 0
    best_cls_1_acc_iter_now = 0
    val_macro_OA = 0
    best_macro_OA_now = 0
    best_macro_OA_iter_now = 0
    start_iter = 0
    if cfg.train.resume is not None:
        if os.path.isfile(cfg.train.resume):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(cfg.train.resume)
            )

            # load model state
            checkpoint = torch.load(cfg.train.resume)
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            # best_cls_1_acc_now = checkpoint["best_cls_1_acc_now"]
            # best_cls_1_acc_iter_now = checkpoint["best_cls_1_acc_iter_now"]

            start_iter = checkpoint["epoch"]
            logger.info(
                "Loaded checkpoint '{}' (iter {})".format(
                    cfg.train.resume, checkpoint["epoch"]
                )
            )

            # copy tensorboard files
            resume_src_dir = osp.split(cfg.train.resume)[0]
            # shutil.copytree(resume_src_dir, writer.get_logdir())
            for file in os.listdir(resume_src_dir):
                if not ('.log' in file or '.yml' in file or '_last_model' in file):
                # if 'events.out.tfevents' in file:
                    resume_dst_dir = writer.get_logdir()
                    fu.copy(osp.join(resume_src_dir, file), resume_dst_dir, )

        else:
            logger.info("No checkpoint found at '{}'".format(cfg.train.resume))

    # Setup Metrics
    running_metrics_val = runningScore(2)
    runing_metrics_train = runningScore(2)
    val_loss_meter = averageMeter()
    train_time_meter = averageMeter()

    # train
    it = start_iter
    train_start_time = time.time() 
    train_val_start_time = time.time()
    model.train()   
    while it < train_iter:
        for (file_a, file_b, label, mask) in trainloader:
            it += 1           
            file_a = file_a.to(device)            
            file_b = file_b.to(device)            
            label = label.to(device)            
            mask = mask.to(device)

            optimizer.zero_grad()
            # print(f'dtype: {file_a.dtype}')
            outputs = model(file_a, file_b)
            loss = loss_fn(input=outputs, target=label, mask=mask)
            loss.backward()

            # print('conv11: ', model.conv11.weight.grad, model.conv11.weight.grad.shape)
            # print('conv21: ', model.conv21.weight.grad, model.conv21.weight.grad.shape)
            # print('conv31: ', model.conv31.weight.grad, model.conv31.weight.grad.shape)

            # In PyTorch 1.1.0 and later, you should call `optimizer.step()` before `lr_scheduler.step()`
            optimizer.step()
            scheduler.step()
            
            # record the acc of the minibatch
            pred = outputs.max(1)[1].cpu().numpy()
            runing_metrics_train.update(label.cpu().numpy(), pred, mask.cpu().numpy())

            train_time_meter.update(time.time() - train_start_time)

            if it % cfg.train.print_interval == 0:
                # acc of the samples between print_interval
                score, _ = runing_metrics_train.get_scores()
                train_cls_0_acc, train_cls_1_acc = score['Acc']
                fmt_str = "Iter [{:d}/{:d}]  train Loss: {:.4f}  Time/Image: {:.4f},\n0:{:.4f}\n1:{:.4f}"
                print_str = fmt_str.format(
                    it,
                    train_iter,
                    loss.item(),  # .item() extracts the loss value as a Python float
                    train_time_meter.avg / cfg.train.batch_size,
                    train_cls_0_acc,
                    train_cls_1_acc)
                runing_metrics_train.reset()
                train_time_meter.reset()
                logger.info(print_str)
                writer.add_scalar('loss/train_loss', loss.item(), it)
                writer.add_scalars('metrics/train', {'cls_0':train_cls_0_acc, 'cls_1':train_cls_1_acc}, it)
                # writer.add_scalar('train_metrics/acc/cls_0', train_cls_0_acc, it)
                # writer.add_scalar('train_metrics/acc/cls_1', train_cls_1_acc, it)

            if it % cfg.train.val_interval == 0 or \
               it == train_iter:
                val_start_time = time.time()
                model.eval()            # switch layers such as dropout / batch norm to eval mode
                with torch.no_grad():   # disable autograd, save memory usage
                    for (file_a_val, file_b_val, label_val, mask_val) in valloader:      
                        file_a_val = file_a_val.to(device)            
                        file_b_val = file_b_val.to(device)

                        outputs = model(file_a_val, file_b_val)
                        # tensor.max() returns the maximum value and its indices
                        pred = outputs.max(1)[1].cpu().numpy()
                        running_metrics_val.update(label_val.numpy(), pred, mask_val.numpy())
            
                        label_val = label_val.to(device)            
                        mask_val = mask_val.to(device)
                        val_loss = loss_fn(input=outputs, target=label_val, mask=mask_val)
                        val_loss_meter.update(val_loss.item())

                score, _ = running_metrics_val.get_scores()
                val_cls_0_acc, val_cls_1_acc = score['Acc']

                writer.add_scalar('loss/val_loss', val_loss_meter.avg, it)
                logger.info(f"Iter [{it}/{train_iter}], val Loss: {val_loss_meter.avg:.4f} Time/Image: {(time.time()-val_start_time)/len(v_loader):.4f}\n0: {val_cls_0_acc:.4f}\n1:{val_cls_1_acc:.4f}")
                # lr_now = optimizer.param_groups[0]['lr']
                # logger.info(f'lr: {lr_now}')
                # writer.add_scalar('lr', lr_now, it+1)

                logger.info('0: {:.4f}\n1:{:.4f}'.format(val_cls_0_acc, val_cls_1_acc))
                writer.add_scalars('metrics/val', {'cls_0':val_cls_0_acc, 'cls_1':val_cls_1_acc}, it)
                # writer.add_scalar('val_metrics/acc/cls_0', val_cls_0_acc, it)
                # writer.add_scalar('val_metrics/acc/cls_1', val_cls_1_acc, it)

                val_loss_meter.reset()
                running_metrics_val.reset()

                # OA=score["Overall_Acc"]
                val_macro_OA = (val_cls_0_acc+val_cls_1_acc)/2
                if val_macro_OA >= best_macro_OA_now and it>200:
                    best_macro_OA_now = val_macro_OA
                    best_macro_OA_iter_now = it
                    state = {
                        "epoch": it,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_macro_OA_now": best_macro_OA_now,
                        'best_macro_OA_iter_now':best_macro_OA_iter_now,
                    }
                    save_path = os.path.join(writer.file_writer.get_logdir(), "{}_{}_best_model.pkl".format(cfg.model.arch,cfg.data.dataloader))
                    torch.save(state, save_path)

                    logger.info("best OA now =  %.8f" % (best_macro_OA_now))
                    logger.info("best OA iter now= %d" % (best_macro_OA_iter_now))

                train_val_time = time.time() - train_val_start_time
                remain_time = train_val_time * (train_iter-it) / it
                m, s = divmod(remain_time, 60)
                h, m = divmod(m, 60)
                if s != 0:
                    train_time = "Remain train time = %d hours %d minutes %d seconds \n" % (h, m, s)
                else:
                    train_time = "Remain train time : train completed.\n"
                logger.info(train_time)
                model.train()

            train_start_time = time.time() 

    logger.info("best OA now =  %.8f" % (best_macro_OA_now))
    logger.info("best OA iter now= %d" % (best_macro_OA_iter_now))

    state = {
            "epoch": it,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scheduler_state": scheduler.state_dict(),
            "best_macro_OA_now": best_macro_OA_now,
            'best_macro_OA_iter_now':best_macro_OA_iter_now,
            }
    save_path = os.path.join(writer.file_writer.get_logdir(), "{}_{}_last_model.pkl".format(cfg.model.arch, cfg.data.dataloader))
    torch.save(state, save_path)
Example #5
            progress.reset(t_ds)
            progress.reset(t_test)
            for images, labels in dl:
                images = images.to(config.device).expand(-1, 3, -1, -1)
                labels = labels.to(config.device)

                classes = eff_net(images)
                loss = cross_entropy(classes, labels)
                losses.append(loss.item())
                progress.update(t_ds,
                                advance=1,
                                description=f'[magenta] {mean(losses):.5f}')

                optim.zero_grad()
                loss.backward()
                optim.step()

            correct = 0
            total = 0
            for images, labels in test:
                images = images.to(config.device).expand(-1, 3, -1, -1)
                labels = labels.to(config.device)
                classes = eff_net(images)
                classes = torch.argmax(classes, dim=1)
                for pred, target in zip(classes, labels):
                    if pred == target:
                        correct += 1
                    total += 1
                progress.update(t_test,
                                advance=1,
                                description=f'[blue]{correct}/{total}')