Example #1
def main(args, logger):
    # trn_df = pd.read_csv(f'{MNT_DIR}/inputs/origin/train.csv')
    trn_df = pd.read_pickle(f'{MNT_DIR}/inputs/nes_info/trn_df.pkl')
    trn_df['is_original'] = 1

    gkf = GroupKFold(n_splits=5).split(
        X=trn_df.question_body,
        groups=trn_df.question_body_le,
    )

    histories = {
        'trn_loss': {},
        'val_loss': {},
        'val_metric': {},
        'val_metric_raws': {},
    }
    loaded_fold = -1
    loaded_epoch = -1
    if args.checkpoint:
        histories, loaded_fold, loaded_epoch = load_checkpoint(args.checkpoint)

    fold_best_metrics = []
    fold_best_metrics_raws = []
    for fold, (trn_idx, val_idx) in enumerate(gkf):
        if fold < loaded_fold:
            fold_best_metrics.append(np.max(histories["val_metric"][fold]))
            fold_best_metrics_raws.append(
                histories["val_metric_raws"][fold][np.argmax(
                    histories["val_metric"][fold])])
            continue
        sel_log(
            f' --------------------------- start fold {fold} --------------------------- ',
            logger)
        fold_trn_df = trn_df.iloc[trn_idx]  # .query('is_original == 1')
        fold_trn_df = fold_trn_df.drop(['is_original', 'question_body_le'],
                                       axis=1)
        # use only original row
        fold_val_df = trn_df.iloc[val_idx].query('is_original == 1')
        fold_val_df = fold_val_df.drop(['is_original', 'question_body_le'],
                                       axis=1)
        if args.debug:
            fold_trn_df = fold_trn_df.sample(100, random_state=71)
            fold_val_df = fold_val_df.sample(100, random_state=71)
        temp = pd.Series(
            list(
                itertools.chain.from_iterable(
                    fold_trn_df.question_title.apply(lambda x: x.split(' ')) +
                    fold_trn_df.question_body.apply(lambda x: x.split(' ')) +
                    fold_trn_df.answer.apply(lambda x: x.split(' '))))
        ).value_counts()
        tokens = temp[temp >= 10].index.tolist()
        # NOTE: the frequency-based token list above is immediately overridden
        # by the hard-coded category tokens below.
        tokens = [
            'CAT_TECHNOLOGY'.casefold(),
            'CAT_STACKOVERFLOW'.casefold(),
            'CAT_CULTURE'.casefold(),
            'CAT_SCIENCE'.casefold(),
            'CAT_LIFE_ARTS'.casefold(),
        ]

        trn_dataset = QUESTDataset(
            df=fold_trn_df,
            mode='train',
            tokens=tokens,
            augment=[],
            tokenizer_type=TOKENIZER_TYPE,
            pretrained_model_name_or_path=TOKENIZER_PRETRAIN,
            do_lower_case=True,
            LABEL_COL=LABEL_COL,
            t_max_len=30,
            q_max_len=239 * 2,
            a_max_len=239 * 0,
            tqa_mode=TQA_MODE,
            TBSEP='[TBSEP]',
            pos_id_type='arange',
            MAX_SEQUENCE_LENGTH=MAX_SEQ_LEN,
            rm_zero=RM_ZERO,
        )
        # update token
        trn_sampler = RandomSampler(data_source=trn_dataset)
        trn_loader = DataLoader(trn_dataset,
                                batch_size=BATCH_SIZE,
                                sampler=trn_sampler,
                                num_workers=os.cpu_count(),
                                worker_init_fn=lambda x: np.random.seed(),  # reseed NumPy in each worker
                                drop_last=True,
                                pin_memory=True)
        val_dataset = QUESTDataset(
            df=fold_val_df,
            mode='valid',
            tokens=tokens,
            augment=[],
            tokenizer_type=TOKENIZER_TYPE,
            pretrained_model_name_or_path=TOKENIZER_PRETRAIN,
            do_lower_case=True,
            LABEL_COL=LABEL_COL,
            t_max_len=30,
            q_max_len=239 * 2,
            a_max_len=239 * 0,
            tqa_mode=TQA_MODE,
            TBSEP='[TBSEP]',
            pos_id_type='arange',
            MAX_SEQUENCE_LENGTH=MAX_SEQ_LEN,
            rm_zero=RM_ZERO,
        )
        val_sampler = RandomSampler(data_source=val_dataset)
        val_loader = DataLoader(val_dataset,
                                batch_size=BATCH_SIZE,
                                sampler=val_sampler,
                                num_workers=os.cpu_count(),
                                worker_init_fn=lambda x: np.random.seed(),
                                drop_last=False,
                                pin_memory=True)

        fobj = BCEWithLogitsLoss()
        state_dict = BertModel.from_pretrained(MODEL_PRETRAIN).state_dict()
        model = BertModelForBinaryMultiLabelClassifier(
            num_labels=len(LABEL_COL),
            config_path=MODEL_CONFIG_PATH,
            state_dict=state_dict,
            token_size=len(trn_dataset.tokenizer),
            MAX_SEQUENCE_LENGTH=MAX_SEQ_LEN,
        )
        # optimizer = optim.Adam(model.parameters(), lr=3e-5)
        optimizer = optim.SGD(model.parameters(), lr=1e-1)
        # torchcontrib SWA auto mode: swa_start and swa_freq count optimizer steps
        optimizer = SWA(optimizer, swa_start=2, swa_freq=5, swa_lr=1e-1)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                         T_max=MAX_EPOCH,
                                                         eta_min=1e-2)

        # load checkpoint model, optim, scheduler
        if args.checkpoint and fold == loaded_fold:
            load_checkpoint(args.checkpoint, model, optimizer, scheduler)

        for epoch in tqdm(list(range(MAX_EPOCH))):
            if fold <= loaded_fold and epoch <= loaded_epoch:
                continue
            if epoch < 1:
                model.freeze_unfreeze_bert(freeze=True, logger=logger)
            else:
                model.freeze_unfreeze_bert(freeze=False, logger=logger)
            model = DataParallel(model)
            model = model.to(DEVICE)
            trn_loss = train_one_epoch(model, fobj, optimizer, trn_loader,
                                       DEVICE)
            if epoch > 2:
                # swap in the SWA average and refresh BN stats before validating
                optimizer.swap_swa_sgd()
                optimizer.bn_update(trn_loader, model)
            val_loss, val_metric, val_metric_raws, val_y_preds, val_y_trues, val_qa_ids = test(
                model, fobj, val_loader, DEVICE, mode='valid')
            if epoch > 2:
                optimizer.swap_swa_sgd()  # swap back to the fast weights for the next epoch

            scheduler.step()
            if fold in histories['trn_loss']:
                histories['trn_loss'][fold].append(trn_loss)
            else:
                histories['trn_loss'][fold] = [
                    trn_loss,
                ]
            if fold in histories['val_loss']:
                histories['val_loss'][fold].append(val_loss)
            else:
                histories['val_loss'][fold] = [
                    val_loss,
                ]
            if fold in histories['val_metric']:
                histories['val_metric'][fold].append(val_metric)
            else:
                histories['val_metric'][fold] = [
                    val_metric,
                ]
            if fold in histories['val_metric_raws']:
                histories['val_metric_raws'][fold].append(val_metric_raws)
            else:
                histories['val_metric_raws'][fold] = [
                    val_metric_raws,
                ]

            logging_val_metric_raws = ''
            for val_metric_raw in val_metric_raws:
                logging_val_metric_raws += f'{float(val_metric_raw):.4f}, '

            sel_log(
                f'fold : {fold} -- epoch : {epoch} -- '
                f'trn_loss : {float(trn_loss.detach().to("cpu").numpy()):.4f} -- '
                f'val_loss : {float(val_loss.detach().to("cpu").numpy()):.4f} -- '
                f'val_metric : {float(val_metric):.4f} -- '
                f'val_metric_raws : {logging_val_metric_raws}', logger)
            model = model.to('cpu')
            model = model.module
            save_checkpoint(f'{MNT_DIR}/checkpoints/{EXP_ID}/{fold}', model,
                            optimizer, scheduler, histories, val_y_preds,
                            val_y_trues, val_qa_ids, fold, epoch, val_loss,
                            val_metric)
        fold_best_metrics.append(np.max(histories["val_metric"][fold]))
        fold_best_metrics_raws.append(
            histories["val_metric_raws"][fold][np.argmax(
                histories["val_metric"][fold])])
        save_and_clean_for_prediction(f'{MNT_DIR}/checkpoints/{EXP_ID}/{fold}',
                                      trn_dataset.tokenizer,
                                      clean=False)
        del model

    # calc training stats
    fold_best_metric_mean = np.mean(fold_best_metrics)
    fold_best_metric_std = np.std(fold_best_metrics)
    fold_stats = f'{EXP_ID} : {fold_best_metric_mean:.4f} +- {fold_best_metric_std:.4f}'
    sel_log(fold_stats, logger)
    send_line_notification(fold_stats)

    fold_best_metrics_raws_mean = np.mean(fold_best_metrics_raws, axis=0)
    fold_raw_stats = ''
    for metric_stats_raw in fold_best_metrics_raws_mean:
        fold_raw_stats += f'{float(metric_stats_raw):.4f},'
    sel_log(fold_raw_stats, logger)
    send_line_notification(fold_raw_stats)

    sel_log('now saving best checkpoints...', logger)
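
Note: every example on this page uses the SWA wrapper from torchcontrib. As a reference point, here is a minimal, self-contained sketch of the automatic averaging mode these snippets rely on; the toy model and loader are illustrative, not taken from any of the examples.

# Minimal torchcontrib SWA sketch (automatic mode); toy model/loader are placeholders.
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset
from torchcontrib.optim import SWA

model = nn.Sequential(nn.Linear(10, 16), nn.BatchNorm1d(16), nn.ReLU(), nn.Linear(16, 1))
loader = DataLoader(TensorDataset(torch.randn(64, 10), torch.randn(64, 1)), batch_size=8)

base_opt = optim.SGD(model.parameters(), lr=0.1)
# After 10 optimizer steps, fold the weights into the running average every 5 steps.
opt = SWA(base_opt, swa_start=10, swa_freq=5, swa_lr=0.05)

loss_fn = nn.MSELoss()
for epoch in range(5):
    for x, y in loader:
        opt.zero_grad()
        loss_fn(model(x), y).backward()
        opt.step()

opt.swap_swa_sgd()            # replace the fast weights with the SWA average
opt.bn_update(loader, model)  # recompute BatchNorm statistics under the averaged weights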
Example #2
def train(model_name, optim='adam'):
    train_dataset = PretrainDataset(output_shape=config['image_resolution'])
    train_loader = DataLoader(train_dataset,
                              batch_size=config['batch_size'],
                              shuffle=True,
                              num_workers=8,
                              pin_memory=True,
                              drop_last=True)

    val_dataset = IDRND_dataset_CV(fold=0,
                                   mode=config['mode'].replace('train', 'val'),
                                   double_loss_mode=True,
                                   output_shape=config['image_resolution'])
    val_loader = DataLoader(val_dataset,
                            batch_size=config['batch_size'],
                            shuffle=True,
                            num_workers=4,
                            drop_last=False)

    if model_name == 'EF':
        model = DoubleLossModelTwoHead(base_model=EfficientNet.from_pretrained(
            'efficientnet-b3')).to(device)
        model.load_state_dict(
            torch.load(
                f"../models_weights/pretrained/{model_name}_{4}_2.0090592697255896_1.0.pth"
            ))
    elif model_name == 'EFGAP':
        model = DoubleLossModelTwoHead(
            base_model=EfficientNetGAP.from_pretrained('efficientnet-b3')).to(
                device)
        model.load_state_dict(
            torch.load(
                f"../models_weights/pretrained/{model_name}_{4}_2.3281182915644134_1.0.pth"
            ))

    criterion = FocalLoss(add_weight=False).to(device)
    criterion4class = CrossEntropyLoss().to(device)

    if optim == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=config['learning_rate'],
                                     weight_decay=config['weight_decay'])
    elif optim == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=config['learning_rate'],
                                    weight_decay=config['weight_decay'],
                                    nesterov=False)
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    momentum=0.9,
                                    lr=config['learning_rate'],
                                    weight_decay=config['weight_decay'],
                                    nesterov=True)

    steps_per_epoch = len(train_loader) - 15
    swa = SWA(optimizer,
              swa_start=config['swa_start'] * steps_per_epoch,
              swa_freq=int(config['swa_freq'] * steps_per_epoch),
              swa_lr=config['learning_rate'] / 10)
    scheduler = ExponentialLR(swa, gamma=0.9)
    # scheduler = StepLR(swa, step_size=5*steps_per_epoch, gamma=0.5)

    global_step = 0
    for epoch in trange(10):
        # skip the first 5 epochs entirely, only advancing the LR schedule
        if epoch < 5:
            scheduler.step()
            continue
        model.train()
        train_bar = tqdm(train_loader)
        train_bar.set_description_str(desc=f"N epochs - {epoch}")

        for step, batch in enumerate(train_bar):
            global_step += 1
            image = batch['image'].to(device)
            label4class = batch['label0'].to(device)
            label = batch['label1'].to(device)

            output4class, output = model(image)
            loss4class = criterion4class(output4class, label4class)
            loss = criterion(output.squeeze(), label)
            swa.zero_grad()
            total_loss = loss4class * 0.5 + loss * 0.5
            total_loss.backward()
            swa.step()
            train_writer.add_scalar(tag="learning_rate",
                                    scalar_value=scheduler.get_lr()[0],
                                    global_step=global_step)
            train_writer.add_scalar(tag="BinaryLoss",
                                    scalar_value=loss.item(),
                                    global_step=global_step)
            train_writer.add_scalar(tag="SoftMaxLoss",
                                    scalar_value=loss4class.item(),
                                    global_step=global_step)
            train_bar.set_postfix_str(f"Loss = {loss.item()}")
            try:
                train_writer.add_scalar(tag="idrnd_score",
                                        scalar_value=idrnd_score_pytorch(
                                            label, output),
                                        global_step=global_step)
                train_writer.add_scalar(tag="far_score",
                                        scalar_value=far_score(label, output),
                                        global_step=global_step)
                train_writer.add_scalar(tag="frr_score",
                                        scalar_value=frr_score(label, output),
                                        global_step=global_step)
                train_writer.add_scalar(tag="accuracy",
                                        scalar_value=bce_accuracy(
                                            label, output),
                                        global_step=global_step)
            except Exception:
                pass

        if (epoch > config['swa_start']
                and epoch % 2 == 0) or (epoch == config['number_epochs'] - 1):
            swa.swap_swa_sgd()
            swa.bn_update(train_loader, model, device)
            swa.swap_swa_sgd()

        scheduler.step()
        evaluate(model, val_loader, epoch, model_name)
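
The swap_swa_sgd() / bn_update() / swap_swa_sgd() sequence above is the standard way to look at the SWA average mid-training without losing the current fast weights: the first swap moves the average into the model, bn_update refreshes BatchNorm statistics for it, and the second swap restores the fast weights so training can continue. A hedged sketch of that pattern as a helper; the evaluate callable and loaders are placeholders, not part of the example:

# Sketch: evaluate the SWA average mid-training, then resume on the fast weights.
def evaluate_swa_snapshot(swa, model, train_loader, val_loader, evaluate):
    swa.swap_swa_sgd()                     # fast weights -> SWA average
    swa.bn_update(train_loader, model)     # refresh BN stats for the averaged weights
    metrics = evaluate(model, val_loader)  # score the averaged model
    swa.swap_swa_sgd()                     # SWA average -> fast weights
    return metrics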
Example #3
def main(args, dst_folder):
    # best_ac records only the best top-1 accuracy on the validation set.
    best_ac = 0.0
    # os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    if args.cuda_dev == 1:
        torch.cuda.set_device(1)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    torch.backends.cudnn.deterministic = True  # fix the GPU to deterministic mode
    torch.manual_seed(args.seed)  # CPU seed
    if device.type == "cuda":  # a torch.device never equals the plain string "cuda"
        torch.cuda.manual_seed_all(args.seed)  # GPU seed

    random.seed(args.seed)  # python seed for image transformation
    np.random.seed(args.seed)

    if args.dataset == 'svhn':
        mean = [x / 255 for x in [127.5, 127.5, 127.5]]
        std = [x / 255 for x in [127.5, 127.5, 127.5]]
    elif args.dataset == 'cifar100':
        mean = [0.5071, 0.4867, 0.4408]
        std = [0.2675, 0.2565, 0.2761]

    if args.DA == "standard":
        transform_train = transforms.Compose([
            transforms.Pad(2, padding_mode='reflect'),
            transforms.RandomCrop(32),
            #transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

    elif args.DA == "jitter":
        transform_train = transforms.Compose([
            transforms.Pad(2, padding_mode='reflect'),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
            transforms.RandomCrop(32),
            #SVHNPolicy(),
            #AutoAugment(),
            #transforms.RandomHorizontalFlip(),
            
            transforms.ToTensor(),
            #Cutout(n_holes=1,length=20),
            transforms.Normalize(mean, std),
        ])
    else:
        raise ValueError("Wrong value for --DA argument.")  # avoid a later NameError on transform_train


    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    # data loader
    train_loader, test_loader, train_noisy_indexes = data_config(args, transform_train, transform_test,  dst_folder)


    if args.network == "MT_Net":
        print("Loading MT_Net...")
        model = MT_Net(num_classes = args.num_classes, dropRatio = args.dropout).to(device)

    elif args.network == "WRN28_2_wn":
        print("Loading WRN28_2...")
        model = WRN28_2_wn(num_classes = args.num_classes, dropout = args.dropout).to(device)

    elif args.network == "PreactResNet18_WNdrop":
        print("Loading preActResNet18_WNdrop...")
        model = PreactResNet18_WNdrop(drop_val = args.dropout, num_classes = args.num_classes).to(device)


    print('Total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0))

    milestones = args.M

    if args.swa == 'True':
        # to install it:
        # pip3 install torchcontrib
        # git clone https://github.com/pytorch/contrib.git
        # cd contrib
        # sudo python3 setup.py install
        from torchcontrib.optim import SWA
        #base_optimizer = RAdam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), weight_decay=1e-4)
        base_optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=1e-4)
        optimizer = SWA(base_optimizer, swa_lr=args.swa_lr)

    else:
        #optimizer = RAdam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), weight_decay=1e-4)
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=1e-4)

    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)



    loss_train_epoch = []
    loss_val_epoch = []
    acc_train_per_epoch = []
    acc_val_per_epoch = []
    new_labels = []


    exp_path = os.path.join('./', 'noise_models_{0}'.format(args.experiment_name), str(args.labeled_samples))
    res_path = os.path.join('./', 'metrics_{0}'.format(args.experiment_name), str(args.labeled_samples))

    if not os.path.isdir(res_path):
        os.makedirs(res_path)

    if not os.path.isdir(exp_path):
        os.makedirs(exp_path)

    cont = 0

    load = False
    save = True

    if args.initial_epoch != 0:
        initial_epoch = args.initial_epoch
        load = True
        save = False

    if args.dataset_type == 'sym_noise_warmUp':
        load = False
        save = True

    if load:
        if args.loss_term == 'Reg_ep':
            train_type = 'C'
        if args.loss_term == 'MixUp_ep':
            train_type = 'M'
        if args.dropout > 0.0:
            train_type = train_type + 'drop' + str(int(10*args.dropout))
        if args.beta == 0.0:
            train_type = train_type + 'noReg'
        path = './checkpoints/warmUp_{6}_{5}_{0}_{1}_{2}_{3}_S{4}.hdf5'.format(initial_epoch, \
                                                                                args.dataset, \
                                                                                args.labeled_samples, \
                                                                                args.network, \
                                                                                args.seed, \
                                                                                args.Mixup_Alpha, \
                                                                                train_type)

        checkpoint = torch.load(path)
        print("Load model in epoch " + str(checkpoint['epoch']))
        print("Path loaded: ", path)
        model.load_state_dict(checkpoint['state_dict'])
        print("Relabeling the unlabeled samples...")
        model.eval()
        initial_rand_relab = args.label_noise
        results = np.zeros((len(train_loader.dataset), 10), dtype=np.float32)

        for images, images_pslab, labels, soft_labels, index in train_loader:

            images = images.to(device)
            labels = labels.to(device)
            soft_labels = soft_labels.to(device)

            outputs = model(images)
            prob, loss = loss_soft_reg_ep(outputs, labels, soft_labels, device, args)
            results[index.detach().numpy().tolist()] = prob.cpu().detach().numpy().tolist()

        train_loader.dataset.update_labels_randRelab(results, train_noisy_indexes, initial_rand_relab)
        print("Start training...")

    for epoch in range(1, args.epoch + 1):
        st = time.time()
        scheduler.step()
        # train for one epoch
        print(args.experiment_name, args.labeled_samples)

        loss_per_epoch, top_5_train_ac, top1_train_acc_original_labels, \
        top1_train_ac, train_time = train_CrossEntropy_partialRelab(\
                                                        args, model, device, \
                                                        train_loader, optimizer, \
                                                        epoch, train_noisy_indexes)


        loss_train_epoch += [loss_per_epoch]

        # test
        if args.validation_exp == "True":
            loss_per_epoch, acc_val_per_epoch_i = validating(args, model, device, test_loader)
        else:
            loss_per_epoch, acc_val_per_epoch_i = testing(args, model, device, test_loader)

        loss_val_epoch += loss_per_epoch
        acc_train_per_epoch += [top1_train_ac]
        acc_val_per_epoch += acc_val_per_epoch_i



        ####################################################################################################
        #############################               SAVING MODELS                ###########################
        ####################################################################################################

        if not os.path.exists('./checkpoints'):
            os.mkdir('./checkpoints')

        if epoch == 1:
            best_acc_val = acc_val_per_epoch_i[-1]
            snapBest = 'best_epoch_%d_valLoss_%.5f_valAcc_%.5f_noise_%d_bestAccVal_%.5f' % (
                epoch, loss_per_epoch[-1], acc_val_per_epoch_i[-1], args.labeled_samples, best_acc_val)
            torch.save(model.state_dict(), os.path.join(exp_path, snapBest + '.pth'))
            torch.save(optimizer.state_dict(), os.path.join(exp_path, 'opt_' + snapBest + '.pth'))
        else:
            if acc_val_per_epoch_i[-1] > best_acc_val:
                best_acc_val = acc_val_per_epoch_i[-1]

                if cont > 0:
                    try:
                        os.remove(os.path.join(exp_path, 'opt_' + snapBest + '.pth'))
                        os.remove(os.path.join(exp_path, snapBest + '.pth'))
                    except OSError:
                        pass
                snapBest = 'best_epoch_%d_valLoss_%.5f_valAcc_%.5f_noise_%d_bestAccVal_%.5f' % (
                    epoch, loss_per_epoch[-1], acc_val_per_epoch_i[-1], args.labeled_samples, best_acc_val)
                torch.save(model.state_dict(), os.path.join(exp_path, snapBest + '.pth'))
                torch.save(optimizer.state_dict(), os.path.join(exp_path, 'opt_' + snapBest + '.pth'))

        cont += 1

        if epoch == args.epoch:
            snapLast = 'last_epoch_%d_valLoss_%.5f_valAcc_%.5f_noise_%d_bestValLoss_%.5f' % (
                epoch, loss_per_epoch[-1], acc_val_per_epoch_i[-1], args.labeled_samples, best_acc_val)
            torch.save(model.state_dict(), os.path.join(exp_path, snapLast + '.pth'))
            torch.save(optimizer.state_dict(), os.path.join(exp_path, 'opt_' + snapLast + '.pth'))


        #### Save models for ensembles:
        if (epoch >= 150) and (epoch%2 == 0) and (args.save_checkpoint == "True"):
            print("Saving model ...")
            out_path = './checkpoints/ENS_{0}_{1}'.format(args.experiment_name, args.labeled_samples)
            if not os.path.exists(out_path):
                os.makedirs(out_path)
            torch.save(model.state_dict(), out_path + "/epoch_{0}.pth".format(epoch))

        ### Saving model to load it again
        # cond = epoch%1 == 0
        if args.dataset_type == 'sym_noise_warmUp':
            if args.loss_term == 'Reg_ep':
                train_type = 'C'
            if args.loss_term == 'MixUp_ep':
                train_type = 'M'
            if args.dropout > 0.0:
                train_type = train_type + 'drop' + str(int(10*args.dropout))
            if args.beta == 0.0:
                train_type = train_type + 'noReg'


            cond = (epoch==args.epoch)
            name = 'warmUp_{1}_{0}'.format(args.Mixup_Alpha, train_type)
            save = True
        else:
            cond = (epoch==args.epoch)
            name = 'warmUp_{1}_{0}'.format(args.Mixup_Alpha, train_type)
            save = True


        if cond and save:
            print("Saving models...")
            path = './checkpoints/{0}_{1}_{2}_{3}_{4}_S{5}.hdf5'.format(name, epoch, args.dataset, args.labeled_samples, args.network, args.seed)

            save_checkpoint({
                    'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'optimizer' : optimizer.state_dict(),
                    'loss_train_epoch' : np.asarray(loss_train_epoch),
                    'loss_val_epoch' : np.asarray(loss_val_epoch),
                    'acc_train_per_epoch' : np.asarray(acc_train_per_epoch),
                    'acc_val_per_epoch' : np.asarray(acc_val_per_epoch),
                    'labels': np.asarray(train_loader.dataset.soft_labels)
                }, filename = path)



        ####################################################################################################
        ############################               SAVING METRICS                ###########################
        ####################################################################################################



        # Save losses:
        np.save(res_path + '/' + str(args.labeled_samples) + '_LOSS_epoch_train.npy', np.asarray(loss_train_epoch))
        np.save(res_path + '/' + str(args.labeled_samples) + '_LOSS_epoch_val.npy', np.asarray(loss_val_epoch))

        # save accuracies:
        np.save(res_path + '/' + str(args.labeled_samples) + '_accuracy_per_epoch_train.npy',
                np.asarray(acc_train_per_epoch))
        np.save(res_path + '/' + str(args.labeled_samples) + '_accuracy_per_epoch_val.npy', np.asarray(acc_val_per_epoch))

        # save the new labels
        new_labels.append(train_loader.dataset.labels)
        np.save(res_path + '/' + str(args.labeled_samples) + '_new_labels.npy',
                np.asarray(new_labels))

        #logging.info('Epoch: [{}|{}], train_loss: {:.3f}, top1_train_ac: {:.3f}, top1_val_ac: {:.3f}, train_time: {:.3f}'.format(epoch, args.epoch, loss_per_epoch[-1], top1_train_ac, acc_val_per_epoch_i[-1], time.time() - st))

    # applying swa
    if args.swa == 'True':
        optimizer.swap_swa_sgd()
        optimizer.bn_update(train_loader, model, device)
        if args.validation_exp == "True":
            loss_swa, acc_val_swa = validating(args, model, device, test_loader)
        else:
            loss_swa, acc_val_swa = testing(args, model, device, test_loader)

        snapLast = 'last_epoch_%d_valLoss_%.5f_valAcc_%.5f_noise_%d_bestValLoss_%.5f_swaAcc_%.5f' % (
            epoch, loss_per_epoch[-1], acc_val_per_epoch_i[-1], args.labeled_samples, best_acc_val, acc_val_swa[0])
        torch.save(model.state_dict(), os.path.join(exp_path, snapLast + '.pth'))
        torch.save(optimizer.state_dict(), os.path.join(exp_path, 'opt_' + snapLast + '.pth'))

    # save_fig(dst_folder)
    print('Best ac:%f' % best_acc_val)
    record_result(dst_folder, best_acc_val)  # best_ac is never updated; report the tracked best
Example #4
    # Call

    print("Starting model training....")

    n_epochs = setting_dict['epochs']
    lr_patience = setting_dict['optimizer']['sheduler']['patience']
    lr_factor = setting_dict['optimizer']['sheduler']['factor']

    if weight_path is None:
        best_epoch = train(model, dataloaders, objective, optimizer, n_epochs,
                           Path_list[1], Path_list[2],
                           lr_patience=lr_patience, lr_factor=lr_factor,
                           dice=False, seperate_loss=False,
                           adabn=setting_dict["data"]["adabn_train"],
                           own_sheduler=(not setting_dict["optimizer"]["longshedule"]))
    else:
        # load the checkpoint once instead of re-reading it for every argument
        checkpoint = torch.load(weight_path)
        optimizer.load_state_dict(checkpoint["optimizer"])
        best_epoch = train(model, dataloaders, objective, optimizer,
                           n_epochs - checkpoint["epoch"],
                           Path_list[1], Path_list[2],
                           start_epoch=checkpoint["epoch"] + 1,
                           loss_dict=checkpoint["loss_dict"],
                           lr_patience=lr_patience, lr_factor=lr_factor,
                           dice=False, seperate_loss=False,
                           adabn=setting_dict["data"]["adabn_train"],
                           own_sheduler=(not setting_dict["optimizer"]["longshedule"]))

    print("model training finished! yey!")

    if optimizer_name == "SWA":
        print ("Updating batch norm pars for SWA")
        train_dataset.dataset.SWA = True
        SWA_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=cpu_count)
        optimizer.swap_swa_sgd()
        optimizer.bn_update(SWA_loader, model, device='cuda')
        state = {
                'epoch': n_epochs,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'loss_dict': {}
                }
        torch.save(state, os.path.join(Path_list[2],'weights_SWA.pt'))
Example #5
    def train(self):
        # prepare data
        train_data = self.data('train')
        train_steps = int((len(train_data) + self.config.batch_size - 1) /
                          self.config.batch_size)
        train_dataloader = DataLoader(train_data,
                                      batch_size=self.config.batch_size,
                                      collate_fn=self.get_collate_fn('train'),
                                      shuffle=True,
                                      num_workers=2)

        # prepare optimizer
        params_lr = [{
            "params": self.model.bert_parameters,
            'lr': self.config.small_lr
        }, {
            "params": self.model.other_parameters,
            'lr': self.config.large_lr
        }]
        optimizer = torch.optim.Adam(params_lr)
        optimizer = SWA(optimizer)

        # prepare early stopping
        early_stopping = EarlyStopping(self.model,
                                       self.config.best_model_path,
                                       big_server=BIG_GPU,
                                       mode='max',
                                       patience=10,
                                       verbose=True)

        # prepare the learning rate schedule
        learning_schedual = LearningSchedual(
            optimizer, self.config.epochs, train_steps,
            [self.config.small_lr, self.config.large_lr])

        # prepare other
        aux = REModelAux(self.config, train_steps)
        moving_log = MovingData(window=500)

        ending_flag = False
        # self.model.load_state_dict(torch.load(ROOT_SAVED_MODEL + 'temp_model.ckpt'))
        #
        # with torch.no_grad():
        #     self.model.eval()
        #     print(self.eval())
        #     return
        for epoch in range(0, self.config.epochs):
            for step, (inputs, y_trues,
                       spo_info) in enumerate(train_dataloader):
                inputs = [aaa.cuda() for aaa in inputs]
                y_trues = [aaa.cuda() for aaa in y_trues]
                if epoch > 0 or step == 1000:
                    self.model.detach_bert = False
                # train ================================================================================================
                preds = self.model(inputs)
                loss = self.calculate_loss(preds, y_trues, inputs[1],
                                           inputs[2])
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)
                optimizer.step()

                with torch.no_grad():

                    logs = {'lr0': 0, 'lr1': 0}
                    if (epoch > 0 or step > 620) and False:  # branch intentionally disabled via `and False`
                        sbj_f1, spo_f1 = self.calculate_train_f1(
                            spo_info[0], preds, spo_info[1:3],
                            inputs[2].cpu().numpy())
                        metrics_data = {
                            'loss': loss.cpu().numpy(),
                            'sampled_num': 1,
                            'sbj_correct_num': sbj_f1[0],
                            'sbj_pred_num': sbj_f1[1],
                            'sbj_true_num': sbj_f1[2],
                            'spo_correct_num': spo_f1[0],
                            'spo_pred_num': spo_f1[1],
                            'spo_true_num': spo_f1[2]
                        }
                        moving_data = moving_log(epoch * train_steps + step,
                                                 metrics_data)
                        logs['loss'] = moving_data['loss'] / moving_data[
                            'sampled_num']
                        logs['sbj_precise'], logs['sbj_recall'], logs[
                            'sbj_f1'] = calculate_f1(
                                moving_data['sbj_correct_num'],
                                moving_data['sbj_pred_num'],
                                moving_data['sbj_true_num'],
                                verbose=True)
                        logs['spo_precise'], logs['spo_recall'], logs[
                            'spo_f1'] = calculate_f1(
                                moving_data['spo_correct_num'],
                                moving_data['spo_pred_num'],
                                moving_data['spo_true_num'],
                                verbose=True)
                    else:
                        metrics_data = {
                            'loss': loss.cpu().numpy(),
                            'sampled_num': 1
                        }
                        moving_data = moving_log(epoch * train_steps + step,
                                                 metrics_data)
                        logs['loss'] = moving_data['loss'] / moving_data[
                            'sampled_num']

                    # update lr
                    logs['lr0'], logs['lr1'] = learning_schedual.update_lr(
                        epoch, step)

                    if step == int(train_steps / 2) or step + 1 == train_steps:
                        self.model.eval()
                        torch.save(self.model.state_dict(),
                                   ROOT_SAVED_MODEL + 'temp_model.ckpt')
                        aux.new_line()
                        # dev ==========================================================================================
                        dev_result = self.eval()
                        logs['dev_loss'] = dev_result['loss']
                        logs['dev_sbj_precise'] = dev_result['sbj_precise']
                        logs['dev_sbj_recall'] = dev_result['sbj_recall']
                        logs['dev_sbj_f1'] = dev_result['sbj_f1']
                        logs['dev_spo_precise'] = dev_result['spo_precise']
                        logs['dev_spo_recall'] = dev_result['spo_recall']
                        logs['dev_spo_f1'] = dev_result['spo_f1']
                        logs['dev_precise'] = dev_result['precise']
                        logs['dev_recall'] = dev_result['recall']
                        logs['dev_f1'] = dev_result['f1']

                        # other thing
                        early_stopping(logs['dev_f1'])
                        if logs['dev_f1'] > 0.730:
                            optimizer.update_swa()

                        # test =========================================================================================
                        if (epoch + 1 == self.config.epochs and step + 1
                                == train_steps) or early_stopping.early_stop:
                            ending_flag = True
                            optimizer.swap_swa_sgd()
                            optimizer.bn_update(train_dataloader, self.model)
                            torch.save(self.model.state_dict(),
                                       ROOT_SAVED_MODEL + 'swa.ckpt')
                            self.test(ROOT_SAVED_MODEL + 'swa.ckpt')

                        self.model.train()
                aux.show_log(epoch, step, logs)
                if ending_flag:
                    return
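
Unlike the earlier examples, this one constructs SWA(optimizer) with no swa_start/swa_freq, i.e. torchcontrib's manual mode: nothing is averaged until update_swa() is called, which here happens only when the dev F1 clears 0.730. A minimal sketch of that metric-gated averaging; model, num_epochs, the loaders, train_one_epoch, and evaluate are placeholders, not code from the example:

# Manual-mode SWA: only checkpoints that pass a quality gate enter the average (sketch).
import torch
from torchcontrib.optim import SWA

optimizer = SWA(torch.optim.Adam(model.parameters()))  # no automatic schedule

for epoch in range(num_epochs):
    train_one_epoch(model, optimizer)     # placeholder training step
    dev_f1 = evaluate(model, dev_loader)  # placeholder evaluation
    if dev_f1 > 0.730:                    # same gate as the example above
        optimizer.update_swa()            # fold the current weights into the average

optimizer.swap_swa_sgd()                  # adopt the average once training ends
optimizer.bn_update(train_loader, model)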
Example #6
    progress["val_dice"].append(dice)
    progress["val_hausdorff"].append(hausdorff)
    progress["val_assd"].append(assd)

    dict2df(progress, args.output_dir + 'progress.csv')

    scheduler_step(optimizer, scheduler, iou, args)

# --------------------------------------------------------------------------------------------------------------- #
# --------------------------------------------------------------------------------------------------------------- #
# --------------------------------------------------------------------------------------------------------------- #

if args.apply_swa:
    torch.save(optimizer.state_dict(), args.output_dir + "/optimizer_" + args.model_name + "_before_swa_swap.pt")
    optimizer.swap_swa_sgd()  # Set the weights of your model to their SWA averages
    optimizer.bn_update(train_loader, model, device='cuda')

    torch.save(
        model.state_dict(),
        args.output_dir + "/swa_checkpoint_last_bn_update_{}epochs_lr{}.pt".format(args.epochs, args.swa_lr)
    )

    iou, dice, hausdorff, assd, val_loss, stats = val_step(
        val_loader, model, criterion, weights_criterion, multiclass_criterion, args.binary_threshold,
        generate_stats=True, generate_overlays=args.eval_overlays, save_path=os.path.join(args.output_dir, "swa_preds")
    )

    print("[SWA] Val IOU: %s, Val Dice: %s" % (iou, dice))

print("\n---------------")
val_iou = np.array(progress["val_iou"])
Example #7
def main(args):
    best_ac = 0.0

    #####################
    # Initializing seeds and preparing GPU
    if args.cuda_dev == 1:
        torch.cuda.set_device(1)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.backends.cudnn.deterministic = True  # fix the GPU to deterministic mode
    torch.manual_seed(args.seed)  # CPU seed
    if device.type == "cuda":  # a torch.device never equals the plain string "cuda"
        torch.cuda.manual_seed_all(args.seed)  # GPU seed
    random.seed(args.seed)  # python seed for image transformation
    np.random.seed(args.seed)
    #####################

    if args.dataset == 'cifar10':
        mean = [0.4914, 0.4822, 0.4465]
        std = [0.2023, 0.1994, 0.2010]
    elif args.dataset == 'cifar100':
        mean = [0.5071, 0.4867, 0.4408]
        std = [0.2675, 0.2565, 0.2761]
    elif args.dataset == 'miniImagenet':
        mean = [0.4728, 0.4487, 0.4031]
        std = [0.2744, 0.2663, 0.2806]

    if args.DA == "standard":
        transform_train = transforms.Compose([
            transforms.Pad(6, padding_mode='reflect'),
            transforms.RandomCrop(84),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

    elif args.DA == "jitter":
        transform_train = transforms.Compose([
            transforms.Pad(6, padding_mode='reflect'),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
            transforms.RandomCrop(84),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
    else:
        raise ValueError("Wrong value for --DA argument.")  # avoid a later NameError on transform_train


    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    # data loader
    train_loader, test_loader, unlabeled_indexes = data_config(args, transform_train, transform_test)

    if args.network == "TE_Net":
        print("Loading TE_Net...")
        model = TE_Net(num_classes = args.num_classes).to(device)

    elif args.network == "MT_Net":
        print("Loading MT_Net...")
        model = MT_Net(num_classes = args.num_classes).to(device)

    elif args.network == "resnet18":
        print("Loading Resnet18...")
        model = resnet18(num_classes = args.num_classes).to(device)

    elif args.network == "resnet18_wndrop":
        print("Loading Resnet18...")
        model = resnet18_wndrop(num_classes = args.num_classes).to(device)


    print('Total params: {:.2f} M'.format((sum(p.numel() for p in model.parameters()) / 1000000.0)))

    milestones = args.M

    if args.swa == 'True':
        # to install it:
        # pip3 install torchcontrib
        # git clone https://github.com/pytorch/contrib.git
        # cd contrib
        # sudo python3 setup.py install
        from torchcontrib.optim import SWA
        base_optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.wd)
        optimizer = SWA(base_optimizer, swa_lr=args.swa_lr)

    else:
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.wd)

    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

    loss_train_epoch = []
    loss_val_epoch = []
    acc_train_per_epoch = []
    acc_val_per_epoch = []
    new_labels = []

    exp_path = os.path.join('./', 'ssl_models_{0}'.format(args.experiment_name), str(args.labeled_samples))
    res_path = os.path.join('./', 'metrics_{0}'.format(args.experiment_name), str(args.labeled_samples))

    if not os.path.isdir(res_path):
        os.makedirs(res_path)

    if not os.path.isdir(exp_path):
        os.makedirs(exp_path)

    cont = 0
    load = False
    save = True

    if args.load_epoch != 0:
        load_epoch = args.load_epoch
        load = True
        save = False

    if args.dataset_type == 'ssl_warmUp':
        load = False
        save = True

    if load:
        if args.loss_term == 'Reg_ep':
            train_type = 'C'
        if args.loss_term == 'MixUp_ep':
            train_type = 'M'
        path = './checkpoints/warmUp_{0}_{1}_{2}_{3}_{4}_{5}_S{6}.hdf5'.format(train_type, \
                                                                                args.Mixup_Alpha, \
                                                                                load_epoch, \
                                                                                args.dataset, \
                                                                                args.labeled_samples, \
                                                                                args.network, \
                                                                                args.seed)

        checkpoint = torch.load(path)
        print("Load model in epoch " + str(checkpoint['epoch']))
        print("Path loaded: ", path)
        model.load_state_dict(checkpoint['state_dict'])
        print("Relabeling the unlabeled samples...")
        model.eval()
        results = np.zeros((len(train_loader.dataset), args.num_classes), dtype=np.float32)
        for images, images_pslab, labels, soft_labels, index in train_loader:

            images = images.to(device)
            labels = labels.to(device)
            soft_labels = soft_labels.to(device)

            outputs = model(images)
            prob, loss = loss_soft_reg_ep(outputs, labels, soft_labels, device, args)
            results[index.detach().numpy().tolist()] = prob.cpu().detach().numpy().tolist()

        train_loader.dataset.update_labels_randRelab(results, unlabeled_indexes, args.label_noise)
        print("Start training...")

    ####################################################################################################
    ###############################               TRAINING                ##############################
    ####################################################################################################

    for epoch in range(1, args.epoch + 1):
        st = time.time()
        scheduler.step()
        # train for one epoch
        print(args.experiment_name, args.labeled_samples)

        loss_per_epoch_train, \
        top_5_train_ac, \
        top1_train_acc_original_labels,\
        top1_train_ac, \
        train_time = train_CrossEntropy_partialRelab(args, model, device, \
                                        train_loader, optimizer, \
                                        epoch, unlabeled_indexes)



        loss_train_epoch += [loss_per_epoch_train]

        loss_per_epoch_test, acc_val_per_epoch_i = testing(args, model, device, test_loader)

        loss_val_epoch += loss_per_epoch_test
        acc_train_per_epoch += [top1_train_ac]
        acc_val_per_epoch += acc_val_per_epoch_i


        ####################################################################################################
        #############################               SAVING MODELS                ###########################
        ####################################################################################################
        if not os.path.exists('./checkpoints'):
            os.mkdir('./checkpoints')

        if epoch == 1:
            best_acc_val = acc_val_per_epoch_i[-1]
            snapBest = 'best_epoch_%d_valLoss_%.5f_valAcc_%.5f_labels_%d_bestAccVal_%.5f' % (
                epoch, loss_per_epoch_test[-1], acc_val_per_epoch_i[-1], args.labeled_samples, best_acc_val)
            torch.save(model.state_dict(), os.path.join(exp_path, snapBest + '.pth'))
            torch.save(optimizer.state_dict(), os.path.join(exp_path, 'opt_' + snapBest + '.pth'))
        else:
            if acc_val_per_epoch_i[-1] > best_acc_val:
                best_acc_val = acc_val_per_epoch_i[-1]

                if cont > 0:
                    try:
                        os.remove(os.path.join(exp_path, 'opt_' + snapBest + '.pth'))
                        os.remove(os.path.join(exp_path, snapBest + '.pth'))
                    except OSError:
                        pass
                snapBest = 'best_epoch_%d_valLoss_%.5f_valAcc_%.5f_labels_%d_bestAccVal_%.5f' % (
                    epoch, loss_per_epoch_test[-1], acc_val_per_epoch_i[-1], args.labeled_samples, best_acc_val)
                torch.save(model.state_dict(), os.path.join(exp_path, snapBest + '.pth'))
                torch.save(optimizer.state_dict(), os.path.join(exp_path, 'opt_' + snapBest + '.pth'))

        cont += 1

        if epoch == args.epoch:
            snapLast = 'last_epoch_%d_valLoss_%.5f_valAcc_%.5f_labels_%d_bestValLoss_%.5f' % (
                epoch, loss_per_epoch_test[-1], acc_val_per_epoch_i[-1], args.labeled_samples, best_acc_val)
            torch.save(model.state_dict(), os.path.join(exp_path, snapLast + '.pth'))
            torch.save(optimizer.state_dict(), os.path.join(exp_path, 'opt_' + snapLast + '.pth'))

        ### Saving model to load it again
        # cond = epoch%1 == 0
        if args.dataset_type == 'ssl_warmUp':
            if args.loss_term == 'Reg_ep':
                train_type = 'C'
            if args.loss_term == 'MixUp_ep':
                train_type = 'M'

            cond = (epoch==args.epoch)
            name = 'warmUp_{1}_{0}'.format(args.Mixup_Alpha, train_type)
            save = True
        else:
            cond = (epoch==args.epoch)
            name = 'warmUp_{1}_{0}'.format(args.Mixup_Alpha, train_type)
            save = True

        #print(cond)
        #print(save)
        if cond and save:
            print("Saving models...")
            path = './checkpoints/{0}_{1}_{2}_{3}_{4}_S{5}.hdf5'.format(name, epoch, args.dataset, \
                                                                        args.labeled_samples, \
                                                                        args.network, \
                                                                        args.seed)
            save_checkpoint({
                    'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'optimizer' : optimizer.state_dict(),
                    'loss_train_epoch' : np.asarray(loss_train_epoch),
                    'loss_val_epoch' : np.asarray(loss_val_epoch),
                    'acc_train_per_epoch' : np.asarray(acc_train_per_epoch),
                    'acc_val_per_epoch' : np.asarray(acc_val_per_epoch),
                    'labels': np.asarray(train_loader.dataset.soft_labels)

                }, filename = path)

        ####################################################################################################
        ############################               SAVING METRICS                ###########################
        ####################################################################################################

        # Save losses:
        np.save(res_path + '/' + str(args.labeled_samples) + '_LOSS_epoch_train.npy', np.asarray(loss_train_epoch))
        np.save(res_path + '/' + str(args.labeled_samples) + '_LOSS_epoch_val.npy', np.asarray(loss_val_epoch))

        # save accuracies:
        np.save(res_path + '/' + str(args.labeled_samples) + '_accuracy_per_epoch_train.npy',np.asarray(acc_train_per_epoch))
        np.save(res_path + '/' + str(args.labeled_samples) + '_accuracy_per_epoch_val.npy', np.asarray(acc_val_per_epoch))

    # applying swa
    if args.swa == 'True':
        optimizer.swap_swa_sgd()
        optimizer.bn_update(train_loader, model, device)
        loss_swa, acc_val_swa = testing(args, model, device, test_loader)

        snapLast = 'last_epoch_%d_valLoss_%.5f_valAcc_%.5f_labels_%d_bestValLoss_%.5f_swaAcc_%.5f' % (
            epoch, loss_per_epoch_test[-1], acc_val_per_epoch_i[-1], args.labeled_samples, best_acc_val, acc_val_swa[0])
        torch.save(model.state_dict(), os.path.join(exp_path, snapLast + '.pth'))
        torch.save(optimizer.state_dict(), os.path.join(exp_path, 'opt_' + snapLast + '.pth'))

    print('Best ac:%f' % best_acc_val)
Example #8
			swa.step()
			train_writer.add_scalar(tag="learning_rate", scalar_value=scheduler.get_lr()[0], global_step=global_step)
			train_writer.add_scalar(tag="BinaryLoss", scalar_value=loss.item(), global_step=global_step)
			train_writer.add_scalar(tag="SoftMaxLoss", scalar_value=loss4class.item(), global_step=global_step)
			train_bar.set_postfix_str(f"Loss = {loss.item()}")
			try:
				train_writer.add_scalar(tag="idrnd_score", scalar_value=idrnd_score_pytorch(label, output), global_step=global_step)
				train_writer.add_scalar(tag="far_score", scalar_value=far_score(label, output), global_step=global_step)
				train_writer.add_scalar(tag="frr_score", scalar_value=frr_score(label, output), global_step=global_step)
				train_writer.add_scalar(tag="accuracy", scalar_value=bce_accuracy(label, output), global_step=global_step)
			except Exception:
				pass

		if (epoch > config['swa_start'] and epoch % 2 == 0) or (epoch == config['number_epochs']-1):
			swa.swap_swa_sgd()
			swa.bn_update(train_loader, model, device)
			swa.swap_swa_sgd()

		model.eval()
		val_bar = tqdm(val_loader)
		val_bar.set_description_str(desc=f"N epochs - {epoch}")
		outputs = []
		targets = []
		user_ids = []
		frames = []
		for step, batch in enumerate(val_bar):
			image = batch['image'].to(device)
			label4class = batch['label0'].to(device)
			label = batch['label1']
			user_id = batch['user_id']
			frame = batch['frame']
Example #9
    def train(self, train_inputs):
        config = self.config.fitting
        model = train_inputs['model']
        train_data = train_inputs['train_data']
        dev_data = train_inputs['dev_data']
        epoch_start = train_inputs['epoch_start']

        train_steps = int((len(train_data) + config.batch_size - 1) / config.batch_size)
        train_dataloader = DataLoader(train_data,
                                      batch_size=config.batch_size,
                                      collate_fn=self.get_collate_fn('train'),
                                      shuffle=True)
        params_lr = []
        for key, value in model.get_params().items():
            if key in config.lr:
                params_lr.append({"params": value, 'lr': config.lr[key]})
        optimizer = torch.optim.Adam(params_lr)
        optimizer = SWA(optimizer)

        early_stopping = EarlyStopping(model, ROOT_WEIGHT, mode='max', patience=3)
        learning_schedual = LearningSchedual(optimizer, config.epochs, config.end_epoch, train_steps, config.lr)

        aux = ModelAux(self.config, train_steps)
        moving_log = MovingData(window=100)

        ending_flag = False
        detach_flag = False
        swa_flag = False
        fgm = FGM(model)
        for epoch in range(epoch_start, config.epochs):
            for step, (inputs, targets, others) in enumerate(train_dataloader):
                inputs = dict([(key, value[0].cuda() if value[1] else value[0]) for key, value in inputs.items()])
                targets = dict([(key, value.cuda()) for key, value in targets.items()])
                if epoch > 0 and step == 0:
                    model.detach_ptm(False)
                    detach_flag = False
                if epoch == 0 and step == 0:
                    model.detach_ptm(True)
                    detach_flag = True
                # train ================================================================================================
                preds = model(inputs, en_decode=config.verbose)
                loss = model.cal_loss(preds, targets, inputs['mask'])
                loss['back'].backward()

                # adversarial training (FGM)
                if (not detach_flag) and config.en_fgm:
                    fgm.attack(emb_name='word_embeddings')  # add an adversarial perturbation to the embeddings
                    preds_adv = model(inputs, en_decode=False)
                    loss_adv = model.cal_loss(preds_adv, targets, inputs['mask'])
                    loss_adv['back'].backward()  # backprop, accumulating the adversarial gradients on top of the normal ones
                    fgm.restore(emb_name='word_embeddings')  # restore the original embedding parameters

                # torch.nn.utils.clip_grad_norm(model.parameters(), 1)
                optimizer.step()
                optimizer.zero_grad()
                with torch.no_grad():
                    logs = {}
                    if config.verbose:
                        pred_entity_point = model.find_entity(preds['pred'], others['raw_text'])
                        cn, pn, tn = self.calculate_f1(pred_entity_point, others['raw_entity'])
                        metrics_data = {'loss': loss['show'].cpu().numpy(), 'sampled_num': 1,
                                        'correct_num': cn, 'pred_num': pn,
                                        'true_num': tn}
                        moving_data = moving_log(epoch * train_steps + step, metrics_data)
                        logs['loss'] = moving_data['loss'] / moving_data['sampled_num']
                        logs['precise'], logs['recall'], logs['f1'] = calculate_f1(moving_data['correct_num'],
                                                                                   moving_data['pred_num'],
                                                                                   moving_data['true_num'],
                                                                                   verbose=True)
                    else:
                        metrics_data = {'loss': loss['show'].cpu().numpy(), 'sampled_num': 1}
                        moving_data = moving_log(epoch * train_steps + step, metrics_data)
                        logs['loss'] = moving_data['loss'] / moving_data['sampled_num']
                    # update lr
                    lr_data = learning_schedual.update_lr(epoch, step)
                    logs.update(lr_data)

                    if step + 1 == train_steps:
                        model.eval()
                        aux.new_line()

                        # dev ==========================================================================================

                        eval_inputs = {'model': model,
                                       'data': dev_data,
                                       'type_data': 'dev',
                                       'outfile': train_inputs['dev_res_file']}
                        dev_result = self.eval(eval_inputs)
                        logs['dev_loss'] = dev_result['loss']
                        logs['dev_precise'] = dev_result['precise']
                        logs['dev_recall'] = dev_result['recall']
                        logs['dev_f1'] = dev_result['f1']
                        if logs['dev_f1'] > 0.80:
                            torch.save(model.state_dict(),
                                       "{}/auto_save_{:.6f}.ckpt".format(ROOT_WEIGHT, logs['dev_f1']))
                        if (epoch > 3 or swa_flag) and config.en_swa:
                            optimizer.update_swa()
                            swa_flag = True
                        early_stop, best_score = early_stopping(logs['dev_f1'])

                        # test =========================================================================================
                        if (epoch + 1 == config.epochs and step + 1 == train_steps) or early_stop:
                            ending_flag = True
                            if swa_flag:
                                optimizer.swap_swa_sgd()
                                optimizer.bn_update(train_dataloader, model)

                        model.train()
                aux.show_log(epoch, step, logs)
                if ending_flag:
                    return best_score
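Note: the `fgm` object used in the adversarial-training block above (`fgm.attack` / `fgm.restore`) is not defined in this snippet. Below is a minimal sketch of the FGM helper it appears to assume; the class name, the `epsilon` default, and the L2-normalised perturbation follow the commonly circulated FGM recipe and are assumptions, not code from this example.

import torch

class FGM:
    """Fast Gradient Method: adversarial perturbation on embedding weights."""

    def __init__(self, model, epsilon=1.0):
        self.model = model
        self.epsilon = epsilon
        self.backup = {}

    def attack(self, emb_name='word_embeddings'):
        # Back up the embedding weights, then step them along the
        # (L2-normalised) gradient direction: r = epsilon * g / ||g||.
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                self.backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                if norm != 0 and not torch.isnan(norm):
                    param.data.add_(self.epsilon * param.grad / norm)

    def restore(self, emb_name='word_embeddings'):
        # Put the saved embedding weights back after the adversarial pass.
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name and name in self.backup:
                param.data = self.backup[name]
        self.backup = {}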
Example #10
        args.max_grad_norm,
        writer,
        use_unlab=not args.supervised_only,
    )

    # Save checkpoint
    if epoch % args.save_freq == 0:
        print('Saving...')
        state = {'net': net.state_dict(), 'epoch': epoch, 'means': means}
        os.makedirs(args.ckptdir, exist_ok=True)
        torch.save(state, os.path.join(args.ckptdir, str(epoch) + '.pt'))

    # Save samples and data
    if epoch % args.eval_freq == 0:
        utils.test_classifier(epoch, net, testloader, device, loss_fn, writer)
        if args.swa:
            optimizer.swap_swa_sgd()
            print("updating bn")
            SWA.bn_update(bn_loader, net)
            utils.test_classifier(epoch,
                                  net,
                                  testloader,
                                  device,
                                  loss_fn,
                                  writer,
                                  postfix="_swa")
        os.makedirs(args.ckptdir, exist_ok=True)

        if args.swa:
            optimizer.swap_swa_sgd()
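Note: Examples #10 through #12 wrap their base optimizer in torchcontrib's SWA and call `swap_swa_sgd()` in pairs: once to swap the averaged weights in for evaluation (with `bn_update` to refresh BatchNorm statistics), and once to swap the ordinary weights back for further training. A minimal, self-contained sketch of that pattern, assuming the torchcontrib API; the model, data, and loss here are toy placeholders:

import torch
import torch.nn as nn
from torchcontrib.optim import SWA

model = nn.Linear(10, 2)
base_opt = torch.optim.SGD(model.parameters(), lr=0.1)
opt = SWA(base_opt)  # manual mode: we trigger the averaging ourselves

loss_fn = nn.CrossEntropyLoss()
for _ in range(20):
    x, y = torch.randn(8, 10), torch.randint(0, 2, (8,))
    opt.zero_grad()
    loss_fn(model(x), y).backward()
    opt.step()
    opt.update_swa()   # fold the current weights into the running average

opt.swap_swa_sgd()     # swap the averaged weights in for evaluation
# SWA.bn_update(loader, model) would refresh BatchNorm statistics here;
# this toy model has no BatchNorm layers, so it is omitted.
print(model(torch.randn(1, 10)))
opt.swap_swa_sgd()     # swap the ordinary weights back to resume training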
Example #11
for epoch in range(args.epochs):
    train_loss = train(net, train_loader, optimizer, criterion,
                       args.mixup_prob, args.mixup_alpha, args.cutmix_prob,
                       args.cutmix_alpha, args.grad_clipping)
    valid_loss, kaggle = valid(net, valid_loader, criterion, NUM_TASK)

    show_simple_stats(log, epoch, optimizer, start_timer, kaggle, train_loss,
                      valid_loss)

    if kaggle[1] > best_metric:
        best_metric = kaggle[1]
        torch.save(net.state_dict(), out_dir + '/checkpoint/best_model.pth')

    # learning rate scheduler -------------
    if epoch < args.epochs - 1:  # Prevent step on last epoch
        scheduler_step(args.scheduler, scheduler, optimizer, epoch)

torch.save(net.state_dict(), out_dir + '/checkpoint/last_model.pth')

if args.apply_swa:
    torch.save(
        optimizer.state_dict(), args.output_dir + "/optimizer_" +
        args.model_name + "_last_before_swap.pt")
    optimizer.swap_swa_sgd()
    optimizer.bn_update(train_loader, net, device='cuda')
    torch.save(
        net.state_dict(),
        args.output_dir + "/model_" + args.model_name + "_last_bn_update.pt")

log.write('\n')
Example #12
def train(opt):
    if torch.cuda.is_available():
        # num_gpus = torch.cuda.device_count()
        device = 'cuda'
        torch.cuda.manual_seed(123)
        num_gpus = 1
    else:
        num_gpus = 1
        device = 'cpu'
        torch.manual_seed(123)

    training_params = {
        "batch_size": opt.batch_size * num_gpus,
        "shuffle": True,
        "drop_last": True,
        "collate_fn": collater_train,
        "num_workers": opt.num_worker,
        "pin_memory": True
    }

    test_params = {
        "batch_size": opt.batch_size * num_gpus,
        "shuffle": False,
        "drop_last": False,
        "collate_fn": collater_test,
        "num_workers": opt.num_worker,
        "pin_memory": True
    }

    train_dataset = VOCDetection(
        train=True,
        root=opt.train_dataset_root,
        transform=train_transform(
            width=EFFICIENTDET[opt.network]['input_size'],
            height=EFFICIENTDET[opt.network]['input_size'],
            lamda_norm=False))

    test_dataset = VOCDetection(
        train=False,
        root=opt.test_dataset_root,
        transform=transforms.Compose([
            Normalizer(lamda_norm=False, grey_p=0.0),
            Resizer(EFFICIENTDET[opt.network]['input_size'])
        ]))

    test_dataset_grey = VOCDetection(
        train=False,
        root=opt.test_dataset_root,
        transform=transforms.Compose([
            Normalizer(lamda_norm=False, grey_p=1.0),
            Resizer(EFFICIENTDET[opt.network]['input_size'])
        ]))

    train_generator = DataLoader(train_dataset, **training_params)
    test_generator = DataLoader(test_dataset, **test_params)
    test_grey_generator = DataLoader(test_dataset_grey, **test_params)

    network_id = int(''.join(filter(str.isdigit, opt.network)))
    loss_func = FocalLoss(alpha=opt.alpha,
                          gamma=opt.gamma,
                          smoothing_factor=opt.smoothing_factor)
    model = EfficientDet(MODEL_MAP[opt.network],
                         image_size=[
                             EFFICIENTDET[opt.network]['input_size'],
                             EFFICIENTDET[opt.network]['input_size']
                         ],
                         num_classes=train_dataset.num_classes(),
                         compound_coef=network_id,
                         num_anchors=9,
                         advprop=True,
                         from_pretrain=opt.from_pretrain)
    anchors_finder = Anchors()

    model.to(device)

    if opt.resume is not None:
        _ = resume(model, device, opt.resume)

    model = nn.DataParallel(model)

    if os.path.isdir(opt.log_path):
        shutil.rmtree(opt.log_path)

    os.makedirs(opt.log_path)

    if not os.path.isdir(opt.checkpoint_root_dir):
        os.makedirs(opt.checkpoint_root_dir)

    writer = SummaryWriter(opt.log_path)

    base_optimizer = torch.optim.Adam(model.parameters(),
                                      lr=opt.lr,
                                      weight_decay=opt.weight_decay,
                                      amsgrad=True)
    # optimizer = base_optimizer
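    # torchcontrib SWA in auto mode: averaging starts after the 10th call to
    # optimizer.step() and the running average is updated every 5 steps.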
    optimizer = SWA(base_optimizer, swa_start=10, swa_freq=5)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

    if opt.resume is not None:
        model.eval()

        loss_regression_ls = []
        loss_classification_ls = []
        with torch.no_grad():
            for iter, data in enumerate(tqdm(test_generator)):
                if torch.cuda.is_available():
                    anchors = anchors_finder(data['image'].cuda().float())
                    classification, regression = model(
                        data['image'].cuda().float())
                    cls_loss, reg_loss = loss_func(classification, regression,
                                                   anchors,
                                                   data['annots'].cuda())
                else:
                    anchors = anchors_finder(data['image'].float())
                    classification, regression = model(data['image'].float())
                    cls_loss, reg_loss = loss_func(classification, regression,
                                                   anchors, data['annots'])

                cls_loss = cls_loss.sum()
                reg_loss = reg_loss.sum()

                loss_classification_ls.append(float(cls_loss))
                loss_regression_ls.append(float(reg_loss))

        cls_loss = np.sum(loss_classification_ls) / len(test_dataset)
        reg_loss = np.sum(loss_regression_ls) / len(test_dataset)
        loss = (reg_loss + cls_loss) / 2

        writer.add_scalars('Total_loss', {'test': loss}, 0)
        writer.add_scalars('Regression_loss', {'test': reg_loss}, 0)
        writer.add_scalars('Classification_loss (focal loss)',
                           {'test': cls_loss}, 0)

        print(
            'Epoch: {}/{}. Classification loss: {:1.5f}. Regression loss: {:1.5f}. Total loss: {:1.5f}'
            .format(0, opt.num_epochs, cls_loss, reg_loss, loss))

        mAP_1, _ = evaluate(test_generator,
                            model,
                            iou_threshold=0.5,
                            score_threshold=0.5)
        mAP_5, _ = evaluate(test_generator,
                            model,
                            iou_threshold=0.75,
                            score_threshold=0.1)

        writer.add_scalars(
            'mAP', {
                'score threshold 0.5; iou threshold {}'.format(0.5): mAP_1,
            }, 0)
        writer.add_scalars(
            'mAP', {
                'score threshold 0.1; iou threshold {}'.format(0.75): mAP_5,
            }, 0)

        mAP_1_grey, _ = evaluate(test_grey_generator,
                                 model,
                                 iou_threshold=0.5,
                                 score_threshold=0.5)
        mAP_5_grey, _ = evaluate(test_grey_generator,
                                 model,
                                 iou_threshold=0.75,
                                 score_threshold=0.1)

        writer.add_scalars(
            'mAP', {
                'grey: True; score threshold 0.5; iou threshold {}'.format(0.5):
                mAP_1_grey,
            }, 0)
        writer.add_scalars(
            'mAP', {
                'grey: True; score threshold 0.1; iou threshold {}'.format(0.75):
                mAP_5_grey,
            }, 0)

    model.train()

    num_iter_per_epoch = len(train_generator)
    train_iter = 0
    best_eval_loss = 10.0
    for epoch in range(opt.num_epochs):
        epoch_loss = []
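        # Stash the first (CPU) batch of each epoch; it is passed to
        # optimizer.bn_update() below to recompute BatchNorm statistics
        # for the SWA-averaged weights.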
        bn_update_data_list = []
        progress_bar = tqdm(train_generator)
        for iter, data in enumerate(progress_bar):
            # Fractional-epoch stepping each iteration; note that the epoch
            # argument to scheduler.step() is deprecated in recent PyTorch,
            # and scheduler.step() is also called once per epoch below.
            scheduler.step(epoch + iter / len(train_generator))
            optimizer.zero_grad()

            if torch.cuda.is_available():
                if iter == 0:
                    bn_update_data_list.append(data['image'].float())
                anchors = anchors_finder(data['image'].cuda().float())
                classification, regression = model(
                    data['image'].cuda().float())
                cls_loss, reg_loss = loss_func(classification, regression,
                                               anchors, data['annots'].cuda())
            else:
                if iter == 0:
                    bn_update_data_list.append(data['image'].float())
                anchors = anchors_finder(data['image'].float())
                classification, regression = model(data['image'].float())
                cls_loss, reg_loss = loss_func(classification, regression,
                                               anchors, data['annots'])

            cls_loss = cls_loss.mean()
            reg_loss = reg_loss.mean()
            loss = (reg_loss + cls_loss) / 2

            if loss == 0:
                continue

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           opt.glip_threshold)
            optimizer.step()
            epoch_loss.append(float(loss))
            total_loss = np.mean(epoch_loss)
            train_iter += 1

            progress_bar.set_description(
                'Epoch: {}/{}. Iteration: {}/{}. Cls loss: {:.5f}. Reg loss: {:.5f}. loss: {:.5f} Total loss: {:.5f}'
                .format(epoch + 1, opt.num_epochs, iter + 1,
                        num_iter_per_epoch, cls_loss, reg_loss, loss,
                        total_loss))
            writer.add_scalars('Total_loss', {'train': total_loss}, train_iter)
            writer.add_scalars('Regression_loss', {'train': reg_loss},
                               train_iter)
            writer.add_scalars('Classification_loss (focal loss)',
                               {'train': cls_loss}, train_iter)

        if (epoch + 1) % opt.test_interval == 0:
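            # Evaluate with the SWA-averaged weights: swap them in, refresh
            # BatchNorm statistics on the stashed batch, evaluate, then swap
            # the ordinary weights back (the second swap_swa_sgd) before
            # training resumes.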

            loss_regression_ls = []
            loss_classification_ls = []
            optimizer.swap_swa_sgd()
            optimizer.bn_update(bn_update_data_list, model)
            model.eval()

            with torch.no_grad():
                for iter, data in enumerate(tqdm(test_generator)):

                    if torch.cuda.is_available():
                        anchors = anchors_finder(data['image'].cuda().float())
                        classification, regression = model(
                            data['image'].cuda().float())
                        cls_loss, reg_loss = loss_func(classification,
                                                       regression, anchors,
                                                       data['annots'].cuda())

                    else:
                        anchors = anchors_finder(data['image'].float())
                        classification, regression = model(
                            data['image'].float())
                        cls_loss, reg_loss = loss_func(classification,
                                                       regression, anchors,
                                                       data['annots'])

                    cls_loss = cls_loss.sum()
                    reg_loss = reg_loss.sum()

                    loss_classification_ls.append(float(cls_loss))
                    loss_regression_ls.append(float(reg_loss))

            cls_loss = np.sum(loss_classification_ls) / len(test_dataset)
            reg_loss = np.sum(loss_regression_ls) / len(test_dataset)
            loss = (reg_loss + cls_loss) / 2

            writer.add_scalars('Total_loss', {'test': loss}, train_iter)
            writer.add_scalars('Regression_loss', {'test': reg_loss},
                               train_iter)
            writer.add_scalars('Classification_loss (focal loss)',
                               {'test': cls_loss}, train_iter)

            print(
                'Epoch: {}/{}. Classification loss: {:1.5f}. Regression loss: {:1.5f}. Total loss: {:1.5f}'
                .format(epoch + 1, opt.num_epochs, cls_loss, reg_loss,
                        loss))

            if 0 < loss < best_eval_loss and (epoch + 1) % opt.eval_interval != 0:
                best_eval_loss = loss

                mAP_1, _ = evaluate(test_generator,
                                    model,
                                    iou_threshold=0.5,
                                    score_threshold=0.5)
                mAP_5, _ = evaluate(test_generator,
                                    model,
                                    iou_threshold=0.75,
                                    score_threshold=0.1)

                writer.add_scalars('mAP', {
                    'score threshold 0.5; iou threshold {}'.format(0.5):
                    mAP_1,
                }, train_iter)
                writer.add_scalars('mAP', {
                    'score threshold 0.1; iou threshold {}'.format(0.75):
                    mAP_5,
                }, train_iter)

                mAP_1_grey, _ = evaluate(test_grey_generator,
                                         model,
                                         iou_threshold=0.5,
                                         score_threshold=0.5)
                mAP_5_grey, _ = evaluate(test_grey_generator,
                                         model,
                                         iou_threshold=0.75,
                                         score_threshold=0.1)

                writer.add_scalars(
                    'mAP', {
                        'grey: True; score threshold 0.5; iou threshold {}'.format(0.5):
                        mAP_1_grey,
                    }, train_iter)
                writer.add_scalars(
                    'mAP', {
                        'grey: True; score threshold 0.1; iou threshold {}'.format(0.75):
                        mAP_5_grey,
                    }, train_iter)

                state_dict = (model.module.state_dict()
                              if torch.cuda.device_count() > 1
                              else model.state_dict())
                checkpoint_dict = {
                    'epoch': epoch + 1,
                    'state_dict': state_dict
                }
                torch.save(
                    checkpoint_dict,
                    os.path.join(
                        opt.checkpoint_root_dir,
                        '{}_checkpoint_epoch_{}_loss_{}.pth'.format(
                            'model_v1', epoch + 1, loss)))

            if (epoch + 1) % opt.eval_interval == 0:
                mAP_1, _ = evaluate(test_generator,
                                    model,
                                    iou_threshold=0.5,
                                    score_threshold=0.5)
                mAP_5, _ = evaluate(test_generator,
                                    model,
                                    iou_threshold=0.75,
                                    score_threshold=0.1)

                writer.add_scalars('mAP', {
                    'score threshold 0.5; iou threshold {}'.format(0.5):
                    mAP_1,
                }, train_iter)
                writer.add_scalars('mAP', {
                    'score threshold 0.1; iou threshold {}'.format(0.75):
                    mAP_5,
                }, train_iter)

                mAP_1_grey, _ = evaluate(test_grey_generator,
                                         model,
                                         iou_threshold=0.5,
                                         score_threshold=0.5)
                mAP_5_grey, _ = evaluate(test_grey_generator,
                                         model,
                                         iou_threshold=0.75,
                                         score_threshold=0.1)

                writer.add_scalars(
                    'mAP', {
                        'grey: True; score threshold 0.5; iou threshold {}'.format(0.5):
                        mAP_1_grey,
                    }, train_iter)
                writer.add_scalars(
                    'mAP', {
                        'grey: True; score threshold 0.1; iou threshold {}'.format(0.75):
                        mAP_5_grey,
                    }, train_iter)

                state_dict = (model.module.state_dict()
                              if torch.cuda.device_count() > 1
                              else model.state_dict())
                checkpoint_dict = {
                    'epoch': epoch + 1,
                    'state_dict': state_dict
                }
                torch.save(
                    checkpoint_dict,
                    os.path.join(
                        opt.checkpoint_root_dir,
                        '{}_checkpoint_epoch_{}_loss_{}.pth'.format(
                            'model_v1', epoch + 1, loss)))

            optimizer.swap_swa_sgd()

            model.train()

        scheduler.step()

    writer.close()