Exemplo n.º 1
0
def main():
    """Train an accent classifier and save the best checkpoint plus a run summary.

    Reads command-line options and makes a stratified 70/30 train/validation
    split of the label CSV (stratified on its ``accent`` column). It builds
    ``dataset.DaconDataset`` loaders that use ``get_sampler`` samplers, then
    creates either ``CustomModel`` or a pretrained EfficientNet
    (``in_channels=1``, ``num_classes=6``). Training uses Adam, cross-entropy
    loss and ``ReduceLROnPlateau``, and per-epoch metrics are logged to
    Weights & Biases.

    Outputs go to ``./model/<MMDDHHMMSS>/``:
      * ``model_best.pth`` -- weights with the lowest validation loss so far,
        saved without any ``DataParallel`` wrapper so ``--resume`` can load it;
      * ``log.txt`` -- the arguments and a best-epoch summary;
      * ``history.png`` -- loss/accuracy curves.

    Training stops early once the validation loss has failed to improve for
    ``--patient`` consecutive epochs. Invalid paths end the program through
    ``argparse`` (``SystemExit``) with a message naming the bad option.
    """
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    parser = argparse.ArgumentParser()
    parser.add_argument('--image_path', type=str, default="./data/cache/train")
    parser.add_argument('--label_path', type=str, default="./data/cache/train.csv")
    parser.add_argument('--kfold_idx', type=int, default=0)

    # parser.add_argument('--model', type=str, default='CustomModel')
    parser.add_argument('--model', type=str, default='efficientnet-b0')
    parser.add_argument('--epochs', type=int, default=2000)
    parser.add_argument('--batch_size', type=int, default=50)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--patient', type=int, default=8)
    parser.add_argument('--seed', type=int, default=42)

    parser.add_argument('--device', type=str, default=device)
    parser.add_argument('--resume', type=str, default=None)
    parser.add_argument('--comments', type=str, default=None)

    args = parser.parse_args()

    print('=' * 50)
    print('[info msg] arguments\n')
    for key, value in vars(args).items():
        print(key, ":", value)
    print('=' * 50)

    # Use parser.error instead of `assert`: asserts are stripped under
    # `python -O`, and parser.error reports which option is wrong.
    if not os.path.isdir(args.image_path):
        parser.error('wrong path: --image_path {} is not a directory'.format(args.image_path))
    if not os.path.isfile(args.label_path):
        parser.error('wrong path: --label_path {} is not a file'.format(args.label_path))
    if args.resume and not os.path.isfile(args.resume):
        parser.error('wrong path: --resume {} is not a file'.format(args.resume))

    seed_everything(args.seed)

    data_df = pd.read_csv(args.label_path)

    # n_splits=1, so the splitter yields exactly one (train, valid) index pair.
    sss = StratifiedShuffleSplit(n_splits=1, test_size=0.3, random_state=args.seed)
    train_idx, valid_idx = next(sss.split(X=data_df['id'], y=data_df['accent']))
    train_df = data_df.iloc[train_idx]
    valid_df = data_df.iloc[valid_idx]

    train_data = dataset.DaconDataset(
        image_folder=args.image_path,
        label_df=train_df,
    )
    valid_data = dataset.DaconDataset(
        image_folder=args.image_path,
        label_df=valid_df,
    )

    # NOTE(review): the validation loader also draws through get_sampler; if
    # that sampler re-weights or samples with replacement, the validation
    # metrics are not computed over the plain validation set -- confirm.
    train_sampler = get_sampler(df=train_df, dataset=train_data)
    valid_sampler = get_sampler(df=valid_df, dataset=valid_data)

    # `shuffle` must stay unset: DataLoader rejects shuffle=True together
    # with a sampler.
    train_data_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        sampler=train_sampler,
    )
    valid_data_loader = torch.utils.data.DataLoader(
        valid_data,
        batch_size=args.batch_size,
        sampler=valid_sampler,
    )

    if args.model == 'CustomModel':
        model = CustomModel()
    else:
        model = EfficientNet.from_pretrained(args.model, in_channels=1, num_classes=6, dropout_rate=0.3, advprop=True)
    print('[info msg] {} model is created\n'.format(args.model))

    print('=' * 50)

    if args.resume:
        # map_location='cpu' lets a checkpoint written on a GPU be loaded on a
        # CPU-only machine; the model is moved to args.device further below.
        model.load_state_dict(torch.load(args.resume, map_location='cpu'))
        print('[info msg] pre-trained weight is loaded !!\n')
        print(args.resume)
        print('=' * 50)

    # Keep a handle on the unwrapped module. DataParallel prefixes every
    # state_dict key with "module.", and such checkpoints would not load
    # through --resume above, so the checkpoint is saved from `base_model`.
    base_model = model
    target_device = torch.device(args.device)
    # Compare the device *type*: the default device is "cuda:0", so the
    # original `args.device == 'cuda'` never matched. DataParallel needs the
    # module on device_ids[0] (GPU 0 by default), so only wrap for that GPU.
    if (target_device.type == 'cuda'
            and target_device.index in (None, 0)
            and torch.cuda.device_count() > 1):
        model = torch.nn.DataParallel(model)

    ##### Wandb ######
    wandb.init(project='dacon_voice')
    wandb.run.name = args.comments
    wandb.config.update(args)
    wandb.watch(model)
    ##################

    model.to(args.device)

    optimizer = torch.optim.Adam(model.parameters(), args.lr)
    criterion = torch.nn.CrossEntropyLoss()
    # Halve the learning rate when the monitored value has not decreased for
    # 2 steps; the scheduler is handed to trainer.validate below.
    scheduler = ReduceLROnPlateau(
        optimizer=optimizer,
        mode='min',
        patience=2,
        factor=0.5,
        verbose=True,
    )

    train_loss = []
    train_acc = []
    valid_loss = []
    valid_acc = []

    best_loss = float("inf")
    patient = 0  # consecutive epochs without a validation-loss improvement

    date_time = datetime.now().strftime("%m%d%H%M%S")
    save_dir = os.path.join('./model', date_time)

    print('[info msg] training start !!\n')
    start_time = datetime.now()
    for epoch in range(args.epochs):
        print('Epoch {}/{}'.format(epoch + 1, args.epochs))
        train_epoch_loss, train_epoch_acc = trainer.train(
            train_loader=train_data_loader,
            model=model,
            loss_func=criterion,
            device=args.device,
            optimizer=optimizer,
        )
        train_loss.append(train_epoch_loss)
        train_acc.append(train_epoch_acc)

        valid_epoch_loss, valid_epoch_acc = trainer.validate(
            valid_loader=valid_data_loader,
            model=model,
            loss_func=criterion,
            device=args.device,
            scheduler=scheduler,
        )
        valid_loss.append(valid_epoch_loss)
        valid_acc.append(valid_epoch_acc)

        wandb.log({
            "Train Acc": train_epoch_acc,
            "Valid Acc": valid_epoch_acc,
            "Train Loss": train_epoch_loss,
            "Valid Loss": valid_epoch_loss,
        })

        if valid_epoch_loss < best_loss:
            patient = 0
            best_loss = valid_epoch_loss

            Path(save_dir).mkdir(parents=True, exist_ok=True)
            torch.save(base_model.state_dict(), os.path.join(save_dir, 'model_best.pth'))
            print('MODEL IS SAVED TO {}!!!'.format(date_time))
        else:
            patient += 1
            if patient >= args.patient:
                print('=======' * 10)
                print("[Info message] Early stopper is activated")
                break

    elapsed_time = datetime.now() - start_time

    # The directory is otherwise only created when a checkpoint is saved
    # (never, e.g., if every validation loss is NaN); log.txt and
    # history.png need it regardless.
    Path(save_dir).mkdir(parents=True, exist_ok=True)

    if not train_loss:
        # e.g. --epochs 0: np.argmin below would fail on an empty array.
        print('[info msg] no epoch was run (--epochs {}); nothing to summarise'.format(args.epochs))
        return

    train_loss = np.array(train_loss)
    train_acc = np.array(train_acc)
    valid_loss = np.array(valid_loss)
    valid_acc = np.array(valid_acc)

    # 0-based positions in the history arrays (used as plot x coordinates);
    # reported epoch numbers add 1 to match the "Epoch k/N" progress lines.
    best_train_pos = int(np.argmin(train_loss))
    best_loss_pos = int(np.argmin(valid_loss))

    print('=' * 50)
    print('[info msg] training is done\n')
    print("Time taken: {}".format(elapsed_time))
    print("best loss is {} w/ acc {} at epoch : {}".format(best_loss, valid_acc[best_loss_pos], best_loss_pos + 1))

    print('=' * 50)
    print('[info msg] {} model weight and log is save to {}\n'.format(args.model, save_dir))

    with open(os.path.join(save_dir, 'log.txt'), 'w') as f:
        for key, value in vars(args).items():
            f.write('{} : {}\n'.format(key, value))

        f.write('\n')
        f.write('total epochs : {}\n'.format(train_loss.shape[0]))
        f.write('time taken : {}\n'.format(elapsed_time))
        f.write('best_train_loss {} w/ acc {} at epoch : {}\n'.format(
            train_loss[best_train_pos], train_acc[best_train_pos], best_train_pos + 1))
        f.write('best_valid_loss {} w/ acc {} at epoch : {}\n'.format(
            valid_loss[best_loss_pos], valid_acc[best_loss_pos], best_loss_pos + 1))

    plt.figure(figsize=(15, 5))
    plt.subplot(1, 2, 1)
    plt.plot(train_loss, label='train loss')
    plt.plot(valid_loss, 'o', label='valid loss')
    plt.axvline(x=best_loss_pos, color='r', linestyle='--', linewidth=1.5)
    plt.legend()
    plt.subplot(1, 2, 2)
    plt.plot(train_acc, label='train acc')
    plt.plot(valid_acc, 'o', label='valid acc')
    plt.axvline(x=best_loss_pos, color='r', linestyle='--', linewidth=1.5)
    plt.legend()
    plt.savefig(os.path.join(save_dir, 'history.png'))
    plt.close()  # release the figure; savefig does not close it
