Example #1
import os

import torch
import torch.optim as optim

# CNN, train, test, and OuterTrainer are assumed to be defined in the
# surrounding project.


def main():
    # Training settings
    from config import args
    from config import train_dataset
    from config import test_dataset
    from config import fun_params

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.test_batch_size, shuffle=True, **kwargs)

    target_model = CNN(input_channel=1)
    target_model = torch.nn.DataParallel(target_model).to(device)  # respect the chosen device instead of forcing CUDA
    target_optimizer = optim.SGD(target_model.parameters(), lr=args.lr, momentum=args.momentum)

    for epoch in range(args.target_epoch):
        train(args, target_model, device, train_loader, target_optimizer, epoch)
    target_model_path = os.path.join(fun_params['logdir'], 'target_model.pth.tar')
    torch.save({'state_dict': target_model.state_dict()}, target_model_path)
    test_loss, correct = test(target_model, test_loader)
    print('Target model training completed. Test loss %f, test accuracy %f' % (test_loss, correct))
    trainer = OuterTrainer(target_model, device, train_dataset, test_dataset, fun_params)
    trainer.train(fun_params['n_process'])
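
Example #1 calls train and test helpers defined elsewhere in the project. Given that test returns a (loss, accuracy) pair, a minimal sketch of what it might look like (an assumption, not the project's actual code):

import torch
import torch.nn.functional as F

def test(model, test_loader):
    # hypothetical evaluation helper; the real project's version may differ
    model.eval()
    device = next(model.parameters()).device
    test_loss, correct = 0.0, 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            correct += output.argmax(dim=1).eq(target).sum().item()
    test_loss /= len(test_loader.dataset)
    return test_loss, correct / len(test_loader.dataset)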
Example #2
    def test(self, model_path):
        model = CNN(input_channel=1)
        model = torch.nn.DataParallel(model).cuda()
        optimizer = optim.SGD(model.parameters(),
                              lr=self.inner_args['inner_lr'],
                              momentum=self.inner_args['inner_momentum'])
        cross_entropy = SoftCrossEntropy()
        model.train()

        self.target_model.eval()

        eloss = loss_net(self.inner_args['nclass'])
        eloss = torch.nn.DataParallel(eloss).cuda()

        assert os.path.exists(model_path), 'model path does not exist.'
        check_point = torch.load(model_path)
        eloss.load_state_dict(check_point['state_dict'])

        eloss.eval()
        minloss = 10000
        # print('start train model')
        train_loader = torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.inner_args['inner_batch_size'],
            shuffle=True)
        test_loader = torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=self.inner_args['test_batch_size'],
            shuffle=True)

        for batch_idx, (data, target) in enumerate(train_loader):
            data = data.cuda()
            optimizer.zero_grad()

            output = model(data)
            target_output = self.target_model(data)
            target = target.view(-1, 1)

            target = torch.zeros(data.shape[0],
                                 10).scatter_(1, target.long(), 1.0)
            varInput = torch.cat((target.cuda(), target_output), 1)
            # print(varInput.shape, target_output.shape)
            soft_target = eloss(varInput)
            # print(soft_target)
            loss = cross_entropy(output, soft_target)
            if self.save_model:
                self.log.log_tabular('epoch', batch_idx)
                self.log.log_tabular('loss', loss.item())
                self.log.dump_tabular()
                if loss.item() < minloss:
                    minloss = loss.item()  # track the best loss seen so far
                    torch.save(
                        {
                            'epoch': batch_idx + 1,
                            'state_dict': model.state_dict(),
                            'best_loss': minloss,
                            'optimizer': optimizer.state_dict()
                        }, self.log.path_model)
            loss.backward()
            optimizer.step()
Example #3
    def train(self, epoch, pool_rank, phi):
        model = CNN(input_channel=1)
        model = torch.nn.DataParallel(model).cuda()
        optimizer = optim.SGD(model.parameters(),
                              lr=self.inner_lr,
                              momentum=self.inner_momentum)
        cross_entropy = SoftCrossEntropy()
        model.train()

        self.target_model.eval()

        eloss = loss_net(self.nclass)
        eloss = torch.nn.DataParallel(eloss).cuda()
        # assigning into the dict returned by state_dict() does not modify the
        # module's parameters, so build a perturbed copy and load it back
        state = eloss.state_dict()
        for key in state:
            state[key] = state[key] + phi[key]
        eloss.load_state_dict(state)
        eloss.eval()
        minloss = 10000
        # print('start train model')
        for e in range(self.inner_epoch_freq):
            for batch_idx, (data, target) in enumerate(self.train_loader):
                data = data.cuda()
                optimizer.zero_grad()

                output = model(data)
                target_output = self.target_model(data)
                target = target.view(-1, 1)

                target = torch.zeros(data.shape[0],
                                     10).scatter_(1, target.long(), 1.0)
                varInput = torch.cat((target.cuda(), target_output), 1)
                # print(varInput.shape, target_output.shape)
                soft_target = eloss(varInput)
                # print(soft_target)
                loss = cross_entropy(output, soft_target)
                if self.save_model:
                    if epoch % 20 == 0:
                        self.log.log_tabular('epoch', batch_idx)
                        self.log.log_tabular('loss', loss.item())
                    if loss.item() < minloss:
                        minloss = loss.item()  # track the best loss seen so far
                        torch.save(
                            {
                                'epoch': batch_idx + 1,
                                'state_dict': model.state_dict(),
                                'best_loss': minloss,
                                'optimizer': optimizer.state_dict()
                            }, self.log.path_model)
                loss.backward()
                optimizer.step()
                # print('[pool: %d] epoch %d, loss: %f' % (pool_rank, epoch, loss.item()))

            if epoch % 20 == 0 and self.save_model:
                self.log.dump_tabular()

        accuracy = self.test(model)
        return accuracy
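
Examples #2 and #3 depend on a SoftCrossEntropy criterion that takes soft (probability) targets instead of class indices. A minimal sketch of such a loss, assuming each row of soft_target is a probability distribution:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SoftCrossEntropy(nn.Module):
    # cross-entropy against soft targets: batch mean of -sum_c p(c) * log q(c)
    def forward(self, logits, soft_target):
        log_probs = F.log_softmax(logits, dim=1)
        return torch.mean(torch.sum(-soft_target * log_probs, dim=1))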
Example #4
import torch
import torch.nn as nn
import torch.optim as optim

# CharacterDataset and CNN are assumed to come from the surrounding project.


class Trainer:
    def __init__(self, num_epochs=5, batch_size=4, lr=0.001):
        self.num_epochs = num_epochs

        trainset = CharacterDataset()
        self.trainloader = torch.utils.data.DataLoader(trainset,
                                                       batch_size=batch_size,
                                                       shuffle=True,
                                                       num_workers=2)

        self.net = CNN()
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.SGD(self.net.parameters(), lr=lr, momentum=0.9)

    def train(self):
        self.net.train()  # make sure dropout/batch-norm layers are in training mode
        for epoch in range(self.num_epochs):
            for i, data in enumerate(self.trainloader, 0):
                # get the inputs
                inputs, labels = data

                # zero the parameter gradients
                self.optimizer.zero_grad()

                # forward + backward + optimize
                outputs = self.net(inputs)
                loss = self.criterion(outputs, labels)
                loss.backward()
                self.optimizer.step()

    def test_model_against_trainset(self):
        correct = 0
        total = 0

        self.net.eval()  # deterministic evaluation (dropout/batch-norm frozen)
        with torch.no_grad():
            for data in self.trainloader:
                images, labels = data

                outputs = self.net(images)
                _, predicted = torch.max(outputs.data, 1)

                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        return correct / total

    def save_model(self, filepath='res/model.pt'):
        torch.save(self.net.state_dict(), filepath)
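
A typical driver for this class might look as follows (hypothetical usage; CharacterDataset and CNN come from the surrounding project):

trainer = Trainer(num_epochs=5, batch_size=4, lr=0.001)
trainer.train()
print('train-set accuracy:', trainer.test_model_against_trainset())
trainer.save_model('res/model.pt')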
Example #5
    train_dataset = IQADataset(dataset, config, index, "train")
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=0)
    val_dataset = IQADataset(dataset, config, index, "val")
    val_loader = torch.utils.data.DataLoader(val_dataset)

    test_dataset = IQADataset(dataset, config, index, "test")
    test_loader = torch.utils.data.DataLoader(test_dataset)

    model = CNN().to(device)

    criterion = nn.L1Loss()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                weight_decay=args.weight_decay)

    best_SROCC = -1

    for epoch in range(args.epochs):
        # train
        model.train()
        LOSS = 0
        for i, (patches, label) in enumerate(train_loader):
            patches = patches.to(device)
            label = label.to(device)

            optimizer.zero_grad()
            outputs = model(patches)
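
The snippet tracks best_SROCC, the best Spearman rank-order correlation between predicted and ground-truth quality scores. A minimal sketch of that metric using scipy (the project's own implementation may differ):

from scipy import stats

def srocc(predictions, labels):
    # Spearman rank-order correlation coefficient between two 1-D arrays
    return stats.spearmanr(predictions, labels)[0]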
Example #6
    cv_start = '2019-03-02 01:30:00'
    cv_end = '2019-12-01 15:00:00'

    test_start = '2019-12-02 01:30:00'
    test_end = '2021-04-20 15:00:00'

    data_processor = DataProcessor()
    train_loader = data_processor.get_data(train_start, train_end)
    cv_loader = data_processor.get_data(cv_start, cv_end)
    test_loader = data_processor.get_data(test_start, test_end)

    model = CNN(transform='GAF')
    model = model.double()
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    max_epochs = 100
    accuracy_list = []
    for t in range(max_epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        train_loop(train_loader, model, loss_fn, optimizer)
        acc = test_loop(cv_loader, model, loss_fn)
        accuracy_list.append(acc)

        # early stopping: if the best CV accuracy in the last 8 epochs is the
        # oldest one, there has been no improvement for 7 consecutive epochs
        if len(accuracy_list) > 8:
            print(accuracy_list[-8:])
            if max(accuracy_list[-8:]) == accuracy_list[-8]:
                break
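
The break above implements a simple patience rule: stop once the best CV accuracy within the last 8 epochs is the oldest of them, i.e. 7 consecutive epochs brought no improvement. The same idea as a standalone helper (a sketch, not the original code):

def should_stop(history, patience=7):
    # stop when nothing in the last `patience` epochs beats the value
    # recorded just before that window
    if len(history) <= patience:
        return False
    window = history[-(patience + 1):]
    return max(window) == window[0]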
Example #7
training_labels = th.unsqueeze(training_labels, 1)
training_set = TensorDataset(training_data, training_labels)
training_loader = DataLoader(training_set, args.batch_size)

validation_data, validation_labels = th.from_numpy(
    validation_data), th.from_numpy(validation_labels)
validation_data = validation_data.view(-1, 1, 28, 28)
validation_labels = th.unsqueeze(validation_labels, 1)
validation_set = TensorDataset(validation_data, validation_labels)
validation_loader = DataLoader(validation_set, args.batch_size)

model = CNN()
if cuda:
    model.cuda()
criterion = nn.L1Loss()
optimizer = Adam(model.parameters(), lr=1e-3)

for epoch in range(args.n_epochs):
    for iteration, batch in enumerate(training_loader):
        data, labels = batch
        if cuda:
            data, labels = data.cuda(), labels.cuda()

        # clone keeps the device and dtype of `labels`, avoiding a CPU/GPU mismatch
        noisy_labels = labels.clone()
        n_noises = int(args.noise * labels.size()[0])
        noise = th.from_numpy(np.random.choice(np.arange(1, 10), n_noises))
        if cuda:
            noise = noise.cuda()
        if n_noises > 0:
            noisy_labels[:n_noises] = (labels[:n_noises] + noise) % 10
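
The noise-injection logic can be factored into a helper; a sketch under the same assumptions (10 classes, a leading fraction of each batch corrupted, offsets drawn from [1, 9] so a corrupted label always changes):

import numpy as np
import torch as th

def corrupt_labels(labels, noise_frac, n_classes=10):
    # return a copy of `labels` whose leading noise_frac fraction is shifted
    # by a random non-zero offset modulo n_classes
    noisy = labels.clone()
    n = int(noise_frac * labels.size(0))
    if n > 0:
        offset = th.from_numpy(
            np.random.choice(np.arange(1, n_classes), n)).to(labels.device)
        noisy[:n] = (labels[:n] + offset) % n_classes
    return noisy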
Example #8
                                        shuffle=False,
                                        num_workers=4)

    print("Train : ", len(trains.dataset))
    print("Test : ", len(tests.dataset))

    dropmax_cnn = DropMaxCNN(n_classes=n_classes).to(device)
    dropmax_loss = DropMax(device)
    adam = optim.Adam(dropmax_cnn.parameters(),
                      lr=0.0005,
                      betas=(0.5, 0.999),
                      weight_decay=1e-4)

    cnn = CNN(n_classes=n_classes).to(device)
    ce_loss = nn.CrossEntropyLoss()
    adam2 = optim.Adam(cnn.parameters(),
                       lr=0.0005,
                       betas=(0.5, 0.999),
                       weight_decay=1e-4)

    for epoch in range(30):
        dropmax_cnn.train()
        print("Epoch : ", epoch)
        for i, (img, target) in enumerate(trains):
            img = img.to(device)
            one_hot = torch.zeros(img.shape[0], n_classes)
            one_hot.scatter_(1, target.unsqueeze(dim=1), 1)
            one_hot = one_hot.to(device)

            o, p, r, q = dropmax_cnn(img)
            loss = dropmax_loss(o, p, r, q, one_hot)
Example #9
training_data, training_labels = th.from_numpy(training_data), th.from_numpy(
    training_labels)
training_data = training_data.view(-1, 1, 28, 28)
training_set = TensorDataset(training_data, training_labels)
training_loader = DataLoader(training_set, args.batch_size)

validation_data, validation_labels = th.from_numpy(
    validation_data), th.from_numpy(validation_labels)
validation_data = validation_data.view(-1, 1, 28, 28)
validation_set = TensorDataset(validation_data, validation_labels)
validation_loader = DataLoader(validation_set, args.batch_size)

model = CNN()
model.cuda()
optimizer = SGD(model.parameters(), lr=1e-2, momentum=0.9)

for epoch in range(args.n_epochs):
    for iteration, batch in enumerate(training_loader):
        data, labels = batch
        data, labels = data.cuda(), labels.cuda()
        n_noises = int(args.noise * labels.size()[0])
        noise = np.random.choice(np.arange(10), n_noises)
        labels[:n_noises] = th.from_numpy(noise).cuda()  # move the noise to the GPU before assignment
        output = model(data)
        loss = F.nll_loss(F.log_softmax(output, dim=1), labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
Example #10
def main():
    batch_size = 100
    train_data, val_data, test_data = create_train_val_test_split(batch_size)
    data_feeder = DataFeeder(train_data,
                             preprocess_workers=1,
                             cuda_workers=1,
                             cpu_size=10,
                             cuda_size=10,
                             batch_size=batch_size,
                             use_cuda=True,
                             volatile=False)
    data_feeder.start_queue_threads()
    val_data = make_batch(len(val_data),
                          0,
                          val_data,
                          use_cuda=True,
                          volatile=True)
    test_data = make_batch(len(test_data),
                           0,
                           test_data,
                           use_cuda=True,
                           volatile=True)

    cnn = CNN().cuda()
    fcc = FCC().cuda()

    optimizer_cnn = optim.SGD(cnn.parameters(),
                              lr=0.001,
                              momentum=0.9,
                              weight_decay=0.00001)
    optimizer_fcc = optim.SGD(fcc.parameters(),
                              lr=0.001,
                              momentum=0.9,
                              weight_decay=0.00001)

    cnn_train_loss = Logger("cnn_train_losses.txt")
    cnn_val_loss = Logger("cnn_val_losses.txt")
    cnn_val_acc = Logger("cnn_val_acc.txt")
    fcc_train_loss = Logger("fcc_train_losses.txt")
    fcc_val_loss = Logger("fcc_val_losses.txt")
    fcc_val_acc = Logger("fcc_val_acc.txt")

    #permute = Variable(torch.from_numpy(np.random.permutation(28*28)).long().cuda(), requires_grad=False)
    permute = None

    for i in range(100001):
        images, labels = data_feeder.get_batch()
        train(cnn, optimizer_cnn, images, labels, i, cnn_train_loss, permute)
        train(fcc, optimizer_fcc, images, labels, i, fcc_train_loss, permute)
        if i % 100 == 0:
            print(i)
            evaluate_acc(batch_size, cnn, val_data, i, cnn_val_loss,
                         cnn_val_acc, permute)
            evaluate_acc(batch_size, fcc, val_data, i, fcc_val_loss,
                         fcc_val_acc, permute)
        if i in [70000, 90000]:
            decrease_lr(optimizer_cnn)
            decrease_lr(optimizer_fcc)
        if i % 1000 == 0:
            torch.save(cnn.state_dict(),
                       "savedir/cnn_it" + str(i // 1000) + "k.pth")
            torch.save(fcc.state_dict(),
                       "savedir/fcc_it" + str(i // 1000) + "k.pth")

    data_feeder.kill_queue_threads()

    import evaluate
    evaluate.main(permute)
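
decrease_lr is defined elsewhere in the project; a plausible sketch (the division factor is an assumption) that lowers every parameter group's learning rate:

def decrease_lr(optimizer, factor=10.0):
    # hypothetical helper: divide each parameter group's learning rate by `factor`
    for group in optimizer.param_groups:
        group['lr'] /= factor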
Example #11
# learning rate
lr = 1e-4
num_epochs = int(args.epochs)

similarity_dims = 1000
optimizer_name = 'SGD'
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.StreamHandler(sys.stdout))
OPTIMIZER = {'Adam': optim.Adam, 'SGD': optim.SGD}

logger.info('Build the model')
# use resnet 50
model = CNN(int(args.cnn_size))
#testloader = DataLoader(test_set, batch_size=10, shuffle=False, num_workers=2)
optimizer = OPTIMIZER[optimizer_name](model.parameters(), lr=lr)

# Save arguments used to create model for restoring the model later

similarity_margin = 0.03


def test_model(model, test_dl):

    model.eval()
    running_loss = 0.0
    running_corrects = 0
    iteration = 0

    for data in test_dl:
Example #12
    tests = torch.utils.data.DataLoader(
        datasets.MNIST('./data', train=False, transform=transform),
        batch_size=32, shuffle=False, num_workers=4
    )

    print("Train : ", len(trains.dataset))
    print("Test : ", len(tests.dataset))

    dropmax_cnn = DropMaxCNN(n_classes=n_classes).to(device)
    dropmax_loss = DropMax(device)
    adam = optim.Adam(dropmax_cnn.parameters(), lr=0.0005, betas=(0.5, 0.999), weight_decay=1e-4)

    cnn = CNN(n_classes=n_classes).to(device)
    ce_loss = nn.CrossEntropyLoss()
    adam2 = optim.Adam(cnn.parameters(), lr=0.0005, betas=(0.5, 0.999), weight_decay=1e-4)

    for epoch in range(30):
        dropmax_cnn.train()
        print("Epoch : ", epoch)
        for i, (img, target) in enumerate(trains):
            img = img.to(device)
            one_hot = torch.zeros(img.shape[0], n_classes)
            one_hot.scatter_(1, target.unsqueeze(dim=1), 1)
            one_hot = one_hot.to(device)

            o, p, r, q = dropmax_cnn(img)
            loss = dropmax_loss(o, p, r, q, one_hot)

            adam.zero_grad()
            loss.backward()
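
Both DropMax examples build one-hot targets by hand with scatter_. For reference, the same tensor can be produced with torch.nn.functional.one_hot (assuming integer class targets):

import torch.nn.functional as F

one_hot = F.one_hot(target, num_classes=n_classes).float().to(device)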
Example #13
def training_model(num):
    # torch.manual_seed(RANDOM_SEED_NUM)

    # Load the dataset
    X_training, Y_training, X_dev, Y_dev, X_test, Y_test = data_load()
    print(len(X_training))

    modellist = []
    training_losseslist = []
    test_accuracieslist = []
    training_losses = []
    test_accuracies = []
    y_truelist = []
    y_predlist = []
    y_true = []
    y_pred = []

    for t in range(TEST_NUM):
        try:
            # Create the model
            model = CNN().to(device)

            # Loss and optimizer
            criterion = nn.CrossEntropyLoss()
            optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

            training_losses = []
            test_accuracies = []

            # Train the model
            train_dataset = TensorDataset(X_training, Y_training)
            train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
            dev_dataset = TensorDataset(X_dev, Y_dev)
            dev_loader = DataLoader(dev_dataset)
            test_dataset = TensorDataset(X_test, Y_test)
            test_loader = DataLoader(test_dataset)

            # total_step = len(train_loader)  # how many batches for one epoch
            for epoch in range(num_epochs):
                for i, (inputs, labels) in enumerate(train_loader):
                    inputs = inputs.reshape(-1, sequence_length, input_size)
                    inputs = inputs.permute(0, 2, 1).to(device)
                    labels = labels.reshape(-1).to(device)

                    # Forward pass
                    outputs = model(inputs)
                    training_loss = criterion(outputs, labels)

                    # Backward and optimize
                    optimizer.zero_grad()
                    training_loss.backward()
                    optimizer.step()

                if (epoch + 1) % 2 == 0:
                    print('Test [{}/{}], Epoch [{}/{}], Loss: {:.4f}'
                          .format(t + 1, TEST_NUM, epoch + 1, num_epochs, training_loss.item()))

                    # Get the value of loss
                    training_losses.append(training_loss.item())

                    # Test the model on the dev set
                    model.eval()
                    with torch.no_grad():
                        y_true = []
                        y_pred = []
                        correct = 0
                        total = 0
                        for j, (inputs, labels) in enumerate(dev_loader):
                            inputs = inputs.reshape(-1, sequence_length, input_size)
                            inputs = inputs.permute(0, 2, 1).to(device)
                            labels = labels.reshape(-1).to(device)
                            outputs = model(inputs)
                            _, predicted = torch.max(outputs.data, 1)
                            total += labels.size(0)
                            correct += (predicted == labels).sum().item()
                            y_true.append(labels.item())
                            y_pred.append(predicted.item())
                        test_accuracies.append(correct / total)
                    model.train()  # back to training mode for the next epochs

        except KeyboardInterrupt:
            print('Stop!')

        modellist.append(model)
        training_losseslist.append(training_losses)
        test_accuracieslist.append(test_accuracies)
        y_truelist.append(y_true)
        y_predlist.append(y_pred)

    # Print accuracy of the model
    accuracy = []
    for item in test_accuracieslist:
        accuracy.append(item[-1] * 100)
    max_accuracy = max(accuracy)
    print('Dev accuracy of model No.{} on the dev action samples: {} %'.format(num + 1, max_accuracy))

    # Show or save the graph of variance and bias analysis, and confusion matrix graph
    variance_and_bias_analysis(training_losseslist, test_accuracieslist)
    save('trials' + str(num+1) + '_loss_accuracy' + '.png')
    plot_confusion_matrix(y_truelist, y_predlist, LABELS)
    save('trials_' + str(num+1) + '_confusion_matrix' + '.png')

    return max_accuracy, modellist[accuracy.index(max_accuracy)], test_loader
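
A hypothetical driver for this function, keeping the best of the TEST_NUM candidate models for trial 0:

best_acc, best_model, test_loader = training_model(0)
print('best dev accuracy: %.2f %%' % best_acc)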
Example #14
def train(train_file_path,
          val_file_path,
          in_channels,
          num_class,
          batch_norm,
          dropout,
          n_epochs,
          batch_size,
          lr,
          momentum,
          weight_decay,
          optim_type,
          ckpt_path,
          max_ckpt_save_num,
          ckpt_save_interval,
          val_interval,
          resume,
          device='cpu'):
    '''
    The main training procedure
    ----------------------------
    :param train_file_path: file list of training image paths and labels
    :param val_file_path: file list of validation image paths and labels
    :param in_channels: channel number of image
    :param num_class: number of classes, in this task it is 26 English letters
    :param batch_norm: whether to use batch normalization in convolutional layers and linear layers
    :param dropout: dropout ratio of dropout layer which ranges from 0 to 1
    :param n_epochs: number of training epochs
    :param batch_size: batch size of training
    :param lr: learning rate
    :param momentum: only used if optim_type == 'sgd'
    :param weight_decay: the factor of L2 penalty on network weights
    :param optim_type: optimizer, which can be set as 'sgd', 'adagrad', 'rmsprop', 'adam', or 'adadelta'
    :param ckpt_path: path to save checkpoint models
    :param max_ckpt_save_num: maximum number of saving checkpoint models
    :param ckpt_save_interval: intervals of saving checkpoint models, e.g., if ckpt_save_interval = 2, then save checkpoint models every 2 epochs
    :param val_interval: intervals of validation, e.g., if val_interval = 5, then do validation after each 5 training epochs
    :param resume: path to resume model
    :param device: 'cpu' or 'cuda'; use 'cpu' for this homework if a GPU with CUDA support is not available
    '''

    # construct training and validation data loader
    train_loader = dataLoader(train_file_path,
                              norm_size=(32, 32),
                              batch_size=batch_size)
    val_loader = dataLoader(val_file_path, norm_size=(32, 32), batch_size=1)

    model = CNN(in_channels, num_class, batch_norm, dropout)

    # put the model on CPU or GPU
    model = model.to(device)

    # define loss function and optimizer
    loss_func = nn.CrossEntropyLoss()

    if optim_type == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr,
                              momentum=momentum,
                              weight_decay=weight_decay)
    elif optim_type == 'adagrad':
        optimizer = optim.Adagrad(model.parameters(),
                                  lr,
                                  weight_decay=weight_decay)
    elif optim_type == 'rmsprop':
        optimizer = optim.RMSprop(model.parameters(),
                                  lr,
                                  weight_decay=weight_decay)
    elif optim_type == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               lr,
                               weight_decay=weight_decay)
    elif optim_type == 'adadelta':
        optimizer = optim.Adadelta(model.parameters(),
                                   lr,
                                   weight_decay=weight_decay)
    else:
        raise ValueError(
            'optim_type should be one of sgd, adagrad, rmsprop, adam, or adadelta'
        )

    if resume is not None:
        print('[Info] resuming model from %s ...' % resume)
        checkpoint = torch.load(resume)
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])

    # training
    # loss of each training epoch
    losses = []
    # accuracy on the validation set after each validation run
    accuracy_list = []
    val_epochs = []

    print('training...')
    for epoch in range(n_epochs):
        # set the model in training mode
        model.train()

        # to save total loss in one epoch
        total_loss = 0.

        for step, (inputs, labels) in enumerate(train_loader):  # get a batch of data

            # set data type and device
            inputs = inputs.type(torch.float).to(device)
            labels = labels.type(torch.long).to(device)

            # clear gradients in the optimizer
            optimizer.zero_grad()

            # run the model (the forward pass)
            out = model(inputs)

            # compute the cross-entropy loss, and call backward to propagate gradients
            loss = loss_func(out, labels)
            loss.backward()

            # update parameters of the model
            optimizer.step()

            # accumulate the total loss; loss.item() returns the value of the
            # tensor as a standard Python number (this is not differentiable)
            total_loss += loss.item()

        # average loss over all iterations in this epoch
        avg_loss = total_loss / len(train_loader)
        losses.append(avg_loss)

        # evaluate model on validation set
        if (epoch + 1) % val_interval == 0:
            val_accuracy = eval_one_epoch(model, val_loader, device)
            accuracy_list.append(val_accuracy)
            val_epochs.append(epoch)
            print(
                'Epoch {:02d}: loss = {:.3f}, accuracy on validation set = {:.3f}'
                .format(epoch + 1, avg_loss, val_accuracy))

        if (epoch + 1) % ckpt_save_interval == 0:
            # get info of all saved checkpoints
            ckpt_list = glob.glob(os.path.join(ckpt_path, 'ckpt_epoch_*.pth'))
            # sort checkpoints by saving time
            ckpt_list.sort(key=os.path.getmtime)
            # remove surplus ckpt file if the number is larger than max_ckpt_save_num
            if len(ckpt_list) >= max_ckpt_save_num:
                for cur_file_idx in range(
                        0,
                        len(ckpt_list) - max_ckpt_save_num + 1):
                    os.remove(ckpt_list[cur_file_idx])

            # save model parameters in a file
            ckpt_name = os.path.join(ckpt_path,
                                     'ckpt_epoch_%d.pth' % (epoch + 1))
            save_dict = {
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
                'configs': {
                    'in_channels': in_channels,
                    'num_class': num_class,
                    'batch_norm': batch_norm,
                    'dropout': dropout
                }
            }

            torch.save(save_dict, ckpt_name)
            print('Model saved in {}\n'.format(ckpt_name))

    plot(losses, accuracy_list, val_epochs, ckpt_path)
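
For reference, a hypothetical invocation of train() with illustrative values (the paths and hyperparameters below are assumptions, not the project's defaults):

train(train_file_path='data/train_list.txt',
      val_file_path='data/val_list.txt',
      in_channels=1,
      num_class=26,
      batch_norm=True,
      dropout=0.5,
      n_epochs=50,
      batch_size=64,
      lr=0.01,
      momentum=0.9,
      weight_decay=1e-4,
      optim_type='sgd',
      ckpt_path='checkpoints',
      max_ckpt_save_num=5,
      ckpt_save_interval=2,
      val_interval=5,
      resume=None,
      device='cuda' if torch.cuda.is_available() else 'cpu')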