예제 #1
0
def main(args, dst_folder):
    """Full training loop for noisy-/partial-label experiments.

    Builds the data pipeline and network selected by ``args``, optionally
    resumes from a warm-up checkpoint (relabeling the unlabeled samples),
    trains for ``args.epoch`` epochs while checkpointing the best and last
    models, and dumps loss/accuracy/label metrics to disk after every epoch.

    Args:
        args: parsed CLI namespace (dataset, network, optimizer, checkpoint
            and experiment options).
        dst_folder: experiment output folder forwarded to ``data_config``
            and ``record_result``.
    """
    # Best top-1 accuracy on the validation set seen so far.  Initialized
    # here so it is defined even if the epoch loop body never runs.
    best_acc_val = 0.0

    # os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    if args.cuda_dev == 1:
        torch.cuda.set_device(1)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    torch.backends.cudnn.deterministic = True  # fix the GPU to deterministic mode
    torch.manual_seed(args.seed)  # CPU seed
    # BUGFIX: ``device`` is a torch.device object, so the original comparison
    # ``device == "cuda"`` was always False and the GPU was never seeded.
    if device.type == "cuda":
        torch.cuda.manual_seed_all(args.seed)  # GPU seed

    random.seed(args.seed)  # python seed for image transformation
    np.random.seed(args.seed)

    # Per-channel normalization statistics for the supported datasets.
    if args.dataset == 'svhn':
        mean = [x / 255 for x in [127.5, 127.5, 127.5]]
        std = [x / 255 for x in [127.5, 127.5, 127.5]]
    elif args.dataset == 'cifar100':
        mean = [0.5071, 0.4867, 0.4408]
        std = [0.2675, 0.2565, 0.2761]
    else:
        # Fail fast: the original code fell through and later crashed with a
        # NameError on ``mean``/``std``.
        raise ValueError("Unsupported --dataset value: {0}".format(args.dataset))

    if args.DA == "standard":
        transform_train = transforms.Compose([
            transforms.Pad(2, padding_mode='reflect'),
            transforms.RandomCrop(32),
            #transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

    elif args.DA == "jitter":
        transform_train = transforms.Compose([
            transforms.Pad(2, padding_mode='reflect'),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
            transforms.RandomCrop(32),
            #SVHNPolicy(),
            #AutoAugment(),
            #transforms.RandomHorizontalFlip(),

            transforms.ToTensor(),
            #Cutout(n_holes=1,length=20),
            transforms.Normalize(mean, std),
        ])
    else:
        # Fail fast: the original code only printed a warning and later
        # crashed with a NameError on ``transform_train``.
        raise ValueError("Wrong value for --DA argument.")

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    # data loader
    train_loader, test_loader, train_noisy_indexes = data_config(args, transform_train, transform_test, dst_folder)

    # Network selection.
    if args.network == "MT_Net":
        print("Loading MT_Net...")
        model = MT_Net(num_classes=args.num_classes, dropRatio=args.dropout).to(device)

    elif args.network == "WRN28_2_wn":
        print("Loading WRN28_2...")
        model = WRN28_2_wn(num_classes=args.num_classes, dropout=args.dropout).to(device)

    elif args.network == "PreactResNet18_WNdrop":
        print("Loading preActResNet18_WNdrop...")
        model = PreactResNet18_WNdrop(drop_val=args.dropout, num_classes=args.num_classes).to(device)

    print('Total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0))

    milestones = args.M

    if args.swa == 'True':
        # Stochastic Weight Averaging wrapper around SGD.  To install it:
        # pip3 install torchcontrib
        # git clone https://github.com/pytorch/contrib.git
        # cd contrib
        # sudo python3 setup.py install
        from torchcontrib.optim import SWA
        #base_optimizer = RAdam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), weight_decay=1e-4)
        base_optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=1e-4)
        optimizer = SWA(base_optimizer, swa_lr=args.swa_lr)

    else:
        #optimizer = RAdam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), weight_decay=1e-4)
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=1e-4)

    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

    # Per-epoch histories, dumped to .npy files after every epoch.
    loss_train_epoch = []
    loss_val_epoch = []
    acc_train_per_epoch = []
    acc_val_per_epoch = []
    new_labels = []

    exp_path = os.path.join('./', 'noise_models_{0}'.format(args.experiment_name), str(args.labeled_samples))
    res_path = os.path.join('./', 'metrics_{0}'.format(args.experiment_name), str(args.labeled_samples))

    if not os.path.isdir(res_path):
        os.makedirs(res_path)

    if not os.path.isdir(exp_path):
        os.makedirs(exp_path)

    # Epochs completed so far; guards deletion of the previous best snapshot.
    cont = 0

    load = False
    save = True

    if args.initial_epoch != 0:
        # Resume from a warm-up checkpoint instead of training from scratch.
        initial_epoch = args.initial_epoch
        load = True
        save = False

    if args.dataset_type == 'sym_noise_warmUp':
        load = False
        save = True

    if load:
        # Rebuild the warm-up checkpoint name from the training options.
        # NOTE(review): ``train_type`` stays undefined if ``args.loss_term``
        # is neither 'Reg_ep' nor 'MixUp_ep' -- confirm these are the only
        # accepted values.
        if args.loss_term == 'Reg_ep':
            train_type = 'C'
        if args.loss_term == 'MixUp_ep':
            train_type = 'M'
        if args.dropout > 0.0:
            train_type = train_type + 'drop' + str(int(10 * args.dropout))
        if args.beta == 0.0:
            train_type = train_type + 'noReg'
        path = './checkpoints/warmUp_{6}_{5}_{0}_{1}_{2}_{3}_S{4}.hdf5'.format(initial_epoch, \
                                                                                args.dataset, \
                                                                                args.labeled_samples, \
                                                                                args.network, \
                                                                                args.seed, \
                                                                                args.Mixup_Alpha, \
                                                                                train_type)

        checkpoint = torch.load(path)
        print("Load model in epoch " + str(checkpoint['epoch']))
        print("Path loaded: ", path)
        model.load_state_dict(checkpoint['state_dict'])
        print("Relabeling the unlabeled samples...")
        model.eval()
        initial_rand_relab = args.label_noise
        # BUGFIX: the per-sample probability buffer was hard-coded to 10
        # columns, which breaks for any num_classes != 10 (e.g. cifar100).
        results = np.zeros((len(train_loader.dataset), args.num_classes), dtype=np.float32)

        for images, images_pslab, labels, soft_labels, index in train_loader:

            images = images.to(device)
            labels = labels.to(device)
            soft_labels = soft_labels.to(device)

            outputs = model(images)
            prob, loss = loss_soft_reg_ep(outputs, labels, soft_labels, device, args)
            # Scatter the predicted probabilities back into dataset order.
            results[index.detach().numpy().tolist()] = prob.cpu().detach().numpy().tolist()

        train_loader.dataset.update_labels_randRelab(results, train_noisy_indexes, initial_rand_relab)
        print("Start training...")

    for epoch in range(1, args.epoch + 1):
        st = time.time()
        # NOTE: stepping the scheduler before training keeps the original
        # (pre-PyTorch-1.1) LR schedule; do not reorder without re-validating.
        scheduler.step()
        # train for one epoch
        print(args.experiment_name, args.labeled_samples)

        loss_per_epoch, top_5_train_ac, top1_train_acc_original_labels, \
        top1_train_ac, train_time = train_CrossEntropy_partialRelab(\
                                                        args, model, device, \
                                                        train_loader, optimizer, \
                                                        epoch, train_noisy_indexes)

        loss_train_epoch += [loss_per_epoch]

        # Evaluate on the validation or test split.
        if args.validation_exp == "True":
            loss_per_epoch, acc_val_per_epoch_i = validating(args, model, device, test_loader)
        else:
            loss_per_epoch, acc_val_per_epoch_i = testing(args, model, device, test_loader)

        loss_val_epoch += loss_per_epoch
        acc_train_per_epoch += [top1_train_ac]
        acc_val_per_epoch += acc_val_per_epoch_i

        ####################################################################################################
        #############################               SAVING MODELS                ###########################
        ####################################################################################################

        if not os.path.exists('./checkpoints'):
            os.mkdir('./checkpoints')

        if epoch == 1:
            best_acc_val = acc_val_per_epoch_i[-1]
            snapBest = 'best_epoch_%d_valLoss_%.5f_valAcc_%.5f_noise_%d_bestAccVal_%.5f' % (
                epoch, loss_per_epoch[-1], acc_val_per_epoch_i[-1], args.labeled_samples, best_acc_val)
            torch.save(model.state_dict(), os.path.join(exp_path, snapBest + '.pth'))
            torch.save(optimizer.state_dict(), os.path.join(exp_path, 'opt_' + snapBest + '.pth'))
        else:
            if acc_val_per_epoch_i[-1] > best_acc_val:
                best_acc_val = acc_val_per_epoch_i[-1]

                # Drop the previous best snapshot before writing the new one.
                if cont > 0:
                    try:
                        os.remove(os.path.join(exp_path, 'opt_' + snapBest + '.pth'))
                        os.remove(os.path.join(exp_path, snapBest + '.pth'))
                    except OSError:
                        pass
                snapBest = 'best_epoch_%d_valLoss_%.5f_valAcc_%.5f_noise_%d_bestAccVal_%.5f' % (
                    epoch, loss_per_epoch[-1], acc_val_per_epoch_i[-1], args.labeled_samples, best_acc_val)
                torch.save(model.state_dict(), os.path.join(exp_path, snapBest + '.pth'))
                torch.save(optimizer.state_dict(), os.path.join(exp_path, 'opt_' + snapBest + '.pth'))

        cont += 1

        if epoch == args.epoch:
            # Always keep a snapshot of the final epoch.
            snapLast = 'last_epoch_%d_valLoss_%.5f_valAcc_%.5f_noise_%d_bestValLoss_%.5f' % (
                epoch, loss_per_epoch[-1], acc_val_per_epoch_i[-1], args.labeled_samples, best_acc_val)
            torch.save(model.state_dict(), os.path.join(exp_path, snapLast + '.pth'))
            torch.save(optimizer.state_dict(), os.path.join(exp_path, 'opt_' + snapLast + '.pth'))

        #### Save models for ensembles:
        if (epoch >= 150) and (epoch % 2 == 0) and (args.save_checkpoint == "True"):
            print("Saving model ...")
            out_path = './checkpoints/ENS_{0}_{1}'.format(args.experiment_name, args.labeled_samples)
            if not os.path.exists(out_path):
                os.makedirs(out_path)
            torch.save(model.state_dict(), out_path + "/epoch_{0}.pth".format(epoch))

        ### Saving model to load it again
        # cond = epoch%1 == 0
        if args.dataset_type == 'sym_noise_warmUp':
            if args.loss_term == 'Reg_ep':
                train_type = 'C'
            if args.loss_term == 'MixUp_ep':
                train_type = 'M'
            if args.dropout > 0.0:
                train_type = train_type + 'drop' + str(int(10 * args.dropout))
            if args.beta == 0.0:
                train_type = train_type + 'noReg'

            cond = (epoch == args.epoch)
            name = 'warmUp_{1}_{0}'.format(args.Mixup_Alpha, train_type)
            save = True
        else:
            # NOTE(review): this branch reuses ``train_type`` from the resume
            # section above; it is undefined when no warm-up checkpoint was
            # loaded -- confirm this path is only reached after a resume.
            cond = (epoch == args.epoch)
            name = 'warmUp_{1}_{0}'.format(args.Mixup_Alpha, train_type)
            save = True

        if cond and save:
            print("Saving models...")
            path = './checkpoints/{0}_{1}_{2}_{3}_{4}_S{5}.hdf5'.format(name, epoch, args.dataset, args.labeled_samples, args.network, args.seed)

            save_checkpoint({
                    'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'optimizer' : optimizer.state_dict(),
                    'loss_train_epoch' : np.asarray(loss_train_epoch),
                    'loss_val_epoch' : np.asarray(loss_val_epoch),
                    'acc_train_per_epoch' : np.asarray(acc_train_per_epoch),
                    'acc_val_per_epoch' : np.asarray(acc_val_per_epoch),
                    'labels': np.asarray(train_loader.dataset.soft_labels)
                }, filename = path)

        ####################################################################################################
        ############################               SAVING METRICS                ###########################
        ####################################################################################################

        # Save losses:
        np.save(res_path + '/' + str(args.labeled_samples) + '_LOSS_epoch_train.npy', np.asarray(loss_train_epoch))
        np.save(res_path + '/' + str(args.labeled_samples) + '_LOSS_epoch_val.npy', np.asarray(loss_val_epoch))

        # save accuracies:
        np.save(res_path + '/' + str(args.labeled_samples) + '_accuracy_per_epoch_train.npy',
                np.asarray(acc_train_per_epoch))
        np.save(res_path + '/' + str(args.labeled_samples) + '_accuracy_per_epoch_val.npy', np.asarray(acc_val_per_epoch))

        # save the new labels
        new_labels.append(train_loader.dataset.labels)
        np.save(res_path + '/' + str(args.labeled_samples) + '_new_labels.npy',
                np.asarray(new_labels))

        #logging.info('Epoch: [{}|{}], train_loss: {:.3f}, top1_train_ac: {:.3f}, top1_val_ac: {:.3f}, train_time: {:.3f}'.format(epoch, args.epoch, loss_per_epoch[-1], top1_train_ac, acc_val_per_epoch_i[-1], time.time() - st))

    # applying swa
    if args.swa == 'True':
        # Swap in the averaged SWA weights and recompute BatchNorm statistics
        # before the final evaluation.
        optimizer.swap_swa_sgd()
        optimizer.bn_update(train_loader, model, device)
        if args.validation_exp == "True":
            loss_swa, acc_val_swa = validating(args, model, device, test_loader)
        else:
            loss_swa, acc_val_swa = testing(args, model, device, test_loader)

        snapLast = 'last_epoch_%d_valLoss_%.5f_valAcc_%.5f_noise_%d_bestValLoss_%.5f_swaAcc_%.5f' % (
            epoch, loss_per_epoch[-1], acc_val_per_epoch_i[-1], args.labeled_samples, best_acc_val, acc_val_swa[0])
        torch.save(model.state_dict(), os.path.join(exp_path, snapLast + '.pth'))
        torch.save(optimizer.state_dict(), os.path.join(exp_path, 'opt_' + snapLast + '.pth'))

    # save_fig(dst_folder)
    print('Best ac:%f' % best_acc_val)
    # BUGFIX: the original recorded ``best_ac``, a placeholder that was never
    # updated (always 0.0); record the tracked best validation accuracy.
    record_result(dst_folder, best_acc_val)
예제 #2
0
            if scheduler is not None:
                scheduler.step(total_pesq / num_test_data)

            writer.add_scalar('Loss/valid', total_loss / num_test_data, epoch)
            # writer.add_scalar('PESQ/valid', total_pesq / num_test_data, epoch)

            # checkpointing
            curr_loss = total_loss / num_test_data
            if  curr_loss < best_loss:
                best_loss = curr_loss
                save_path = os.path.join(ckpt_path, 'model_best.ckpt')
                print(f'Saving checkpoint to {save_path}')
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': net.state_dict(), 
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': total_loss / num_test_data,
                }, save_path)
            # curr_pesq = total_pesq / num_test_data
            # if  curr_pesq > best_pesq:
                # best_pesq = curr_pesq
                # save_path = os.path.join(ckpt_path, 'model_best.ckpt')
                # print(f'Saving checkpoint to {save_path}')
                # torch.save({
                    # 'epoch': epoch,
                    # 'model_state_dict': net.state_dict(), 
                    # 'optimizer_state_dict': optimizer.state_dict(),
                    # 'loss': total_loss / num_test_data,
                    # 'pesq': total_pesq / num_test_data
                # }, save_path)
