Example #1
                }, model_path)

        # update loggers
        writer.add_scalars('Loss/', {'train loss': err,
                                     'valid loss': valid_err}, epoch)

    writer.close()


if __name__ == '__main__':
    # argument setting
    train_args = train_argument()

    # config
    model_config(train_args, save=True)     # save model configuration before training

    # set cuda
    torch.cuda.set_device(train_args.gpu_id)

    # model
    model = model_builder(train_args.model_name, train_args.scale, *train_args.model_args).cuda()

    # optimizer and criterion
    optimizer = optim.Adam(model.parameters(), lr=train_args.lr)
    criterion = nn.MSELoss()

    # dataset
    full_set = AxisDataSet(train_args.train_path, train_args.target_path)

    # build hold out CV
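    # A minimal hold-out split sketch, assuming torch.utils.data.random_split;
    # the 80/20 ratio and loader settings below are illustrative assumptions,
    # not the original code.
    from torch.utils.data import DataLoader, random_split

    train_len = int(0.8 * len(full_set))
    train_set, valid_set = random_split(
        full_set, [train_len, len(full_set) - train_len])

    train_loader = DataLoader(train_set, batch_size=train_args.batch_size,
                              shuffle=True, num_workers=train_args.num_workers)
    valid_loader = DataLoader(valid_set, batch_size=train_args.batch_size,
                              shuffle=False, num_workers=train_args.num_workers)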
Example #2
def train(model, train_loader, valid_loader, optimizer, criterion, args):
    # content_loss
    best_err = None
    feature_extractor = FeatureExtractor().cuda()
    feature_extractor.eval()

    writer, log_path = writer_builder(args.log_path,
                                      args.model_name,
                                      load=args.load)

    # initial checkpoint state
    checkpoint = {
        'epoch': 1,  # start from 1
        'train_iter': 0,  # train iteration
        'valid_iter': 0,  # valid iteration
    }
    model_path = os.path.join(log_path, f'{args.model_name}_{args.scale}x.pt')

    # config
    model_config(args,
                 save=log_path)  # save model configuration before training

    # load model from an existing .pt file
    if args.load and os.path.isfile(model_path):
        r"""
        Load a checkpoint from an existing parameter file.

        state_dict: model's state dict
        epoch: the epoch in which the parameters were last updated
        """
        checkpoint = torch.load(model_path, map_location=f'cuda:{args.gpu_id}')
        checkpoint['epoch'] += 1  # start from next epoch
        checkpoint['train_iter'] += 1
        checkpoint['valid_iter'] += 1
        model.load_state_dict(checkpoint['state_dict'])

    # initialize the early_stopping object
    if args.early_stop:
        early_stopping = EarlyStopping(patience=args.patience,
                                       threshold=args.threshold,
                                       verbose=args.verbose,
                                       path=model_path)

    if args.scheduler:
        scheduler = schedule_builder(optimizer, args.scheduler, args.step,
                                     args.factor)

    # progress bar postfix value
    pbar_postfix = {
        'MSE loss': 0.0,
        'Content loss': 0.0,
        'lr': args.lr,
    }

    for epoch in range(checkpoint['epoch'], args.epochs + 1):
        model.train()
        err = 0.0
        valid_err = 0.0

        train_bar = tqdm(train_loader,
                         desc=f'Train epoch: {epoch}/{args.epochs}')
        for data in train_bar:
            # load data from data loader
            inputs, target, _ = data
            inputs, target = inputs.cuda(), target.cuda()

            # predict the corrected 6-axis data
            pred = model(inputs)

            # MSE loss
            mse_loss = args.alpha * criterion(pred - inputs, target - inputs)

            # content loss
            gen_features = feature_extractor(pred)
            real_features = feature_extractor(target)
            content_loss = args.beta * criterion(gen_features, real_features)

            # combined for compatibility, at the cost of extra memory
            loss = mse_loss + content_loss

            # update progress bar
            pbar_postfix['MSE loss'] = mse_loss.item()
            pbar_postfix['Content loss'] = content_loss.item()

            # show current lr
            if args.scheduler:
                pbar_postfix['lr'] = optimizer.param_groups[0]['lr']

            train_bar.set_postfix(pbar_postfix)

            err += loss.sum().item() * inputs.size(0)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # update writer
            writer.add_scalar('Iteration/train loss',
                              loss.sum().item(), checkpoint['train_iter'])
            checkpoint['train_iter'] += 1

        # cross validation
        valid_bar = tqdm(valid_loader,
                         desc=f'Valid epoch: {epoch}/{args.epochs}',
                         leave=False)
        model.eval()
        input_epoch = pred_epoch = target_epoch = torch.empty(0, 0)
        with torch.no_grad():
            for data in valid_bar:
                inputs, target, _ = data
                inputs, target = inputs.cuda(), target.cuda()

                pred = model(inputs)

                # MSE loss
                mse_loss = criterion(pred - inputs, target - inputs)

                # content loss
                gen_features = feature_extractor(pred)
                real_features = feature_extractor(target)
                content_loss = criterion(gen_features, real_features)

                # combined for compatibility
                loss = mse_loss + content_loss

                # update progress bar
                pbar_postfix['MSE loss'] = mse_loss.item()
                pbar_postfix['Content loss'] = content_loss.item()

                # show current lr
                if args.scheduler:
                    pbar_postfix['lr'] = optimizer.param_groups[0]['lr']

                valid_bar.set_postfix(pbar_postfix)

                valid_err += loss.sum().item() * inputs.size(0)

                # update writer
                writer.add_scalar('Iteration/valid loss',
                                  loss.sum().item(), checkpoint['valid_iter'])
                checkpoint['valid_iter'] += 1

                # keep the last validation batch for csv output below
                if epoch % args.check_interval == 0:
                    input_epoch = inputs
                    pred_epoch = pred
                    target_epoch = target

        # out2csv every check interval epochs (default: 5)
        if epoch % args.check_interval == 0:

            # tensor to csv file
            out2csv(input_epoch, f'{epoch}', 'input', args.out_num,
                    args.save_path, args.stroke_length)
            out2csv(pred_epoch, f'{epoch}', 'output', args.out_num,
                    args.save_path, args.stroke_length)
            out2csv(target_epoch, f'{epoch}', 'target', args.out_num,
                    args.save_path, args.stroke_length)

        # average the accumulated losses over each dataset
        err /= len(train_loader.dataset)
        valid_err /= len(valid_loader.dataset)
        print(f'\ntrain loss: {err:.4f}, valid loss: {valid_err:.4f}')

        # update scheduler
        if args.scheduler:
            scheduler.step()

        # update loggers
        writer.add_scalars(
            'Epoch',
            {
                'train loss': err,
                'valid loss': valid_err
            },
            epoch,
        )

        # early_stopping needs the validation loss to check if it has decreased,
        # and if it has, it will make a checkpoint of the current model
        if args.early_stop:
            early_stopping(valid_err, model, epoch)

            if early_stopping.early_stop:
                print("Early stopping")
                break
        # if early stopping is disabled, store the model whenever err is lowest
        elif epoch == checkpoint['epoch'] or err < best_err:
            best_err = err  # initialized on the first epoch, then updated on improvement

            # save current epoch and model parameters
            torch.save(
                {
                    'state_dict': model.state_dict(),
                    'epoch': epoch,
                    'train_iter': checkpoint['train_iter'],
                    'valid_iter': checkpoint['valid_iter'],
                }, model_path)

    writer.close()
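
# EarlyStopping is imported from elsewhere in this repo; below is a minimal
# sketch consistent with the constructor and call signature used above
# (patience/threshold/verbose/path, called with valid_err, model, epoch).
# The internals are assumptions, not the repo's actual implementation.
import torch


class EarlyStopping:
    def __init__(self, patience=7, threshold=0.0, verbose=False,
                 path='checkpoint.pt'):
        self.patience = patience    # epochs to wait after the last improvement
        self.threshold = threshold  # minimum decrease that counts as improvement
        self.verbose = verbose
        self.path = path
        self.best_loss = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, valid_loss, model, epoch):
        # a sufficiently large decrease in validation loss counts as improvement
        if self.best_loss is None or valid_loss < self.best_loss - self.threshold:
            self.best_loss = valid_loss
            self.counter = 0
            torch.save({'state_dict': model.state_dict(), 'epoch': epoch},
                       self.path)
            if self.verbose:
                print(f'valid loss improved to {valid_loss:.4f}, checkpoint saved')
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True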
Example #3
        # combined for compatibility
        loss = content_loss + mse_loss
        err += loss.sum().item() * inputs.size(0)

    err /= len(test_loader.dataset)
    print(f'test error: {err:.4f}')


if __name__ == '__main__':
    # argument setting
    test_args = test_argument()

    if test_args.doc:
        test_args = config_loader(test_args.doc, test_args)
    # config
    model_config(test_args,
                 save=False)  # print model configuration of evaluation

    # set cuda
    torch.cuda.set_device(test_args.gpu_id)

    # model
    model = model_builder(test_args.model_name, test_args.scale,
                          **test_args.model_args).cuda()

    # criterion
    criterion = criterion_builder(test_args.criterion)
    # optimizer = None # don't need optimizer in test

    # dataset
    test_set = AxisDataSet(test_args.test_path, test_args.target_path)
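
# criterion_builder comes from this repo's utilities; below is a plausible
# sketch that maps a loss name to a torch criterion. The accepted names are
# assumptions for illustration, not the repo's actual implementation.
import torch.nn as nn


def criterion_builder(name):
    losses = {
        'mse': nn.MSELoss(),
        'l1': nn.L1Loss(),
        'huber': nn.SmoothL1Loss(),
    }
    return losses[name.lower()]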
Example #4
        num_workers=train_args.num_workers,
        #   sampler=train_sampler,
        pin_memory=False,
    )
    valid_loader = DataLoader(
        valid_set,
        batch_size=train_args.batch_size,
        shuffle=False,
        num_workers=train_args.num_workers,
        #   sampler=valid_sampler,
        pin_memory=False,
    )

    # model summary
    data, _, _ = train_set[0]
    summary(
        model,
        tuple(data.shape),
        batch_size=train_args.batch_size,
        device='cuda',
        model_name=train_args.model_name.upper(),
    )

    # training
    train(model, train_loader, valid_loader, optimizer, criterion, train_args)

    # config
    model_config(train_args,
                 save=False)  # print model configuration after training

    postprocessor(train_args.save_path)
Example #5
def main():
    start_epoch = 0

    tensorboard_dir = "./tensorboard/OOD_" + args.dataset
    if args.metric:
        save_model = "./save_model_" + args.dataset + "_metric"
    else:
        save_model = "./save_model_" + args.dataset

    #### create folder
    Path(os.path.join(save_model, env, args.net_type)).mkdir(exist_ok=True,
                                                             parents=True)

    if args.board_clear:
        board_clear(tensorboard_dir)

    idx = tensorboard_idx(tensorboard_dir)
    summary = SummaryWriter(os.path.join(tensorboard_dir, str(idx)))

    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Hyper-parameters
    eps = 1e-8

    ### data config
    resize = (160, 160)
    train_dataset, train_loader, test_dataset, test_loader, out_test_dataset, out_test_loader, OOD_dataset, OOD_loader = data_config(
        image_dir, OOD_dir, args.num_classes, args.OOD_num_classes,
        args.batch_size, args.num_instances, args.soft_label,
        args.custom_sampler, args.not_test_ODIN, args.transfer, resize)

    ##### model, optimizer config
    model = model_config(args.net_type, args.num_classes, args.OOD_num_classes)

    #### batch_number
    batch_num = (len(train_loader) / args.batch_size
                 if args.custom_sampler else len(train_loader))

    optimizer = optim.SGD(model.parameters(),
                          lr=args.init_lr,
                          momentum=0.9,
                          nesterov=args.nesterov)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, args.num_epochs * batch_num, eta_min=args.init_lr / 10)

    if args.resume:
        print("load checkpoint_last")
        checkpoint = torch.load(
            os.path.join(save_model, env, args.net_type,
                         'checkpoint_last.pth.tar'))

        ##### load model
        model.load_state_dict(checkpoint["model"])
        start_epoch = checkpoint["epoch"]
        optimizer = optim.SGD(model.parameters(), lr=checkpoint["init_lr"])

    #### loss config
    criterion = nn.BCEWithLogitsLoss()
    triplet = torch.nn.TripletMarginLoss(margin=0.5, p=2)

    # Start training
    j = 0
    best_score = 0
    score = 0
    triplet_loss = torch.tensor(0)  # placeholder so .item() succeeds when metric loss is disabled
    membership_loss = torch.tensor(0)  # placeholder when membership loss is disabled
    transfer_loss = torch.tensor(0)  # placeholder when transfer loss is disabled
    for epoch in range(start_epoch, args.num_epochs):
        OOD_data = 0  # placeholder; real OOD batches come from the commented-out zip loop below
        total_loss = 0
        triplet_running_loss = 0
        membership_running_loss = 0
        transfer_running_loss = 0
        class_running_loss = 0
        train_acc = 0
        test_acc = 0
        stime = time.time()

        for i, train_data in enumerate(train_loader):
            # for i, (train_data, OOD_data) in enumerate(zip(train_loader, OOD_loader)):
            #### initialized
            model = model.to(device).train()
            optimizer.zero_grad()

            org_image = train_data['input'] + 0.01 * torch.randn_like(
                train_data['input'])
            org_image = org_image.to(device)
            gt = train_data['label'].type(torch.FloatTensor).to(device)

            #### forward path
            output, output_list = model.feature_list(org_image)

            #### calc loss
            if args.transfer:
                transfer_loss = Transfer_Loss(model, OOD_data, criterion,
                                              args.num_classes, device)

            if args.metric:
                triplet_loss = Metric_Loss(output_list, gt, triplet)

            if args.membership:
                membership_loss = Membership_Loss(output, gt, args.num_classes)

            class_loss = criterion(output, gt)

            #### backpropagation
            total_backward_loss = class_loss + triplet_loss + membership_loss + transfer_loss
            total_backward_loss.backward()
            optimizer.step()
            scheduler.step()

            #### calc accuracy and running loss update
            train_acc += sum(
                torch.argmax(torch.sigmoid(output), dim=1) == torch.argmax(
                    gt, dim=1)).cpu().detach().item()

            class_running_loss += class_loss.item()
            triplet_running_loss += triplet_loss.item()
            membership_running_loss += membership_loss.item()
            transfer_running_loss += transfer_loss.item()
            total_loss += total_backward_loss.item()

        #### test_classification
        with torch.no_grad():
            test_label, test_acc = test(model, test_loader, args.num_classes,
                                        device)

        #### print status
        print(
            'Epoch [{}/{}], Step {}, exe time: {:.2f}, lr: {:.4f}*e-4'.format(
                epoch, args.num_epochs, i + 1,
                time.time() - stime,
                scheduler.get_last_lr()[0] * 10**4))

        print(
            'class_loss = {:.4f}, membership_loss = {:.4f}, transfer_loss = {:.4f}, total_loss = {:.4f}'
            .format(class_running_loss / batch_num,
                    membership_running_loss / batch_num,
                    transfer_running_loss / batch_num, total_loss / batch_num))

        if args.dataset == 'caltech' or args.dataset == 'dog':
            print("train accuracy total : {:.4f}".format(
                train_acc / train_dataset.num_image))
        else:
            print("train accuracy total : {:.4f}".format(
                train_acc / (batch_num * args.batch_size)))
        print("test accuracy total : {:.4f}".format(test_acc /
                                                    test_dataset.num_image))
        #### class-wise test accuracy
        for label in range(args.num_classes):
            print("label{}".format(label), end=" ")
            print("{:.4f}%".format(test_label[label] /
                                   test_dataset.len_list[label] * 100),
                  end=" ")
        print()
        print()

        #### test ODIN
        if epoch % 10 == 9 and args.not_test_ODIN:
            best_TNR, best_AUROC = test_ODIN(model, test_loader,
                                             out_test_loader, args.net_type,
                                             args)
            summary.add_scalar('AD_acc/AUROC', best_AUROC, epoch)
            summary.add_scalar('AD_acc/TNR', best_TNR, epoch)

        #### update tensorboard
        summary.add_scalar('loss/loss', total_loss / batch_num, epoch)
        summary.add_scalar('loss/membership_loss',
                           membership_running_loss / batch_num, epoch)
        summary.add_scalar('loss/transfer_loss',
                           transfer_running_loss / batch_num, epoch)
        summary.add_scalar('acc/train_acc',
                           train_acc / train_dataset.num_image, epoch)
        summary.add_scalar('acc/test_acc', test_acc / test_dataset.num_image,
                           epoch)
        summary.add_scalar("learning_rate/lr",
                           scheduler.get_last_lr()[0], epoch)
        time.sleep(0.001)

        #### save model
        torch.save(
            {
                'model': model.state_dict(),
                'epoch': epoch,
                'init_lr': scheduler.get_last_lr()[0]
            },
            os.path.join(save_model, env, args.net_type,
                         'checkpoint_last.pth.tar'))
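
# Hypothetical entry point, not shown in the original snippet; main() would
# typically be launched as:
if __name__ == '__main__':
    main()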