Esempio n. 1
0
def demo_test(args):
    """Build the model, criterion and test loader from ``args`` and run ``test``.

    Args:
        args: argument namespace. When ``args.doc`` is truthy it is first
            replaced by the result of ``config_loader(args.doc, args)``.
            The attributes read afterwards are ``gpu_id``, ``model_name``,
            ``scale``, ``model_args``, ``criterion``, ``test_path``,
            ``target_path``, ``batch_size`` and ``num_workers``.
    """
    # optionally override the arguments with a config document
    if args.doc:
        args = config_loader(args.doc, args)

    # select the CUDA device before the model is moved onto it
    torch.cuda.set_device(args.gpu_id)

    # network, moved to the selected GPU
    net = model_builder(args.model_name, args.scale, **args.model_args)
    net = net.cuda()

    # criterion
    loss_fn = criterion_builder(args.criterion)

    # evaluation data: iterated in order, pinned memory disabled
    eval_data = AxisDataSet(args.test_path, args.target_path)
    loader_options = {
        'batch_size': args.batch_size,
        'shuffle': False,
        'num_workers': args.num_workers,
        'pin_memory': False,
    }
    eval_loader = DataLoader(eval_data, **loader_options)

    # run the evaluation
    test(net, eval_loader, loss_fn, args)
Esempio n. 2
0
def model_env(args):
    """Load a trained model and build the objects needed to use it.

    Args:
        args: argument namespace. When ``args.doc`` is truthy it is first
            replaced by the result of ``config_loader(args.doc, args)``.

    Returns:
        tuple: ``(model, criterion, extractor)`` where

        * ``model`` is the network built by ``model_builder``, moved to
          CUDA and populated from the checkpoint's ``'state_dict'`` entry,
        * ``criterion`` is the result of ``criterion_builder(args.criterion)``,
        * ``extractor`` is ``FeatureExtractor().cuda()`` when
          ``args.content_loss`` is truthy, otherwise ``None``.
    """
    # optionally override the arguments with a config document
    if args.doc:
        args = config_loader(args.doc, args)

    # select the CUDA device
    torch.cuda.set_device(args.gpu_id)

    # checkpoint version: only a genuine int is accepted (the exact type
    # check rejects bool as well); anything else falls back to version 0
    if type(args.load) is int:
        version = args.load
    else:
        version = 0

    # <log_path>/<model_name>/version_<n>/<model_name>_<scale>x.pt
    model_path = os.path.join(
        args.log_path,
        args.model_name,
        f'version_{version}',
        f'{args.model_name}_{args.scale}x.pt',
    )
    target_device = f'cuda:{args.gpu_id}'
    checkpoint = torch.load(model_path, map_location=target_device)

    # build the network on the GPU, then restore its weights
    net = model_builder(args.model_name, args.scale, **args.model_args)
    net = net.cuda()
    net.load_state_dict(checkpoint['state_dict'])

    # criterion
    loss_fn = criterion_builder(args.criterion)

    # the feature extractor is only created when content loss is enabled
    feature_net = None
    if args.content_loss:
        feature_net = FeatureExtractor().cuda()

    return net, loss_fn, feature_net