예제 #3
0
class Trainer():
    """Segmentation training driver.

    Reads image/model/run configuration from ``config_path``, builds the
    datasets, data loaders, optimizer (optionally wrapped in torchcontrib
    SWA), loss and metric helpers, and exposes ``__call__`` to run the
    train/valid loop with checkpointing and early stopping.
    """

    def __init__(self, config_path):
        self.image_config, self.model_config, self.run_config = LoadConfig(
            config_path=config_path).train_config()
        # BUGFIX: the original tested ``torch.cuda.is_available`` (the bound
        # function object, which is always truthy) instead of calling it, so
        # the CPU fallback could never trigger on CUDA-less machines.
        self.device = torch.device(
            'cuda:%d' % self.run_config['device_ids'][0]
            if torch.cuda.is_available() else 'cpu')
        self.model = getModel(self.model_config)
        os.makedirs(self.run_config['model_save_path'], exist_ok=True)
        # Scale the worker count with the number of GPUs in use.
        self.run_config['num_workers'] = self.run_config['num_workers'] * len(
            self.run_config['device_ids'])
        self.train_set = Data(root=self.image_config['image_path'],
                              phase='train',
                              data_name=self.image_config['data_name'],
                              img_mode=self.image_config['image_mode'],
                              n_classes=self.model_config['num_classes'],
                              size=self.image_config['image_size'],
                              scale=self.image_config['image_scale'])
        self.valid_set = Data(root=self.image_config['image_path'],
                              phase='valid',
                              data_name=self.image_config['data_name'],
                              img_mode=self.image_config['image_mode'],
                              n_classes=self.model_config['num_classes'],
                              size=self.image_config['image_size'],
                              scale=self.image_config['image_scale'])
        self.className = self.valid_set.className
        self.train_loader = DataLoader(
            self.train_set,
            batch_size=self.run_config['batch_size'],
            shuffle=True,
            num_workers=self.run_config['num_workers'],
            pin_memory=True,
            drop_last=False)
        self.valid_loader = DataLoader(
            self.valid_set,
            batch_size=self.run_config['batch_size'],
            shuffle=True,
            num_workers=self.run_config['num_workers'],
            pin_memory=True,
            drop_last=False)
        train_params = self.model.parameters()
        # lr / weight_decay come from the config as strings, hence eval().
        self.optimizer = RAdam(train_params,
                               lr=eval(self.run_config['lr']),
                               weight_decay=eval(
                                   self.run_config['weight_decay']))
        if self.run_config['swa']:
            self.optimizer = SWA(self.optimizer,
                                 swa_start=10,
                                 swa_freq=5,
                                 swa_lr=0.005)
        # Configure the learning-rate adjustment policy.
        self.lr_scheduler = utils.adjustLR.AdjustLr(self.optimizer)
        if self.run_config['use_weight_balance']:
            weight = utils.weight_balance.getWeight(
                self.run_config['weights_file'])
        else:
            weight = None
        self.Criterion = SegmentationLosses(weight=weight,
                                            cuda=True,
                                            device=self.device,
                                            batch_average=False)
        self.metric = utils.metrics.MetricMeter(
            self.model_config['num_classes'])

    @logger.catch  # record any uncaught error in the log
    def __call__(self):
        """Run the full training session: logging setup, optional pretrain
        loading, epoch loop with LR scheduling, checkpointing and early stop.
        """
        # Set up file logging and TensorBoard writers.
        self.global_name = self.model_config['model_name']
        logger.add(os.path.join(
            self.image_config['image_path'], 'log',
            'log_' + self.global_name + '/train_{time}.log'),
                   format="{time} {level} {message}",
                   level="INFO",
                   encoding='utf-8')
        self.writer = SummaryWriter(logdir=os.path.join(
            self.image_config['image_path'], 'run', 'runs_' +
            self.global_name))
        logger.info("image_config: {} \n model_config: {} \n run_config: {}",
                    self.image_config, self.model_config, self.run_config)
        # Use data parallelism when more than one GPU is configured.
        if len(self.run_config['device_ids']) > 1:
            self.model = nn.DataParallel(
                self.model, device_ids=self.run_config['device_ids'])
        self.model.to(device=self.device)
        cnt = 0  # epochs without improvement, for early stopping
        # Load a pretrained checkpoint if one is given.
        if self.run_config['pretrain'] != '':
            logger.info("loading pretrain %s" % self.run_config['pretrain'])
            try:
                self.load_checkpoint(use_optimizer=True,
                                     use_epoch=True,
                                     use_miou=True)
            # Narrowed from a bare ``except:`` -- fall back to a partial load
            # when the checkpoint does not match the current model exactly.
            except Exception:
                print('load model with changed!!!!!')
                self.load_checkpoint_with_changed(use_optimizer=False,
                                                  use_epoch=False,
                                                  use_miou=False)
        logger.info("start training")

        for epoch in range(self.run_config['start_epoch'],
                           self.run_config['epoch']):
            lr = self.optimizer.param_groups[0]['lr']
            print('epoch=%d, lr=%.8f' % (epoch, lr))
            self.train_epoch(epoch, lr)
            valid_miou = self.valid_epoch(epoch)
            # Choose which learning-rate schedule to apply.
            self.lr_scheduler.LambdaLR_(milestone=5,
                                        gamma=0.92).step(epoch=epoch)
            self.save_checkpoint(epoch, valid_miou, 'last_' + self.global_name)
            if valid_miou > self.run_config['best_miou']:
                cnt = 0
                self.save_checkpoint(epoch, valid_miou,
                                     'best_' + self.global_name)
                logger.info("#############   %d saved   ##############" %
                            epoch)
                self.run_config['best_miou'] = valid_miou
            else:
                cnt += 1
                if cnt == self.run_config['early_stop']:
                    logger.info("early stop")
                    break
        self.writer.close()

    def train_epoch(self, epoch, lr):
        """Train for one epoch, updating metrics and TensorBoard scalars."""
        self.metric.reset()
        train_loss = 0.0
        train_miou = 0.0
        tbar = tqdm(self.train_loader)
        self.model.train()
        for i, (image, mask, edge) in enumerate(tbar):
            tbar.set_description('train_miou:%.6f' % train_miou)
            tbar.set_postfix({"train_loss": train_loss})
            image = image.to(self.device)
            mask = mask.to(self.device)
            edge = edge.to(self.device)
            self.optimizer.zero_grad()
            out = self.model(image)
            # Models may return (aux_logits, final_logits) or just logits.
            if isinstance(out, tuple):
                aux_out, final_out = out[0], out[1]
            else:
                aux_out, final_out = None, out
            if self.model_config['model_name'] == 'ocrnet':
                aux_loss = self.Criterion.build_loss(mode='rmi')(aux_out, mask)
                cls_loss = self.Criterion.build_loss(mode='ce')(final_out,
                                                                mask)
                loss = 0.4 * aux_loss + cls_loss
                loss = loss.mean()
            elif self.model_config['model_name'] == 'hrnet_duc':
                loss_body = self.Criterion.build_loss(
                    mode=self.run_config['loss_type'])(final_out, mask)
                loss_edge = self.Criterion.build_loss(mode='dice')(
                    aux_out.squeeze(), edge)
                loss = loss_body + loss_edge
                loss = loss.mean()
            else:
                loss = self.Criterion.build_loss(
                    mode=self.run_config['loss_type'])(final_out, mask)
            loss.backward()
            self.optimizer.step()
            # NOTE(review): swapping SWA weights after *every* optimizer step
            # is unusual; torchcontrib normally swaps once after training --
            # confirm this is intentional.
            if self.run_config['swa']:
                self.optimizer.swap_swa_sgd()
            with torch.no_grad():
                # Running mean of the loss over the batches seen so far.
                train_loss = ((train_loss * i) + loss.item()) / (i + 1)
                _, pred = torch.max(final_out, dim=1)
                self.metric.add(pred.cpu().numpy(), mask.cpu().numpy())
                train_miou, train_ious = self.metric.miou()
                train_fwiou = self.metric.fw_iou()
                train_accu = self.metric.pixel_accuracy()
                train_fwaccu = self.metric.pixel_accuracy_class()
        logger.info(
            "Epoch:%2d\t lr:%.8f\t Train loss:%.4f\t Train FWiou:%.4f\t Train Miou:%.4f\t Train accu:%.4f\t "
            "Train fwaccu:%.4f" % (epoch, lr, train_loss, train_fwiou,
                                   train_miou, train_accu, train_fwaccu))
        # Per-class IoU report, one "<class>:<iou>" pair per class.
        cls = ""
        ious = list()
        ious_dict = OrderedDict()
        for i, c in enumerate(self.className):
            ious_dict[c] = train_ious[i]
            ious.append(ious_dict[c])
            cls += "%s:" % c + "%.4f "
        ious = tuple(ious)
        logger.info(cls % ious)
        # tensorboard
        self.writer.add_scalar("lr", lr, epoch)
        self.writer.add_scalar("loss/train_loss", train_loss, epoch)
        self.writer.add_scalar("miou/train_miou", train_miou, epoch)
        self.writer.add_scalar("fwiou/train_fwiou", train_fwiou, epoch)
        self.writer.add_scalar("accuracy/train_accu", train_accu, epoch)
        self.writer.add_scalar("fwaccuracy/train_fwaccu", train_fwaccu, epoch)
        self.writer.add_scalars("ious/train_ious", ious_dict, epoch)

    def valid_epoch(self, epoch):
        """Evaluate for one epoch; returns the validation mean IoU."""
        self.metric.reset()
        valid_loss = 0.0
        valid_miou = 0.0
        tbar = tqdm(self.valid_loader)
        self.model.eval()
        with torch.no_grad():
            for i, (image, mask, edge) in enumerate(tbar):
                tbar.set_description('valid_miou:%.6f' % valid_miou)
                tbar.set_postfix({"valid_loss": valid_loss})
                image = image.to(self.device)
                mask = mask.to(self.device)
                edge = edge.to(self.device)
                out = self.model(image)
                if isinstance(out, tuple):
                    aux_out, final_out = out[0], out[1]
                else:
                    aux_out, final_out = None, out
                if self.model_config['model_name'] == 'ocrnet':
                    aux_loss = self.Criterion.build_loss(mode='rmi')(aux_out,
                                                                     mask)
                    cls_loss = self.Criterion.build_loss(mode='ce')(final_out,
                                                                    mask)
                    loss = 0.4 * aux_loss + cls_loss
                    loss = loss.mean()
                elif self.model_config['model_name'] == 'hrnet_duc':
                    loss_body = self.Criterion.build_loss(
                        mode=self.run_config['loss_type'])(final_out, mask)
                    loss_edge = self.Criterion.build_loss(mode='dice')(
                        aux_out.squeeze(), edge)
                    loss = loss_body + loss_edge
                    # loss = loss.mean()
                else:
                    loss = self.Criterion.build_loss(mode='ce')(final_out,
                                                                mask)
                valid_loss = ((valid_loss * i) + float(loss)) / (i + 1)
                _, pred = torch.max(final_out, dim=1)
                self.metric.add(pred.cpu().numpy(), mask.cpu().numpy())
                valid_miou, valid_ious = self.metric.miou()
                valid_fwiou = self.metric.fw_iou()
                valid_accu = self.metric.pixel_accuracy()
                valid_fwaccu = self.metric.pixel_accuracy_class()
            logger.info(
                "epoch:%d\t valid loss:%.4f\t valid fwiou:%.4f\t valid miou:%.4f valid accu:%.4f\t "
                "valid fwaccu:%.4f\t" % (epoch, valid_loss, valid_fwiou,
                                         valid_miou, valid_accu, valid_fwaccu))
            # Per-class IoU report, mirroring train_epoch.
            ious = list()
            cls = ""
            ious_dict = OrderedDict()
            for i, c in enumerate(self.className):
                ious_dict[c] = valid_ious[i]
                ious.append(ious_dict[c])
                cls += "%s:" % c + "%.4f "
            ious = tuple(ious)
            logger.info(cls % ious)
            self.writer.add_scalar("loss/valid_loss", valid_loss, epoch)
            self.writer.add_scalar("miou/valid_miou", valid_miou, epoch)
            self.writer.add_scalar("fwiou/valid_fwiou", valid_fwiou, epoch)
            self.writer.add_scalar("accuracy/valid_accu", valid_accu, epoch)
            self.writer.add_scalar("fwaccuracy/valid_fwaccu", valid_fwaccu,
                                   epoch)
            self.writer.add_scalars("ious/valid_ious", ious_dict, epoch)
        return valid_miou

    def save_checkpoint(self, epoch, best_miou, flag):
        """Persist model/optimizer state plus epoch and best mIoU as
        ``<flag>.pth`` under the configured model_save_path."""
        meta = {
            'epoch': epoch,
            'model': self.model.state_dict(),
            'optim': self.optimizer.state_dict(),
            'bmiou': best_miou
        }
        try:
            # Legacy serialization keeps checkpoints loadable by older torch.
            torch.save(meta,
                       os.path.join(self.run_config['model_save_path'],
                                    '%s.pth' % flag),
                       _use_new_zipfile_serialization=False)
        # Narrowed from a bare ``except:`` -- retry without the kwarg on
        # torch versions that do not support it.
        except Exception:
            torch.save(
                meta,
                os.path.join(self.run_config['model_save_path'],
                             '%s.pth' % flag))

    def load_checkpoint(self, use_optimizer, use_epoch, use_miou):
        """Strict checkpoint load; optionally restore optimizer state,
        starting epoch, and best mIoU."""
        state_dict = torch.load(self.run_config['pretrain'],
                                map_location=self.device)
        self.model.load_state_dict(state_dict['model'])
        if use_optimizer:
            self.optimizer.load_state_dict(state_dict['optim'])
        if use_epoch:
            self.run_config['start_epoch'] = state_dict['epoch'] + 1
        if use_miou:
            self.run_config['best_miou'] = state_dict['bmiou']

    def load_checkpoint_with_changed(self, use_optimizer, use_epoch, use_miou):
        """Partial checkpoint load: keep only weights whose keys exist in the
        current model (skipping any 'edge' branch weights)."""
        state_dict = torch.load(self.run_config['pretrain'],
                                map_location=self.device)
        pretrain_dict = state_dict['model']
        model_dict = self.model.state_dict()
        pretrain_dict = {
            k: v
            for k, v in pretrain_dict.items()
            if k in model_dict and 'edge' not in k
        }
        model_dict.update(pretrain_dict)
        self.model.load_state_dict(model_dict)
        if use_optimizer:
            self.optimizer.load_state_dict(state_dict['optim'])
        if use_epoch:
            self.run_config['start_epoch'] = state_dict['epoch'] + 1
        if use_miou:
            self.run_config['best_miou'] = state_dict['bmiou']