Exemplo n.º 2
0
    # Train CustomModel with SGD (momentum 0.9) and a per-epoch StepLR decay,
    # evaluating after every epoch, then save the final weights.
    train_set = CustomDataset(dir_csv, dir_img, transforms=augs)
    train_loader = DataLoader(train_set, batch_size=batch_size_train, shuffle=True)

    # NOTE(review): the validation set reads the same CSV / image directory as
    # the training set (only the transforms differ), so evaluation runs on
    # training samples -- confirm this is intended.
    val_set = CustomDataset(dir_csv, dir_img, transforms=tr)
    val_loader = DataLoader(val_set, batch_size=batch_size_test, shuffle=False)

    model = CustomModel()
    loss_function = CustomLoss()

    # Move the model before constructing its optimizer, as the torch.optim
    # documentation recommends.
    model.to(device)
    print('Starting optimizer with LR={}'.format(lr))
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    # step_size=1 with one scheduler.step() per epoch: the LR is multiplied
    # by `gamma` after every epoch.
    scheduler = StepLR(optimizer, step_size=1, gamma=gamma)
    for epoch in range(1, num_epochs + 1):
        train(model, device, train_loader, optimizer, epoch, loss_function)
        # Evaluate on the loader built above. The original passed
        # `test_loader`, which is not defined anywhere visible in this file
        # and would raise NameError on the first epoch.
        test(model, device, val_loader, loss_function)
        scheduler.step()

    torch.save(model.state_dict(), "well_trained model.pt")


if __name__ == "__main__":
    # Script entry point: pick the compute device, then train with fixed
    # hyper-parameters (5 epochs, batch 64, LR 0.01 decayed by 0.9 per step).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_epochs = 5
    batch_size = 64
    lr = 0.01
    gamma = 0.9
    main(num_epochs, batch_size, lr, gamma, device)