Example #1
def main():
    # batch_size = 100
    batch_size = 1
    print("here")
    sk_root = '../256x256/sketch/tx_000000000000'
    sk_root = '../256x256/photo/tx_000000000000'
    sk_root ='../test_pair/sketch'
    in_size = 225
    in_size = 224
    train_dataset = DataSet.ImageDataset(sk_root, transform=Compose([Resize(in_size), ToTensor()]))
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, pin_memory=True, num_workers=os.cpu_count(),
                         shuffle=True, drop_last=True)

    test_dataset = DataSet.ImageDataset(sk_root, transform=Compose([Resize(in_size), ToTensor()]),train=False)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, pin_memory=True, num_workers=os.cpu_count(),
                         shuffle=True, drop_last=True)

    num_class = len(train_dataset.classes)
    embed_size = -1
    model = getResnet(num_class=num_class, pretrain=True)
    model.train()
    if torch.cuda.is_available():
        model = model.cuda()
    
    crit = torch.nn.CrossEntropyLoss()
    optim = torch.optim.Adam(model.parameters())
    # optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    # Tensorboard stuff
    # writer = tb.SummaryWriter('./logs')



    count = 0
    epochs = 2
    prints_interval = 1
    max_chpt = 3
    max_acu = -1
    chpt_num = 0
    activation = {}
    def get_activation(name):
        def hook(model, input, output):
            activation[name] = output

        return hook


    # Register a forward hook so the avgpool output is captured into the
    # activation dict on every forward pass.
    model.avgpool.register_forward_hook(get_activation('avgpool'))
    for e in range(epochs):
        print('epoch',e,'started')
        avg_loss = 0
        for i, (X, Y) in enumerate(train_dataloader):

            activation = {}




            if torch.cuda.is_available():
                X, Y = X.cuda(), Y.cuda()

            optim.zero_grad()
            to_image = transforms.ToPILImage()
            output = model(X)
            print(activation['avgpool'].shape)
            loss = crit(output, Y)
            avg_loss += loss.item()
            if i == 0:
                print(loss)
            if i % prints_interval == 0:
                print(f'[Training] {i}/{e}/{epochs} -> Loss: {avg_loss/(i+1)}')
                # writer.add_scalar('train-loss', loss.item(), count)
            loss.backward()

            optim.step()

            count += 1
        print('epoch',e,'loss',avg_loss/len(train_dataloader))
        correct, total, accuracy= 0, 0, 0
        # model.eval()
        for i, (X, Y) in enumerate(test_dataloader):

            if torch.cuda.is_available():
                X, Y = X.cuda(), Y.cuda()
            output = model(X)
            _, predicted = torch.max(output, 1)
            total += Y.size(0)
            correct += (predicted == Y).sum().item()


        accuracy = (correct / total) * 100

        print(f'[Testing] -/{e}/{epochs} -> Accuracy: {accuracy} %',total,correct)
        # model.train()
        if accuracy >= max_acu:
            path = 'checkpoint'+str(chpt_num)+'.pt'
            max_acu = accuracy
            chpt_num= (chpt_num+1)%max_chpt
            set_checkpoint(epoch=e,model=model,optimizer=optim,train_loss=avg_loss/len(train_dataloader),accurate=accuracy,path=path)
            path = 'best.pt'
            set_checkpoint(epoch=e,model=model,optimizer=optim,train_loss=avg_loss/len(train_dataloader),accurate=accuracy,path=path)
import os

import torch
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, Resize, ToTensor
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

from model.siamesenet import SiameseNet
from model.resnet import getResnet
from model.linear import ClassificationNet
import data.dataset as DataSet
import data.datautil as util


def load_checkpoint(path, model, optimizer):
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    train_loss=checkpoint['train_loss']
    print("Checkpoint loaded", 'epoch', epoch, 'loss', loss)


embedding_size = 2
net1 = getResnet(num_class=embedding_size, pretrain=True)
margin = 10
model = SiameseNet(net1)
path = 'best.pt'
optim = torch.optim.Adam(model.parameters())
load_checkpoint(path,model,optim)
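
The training loop above saves with set_checkpoint, which is not shown in these examples. A minimal sketch, assuming it simply writes the counterpart of what load_checkpoint reads back (model and optimizer state plus whatever extra metrics the caller passes), could look like this:

def set_checkpoint(epoch, model, optimizer, path, **extra):
    # Hypothetical helper: bundles model/optimizer state and any extra metrics
    # (train_loss, loss, accurate, ...) passed by the training loops into one file.
    state = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    state.update(extra)
    torch.save(state, path)
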
Example #3
def main():
    # batch_size = 100
    batch_size = 2
    print("here")
    sk_root = '../256x256/sketch/tx_000000000000'
    sk_root = '../test'
    in_size = 225
    in_size = 224
    train_dataset = DataSet.ImageDataset(sk_root, transform=Compose([Resize(in_size), ToTensor()]))
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, pin_memory=True, num_workers=os.cpu_count(),
                                  shuffle=True, drop_last=True)
    tmp_root ='../testpair/photo'
    # util.train_test_split(tmp_root,split=(0.8,0.1,0.1))
    sketch_root = '../testpair/sketch'
    train_dataset = DataSet.PairedDataset(photo_root=tmp_root,sketch_root=sketch_root,transform=Compose([Resize(in_size), ToTensor()]))
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, pin_memory=True, num_workers=os.cpu_count(),
                                  shuffle=True, drop_last=True)
    test_dataset = DataSet.PairedDataset(photo_root=tmp_root,sketch_root=sketch_root,transform=Compose([Resize(in_size), ToTensor()]),train=True)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, pin_memory=True, num_workers=os.cpu_count(),
                                  shuffle=True, drop_last=True)

    # model = SketchANet(num_classes=3)   # superseded by the SiameseNet below
    # model = Net()
    crit = torch.nn.CrossEntropyLoss()
    net1 = getResnet(num_class=100)
    margin = 1
    model = SiameseNet(net1,net1)
    crit = ContrastiveLoss(margin)
    if torch.cuda.is_available():
        model = model.cuda()



    optim = torch.optim.Adam(model.parameters())
    # optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    # Tensorboard stuff
    # writer = tb.SummaryWriter('./logs')
    to_image = transforms.ToPILImage()
    count = 0
    epochs = 200
    prints_interval = 1
    for e in range(epochs):
        print('epoch', e, 'started')
        for i, (X, Y) in enumerate(train_dataloader):

            if torch.cuda.is_available():
                X, Y = (X[0].cuda(), X[1].cuda()), Y.cuda()
            sketch,photo = X
            Y = (Y != train_dataset.class_to_index['unmatched'])
            # for i in range(sketch.shape[0]):
            #     image =to_image(sketch[i])
            #     util.showImage(image)
            #     image =to_image(photo[i])
            #     util.showImage(image)
            #     print(Y)
            optim.zero_grad()
            #
            # image = to_image(X[0])
            # util.showImage(image)
            # print(train_dataset.class_to_idx)
            # print(Y)
            output = model(*X)
            # print(output,Y)

            loss = crit(*output, Y)

            if i % prints_interval == 0:
                print(f'[Training] {i}/{e}/{epochs} -> Loss: {loss.item()}')
                # writer.add_scalar('train-loss', loss.item(), count)

            # to_image = transforms.ToPILImage()
            # image = to_image(X[0])
            # util.showImage(image)
            # print(train_dataset.class_to_idx)
            # print(Y)

            loss.backward()
            optim.step()

            count += 1
        print('epoch', e, 'loss', loss.item())
        correct, total, accuracy = 0, 0, 0
        model.eval()
        # print(f'[Testing] -/{e}/{epochs} -> Accuracy: {accuracy} %', total, correct)
        model.train()
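
ContrastiveLoss is imported from the project and not defined in these snippets. A minimal sketch, assuming the usual margin-based formulation where the target is 1 for a matched sketch/photo pair and 0 for an unmatched one, might be:

import torch
import torch.nn.functional as F

class ContrastiveLoss(torch.nn.Module):
    # Hypothetical version: the snippets only construct it with a margin and
    # call it as crit(embedding1, embedding2, target).
    def __init__(self, margin):
        super().__init__()
        self.margin = margin

    def forward(self, output1, output2, target):
        target = target.float()
        dist = F.pairwise_distance(output1, output2)
        loss = target * dist.pow(2) + \
               (1 - target) * F.relu(self.margin - dist).pow(2)
        return loss.mean()
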
def main():
    # batch_size = 100
    batch_size = 1
    balanced = False
    print("Start Training")

    # sk_root ='../test'
    in_size = 225
    in_size = 224
    tmp_root = '../test_pair/photo'
    sketch_root = '../test_pair/sketch'
    # tmp_root = '../256x256/photo/tx_000000000000'
    # sketch_root = '../256x256/sketch/tx_000000000000'

    transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    train_dataset = DataSet.PairedDataset(photo_root=tmp_root,
                                          sketch_root=sketch_root,
                                          transform=Compose(
                                              [Resize(in_size),
                                               ToTensor()]),
                                          balanced=balanced)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  pin_memory=True,
                                  num_workers=os.cpu_count(),
                                  shuffle=True,
                                  drop_last=True)
    test_dataset = DataSet.PairedDataset(photo_root=tmp_root,
                                         sketch_root=sketch_root,
                                         transform=transform,
                                         train=False)
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=batch_size,
                                 pin_memory=True,
                                 num_workers=os.cpu_count(),
                                 shuffle=True,
                                 drop_last=True)

    embedding_size = 512
    margin = 1
    num_class = len(train_dataset.classes) - 1
    photo_net = getResnet(num_class=num_class,
                          pretrain=True,
                          feature_extract=True)

    # Freeze the photo branch; only the sketch branch is updated during training.
    for param in photo_net.parameters():
        param.requires_grad = False

    sketch_net = getResnet(num_class=num_class,
                           pretrain=True,
                           feature_extract=False)
    softmax_loss = SoftMax(embed_size=embedding_size, num_class=num_class)
    optim = torch.optim.Adam(
        list(sketch_net.parameters()) + list(softmax_loss.parameters()))
    optim = torch.optim.Adam(list(sketch_net.parameters()))
    model = ParallelNet(sketch_net=sketch_net, photo_net=photo_net)
    print(sketch_net)
    contrastive_loss = ContrastiveLoss(margin)

    cross = torch.nn.CrossEntropyLoss()
    if torch.cuda.is_available():
        model = model.cuda()

    # optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    # Tensorboard stuff
    # writer = tb.SummaryWriter('./logs')

    epochs = 100
    prints_interval = 1
    max_chpt = 3
    min_loss = 100000
    chpt_num = 0
    for e in range(epochs):
        print('epoch', e, 'started')
        avg_loss = 0
        for i, (X, Y) in enumerate(train_dataloader):
            one = torch.ones(Y[0].shape)
            zero = torch.zeros(Y[0].shape)
            if torch.cuda.is_available():
                X, Y = (X[0].cuda(), X[1].cuda()), (Y[0].cuda(), Y[1].cuda(),
                                                    Y[2].cuda())
                one, zero = one.cuda(), zero.cuda()
            optim.zero_grad()

            sketch, photo = X
            (Y, label_s, label_p) = Y
            embedding_sketch = sketch_net(sketch)
            embedding_photo = photo_net(photo)
            loss = cross(embedding_sketch, label_s)
            sloss = 0
            # sloss = softmax_loss(embedding_sketch, label_s)
            # sketch_feature = normalize(embedding_sketch)
            # phtot_feature = normalize(embedding_photo)

            Y = torch.where(Y != train_dataset.class_to_index['unmatched'],
                            one, zero)

            closs = 0
            # closs = contrastive_loss(sketch_feature, phtot_feature, Y)

            # loss = 0.0 * closs + 1* sloss

            avg_loss += loss.item()
            if i % prints_interval == 0:
                print(
                    f'[Training] {i}/{e}/{epochs} -> Loss: {avg_loss / (i + 1)} Contrastive: {closs} SoftMax: {sloss}'
                )
            loss.backward()

            optim.step()

        print('epoch', e, 'end', 'Avg loss', avg_loss / len(train_dataloader))
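
The SoftMax module used above is project code that is not shown. Judging from the calls, it is a learnable classification head over the embedding combined with cross-entropy; a minimal sketch under that assumption:

import torch.nn as nn

class SoftMax(nn.Module):
    # Hypothetical stand-in: linear classifier on top of the embedding,
    # scored with cross-entropy against the sketch label.
    def __init__(self, embed_size, num_class):
        super().__init__()
        self.fc = nn.Linear(embed_size, num_class)
        self.ce = nn.CrossEntropyLoss()

    def forward(self, embedding, label):
        return self.ce(self.fc(embedding), label)
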
Example #5
def main():
    # batch_size = 100
    batch_size = 1
    balanced = False
    print("here")
    # sk_root ='../test'
    in_size = 225
    in_size = 224
    tmp_root = '../256x256/photo/tx_000000000000'
    sketch_root = '../256x256/sketch/tx_000000000000'
    # tmp_root = '../rendered_256x256/256x256/photo/tx_000000000000'
    # sketch_root = '../rendered_256x256/256x256/sketch/tx_000000000000'
    tmp_root = '../test_pair/photo'
    sketch_root = '../test_pair/sketch'
    train_dataset = DataSet.PairedDataset(photo_root=tmp_root,
                                          sketch_root=sketch_root,
                                          transform=Compose(
                                              [Resize(in_size),
                                               ToTensor()]),
                                          balanced=balanced)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  pin_memory=True,
                                  num_workers=os.cpu_count(),
                                  shuffle=True,
                                  drop_last=True)

    test_dataset = DataSet.ImageDataset(sketch_root,
                                        transform=Compose(
                                            [Resize(in_size),
                                             ToTensor()]),
                                        train=True)
    # print(test_dataset.classes)
    # print(train_dataset.classes)
    # print(test_dataset.class_to_idx)
    # print(train_dataset.class_to_index)
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=batch_size,
                                 pin_memory=True,
                                 num_workers=os.cpu_count(),
                                 shuffle=True,
                                 drop_last=True)

    num_class = len(train_dataset.classes)
    embedding_size = 200
    net1 = getResnet(num_class=embedding_size, pretrain=True)
    model = SiaClassNet(net1, embedding_size, num_class)

    method = "classify"
    crit = torch.nn.CrossEntropyLoss()
    model.train()

    if torch.cuda.is_available():
        model = model.cuda()

    optim = torch.optim.Adam(model.parameters())
    # optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    # Tensorboard stuff
    # writer = tb.SummaryWriter('./logs')

    count = 0
    epochs = 200
    prints_interval = 1
    max_chpt = 3
    max_loss = -1
    for e in range(epochs):
        print('epoch', e, 'started')
        avg_loss = 0
        for i, (X, Y) in enumerate(train_dataloader):

            if torch.cuda.is_available():
                X, Y = (X[0].cuda(), X[1].cuda()), (Y[0].cuda(), Y[1].cuda(),
                                                    Y[2].cuda())
            sketch, photo = X
            optim.zero_grad()
            to_image = transforms.ToPILImage()
            #output = model(*X)
            output = model(sketch, sketch)
            (Y, label_s, label_p) = Y
            # loss = crit(output, Y)
            loss = crit(output, label_s)
            avg_loss += loss.item()
            if i % prints_interval == 0:
                print(output, label_s)
                print(f'[Training] {i}/{e}/{epochs} -> Loss: {avg_loss/(i+1)}')
                # writer.add_scalar('train-loss', loss.item(), count)
            loss.backward()

            optim.step()

            count += 1
        print('epoch', e, 'Avg loss', avg_loss / len(train_dataloader))

        eval_accu(test_dataloader, model, e, epochs)
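
eval_accu is not defined in the listing. A plausible sketch, assuming it mirrors the test loop from Example #1 but feeds the sketch through both branches of the network (as the training loop above does), is:

import torch

def eval_accu(test_dataloader, model, e, epochs):
    # Hypothetical evaluation helper: top-1 accuracy on the sketch labels.
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for X, Y in test_dataloader:
            if torch.cuda.is_available():
                X, Y = X.cuda(), Y.cuda()
            output = model(X, X)
            _, predicted = torch.max(output, 1)
            total += Y.size(0)
            correct += (predicted == Y).sum().item()
    accuracy = (correct / total) * 100
    print(f'[Testing] -/{e}/{epochs} -> Accuracy: {accuracy} %', total, correct)
    model.train()
    return accuracy
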
def main():
    # batch_size = 100
    batch_size = 100
    balanced = False
    print("Start Training")
    sk_root = '../256x256/sketch/tx_000000000000'
    # sk_root ='../test'
    in_size = 225
    in_size = 224
    tmp_root = '../test_pair/photo'
    sketch_root = '../test_pair/sketch'
    tmp_root = '../256x256/photo/tx_000000000000'
    sketch_root = '../256x256/sketch/tx_000000000000'
    train_dataset = DataSet.PairedDataset(photo_root=tmp_root,
                                          sketch_root=sketch_root,
                                          transform=Compose(
                                              [Resize(in_size),
                                               ToTensor()]),
                                          balanced=balanced)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  pin_memory=True,
                                  num_workers=os.cpu_count(),
                                  shuffle=True,
                                  drop_last=True)
    test_dataset = DataSet.PairedDataset(photo_root=tmp_root,
                                         sketch_root=sketch_root,
                                         transform=Compose(
                                             [Resize(in_size),
                                              ToTensor()]),
                                         train=False)
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=batch_size,
                                 pin_memory=True,
                                 num_workers=os.cpu_count(),
                                 shuffle=True,
                                 drop_last=True)

    num_class = len(train_dataset.classes)
    embedding_size = 10242
    embedding_size = 1024
    embedding_size = 512
    net1 = getResnet(num_class=embedding_size, pretrain=True)
    margin = 1
    model = SiameseNet(net1)

    method = 'metric'
    crit = ContrastiveLoss(margin)
    model.train()
    if torch.cuda.is_available():
        model = model.cuda()

    optim = torch.optim.Adam(model.parameters())
    # optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    # Tensorboard stuff
    # writer = tb.SummaryWriter('./logs')

    count = 0
    epochs = 100
    prints_interval = 100
    max_chpt = 3
    min_loss = 100000
    chpt_num = 0
    for e in range(epochs):
        print('epoch', e, 'started')
        avg_loss = 0
        for i, (X, Y) in enumerate(train_dataloader):
            one = torch.ones(Y[0].shape)
            zero = torch.zeros(Y[0].shape)
            if torch.cuda.is_available():
                X, Y = (X[0].cuda(), X[1].cuda()), (Y[0].cuda(), Y[1].cuda(),
                                                    Y[2].cuda())
                one, zero = one.cuda(), zero.cuda()
            output = model(*X)
            # print(output,Y)
            sketch, photo = X
            #print(sketch.shape)
            optim.zero_grad()
            to_image = transforms.ToPILImage()
            output = model(*X)
            #print(output[0])
            (Y, label_s, label_p) = Y
            Y = torch.where(Y != train_dataset.class_to_index['unmatched'],
                            one, zero)
            loss = crit(*output, Y)
            avg_loss += loss.item()
            if i % prints_interval == 0:
                # print(output,Y)
                print(f'[Training] {i}/{e}/{epochs} -> Loss: {avg_loss/(i+1)}')
                # writer.add_scalar('train-loss', loss.item(), count)
            loss.backward()

            optim.step()

            count += 1
        print('epoch', e, 'Avg loss', avg_loss / len(train_dataloader))
        valid_loss = eval_loss(test_dataloader, model, e, epochs, crit,
                               train_dataset)
        if valid_loss < min_loss:
            path = 'checkpoint' + str(chpt_num) + '.pt'
            min_loss = valid_loss
            chpt_num = (chpt_num + 1) % max_chpt
            set_checkpoint(epoch=e,
                           model=model,
                           optimizer=optim,
                           train_loss=avg_loss / len(train_dataloader),
                           loss=valid_loss,
                           path=path)
            path = 'best.pt'
            set_checkpoint(epoch=e,
                           model=model,
                           optimizer=optim,
                           train_loss=avg_loss / len(train_dataloader),
                           loss=valid_loss,
                           path=path)
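
eval_loss is likewise not shown. A minimal sketch, assuming it replays the training loop above on the validation loader without gradient updates and returns the average contrastive loss, could be:

import torch

def eval_loss(test_dataloader, model, e, epochs, crit, train_dataset):
    # Hypothetical validation helper for the metric-learning setup above.
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for X, Y in test_dataloader:
            one = torch.ones(Y[0].shape)
            zero = torch.zeros(Y[0].shape)
            if torch.cuda.is_available():
                X = (X[0].cuda(), X[1].cuda())
                Y = (Y[0].cuda(), Y[1].cuda(), Y[2].cuda())
                one, zero = one.cuda(), zero.cuda()
            output = model(*X)
            (pair_label, label_s, label_p) = Y
            pair_label = torch.where(
                pair_label != train_dataset.class_to_index['unmatched'], one, zero)
            total_loss += crit(*output, pair_label).item()
    model.train()
    avg = total_loss / len(test_dataloader)
    print(f'[Validation] -/{e}/{epochs} -> Loss: {avg}')
    return avg
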
Example #7
import sys

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn


def imshow(img, one_channel=False):
    # NOTE: the original def line is cut off in this example; the name and
    # signature here are assumed from the function body.
    npimg = img.numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


features = torch.load('features.pt')
print(features.shape)
sys.exit()  # stop here after inspecting the saved features
dataiter = iter(train_dataloader)
images, labels = next(dataiter)
num_class = len(train_dataset.classes)
embed_size = -1
model = getResnet(num_class=3, pretrain=True)
model.eval()
if torch.cuda.is_available():
    model = model.cuda()

crit = torch.nn.CrossEntropyLoss()
optim = torch.optim.Adam(model.parameters())
load_checkpoint('best.pt', model, optim)
train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=100,
                                               shuffle=True)
dataiter = iter(train_dataloader)
images, labels = next(dataiter)
print(images.shape)
# get the class labels for each image
class_labels = [train_dataset.classes[lab] for lab in labels]
                # print(filepath)
                # print(input,input.shape)
                y_pred = model(image_var)
                # print(y_pred.shape)
                smax = nn.Softmax(1)
                smax_out = smax(y_pred)
            # 3.4 save probability to csv files
            csv_map["filename"].extend(filepath)
            for output in smax_out:
                prob = ";".join([str(i) for i in output.data.tolist()])
                csv_map["probability"].append(prob)
        result = pd.DataFrame(csv_map)
        result["probability"] = result["probability"].map(lambda x: [float(i) for i in x.split(";")])
        for index, row in result.iterrows():
            pred_label = np.argmax(row['probability'])+1
            result_str = '{} {}\r\n'.format(row['filename'], pred_label)

            f.writelines(result_str)

if __name__ == '__main__':
    best_model = torch.load(config.test_model_path)

    model = getResnet(config.test_model_name)

    model.load_state_dict(best_model["state_dict"])

    test(config.test_data_path, model)



def main():
    batch_size = 100
    balanced = False
    print("Start Training")

    # sk_root ='../test'
    in_size = 225
    in_size = 224
    tmp_root = '../test_pair/photo'
    sketch_root = '../test_pair/sketch'
    tmp_root = '../256x256/photo/tx_000000000000'
    sketch_root = '../256x256/sketch/tx_000000000000'

    train_dataset = DataSet.PairedDataset(photo_root=tmp_root,
                                          sketch_root=sketch_root,
                                          transform=Compose(
                                              [Resize(in_size),
                                               ToTensor()]),
                                          balanced=balanced)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  pin_memory=True,
                                  num_workers=os.cpu_count(),
                                  shuffle=True,
                                  drop_last=True)
    test_dataset = DataSet.PairedDataset(photo_root=tmp_root,
                                         sketch_root=sketch_root,
                                         transform=Compose(
                                             [Resize(in_size),
                                              ToTensor()]),
                                         train=False)
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=batch_size,
                                 pin_memory=True,
                                 num_workers=os.cpu_count(),
                                 shuffle=True,
                                 drop_last=True)

    num_class = len(train_dataset.classes)
    embed_size = -1
    sketch_net = getResnet(num_class=num_class,
                           pretrain=True,
                           feature_extract=True)
    softmax_loss = SoftMax(embed_size=512, num_class=num_class)
    hinge_loss = ContrastiveLoss(margin=2)
    optim = torch.optim.Adam(
        list(sketch_net.parameters()) + list(softmax_loss.parameters()))
    sketch_net.train()
    photo_net = getResnet(num_class=num_class,
                          pretrain=True,
                          feature_extract=True)
    for param in photo_net.parameters():
        param.requires_grad = False

    if torch.cuda.is_available():
        sketch_net = sketch_net.cuda()
        softmax_loss = softmax_loss.cuda()
        photo_net = photo_net.cuda()
    count = 0
    epochs = 200
    max_chpt = 3
    max_acu = -1
    chpt_num = 0
    activation = {}

    def get_activation(name):
        def hook(model, input, output):
            activation[name] = output

        return hook

    for e in range(epochs):
        print('epoch', e, 'Start')
        (avg_loss, avg_class_loss, avg_hinge_loss,
         accuracy) = eval_model(e,
                                epochs,
                                sketch_net,
                                photo_net,
                                softmax_loss,
                                hinge_loss,
                                [train_dataloader, test_dataloader],
                                optim,
                                train=True)
        print('epoch', e, 'End')
        (avg_loss, avg_class_loss, avg_hinge_loss,
         accuracy) = eval_model(e,
                                epochs,
                                sketch_net,
                                photo_net,
                                softmax_loss,
                                hinge_loss,
                                [train_dataloader, test_dataloader],
                                optim,
                                train=False)

        if accuracy >= max_acu:
            path = 'checkpoint' + str(chpt_num) + '.pt'
            max_acu = accuracy
            chpt_num = (chpt_num + 1) % max_chpt
            set_checkpoint(epoch=e,
                           model=sketch_net,
                           softmax=softmax_loss,
                           optimizer=optim,
                           train_loss=avg_loss / len(train_dataloader),
                           softmax_loss=avg_class_loss,
                           hinge_loss=avg_hinge_loss,
                           accurate=accuracy,
                           path=path)
            path = 'best.pt'
            set_checkpoint(epoch=e,
                           model=sketch_net,
                           softmax=softmax_loss,
                           optimizer=optim,
                           train_loss=avg_loss / len(train_dataloader),
                           softmax_loss=avg_class_loss,
                           hinge_loss=avg_hinge_loss,
                           accurate=accuracy,
                           path=path)
def main():
    fold = 0
    # 4.1 mkdirs
    if not os.path.exists(config.submit):
        os.mkdir(config.submit)
    if not os.path.exists(config.weights):
        os.mkdir(config.weights)
    if not os.path.exists(config.best_models):
        os.mkdir(config.best_models)
    if not os.path.exists(config.logs):
        os.mkdir(config.logs)
    if not os.path.exists(config.weights + config.model_name + os.sep + str(fold) + os.sep):
        os.makedirs(config.weights + config.model_name + os.sep + str(fold) + os.sep)
    if not os.path.exists(config.best_models + config.model_name + os.sep + str(fold) + os.sep):
        os.makedirs(config.best_models + config.model_name + os.sep + str(fold) + os.sep)
        # 4.2 get model and optimizer
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = getResnet(config.model_name)
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    model.to(device)
    # model = torch.nn.DataParallel(model)
    # model.cuda()
    # optimizer = optim.SGD(model.parameters(),lr = config.lr,momentum=0.9,weight_decay=config.weight_decay)
    optimizer = optim.Adam(model.parameters(), lr=config.lr, amsgrad=True, weight_decay=config.weight_decay)
    # criterion = nn.CrossEntropyLoss().cuda()
    criterion = FocalLoss().cuda()
    log = Logger()
    log.open(config.logs + "log_train.txt", mode="a")
    log.write("\n----------------------------------------------- [START %s] %s\n\n" % (
    datetime.now().strftime('%Y-%m-%d %H:%M:%S'), '-' * 51))
    # 4.3 some parameters for  K-fold and restart model
    start_epoch = 0
    best_precision1 = 0
    best_precision_save = 0
    resume = config.resume

    # 4.4 restart the training process
    if resume:
        checkpoint = torch.load(config.best_models + str(fold) + "/model_best.pth.tar")
        start_epoch = checkpoint["epoch"]
        fold = checkpoint["fold"]
        best_precision1 = checkpoint["best_precision1"]
        model.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])

    # 4.5 get files and split for K-fold dataset
    # 4.5.1 read filesf
    train_ = get_files(config.train_data_path, "train")
    # val_data_list = get_files(config.val_data,"val")
    test_files = get_files(config.test_data_path, "test")

    """ 
    #4.5.2 split
    split_fold = StratifiedKFold(n_splits=3)
    folds_indexes = split_fold.split(X=origin_files["filename"],y=origin_files["label"])
    folds_indexes = np.array(list(folds_indexes))
    fold_index = folds_indexes[fold]
    #4.5.3 using fold index to split for train data and val data
    train_data_list = pd.concat([origin_files["filename"][fold_index[0]],origin_files["label"][fold_index[0]]],axis=1)
    val_data_list = pd.concat([origin_files["filename"][fold_index[1]],origin_files["label"][fold_index[1]]],axis=1)
    """
    train_data_list, val_data_list = train_test_split(train_, test_size=0.15, stratify=train_["label"])
    # 4.5.4 load dataset
    train_dataloader = DataLoader(RemoteDataLoader(train_data_list), batch_size=config.batch_size, shuffle=True,
                                  collate_fn=collate_fn, pin_memory=True)
    val_dataloader = DataLoader(RemoteDataLoader(val_data_list, train=False), batch_size=config.batch_size, shuffle=True,
                                collate_fn=collate_fn, pin_memory=False)
    test_dataloader = DataLoader(RemoteDataLoader(test_files, test=True), batch_size=1, shuffle=False, pin_memory=False)
    # scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,"max",verbose=1,patience=3)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    # 4.5.5.1 define metrics
    train_losses = AverageMeter()
    train_top1 = AverageMeter()
    train_top2 = AverageMeter()
    valid_loss = [np.inf, 0, 0]

    valid_loss_list = []
    train_losses_list = []
    model.train()
    # logs
    log.write('** start training here! **\n')
    log.write(
        '                           |------------ VALID -------------|----------- TRAIN -------------|------Accuracy------|------------|\n')
    log.write(
        'lr       iter     epoch    | loss   top-1  top-2            | loss   top-1  top-2           |    Current Best    | time       |\n')
    log.write(
        '-------------------------------------------------------------------------------------------------------------------------------\n')
    # 4.5.5 train
    start = timer()
    for epoch in range(start_epoch, config.epochs):
        scheduler.step(epoch)
        # train
        # global iter
        for iter, (input, target) in enumerate(train_dataloader):
            # 4.5.5 switch to continue train process
            model.train()
            input = Variable(input).cuda()
            target = Variable(torch.from_numpy(np.array(target)).long()).cuda()
            # target = Variable(target).cuda()
            output = model(input)
            loss = criterion(output, target)

            precision1_train, precision2_train = accuracy(output, target, topk=(1, 2))
            train_losses.update(loss.item(), input.size(0)-1)
            train_top1.update(precision1_train[0], input.size(0)-1)
            train_top2.update(precision2_train[0], input.size(0)-1)
            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr = get_learning_rate(optimizer)
            print('\r', end='', flush=True)
            print(
                '%0.4f %5.1f %6.1f        | %0.3f  %0.3f  %0.3f         | %0.3f  %0.3f  %0.3f         |         %s         | %s' % ( \
                    lr, iter / len(train_dataloader) + epoch, epoch,
                    valid_loss[0], valid_loss[1], valid_loss[2],
                    train_losses.avg, train_top1.avg, train_top2.avg, str(best_precision_save),
                    time_to_str((timer() - start), 'min'))
                , end='', flush=True)
            train_losses_list.append([train_losses.avg, train_top1.avg, train_top2.avg])
            train_losses.reset()
            train_top1.reset()
            train_top2.reset()
        # evaluate
        lr = get_learning_rate(optimizer)
        # evaluate every half epoch
        valid_loss = evaluate(val_dataloader, model, criterion)
        valid_loss_list.append(valid_loss)

        is_best = valid_loss[1] > best_precision1
        best_precision1 = max(valid_loss[1], best_precision1)
        try:
            best_precision_save = best_precision1.cpu().data.numpy()
        except:
            pass
        save_checkpoint({
            "epoch": epoch + 1,
            "model_name": config.model_name,
            "state_dict": model.state_dict(),
            "best_precision1": best_precision1,
            "optimizer": optimizer.state_dict(),
            "fold": fold,
            "valid_loss": valid_loss,
        }, is_best, fold)
        save_loss_npy('all_train_loss_{}.npy'.format(epoch+1), train_losses_list)
        save_loss_npy('all_val_loss_{}.npy'.format(epoch+1), valid_loss_list)
        # adjust learning rate
        # scheduler.step(valid_loss[1])
        print("\r", end="", flush=True)
        log.write(
            '%0.4f %5.1f %6.1f        | %0.3f  %0.3f  %0.3f          | %0.3f  %0.3f  %0.3f         |         %s         | %s' % ( \
                lr, 0 + epoch, epoch,
                valid_loss[0], valid_loss[1], valid_loss[2],
                train_losses.avg, train_top1.avg, train_top2.avg, str(best_precision_save),
                time_to_str((timer() - start), 'min'))
            )
        log.write('\n')
        time.sleep(0.01)
    best_model = torch.load(
        config.best_models + os.sep + config.model_name + os.sep + str(fold) + os.sep + 'model_best.pth.tar')
    # covert loss list to np, n*3
    save_loss_npy('all_train_loss.npy', train_losses_list)
    save_loss_npy('all_val_loss.npy', valid_loss_list)
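
save_loss_npy and get_learning_rate are small helpers that are not shown either; plausible sketches under the obvious assumptions:

import numpy as np

def save_loss_npy(path, loss_list):
    # Assumed helper: converts the accumulated list of [loss, top-1, top-2]
    # rows into an array and writes it to disk.
    np.save(path, np.array(loss_list))

def get_learning_rate(optimizer):
    # Assumed helper: current learning rate of the first parameter group,
    # which is what the progress line prints.
    return optimizer.param_groups[0]['lr']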