예제 #4
0
class Optimizer:
    """Thin wrapper around a torch optimizer class.

    Subclasses provide ``optimizer_cls``; the concrete optimizer instance is
    built lazily by :meth:`set_parameters`, which may additionally wrap it in
    torchcontrib's SWA when ``swa_start`` was supplied.  All other methods
    delegate to the underlying optimizer after checking it exists.
    """

    optimizer_cls = None
    optimizer = None
    parameters = None

    def __init__(self,
                 gradient_clipping,
                 swa_start=None,
                 swa_freq=None,
                 swa_lr=None,
                 **kwargs):
        # Anything not consumed here is forwarded to ``optimizer_cls``.
        self.swa_start = swa_start
        self.swa_freq = swa_freq
        self.swa_lr = swa_lr
        self.gradient_clipping = gradient_clipping
        self.optimizer_kwargs = kwargs

    def set_parameters(self, parameters):
        """Bind *parameters* and instantiate the underlying optimizer."""
        self.parameters = tuple(parameters)
        self.optimizer = self.optimizer_cls(self.parameters,
                                            **self.optimizer_kwargs)
        if self.swa_start is None:
            return
        # SWA requested: both companion options must be present.
        from torchcontrib.optim import SWA
        assert self.swa_freq is not None, self.swa_freq
        assert self.swa_lr is not None, self.swa_lr
        self.optimizer = SWA(
            self.optimizer,
            swa_start=self.swa_start,
            swa_freq=self.swa_freq,
            swa_lr=self.swa_lr,
        )

    def check_if_set(self):
        """Guard: fail loudly if :meth:`set_parameters` was never called."""
        assert self.optimizer is not None, (
            'The optimizer is not initialized, call set_parameter before'
            ' using any of the optimizer functions')

    def zero_grad(self):
        self.check_if_set()
        return self.optimizer.zero_grad()

    def step(self):
        self.check_if_set()
        return self.optimizer.step()

    def swap_swa_sgd(self):
        """Swap the SWA-averaged weights in; only valid for SWA optimizers."""
        self.check_if_set()
        from torchcontrib.optim import SWA
        assert isinstance(self.optimizer, SWA), self.optimizer
        return self.optimizer.swap_swa_sgd()

    def clip_grad(self):
        """Clip gradients in place; returns the pre-clip total norm."""
        self.check_if_set()
        # Todo: report clipped and unclipped
        # Todo: allow clip=None but still report grad_norm
        return torch.nn.utils.clip_grad_norm_(self.parameters,
                                              self.gradient_clipping)

    def to(self, device):
        """Move every tensor in the optimizer state to *device* (None: no-op)."""
        if device is None:
            return
        self.check_if_set()
        for bucket in self.optimizer.state.values():
            for key, value in bucket.items():
                if torch.is_tensor(value):
                    bucket[key] = value.to(device)

    def cpu(self):
        return self.to('cpu')

    def cuda(self, device=None):
        assert device is None or isinstance(device, int), device
        if device is None:
            device = torch.device('cuda')
        return self.to(device)

    def load_state_dict(self, state_dict):
        self.check_if_set()
        return self.optimizer.load_state_dict(state_dict)

    def state_dict(self):
        self.check_if_set()
        return self.optimizer.state_dict()