Esempio n. 3
0
    if train_args.doc:
        train_args = config_loader(train_args.doc, train_args)
    # set cuda
    torch.cuda.set_device(train_args.gpu_id)

    # model
    model = model_builder(train_args.model_name, train_args.scale,
                          **train_args.model_args).cuda()

    # optimizer and criterion
    optimizer = optimizer_builder(train_args.optim)  # optimizer class
    optimizer = optimizer(  # optmizer instance
        model.parameters(),
        lr=train_args.lr,
        weight_decay=train_args.weight_decay)
    criterion = criterion_builder(train_args.criterion)

    # dataset
    full_set = AxisDataSet(train_args.train_path, train_args.target_path)

    # build hold out CV
    train_set, valid_set = cross_validation(
        full_set,
        mode='hold',
        p=train_args.holdout_p,
    )

    # dataloader
    train_loader = DataLoader(
        train_set,
        batch_size=train_args.batch_size,
Esempio n. 4
0
    if test_args.doc:
        test_args = config_loader(test_args.doc, test_args)
    # config
    model_config(test_args,
                 save=False)  # print model configuration of evaluation

    # set cuda
    torch.cuda.set_device(test_args.gpu_id)

    # model
    model = model_builder(test_args.model_name, test_args.scale,
                          **test_args.model_args).cuda()

    # criterion
    criterion = criterion_builder(test_args.criterion)
    # optimizer = None # don't need optimizer in test

    # dataset
    test_set = AxisDataSet(test_args.test_path, test_args.target_path)

    test_loader = DataLoader(
        test_set,
        batch_size=test_args.batch_size,
        shuffle=False,
        num_workers=test_args.num_workers,
        pin_memory=False,
    )

    # test
    test(model, test_loader, criterion, test_args)
Esempio n. 5
0
def train(args, writer):
    """Train a model on ``train.csv`` and save the best-validation weights.

    The data is split 80/20 into training and validation subsets. After
    every epoch the validation loss is compared with the best value seen so
    far and, when it improves, the model's ``state_dict`` is written to
    ``{args.save}/mymodel.pth``. Loss and WRMSE curves are passed to
    ``visualize`` once training has finished.

    Args:
        args: argument namespace. Attributes read: ``model``, ``fillna``,
            ``data_aug``, ``batch_size``, ``criterion``, ``lr``,
            ``scheduler``, ``epochs``, ``lr_method``, ``save`` and
            ``tensorboard``.
        writer: object exposing ``add_scalar(tag, value, step)``; only used
            when ``args.tensorboard`` is truthy.

    Returns:
        None
    """
    train_loss_curve = []
    train_wrmse_curve = []
    valid_loss_curve = []
    valid_wrmse_curve = []

    model = model_builder(args.model).to(device)
    model.train()

    # Only parse the raw CSV when it is actually used: FILLNA loads the file
    # itself, so reading it beforehand would be thrown away.
    if args.fillna:
        print("fill nan with K!!")
        data = FILLNA('train.csv')
    else:
        data = pd.read_csv('train.csv', encoding='utf-8')
    if args.data_aug:
        from data_augment import data_aug
        print("It may cost time!!")
        data = data_aug(data)
        print("Augmentation Complete!! Please check data_augment.csv")

    dataset = MLDataset(data)
    train_size = int(0.8 * len(dataset))
    valid_size = len(dataset) - train_size
    train_dataset, valid_dataset = random_split(dataset, [train_size, valid_size])
    train_dataloader = DataLoader(dataset=train_dataset, batch_size=args.batch_size, shuffle=True)
    valid_dataloader = DataLoader(dataset=valid_dataset, batch_size=args.batch_size, shuffle=True)

    # loss function and optimizer
    # can change loss function and optimizer you want
    criterion = criterion_builder(args.criterion)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, amsgrad=True)
    scheduler = None
    if args.scheduler:
        scheduler = schedule_builder(optimizer, args.epochs, args.lr_method)
        print(scheduler)

    best = float("inf")
    # start training
    for e in range(args.epochs):
        train_loss = 0.0
        train_wrmse = 0.0
        valid_loss = 0.0
        valid_wrmse = 0.0

        print(f'\nEpoch: {e+1}/{args.epochs}')
        print('-' * len(f'Epoch: {e+1}/{args.epochs}'))

        # The validation pass below switches to eval mode, so training mode
        # has to be restored at the start of every epoch.
        model.train()
        # tqdm to display progress bar
        for inputs, labels in tqdm(train_dataloader):
            # data from data_loader
            inputs = inputs.float().to(device)
            labels = labels.float().to(device)
            outputs = model(inputs)
            # MSE loss and WRMSE
            loss = criterion(outputs, labels)
            wrmse = WRMSE(outputs, labels, device)
            # weights update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # NOTE(review): the scheduler is built with args.epochs but is
            # stepped once per batch -- confirm schedule_builder expects that.
            if scheduler is not None:
                scheduler.step()
            # Accumulate plain Python numbers so no autograd graph is kept
            # alive across iterations (assumes WRMSE returns a scalar).
            train_loss += loss.item()
            train_wrmse += float(wrmse)

        # =================================================================== #
        # Validation: eval mode and no gradient tracking, so that layers
        # such as dropout / batch-norm behave deterministically and no
        # memory is spent on graphs that are never back-propagated.
        model.eval()
        with torch.no_grad():
            for inputs, labels in tqdm(valid_dataloader):
                # data from data_loader
                inputs = inputs.float().to(device)
                labels = labels.float().to(device)
                outputs = model(inputs)
                # MSE loss and WRMSE
                loss = criterion(outputs, labels)
                wrmse = WRMSE(outputs, labels, device)
                # loss calculate
                valid_loss += loss.item()
                valid_wrmse += float(wrmse)

        # =================================================================== #
        # NOTE(review): per-batch loss values are summed and then divided by
        # the number of samples; whether that is a true per-sample mean
        # depends on the criterion's reduction -- verify.
        loss_epoch = train_loss / len(train_dataset)
        wrmse_epoch = math.sqrt(train_wrmse / len(train_dataset))
        valid_loss_epoch = valid_loss / len(valid_dataset)
        valid_wrmse_epoch = math.sqrt(valid_wrmse / len(valid_dataset))

        # save the best model weights as .pth file
        if valid_loss_epoch < best:
            best = valid_loss_epoch
            torch.save(model.state_dict(), f'{args.save}/mymodel.pth')
        print(f'Training loss: {loss_epoch:.4f}')
        print(f'Training WRMSE: {wrmse_epoch:.4f}')
        print(f'Valid loss: {valid_loss_epoch:.4f}')
        print(f'Valid WRMSE: {valid_wrmse_epoch:.4f}')
        # save loss and wrmse every epoch
        train_loss_curve.append(loss_epoch)
        train_wrmse_curve.append(wrmse_epoch)
        valid_loss_curve.append(valid_loss_epoch)
        valid_wrmse_curve.append(valid_wrmse_epoch)
        if args.tensorboard:
            writer.add_scalar('train/train_loss', loss_epoch, e)
            writer.add_scalar('train/wrmse_loss', wrmse_epoch, e)
            writer.add_scalar('validation/valid_loss', valid_loss_epoch, e)
            writer.add_scalar('validation/wrmse_loss', valid_wrmse_epoch, e)
            writer.add_scalar('lr', optimizer.param_groups[0]['lr'], e)
    # generate training curve
    visualize(train_loss_curve, valid_loss_curve, 'Train Loss')
    visualize(train_wrmse_curve, valid_wrmse_curve, 'Train WRMSE')
    print("\nBest Validation loss:", best)