def train(loaders, save_path):
    """Train a fresh ``Net`` and checkpoint it whenever validation improves.

    Args:
        loaders: Mapping with ``'train'`` and ``'valid'`` entries, each an
            iterable yielding ``(data, target)`` batches.
        save_path: Destination passed to ``torch.save`` for the model's
            ``state_dict`` each time the epoch's validation loss is less
            than or equal to the best seen so far.

    Returns:
        The model as it stands after the final epoch. NOTE: these are the
        last-epoch weights, which are not necessarily the ones stored at
        ``save_path``.
    """
    # Initialize custom defined cnn and place it on the GPU when available.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Net()
    model.to(device)

    # cross entropy loss for classification task
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config.lr)

    # Tracker for the minimum validation loss. A builtin float is used
    # because the ``np.Inf`` alias was removed in NumPy 2.0.
    valid_loss_min = float('inf')

    n_epochs = config.n_epochs
    for epoch in range(1, n_epochs + 1):
        # running means of the per-batch losses for this epoch
        train_loss = 0.0
        valid_loss = 0.0

        model.train()
        for batch_idx, (data, target) in enumerate(loaders['train']):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            # Incremental mean; ``item()`` yields a plain Python float so the
            # accumulator does not hold on to a tensor.
            train_loss += (loss.item() - train_loss) / (batch_idx + 1)

        # validation
        model.eval()
        # No gradients are needed here; disabling autograd avoids building
        # and retaining a graph for every validation batch.
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(loaders['valid']):
                data, target = data.to(device), target.to(device)
                output = model(data)
                loss = criterion(output, target)
                # update the average validation loss
                valid_loss += (loss.item() - valid_loss) / (batch_idx + 1)

        # print training/validation statistics
        print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.
              format(epoch, train_loss, valid_loss))

        # save the model if validation loss has not increased
        if valid_loss <= valid_loss_min:
            torch.save(model.state_dict(), save_path)

            # Updating the validation loss minimum
            valid_loss_min = valid_loss

    # return trained model
    return model
                                           args.cnn_helmet_resize_width, 
                                           args.cnn_helmet_resize_height)),
                                           transforms.ToTensor()])
    

    train_data = datasets.ImageFolder(args.dataset_location,
                                      transform=train_transforms)
    

    train_loader = data_utils.DataLoader(train_data, 
                                         batch_size=args.batch_size, 
                                         shuffle=True,
                                         drop_last = True)
    model = Net()
    if torch.cuda.is_available():
        model.cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), 
                          lr = args.learning_rate) 
    
    for epoch in range(args.epoch):  
        start_time = time.time()
        running_loss = 0.0
        for i, data in enumerate(train_loader, 0):
            inputs, labels = data
            if torch.cuda.is_available():
                inputs = inputs.cuda()
                labels = labels.cuda()
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)