예제 #5
0
    # Call

    print("Starting model training....")

    n_epochs = setting_dict['epochs']
    lr_patience = setting_dict['optimizer']['sheduler']['patience']
    lr_factor = setting_dict['optimizer']['sheduler']['factor']

    if weight_path is None:
        best_epoch = train(model,dataloaders,objective,optimizer,n_epochs,Path_list[1],Path_list[2], lr_patience=lr_patience,lr_factor=lr_factor, dice = False,seperate_loss=False, adabn = setting_dict["data"]["adabn_train"], own_sheduler = (not setting_dict["optimizer"]["longshedule"]))
    else:
        optimizer.load_state_dict(torch.load(weight_path)["optimizer"])
        best_epoch = train(model,dataloaders,objective,optimizer,n_epochs-torch.load(weight_path)["epoch"],Path_list[1],Path_list[2],start_epoch = torch.load(weight_path)["epoch"]+1, loss_dict=torch.load(weight_path)["loss_dict"], lr_patience=lr_patience,lr_factor=lr_factor, dice = False,seperate_loss=False, adabn = setting_dict["data"]["adabn_train"], own_sheduler = (not setting_dict["optimizer"]["longshedule"]))

    print("model training finished! yey!")

    if optimizer_name == "SWA":
        print ("Updating batch norm pars for SWA")
        train_dataset.dataset.SWA = True
        SWA_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=cpu_count)
        optimizer.swap_swa_sgd()
        optimizer.bn_update(SWA_loader, model, device='cuda')
        state = {
                'epoch': n_epochs,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'loss_dict': {}
                }
        torch.save(state, os.path.join(Path_list[2],'weights_SWA.pt'))
예제 #6
0
    progress["val_loss"].append(np.mean(val_loss))
    progress["val_iou"].append(iou)
    progress["val_dice"].append(dice)
    progress["val_hausdorff"].append(hausdorff)
    progress["val_assd"].append(assd)

    dict2df(progress, args.output_dir + 'progress.csv')

    scheduler_step(optimizer, scheduler, iou, args)

# --------------------------------------------------------------------------------------------------------------- #
# --------------------------------------------------------------------------------------------------------------- #
# --------------------------------------------------------------------------------------------------------------- #

# Final SWA step: after the training loop, replace the model weights with
# their SWA running averages and refresh BatchNorm statistics before one
# last evaluation pass.
# NOTE(review): assumes ``optimizer`` is a torchcontrib SWA wrapper --
# ``swap_swa_sgd``/``bn_update`` are SWA-specific methods; confirm upstream.
if args.apply_swa:
    # Keep the pre-swap optimizer state on disk so the run can be resumed.
    torch.save(optimizer.state_dict(), args.output_dir + "/optimizer_" + args.model_name + "_before_swa_swap.pt")
    optimizer.swap_swa_sgd()  # Set the weights of your model to their SWA averages
    # Recompute BatchNorm running stats under the averaged weights.
    optimizer.bn_update(train_loader, model, device='cuda')

    torch.save(
        model.state_dict(),
        args.output_dir + "/swa_checkpoint_last_bn_update_{}epochs_lr{}.pt".format(args.epochs, args.swa_lr)
    )

    # Re-evaluate with the averaged weights, saving stats and (optionally)
    # overlay images under the "swa_preds" folder.
    iou, dice, hausdorff, assd, val_loss, stats = val_step(
        val_loader, model, criterion, weights_criterion, multiclass_criterion, args.binary_threshold,
        generate_stats=True, generate_overlays=args.eval_overlays, save_path=os.path.join(args.output_dir, "swa_preds")
    )

    print("[SWA] Val IOU: %s, Val Dice: %s" % (iou, dice))
예제 #7
0
def main(args):
    """Entry point for semi-supervised training with optional SWA.

    Builds the data pipeline and network selected through ``args``,
    optionally resumes from a warm-up checkpoint (relabeling the unlabeled
    samples with the restored model's predictions), trains for
    ``args.epoch`` epochs while checkpointing the best/last models, dumps
    per-epoch loss/accuracy arrays to disk and, when enabled, finalizes
    Stochastic Weight Averaging.
    """
    #####################
    # Initializing seeds and preparing GPU
    if args.cuda_dev == 1:
        torch.cuda.set_device(1)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.backends.cudnn.deterministic = True  # fix the GPU to deterministic mode
    torch.manual_seed(args.seed)  # CPU seed
    # Compare the device *type*: ``torch.device(...) == "cuda"`` is not a
    # reliable comparison across PyTorch versions and silently skipped the
    # GPU seeding.
    if device.type == "cuda":
        torch.cuda.manual_seed_all(args.seed)  # GPU seed
    random.seed(args.seed)  # python seed for image transformation
    np.random.seed(args.seed)
    #####################

    # Dataset-specific channel statistics for input normalization.
    if args.dataset == 'cifar10':
        mean = [0.4914, 0.4822, 0.4465]
        std = [0.2023, 0.1994, 0.2010]
    elif args.dataset == 'cifar100':
        mean = [0.5071, 0.4867, 0.4408]
        std = [0.2675, 0.2565, 0.2761]
    elif args.dataset == 'miniImagenet':
        mean = [0.4728, 0.4487, 0.4031]
        std = [0.2744, 0.2663, 0.2806]

    # Training-time augmentation (84x84 crops — miniImagenet-sized inputs).
    if args.DA == "standard":
        transform_train = transforms.Compose([
            transforms.Pad(6, padding_mode='reflect'),
            transforms.RandomCrop(84),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
    elif args.DA == "jitter":
        transform_train = transforms.Compose([
            transforms.Pad(6, padding_mode='reflect'),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
            transforms.RandomCrop(84),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
    else:
        print("Wrong value for --DA argument.")

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    # data loader
    train_loader, test_loader, unlabeled_indexes = data_config(args, transform_train, transform_test)

    # Network selection.
    if args.network == "TE_Net":
        print("Loading TE_Net...")
        model = TE_Net(num_classes=args.num_classes).to(device)
    elif args.network == "MT_Net":
        print("Loading MT_Net...")
        model = MT_Net(num_classes=args.num_classes).to(device)
    elif args.network == "resnet18":
        print("Loading Resnet18...")
        model = resnet18(num_classes=args.num_classes).to(device)
    elif args.network == "resnet18_wndrop":
        print("Loading Resnet18...")
        model = resnet18_wndrop(num_classes=args.num_classes).to(device)

    print('Total params: {:.2f} M'.format(sum(p.numel() for p in model.parameters()) / 1000000.0))

    milestones = args.M

    # Checkpoint-name tag: 'M' for MixUp training ('MixUp_ep'), 'C'
    # otherwise ('Reg_ep').  Bound up-front: it was previously assigned only
    # inside the load / warm-up branches, which raised NameError when
    # building checkpoint names on the default path.
    train_type = 'M' if args.loss_term == 'MixUp_ep' else 'C'

    if args.swa == 'True':
        # torchcontrib SWA (pip3 install torchcontrib): wraps the base SGD
        # optimizer and maintains a running average of the weights.
        from torchcontrib.optim import SWA
        base_optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.wd)
        optimizer = SWA(base_optimizer, swa_lr=args.swa_lr)
    else:
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.wd)

    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

    # Per-epoch histories, dumped to res_path after every epoch.
    loss_train_epoch = []
    loss_val_epoch = []
    acc_train_per_epoch = []
    acc_val_per_epoch = []

    exp_path = os.path.join('./', 'ssl_models_{0}'.format(args.experiment_name), str(args.labeled_samples))
    res_path = os.path.join('./', 'metrics_{0}'.format(args.experiment_name), str(args.labeled_samples))

    os.makedirs(res_path, exist_ok=True)
    os.makedirs(exp_path, exist_ok=True)

    cont = 0
    load = False

    if args.load_epoch != 0:
        load_epoch = args.load_epoch
        load = True

    # Warm-up runs always start from scratch.
    if args.dataset_type == 'ssl_warmUp':
        load = False

    if load:
        # Resume from a warm-up checkpoint and relabel the unlabeled samples
        # with the (soft) predictions of the restored model.
        path = './checkpoints/warmUp_{0}_{1}_{2}_{3}_{4}_{5}_S{6}.hdf5'.format(train_type,
                                                                               args.Mixup_Alpha,
                                                                               load_epoch,
                                                                               args.dataset,
                                                                               args.labeled_samples,
                                                                               args.network,
                                                                               args.seed)
        checkpoint = torch.load(path)
        print("Load model in epoch " + str(checkpoint['epoch']))
        print("Path loaded: ", path)
        model.load_state_dict(checkpoint['state_dict'])
        print("Relabeling the unlabeled samples...")
        model.eval()
        results = np.zeros((len(train_loader.dataset), args.num_classes), dtype=np.float32)
        with torch.no_grad():  # inference only: no autograd graph needed
            for images, images_pslab, labels, soft_labels, index in train_loader:
                images = images.to(device)
                labels = labels.to(device)
                soft_labels = soft_labels.to(device)

                outputs = model(images)
                prob, loss = loss_soft_reg_ep(outputs, labels, soft_labels, device, args)
                results[index.detach().numpy().tolist()] = prob.cpu().detach().numpy().tolist()

        train_loader.dataset.update_labels_randRelab(results, unlabeled_indexes, args.label_noise)
        print("Start training...")

    ####################################################################################################
    ###############################               TRAINING                ##############################
    ####################################################################################################

    for epoch in range(1, args.epoch + 1):
        # NOTE: stepping the scheduler at the top of the epoch keeps parity
        # with the original (pre-PyTorch-1.1) ordering of this script.
        scheduler.step()
        print(args.experiment_name, args.labeled_samples)

        loss_per_epoch_train, \
        top_5_train_ac, \
        top1_train_acc_original_labels, \
        top1_train_ac, \
        train_time = train_CrossEntropy_partialRelab(args, model, device,
                                                     train_loader, optimizer,
                                                     epoch, unlabeled_indexes)

        loss_train_epoch += [loss_per_epoch_train]

        loss_per_epoch_test, acc_val_per_epoch_i = testing(args, model, device, test_loader)

        loss_val_epoch += loss_per_epoch_test
        acc_train_per_epoch += [top1_train_ac]
        acc_val_per_epoch += acc_val_per_epoch_i

        ####################################################################################################
        #############################               SAVING MODELS                ###########################
        ####################################################################################################
        if not os.path.exists('./checkpoints'):
            os.mkdir('./checkpoints')

        # Track the best validation accuracy; only the current best snapshot
        # is kept on disk (the previous one is deleted when it is beaten).
        if epoch == 1:
            best_acc_val = acc_val_per_epoch_i[-1]
            snapBest = 'best_epoch_%d_valLoss_%.5f_valAcc_%.5f_labels_%d_bestAccVal_%.5f' % (
                epoch, loss_per_epoch_test[-1], acc_val_per_epoch_i[-1], args.labeled_samples, best_acc_val)
            torch.save(model.state_dict(), os.path.join(exp_path, snapBest + '.pth'))
            torch.save(optimizer.state_dict(), os.path.join(exp_path, 'opt_' + snapBest + '.pth'))
        elif acc_val_per_epoch_i[-1] > best_acc_val:
            best_acc_val = acc_val_per_epoch_i[-1]
            if cont > 0:
                try:
                    os.remove(os.path.join(exp_path, 'opt_' + snapBest + '.pth'))
                    os.remove(os.path.join(exp_path, snapBest + '.pth'))
                except OSError:
                    pass
            snapBest = 'best_epoch_%d_valLoss_%.5f_valAcc_%.5f_labels_%d_bestAccVal_%.5f' % (
                epoch, loss_per_epoch_test[-1], acc_val_per_epoch_i[-1], args.labeled_samples, best_acc_val)
            torch.save(model.state_dict(), os.path.join(exp_path, snapBest + '.pth'))
            torch.save(optimizer.state_dict(), os.path.join(exp_path, 'opt_' + snapBest + '.pth'))

        cont += 1

        if epoch == args.epoch:
            snapLast = 'last_epoch_%d_valLoss_%.5f_valAcc_%.5f_labels_%d_bestValLoss_%.5f' % (
                epoch, loss_per_epoch_test[-1], acc_val_per_epoch_i[-1], args.labeled_samples, best_acc_val)
            torch.save(model.state_dict(), os.path.join(exp_path, snapLast + '.pth'))
            torch.save(optimizer.state_dict(), os.path.join(exp_path, 'opt_' + snapLast + '.pth'))

        # Save a resumable warm-up checkpoint on the final epoch.  The two
        # original branches (ssl_warmUp vs. not) were identical and both
        # forced saving, so they collapse into one unconditional save here.
        if epoch == args.epoch:
            print("Saving models...")
            name = 'warmUp_{1}_{0}'.format(args.Mixup_Alpha, train_type)
            path = './checkpoints/{0}_{1}_{2}_{3}_{4}_S{5}.hdf5'.format(name, epoch, args.dataset,
                                                                        args.labeled_samples,
                                                                        args.network,
                                                                        args.seed)
            save_checkpoint({
                    'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'loss_train_epoch': np.asarray(loss_train_epoch),
                    'loss_val_epoch': np.asarray(loss_val_epoch),
                    'acc_train_per_epoch': np.asarray(acc_train_per_epoch),
                    'acc_val_per_epoch': np.asarray(acc_val_per_epoch),
                    'labels': np.asarray(train_loader.dataset.soft_labels)
                }, filename=path)

        ####################################################################################################
        ############################               SAVING METRICS                ###########################
        ####################################################################################################

        # Save losses:
        np.save(res_path + '/' + str(args.labeled_samples) + '_LOSS_epoch_train.npy', np.asarray(loss_train_epoch))
        np.save(res_path + '/' + str(args.labeled_samples) + '_LOSS_epoch_val.npy', np.asarray(loss_val_epoch))

        # save accuracies:
        np.save(res_path + '/' + str(args.labeled_samples) + '_accuracy_per_epoch_train.npy', np.asarray(acc_train_per_epoch))
        np.save(res_path + '/' + str(args.labeled_samples) + '_accuracy_per_epoch_val.npy', np.asarray(acc_val_per_epoch))

    # SWA finalization: swap in the averaged weights, refresh BatchNorm
    # statistics on the training data, then evaluate and save the result.
    if args.swa == 'True':
        optimizer.swap_swa_sgd()
        optimizer.bn_update(train_loader, model, device)
        loss_swa, acc_val_swa = testing(args, model, device, test_loader)

        snapLast = 'last_epoch_%d_valLoss_%.5f_valAcc_%.5f_labels_%d_bestValLoss_%.5f_swaAcc_%.5f' % (
            epoch, loss_per_epoch_test[-1], acc_val_per_epoch_i[-1], args.labeled_samples, best_acc_val, acc_val_swa[0])
        torch.save(model.state_dict(), os.path.join(exp_path, snapLast + '.pth'))
        torch.save(optimizer.state_dict(), os.path.join(exp_path, 'opt_' + snapLast + '.pth'))

    print('Best ac:%f' % best_acc_val)
예제 #8
0
def train_model(cfg, run_id, save_dir, use_cuda, args, writer):
    """Train a multi-clip model, validating periodically and checkpointing.

    Builds the training DataLoader (optionally with randomized frame skip),
    constructs the model and optimizer (optionally SWA-wrapped), then runs
    ``args.num_epochs`` epochs.  Validation runs every few epochs; whenever
    the validation fMAP improves, everything in ``save_dir`` is removed and
    a fresh checkpoint is written.
    """
    shuffle = True
    print("Run ID : " + args.run_id)

    print("Parameters used : ")
    print("batch_size: " + str(args.batch_size))
    print("lr: " + str(args.learning_rate))
    print("loss weights: " + str(params.weights))

    # Random frame skip varies temporal sampling per item; otherwise use the
    # single fixed skip from the CLI.
    if args.random_skip:
        skip = [x for x in range(0, 4)]
    else:
        skip = [args.skip]
    train_data_gen = Dataset(cfg, args.input_type, 'training', 1.0, args.num_clips, skip, add_background=args.add_background)
    # filter_none is applied at collate time (presumably dropping failed
    # samples — confirm against its definition).
    train_dataloader = DataLoader(train_data_gen, batch_size=args.batch_size, shuffle=shuffle, num_workers=args.num_workers, 
                                  collate_fn=lambda b:filter_none(b, args.num_clips, args.varied_length))

    print("Number of training samples : " + str(len(train_data_gen)))
    steps_per_epoch = len(train_data_gen) / args.batch_size
    print("Steps per epoch: " + str(steps_per_epoch))

    # Reserve one extra class index for background when requested.
    if args.add_background:
        num_classes = cfg.num_classes + 1
    else:
        num_classes = cfg.num_classes

    assert args.num_clips > 1
    model = build_model(args.model_version, args.num_clips, num_classes, args.feature_dim, args.hidden_dim, args.num_layers)

    num_gpus = len(args.gpu.split(','))
    if num_gpus > 1:
        model = torch.nn.DataParallel(model)

    if use_cuda:
        model.cuda()

    if args.optimizer == 'ADAM':
        print("Using ADAM optimizer")
        optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
    elif args.optimizer == 'SGD':
        print("Using SGD optimizer")
        optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=args.weight_decay)
        # NOTE(review): this scheduler is never stepped in this function and
        # is not passed to train_epoch — confirm the LR decay is intended.
        scheduler = MultiStepLR(optimizer, milestones=[40, 80, 120, 160], gamma=0.5)

    # SWA wrapping; SWA averaging calls are expected to happen elsewhere
    # (e.g. inside train_epoch) — none are visible here.
    if args.swa_start > 0:
        optimizer = SWA(optimizer)

    criterion = BCEWithLogitsLoss()

    max_fmap_score, fmap_score = 0, 0
    # loop for each epoch
    for epoch in range(args.num_epochs):
        model = train_epoch(cfg, run_id, epoch, train_dataloader, model, num_classes, optimizer, criterion, writer, use_cuda, args, weights=None, accumulation_steps=args.steps)
        # Validate more often for charades (and later in training); on
        # non-validation epochs fmap_score keeps its previous value.
        if args.dataset in ['charades']:
            validation_interval = 10
            if epoch > 20:
                validation_interval = 5
        else:
            validation_interval = 50
            if epoch > 1000: 
                validation_interval = 10
        if epoch % validation_interval == 0:
            fmap_score = val_epoch(cfg, epoch, model, writer, use_cuda, args)

        if fmap_score > max_fmap_score:
            # Keep only the newest best checkpoint: clear save_dir first.
            for f in os.listdir(save_dir):
                os.remove(os.path.join(save_dir, f))
            save_file_path = os.path.join(save_dir, 'model_{}_{:.4f}.pth'.format(epoch, fmap_score))
            states = {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            torch.save(states, save_file_path)
            max_fmap_score = fmap_score
예제 #9
0
class Trainer(object):
    """BERT sequence-classification trainer with optional SWA and focal loss.

    Wraps model construction (``__init__``), training (``train``),
    evaluation (``evaluate``) and checkpointing (``save_model`` /
    ``load_model``) for an NSMC-style classification task.
    """

    def __init__(self,
                 args,
                 train_dataloader=None,
                 validate_dataloader=None,
                 test_dataloader=None):
        """Build the pretrained BERT classifier and move it to the device."""
        self.args = args
        self.train_dataloader = train_dataloader
        self.validate_dataloader = validate_dataloader
        self.test_dataloader = test_dataloader

        # Labels are simply the class indices 0..num_classes-1.
        self.label_lst = [i for i in range(self.args.num_classes)]
        self.num_labels = self.args.num_classes

        self.config_class = AutoConfig
        self.model_class = BertForSequenceClassification

        self.config = self.config_class.from_pretrained(
            self.args.bert_model_name,
            num_labels=self.num_labels,
            finetuning_task='nsmc',
            id2label={str(i): label
                      for i, label in enumerate(self.label_lst)},
            label2id={label: i
                      for i, label in enumerate(self.label_lst)})
        self.model = self.model_class.from_pretrained(
            self.args.bert_model_name, config=self.config)
        # Created lazily in train(); load_model() only restores them if set.
        self.optimizer = None
        self.scheduler = None

        # GPU or CPU
        self.device = "cuda" if torch.cuda.is_available(
        ) and args.cuda else "cpu"
        self.model.to(self.device)

    def train(self, alpha, gamma):
        """Train with FocalLoss(alpha, gamma); return the best macro-F1.

        Builds AdamW (optionally wrapped in torchcontrib SWA), a linear
        warmup/decay schedule, then runs ``args.num_epochs`` epochs with
        per-epoch validation and checkpointing.  Appends the
        (alpha, gamma, f1) result to ``param_seach.txt``.
        """
        train_dataloader = self.train_dataloader

        t_total = len(train_dataloader) * self.args.num_epochs

        # Prepare optimizer and schedule (linear warmup and decay)
        # NOTE(review): both groups use weight_decay 0.0, so the
        # bias/LayerNorm split below currently has no effect — confirm the
        # intended decay for the non-bias parameters.
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in self.model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.0
        }, {
            'params': [
                p for n, p in self.model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.0
        }]

        if self.args.use_swa:
            base_opt = AdamW(optimizer_grouped_parameters,
                             lr=self.args.lr,
                             eps=1e-8)
            # SWA starts averaging after 4 epochs worth of steps, every 100
            # steps thereafter.
            self.optimizer = SWA(base_opt,
                                 swa_start=4 * len(train_dataloader),
                                 swa_freq=100,
                                 swa_lr=5e-5)
            # Mirror the wrapped optimizer's internals on the SWA wrapper so
            # the LR scheduler below can drive it like a plain optimizer.
            self.optimizer.param_groups = self.optimizer.optimizer.param_groups
            self.optimizer.state = self.optimizer.optimizer.state
            self.optimizer.defaults = self.optimizer.optimizer.defaults

        else:
            self.optimizer = optimizer = AdamW(optimizer_grouped_parameters,
                                               lr=self.args.lr,
                                               eps=1e-8)
        self.scheduler = scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=100,
            num_training_steps=self.args.num_epochs * len(train_dataloader))
        self.criterion = FocalLoss(alpha=alpha, gamma=gamma)

        # Train!
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d",
                    len(self.train_dataloader) * self.args.batch_size)
        logger.info("  Num Epochs = %d", self.args.num_epochs)
        logger.info("  Total train batch size = %d", self.args.batch_size)
        logger.info("  Total optimization steps = %d", t_total)

        global_step = 0
        tr_loss = 0.0
        self.model.zero_grad()
        self.optimizer.zero_grad()

        train_iterator = trange(int(self.args.num_epochs), desc="Epoch")

        fin_result = None
        f1_max = 0.0
        self.model.train()

        for epoch in train_iterator:
            epoch_iterator = tqdm(train_dataloader, desc="Iteration")
            for step, batch in enumerate(epoch_iterator):

                batch = tuple(t.to(self.device) for t in batch)  # GPU or CPU
                inputs = {
                    'input_ids': batch[0],
                    'attention_mask': batch[1],
                    'labels': batch[3],
                    'token_type_ids': batch[2]
                }

                # outputs = self.model(**inputs)
                # loss = outputs[0]

                # # Custom Loss: replace the model's built-in CE loss with
                # focal loss over sigmoid scores vs. one-hot targets.
                loss, logits = self.model(**inputs)
                logits = torch.sigmoid(logits)

                labels = torch.zeros(
                    (len(batch[3]), self.num_labels)).to(self.device)
                labels[range(len(batch[3])), batch[3]] = 1

                loss = self.criterion(logits, labels)

                loss.backward()
                self.optimizer.step()
                self.scheduler.step()  # Update learning rate schedule

                self.model.zero_grad()
                self.optimizer.zero_grad()

                tr_loss += loss.item()
                global_step += 1
                logger.info('train loss %f', loss.item())

            logger.info('total train loss %f', tr_loss / global_step)
            # Once SWA is active (epoch >= 4): swap in the averaged weights
            # for evaluation/saving, then swap back to keep training on the
            # raw weights.
            if epoch >= 4 and self.args.use_swa:
                self.optimizer.swap_swa_sgd()

            fin_result = self.evaluate("validate")
            self.save_model(epoch)
            self.model.train()
            if epoch >= 4 and self.args.use_swa:
                self.optimizer.swap_swa_sgd()

            f1_max = max(fin_result['f1_macro'], f1_max)

        # Leave the model holding the SWA-averaged weights at the end.
        if epoch >= 4 and self.args.use_swa:
            self.optimizer.swap_swa_sgd()
        # (filename typo "param_seach" is a runtime artifact name — kept)
        with open(os.path.join(self.args.base_dir, self.args.result_dir,
                               self.args.train_id, 'param_seach.txt'),
                  "a",
                  encoding="utf-8") as f:
            f.write('alpha: {}, gamma: {}, f1_macro: {}\n'.format(
                alpha, gamma, f1_max))
        return f1_max

    def evaluate(self, mode='test'):
        """Run evaluation on the test or validation split.

        Returns a dict with loss, compute_metrics() entries and macro
        precision/recall/F1; also writes per-sample predictions to
        ``args.prediction_file``.
        """
        if mode == 'test':
            dataloader = self.test_dataloader
        elif mode == 'validate':
            dataloader = self.validate_dataloader
        else:
            raise Exception("Only dev and test dataset available")

        # Eval!
        logger.info("***** Running evaluation on %s dataset *****", mode)
        logger.info("  Num examples = %d",
                    len(dataloader) * self.args.batch_size)
        logger.info("  Batch size = %d", self.args.batch_size)
        eval_loss = 0.0
        nb_eval_steps = 0
        preds = None
        out_label_ids = None

        self.model.eval()

        for batch in tqdm(dataloader, desc="Evaluating"):
            batch = tuple(t.to(self.device) for t in batch)
            with torch.no_grad():
                inputs = {
                    'input_ids': batch[0],
                    'attention_mask': batch[1],
                    'labels': batch[3],
                    'token_type_ids': batch[2]
                }
                outputs = self.model(**inputs)
                tmp_eval_loss, logits = outputs[:2]

                eval_loss += tmp_eval_loss.mean().item()
            nb_eval_steps += 1

            # Accumulate logits and gold labels across batches.
            if preds is None:
                preds = logits.detach().cpu().numpy()
                out_label_ids = inputs['labels'].detach().cpu().numpy()
            else:
                preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
                out_label_ids = np.append(
                    out_label_ids,
                    inputs['labels'].detach().cpu().numpy(),
                    axis=0)

        eval_loss = eval_loss / nb_eval_steps
        results = {"loss": eval_loss}

        preds = np.argmax(preds, axis=1)
        result = compute_metrics(preds, out_label_ids)
        results.update(result)

        p_macro, r_macro, f_macro, support_macro \
            = precision_recall_fscore_support(y_true=out_label_ids, y_pred=preds,
                                              labels=[i for i in range(self.num_labels)], average='macro')

        results.update({
            'precision': p_macro,
            'recall': r_macro,
            'f1_macro': f_macro
        })

        with open(self.args.prediction_file, "w", encoding="utf-8") as f:
            for pred in preds:
                f.write("{}\n".format(pred))

        if mode == 'validate':
            logger.info("***** Eval results *****")
            for key in sorted(results.keys()):
                logger.info("  %s = %s", key, str(results[key]))

        return results

    def save_model(self, num=0):
        """Save model/optimizer/scheduler state as ``epoch_<num>.pth``."""
        state = {
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict()
        }
        torch.save(
            state,
            os.path.join(self.args.base_dir, self.args.result_dir,
                         self.args.train_id, 'epoch_' + str(num) + '.pth'))
        logger.info('model saved')

    def load_model(self, model_name):
        """Restore model (and optimizer/scheduler, when already built)."""

        state = torch.load(
            os.path.join(self.args.base_dir, self.args.result_dir,
                         self.args.train_id, model_name))
        self.model.load_state_dict(state['model'])
        if self.optimizer is not None:
            self.optimizer.load_state_dict(state['optimizer'])
        if self.scheduler is not None:
            self.scheduler.load_state_dict(state['scheduler'])
        logger.info('model loaded')
예제 #10
0
    progress["train_loss"].append(np.mean(train_loss))
    progress["val_loss"].append(np.mean(val_loss))
    progress["val_accuracy"].append(accuracy)

    dict2df(progress, args.output_dir + 'progress.csv')

    scheduler_step(optimizer, scheduler, accuracy, args)

# --------------------------------------------------------------------------------------------------------------- #
# --------------------------------------------------------------------------------------------------------------- #
# --------------------------------------------------------------------------------------------------------------- #

# SWA finalization: checkpoint the raw optimizer state, swap in the
# averaged weights, refresh BatchNorm statistics, and save the SWA model.
if args.apply_swa:
    torch.save(
        optimizer.state_dict(), args.output_dir + "/optimizer_" +
        args.model_name + "_before_swa_swap.pt")
    optimizer.swap_swa_sgd(
    )  # Set the weights of your model to their SWA averages
    # Recompute BatchNorm running statistics for the averaged weights.
    optimizer.bn_update(train_loader, model, device='cuda')

    torch.save(
        model.state_dict(), args.output_dir +
        "/swa_checkpoint_last_bn_update_{}epochs_lr{}.pt".format(
            args.epochs, args.swa_lr))

    accuracy, val_loss = val_step_accuracy(val_loader,
                                           model,
                                           criterion,
                                           weights_criterion,
                                           multiclass_criterion,
예제 #11
0
def main():
    """Train DeepLabV3-ResNet50 for crop segmentation with SWA.

    Builds the dataset and model, optionally resumes from ``args.resume``,
    trains with a combined focal + Lovasz-softmax loss (main + auxiliary
    heads) while folding weights into an SWA average, then swaps in the
    averaged weights and saves them.
    """
    assert torch.cuda.is_available()
    torch.backends.cudnn.benchmark = True
    model_fname = '../data/model_swa_8/deeplabv3_{0}_epoch%d.pth'.format(
        'crops')
    focal_loss = FocalLoss2d()
    train_dataset = CropSegmentation(train=True, crop_size=args.crop_size)
    #     test_dataset = CropSegmentation(train=False, crop_size=args.crop_size)

    model = torchvision.models.segmentation.deeplabv3_resnet50(
        pretrained=False, progress=True, num_classes=5, aux_loss=True)

    if args.train:
        # NOTE(review): this class-weight tensor is built but never passed to
        # the criterion (and has 4 entries for 5 classes) — confirm intent.
        weight = np.ones(4)
        weight[2] = 5
        weight[3] = 5
        w = torch.FloatTensor(weight).cuda()
        criterion = nn.CrossEntropyLoss()  # ignore_index=255 weight=w
        model = nn.DataParallel(model).cuda()

        for param in model.parameters():
            param.requires_grad = True

        base_optimizer = optim.SGD(model.parameters(),
                                   lr=config.lr,
                                   momentum=0.9,
                                   weight_decay=1e-4)
        # Wrap with SWA *before* building the scheduler: the original code
        # referenced ``optimizer`` here before it was defined (NameError).
        optimizer = SWA(base_optimizer)
        scheduler = lr_scheduler.CosineAnnealingLR(base_optimizer,
                                                   T_max=(args.epochs // 9) +
                                                   1)

        dataset_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=args.train,
            pin_memory=True,
            num_workers=args.workers)

        losses = AverageMeter()
        start_epoch = 0

        if args.resume:
            if os.path.isfile(args.resume):
                print('=> loading checkpoint {0}'.format(args.resume))
                checkpoint = torch.load(args.resume)
                start_epoch = checkpoint['epoch']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                print('=> loaded checkpoint {0} (epoch {1})'.format(
                    args.resume, checkpoint['epoch']))

            else:
                print('=> no checkpoint found at {0}'.format(args.resume))

        for epoch in range(start_epoch, args.epochs):
            scheduler.step(epoch)
            model.train()
            for i, (inputs, target) in enumerate(dataset_loader):

                inputs = Variable(inputs.cuda())
                target = Variable(target.cuda())
                outputs = model(inputs)
                # Main + auxiliary heads, each with focal and Lovasz terms;
                # the aux head is down-weighted by 0.1.
                loss1 = focal_loss(outputs['out'], target)
                loss2 = focal_loss(outputs['aux'], target)
                loss01 = loss1 + 0.1 * loss2
                loss3 = lovasz_softmax(outputs['out'], target)
                loss4 = lovasz_softmax(outputs['aux'], target)
                loss02 = loss3 + 0.1 * loss4
                loss = loss01 + loss02
                # Drop into the debugger on a diverged loss.
                if np.isnan(loss.item()) or np.isinf(loss.item()):
                    pdb.set_trace()

                losses.update(loss.item(), args.batch_size)

                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                # Fold the current weights into the SWA average every 5
                # iterations once past a short warm-up.
                if i > 10 and i % 5 == 0:
                    optimizer.update_swa()

                print('epoch: {0}\t'
                      'iter: {1}/{2}\t'
                      'loss: {loss.val:.4f} ({loss.ema:.4f})'.format(
                          epoch + 1, i + 1, len(dataset_loader), loss=losses))

            if epoch > 5:
                torch.save(
                    {
                        'epoch': epoch + 1,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                    }, model_fname % (epoch + 1))
        # Replace the model weights with their SWA average and save them
        # under a sentinel epoch number.
        optimizer.swap_swa_sgd()
        torch.save(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, model_fname % (665 + 1))
예제 #12
0
def main():
    """Train an xView2 damage-assessment model according to a YAML config.

    Reads module-level globals (``args``, ``CONFIG_TREATER``,
    ``disaster_list``, ``logger``, ``ckpts_save_dir``, ...) instead of
    taking parameters.  Two data regimes are supported: per-disaster
    "single*" training (one DataLoader per disaster, batches interleaved
    randomly) or one combined loader.  Optionally finishes with SWA
    weight averaging plus a BatchNorm statistics refresh.
    """
    # Resolve the config path: a known alias mapped directly, an explicit
    # .yaml path, or an alias expanded under experiments/.
    if args.config_path:
        if args.config_path in CONFIG_TREATER:
            load_path = CONFIG_TREATER[args.config_path]
        elif args.config_path.endswith(".yaml"):
            load_path = args.config_path
        else:
            # NOTE(review): this branch indexes CONFIG_TREATER with a key
            # already known to be absent (first condition failed), so it
            # raises KeyError -- confirm the intended lookup.
            load_path = "experiments/" + CONFIG_TREATER[
                args.config_path] + ".yaml"
        with open(load_path, 'rb') as fp:
            config = CfgNode.load_cfg(fp)
    else:
        # NOTE(review): with config=None every later config.* access raises
        # AttributeError -- a config file appears to be effectively required.
        config = None

    # Favor speed over reproducibility: let cuDNN autotune kernels.
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True
    test_model = None
    max_epoch = config.TRAIN.NUM_EPOCHS
    print('data folder: ', args.data_folder)
    torch.backends.cudnn.benchmark = True  # redundant: already set above

    # WORLD_SIZE Generated by torch.distributed.launch.py
    #num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    #is_distributed = num_gpus > 1
    #if is_distributed:
    #    torch.cuda.set_device(args.local_rank)
    #    torch.distributed.init_process_group(
    #        backend="nccl", init_method="env://",
    #    )

    # Model is wrapped together with its loss so that (below) DataParallel
    # parallelizes the loss computation across GPUs as well.
    model = get_model(config)
    model_loss = ModelLossWraper(
        model,
        config.TRAIN.CLASS_WEIGHTS,
        config.MODEL.IS_DISASTER_PRED,
        config.MODEL.IS_SPLIT_LOSS,
    ).cuda()

    #if args.local_rank == 0:
    #from IPython import embed; embed()

    #if is_distributed:
    #    model_loss = nn.SyncBatchNorm.convert_sync_batchnorm(model_loss)
    #    model_loss = nn.parallel.DistributedDataParallel(
    #        model_loss#, device_ids=[args.local_rank], output_device=args.local_rank
    #    )

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    if torch.cuda.device_count() > 1:
        model_loss = nn.DataParallel(model_loss)

    model_loss.to(device)
    cpucount = multiprocessing.cpu_count()

    if config.mode.startswith("single"):
        # Per-disaster regime: one DataLoader per disaster so each batch
        # contains samples from a single disaster.
        trainset_loaders = {}
        loader_len = 0
        # config.mode[6:] strips the "single" prefix to select the
        # disaster list (e.g. "singletrain" -> disaster_list["train"]).
        for disaster in disaster_list[config.mode[6:]]:
            trainset = XView2Dataset(args.data_folder,
                                     rgb_bgr='rgb',
                                     preprocessing={
                                         'flip': True,
                                         'scale': config.TRAIN.MULTI_SCALE,
                                         'crop': config.TRAIN.CROP_SIZE,
                                     },
                                     mode="singletrain",
                                     single_disaster=disaster)
            if len(trainset) > 0:
                train_sampler = None

                trainset_loader = torch.utils.data.DataLoader(
                    trainset,
                    batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
                    shuffle=train_sampler is None,
                    pin_memory=True,
                    drop_last=True,
                    sampler=train_sampler,
                    # Cap worker count on many-core machines.
                    num_workers=cpucount if cpucount < 16 else cpucount // 3)

                trainset_loaders[disaster] = trainset_loader
                loader_len += len(trainset_loader)
                print("added disaster {} with {} samples".format(
                    disaster, len(trainset)))
            else:
                print("skipping disaster ", disaster)

    else:

        # Combined regime: one loader over the whole dataset.
        trainset = XView2Dataset(args.data_folder,
                                 rgb_bgr='rgb',
                                 preprocessing={
                                     'flip': True,
                                     'scale': config.TRAIN.MULTI_SCALE,
                                     'crop': config.TRAIN.CROP_SIZE,
                                 },
                                 mode=config.mode)

        #if is_distributed:
        #    train_sampler = DistributedSampler(trainset)
        #else:
        train_sampler = None

        trainset_loader = torch.utils.data.DataLoader(
            trainset,
            batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
            shuffle=train_sampler is None,
            pin_memory=True,
            drop_last=True,
            sampler=train_sampler,
            num_workers=multiprocessing.cpu_count())
        loader_len = len(trainset_loader)

    model.train()

    # Plain SGD, no weight decay, no Nesterov; LR is adjusted manually
    # per-iteration by adjust_learning_rate below.
    lr_init = config.TRAIN.LR
    optimizer = torch.optim.SGD(
        [{
            'params': filter(lambda p: p.requires_grad, model.parameters()),
            'lr': lr_init
        }],
        lr=lr_init,
        momentum=0.9,
        weight_decay=0.,
        nesterov=False,
    )

    # Total optimizer steps of the regular (pre-SWA) schedule.
    num_iters = max_epoch * loader_len

    if config.SWA:
        # SWA begins only after the regular schedule completes; swa_start
        # and swa_freq are measured in optimizer *steps*, and 40 extra
        # epochs are appended for the averaging phase.
        swa_start = num_iters
        optimizer = SWA(
            optimizer,
            swa_start=swa_start,
            swa_freq=4 * loader_len,
            swa_lr=0.001
        )  #SWA(optimizer, swa_start = None, swa_freq = None, swa_lr = None)#
        #scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, 0.0001, 0.05, step_size_up=1, step_size_down=2*len(trainset_loader)-1, mode='triangular', gamma=1.0, scale_fn=None, scale_mode='cycle', cycle_momentum=False, base_momentum=0.8, max_momentum=0.9, last_epoch=-1)
        lr = 0.0001
        #model.load_state_dict(torch.load("ckpt/dual-hrnet/hrnet_450", map_location='cpu')['state_dict'])
        #print("weights loaded")
        max_epoch = max_epoch + 40

    start_epoch = 0
    losses = AverageMeter()
    model.train()
    # cur_iters tracks the global step for the LR schedule; None means
    # "derive from the epoch index" (resume case).
    cur_iters = 0 if start_epoch == 0 else None
    for epoch in range(start_epoch, max_epoch):

        if config.mode.startswith("single"):
            # Build a randomly interleaved schedule of (disaster, batch)
            # pairs covering every loader once per epoch.
            all_batches = []
            total_len = 0
            for disaster in sorted(list(trainset_loaders.keys())):
                all_batches += [
                    (disaster, idx)
                    for idx in range(len(trainset_loaders[disaster]))
                ]
                total_len += len(trainset_loaders[disaster].dataset)
            # random.sample over the full length == an out-of-place shuffle.
            all_batches = random.sample(all_batches, len(all_batches))
            iterators = {
                disaster: iter(trainset_loaders[disaster])
                for disaster in trainset_loaders.keys()
            }
            if cur_iters is not None:
                # NOTE(review): incremented *before* this epoch's batch loop,
                # so the first epoch already starts at len(all_batches) --
                # the LR schedule looks shifted by one epoch; confirm intent.
                cur_iters += len(all_batches)
            else:
                cur_iters = epoch * len(all_batches)

            for i, (disaster, idx) in enumerate(all_batches):
                lr = optimizer.param_groups[0]['lr']
                # NOTE(review): epoch (an epoch index) is compared against
                # swa_start (an iteration count) -- with SWA enabled this is
                # nearly always true; confirm the intended units.
                if not config.SWA or epoch < swa_start:
                    lr = adjust_learning_rate(optimizer, lr_init, num_iters,
                                              i + cur_iters)
                # Pull the next batch from this disaster's own iterator.
                samples = next(iterators[disaster])
                inputs_pre = samples['pre_img'].to(device)
                inputs_post = samples['post_img'].to(device)
                target = samples['mask_img'].to(device)
                #disaster_target = samples['disaster'].to(device)

                loss = model_loss(inputs_pre, inputs_post,
                                  target)  #, disaster_target)

                # Detached scalar copy used only for logging / NaN check.
                loss_sum = torch.sum(loss).detach().cpu()
                if np.isnan(loss_sum) or np.isinf(loss_sum):
                    print('check')
                # NOTE(review): hard-coded count of 4; presumably meant to be
                # the actual batch size -- confirm.
                losses.update(loss_sum, 4)  # batch size

                # Sum per-GPU losses (DataParallel returns one per replica).
                loss = torch.sum(loss)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                if args.local_rank == 0 and i % 10 == 0:
                    logger.info('epoch: {0}\t'
                                'iter: {1}/{2}\t'
                                'lr: {3:.6f}\t'
                                'loss: {loss.val:.4f} ({loss.ema:.4f})\t'
                                'disaster: {dis}'.format(epoch + 1,
                                                         i + 1,
                                                         len(all_batches),
                                                         lr,
                                                         loss=losses,
                                                         dis=disaster))

            # Drop exhausted iterators so loaders restart next epoch.
            del iterators

        else:
            cur_iters = epoch * len(trainset_loader)

            for i, samples in enumerate(trainset_loader):
                lr = optimizer.param_groups[0]['lr']
                # NOTE(review): same epoch-vs-iteration unit mismatch as in
                # the "single" branch above.
                if not config.SWA or epoch < swa_start:
                    lr = adjust_learning_rate(optimizer, lr_init, num_iters,
                                              i + cur_iters)

                inputs_pre = samples['pre_img'].to(device)
                inputs_post = samples['post_img'].to(device)
                target = samples['mask_img'].to(device)
                #disaster_target = samples['disaster'].to(device)

                loss = model_loss(inputs_pre, inputs_post,
                                  target)  #, disaster_target)

                loss_sum = torch.sum(loss).detach().cpu()
                if np.isnan(loss_sum) or np.isinf(loss_sum):
                    print('check')
                # NOTE(review): hard-coded count of 4 -- see note above.
                losses.update(loss_sum, 4)  # batch size

                loss = torch.sum(loss)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                #if args.swa == "True":
                #scheduler.step()
                #if epoch%4 == 3 and i == len(trainset_loader)-2:
                #    optimizer.update_swa()

                if args.local_rank == 0 and i % 10 == 0:
                    logger.info('epoch: {0}\t'
                                'iter: {1}/{2}\t'
                                'lr: {3:.6f}\t'
                                'loss: {loss.val:.4f} ({loss.ema:.4f})'.format(
                                    epoch + 1,
                                    i + 1,
                                    len(trainset_loader),
                                    lr,
                                    loss=losses))

        # Periodic checkpoint every 50 epochs (test_model is always None
        # in this function, so the second condition always holds).
        if args.local_rank == 0:
            if (epoch + 1) % 50 == 0 and test_model is None:
                torch.save(
                    {
                        'epoch': epoch + 1,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                    }, os.path.join(ckpts_save_dir, 'hrnet_%s' % (epoch + 1)))
    if config.SWA:
        # Snapshot the raw (pre-averaging) weights, swap in the SWA
        # averages, then recompute BatchNorm running statistics over the
        # training data before saving the final SWA checkpoint.
        torch.save(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, os.path.join(ckpts_save_dir, 'hrnet_%s' % ("preSWA")))
        optimizer.swap_swa_sgd()
        # NOTE(review): in "single*" mode, `trainset` here is only the last
        # disaster's dataset (and `train_sampler` may be unset if every
        # disaster was skipped) -- BN stats would be computed on a subset;
        # confirm this is intended.
        bn_loader = torch.utils.data.DataLoader(
            trainset,
            batch_size=2,
            shuffle=train_sampler is None,
            pin_memory=True,
            drop_last=True,
            sampler=train_sampler,
            num_workers=multiprocessing.cpu_count())
        bn_update(bn_loader, model, device='cuda')
        torch.save(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, os.path.join(ckpts_save_dir, 'hrnet_%s' % ("SWA")))