Example #1
    def test(self, model_path):
        # evaluate a trained loss network: train a fresh student CNN using
        # the soft targets produced by the loss network loaded from model_path
        model = CNN(input_channel=1)
        model = torch.nn.DataParallel(model).cuda()
        optimizer = optim.SGD(model.parameters(),
                              lr=self.inner_args['inner_lr'],
                              momentum=self.inner_args['inner_momentum'])
        cross_entropy = SoftCrossEntropy()
        model.train()

        self.target_model.eval()

        eloss = loss_net(self.inner_args['nclass'])
        eloss = torch.nn.DataParallel(eloss).cuda()

        assert os.path.exists(model_path), 'model path does not exist.'
        check_point = torch.load(model_path)
        eloss.load_state_dict(check_point['state_dict'])

        eloss.eval()
        minloss = float('inf')
        train_loader = torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.inner_args['inner_batch_size'],
            shuffle=True)
        # note: test_loader is constructed but never used in this method
        test_loader = torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=self.inner_args['test_batch_size'],
            shuffle=True)

        for batch_idx, (data, target) in enumerate(train_loader):
            data = data.cuda()
            optimizer.zero_grad()

            output = model(data)
            target = target.view(-1, 1)

            # build the soft targets without tracking gradients, since only
            # the student model is being optimized here
            with torch.no_grad():
                target_output = self.target_model(data)
                # one-hot encoding of the labels (10 classes, hard-coded)
                target = torch.zeros(data.shape[0],
                                     10).scatter_(1, target.long(), 1.0)
                var_input = torch.cat((target.cuda(), target_output), 1)
                soft_target = eloss(var_input)
            loss = cross_entropy(output, soft_target)
            if self.save_model:
                self.log.log_tabular('epoch', batch_idx)
                self.log.log_tabular('loss', loss.item())
                self.log.dump_tabular()
                # keep the checkpoint with the lowest loss seen so far;
                # minloss must be updated, otherwise every batch overwrites
                # the checkpoint and 'best_loss' records a stale value
                if loss.item() < minloss:
                    minloss = loss.item()
                    torch.save(
                        {
                            'epoch': batch_idx + 1,
                            'state_dict': model.state_dict(),
                            'best_loss': minloss,
                            'optimizer': optimizer.state_dict()
                        }, self.log.path_model)
            loss.backward()
            optimizer.step()
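
Both of the first two examples rely on SoftCrossEntropy, which is not defined in this listing. A minimal sketch consistent with how it is called here (model logits as input, a soft target distribution as the target) might look as follows; the mean reduction and the assumption that the targets are already normalized probabilities are guesses.

import torch.nn as nn
import torch.nn.functional as F

class SoftCrossEntropy(nn.Module):
    # hypothetical sketch: cross-entropy against soft (non-one-hot) targets
    def forward(self, logits, soft_target):
        log_prob = F.log_softmax(logits, dim=1)
        # average the per-sample cross-entropy over the batch
        return -(soft_target * log_prob).sum(dim=1).mean()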
Example #2
    def train(self, epoch, pool_rank, phi):
        # train a fresh student CNN under a copy of the loss network whose
        # weights are perturbed by phi
        model = CNN(input_channel=1)
        model = torch.nn.DataParallel(model).cuda()
        optimizer = optim.SGD(model.parameters(),
                              lr=self.inner_lr,
                              momentum=self.inner_momentum)
        cross_entropy = SoftCrossEntropy()
        model.train()

        self.target_model.eval()

        eloss = loss_net(self.nclass)
        eloss = torch.nn.DataParallel(eloss).cuda()
        # apply the perturbation phi to the loss network's weights; assigning
        # into the dict returned by state_dict() has no effect on the module,
        # so the perturbed weights must be loaded back in explicitly
        perturbed_state = {
            key: value + phi[key]
            for key, value in eloss.state_dict().items()
        }
        eloss.load_state_dict(perturbed_state)
        eloss.eval()
        minloss = float('inf')
        for e in range(self.inner_epoch_freq):
            for batch_idx, (data, target) in enumerate(self.train_loader):
                data = data.cuda()
                optimizer.zero_grad()

                output = model(data)
                target = target.view(-1, 1)

                # build the soft targets without tracking gradients, since
                # only the student model is being optimized here
                with torch.no_grad():
                    target_output = self.target_model(data)
                    # one-hot encoding of the labels (10 classes, hard-coded)
                    target = torch.zeros(data.shape[0],
                                         10).scatter_(1, target.long(), 1.0)
                    var_input = torch.cat((target.cuda(), target_output), 1)
                    soft_target = eloss(var_input)
                loss = cross_entropy(output, soft_target)
                if self.save_model:
                    if epoch % 20 == 0:
                        self.log.log_tabular('epoch', batch_idx)
                        self.log.log_tabular('loss', loss.item())
                    # keep the checkpoint with the lowest loss seen so far;
                    # as in Example #1, minloss must be updated here
                    if loss.item() < minloss:
                        minloss = loss.item()
                        torch.save(
                            {
                                'epoch': batch_idx + 1,
                                'state_dict': model.state_dict(),
                                'best_loss': minloss,
                                'optimizer': optimizer.state_dict()
                            }, self.log.path_model)
                loss.backward()
                optimizer.step()

            if epoch % 20 == 0 and self.save_model:
                self.log.dump_tabular()

        accuracy = self.test(model)
        return accuracy
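
loss_net is also undefined here. Its input is the one-hot label concatenated with the target model's output (2 * nclass features), and its output is consumed as a soft target distribution, which suggests a small MLP along these lines; the hidden width, depth, and final softmax are assumptions.

import torch.nn as nn

class loss_net(nn.Module):
    # hypothetical sketch of the loss network used above
    def __init__(self, nclass, hidden_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(nclass * 2, hidden_dim),  # one-hot label + teacher output
            nn.ReLU(),
            nn.Linear(hidden_dim, nclass),
            nn.Softmax(dim=1),  # produce a proper soft target distribution
        )

    def forward(self, x):
        return self.net(x)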
Example #3
    # the loop below uses train_loader, which the original snippet never
    # constructs; a matching train split is assumed here (args.batch_size is
    # assumed to exist alongside args.lr and args.weight_decay)
    train_dataset = IQADataset(dataset, config, index, "train")
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True)
    test_dataset = IQADataset(dataset, config, index, "test")
    test_loader = torch.utils.data.DataLoader(test_dataset)

    model = CNN().to(device)

    criterion = nn.L1Loss()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                weight_decay=args.weight_decay)

    best_SROCC = -1

    for epoch in range(args.epochs):
        #train
        model.train()
        total_loss = 0.
        for i, (patches, label) in enumerate(train_loader):
            patches = patches.to(device)
            label = label.to(device)

            optimizer.zero_grad()
            outputs = model(patches)

            loss = criterion(outputs, label)

            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        train_loss = total_loss / len(train_loader)
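
The snippet ends after the training loss is computed, even though best_SROCC is initialized above. A hedged sketch of the validation step that might follow inside the epoch loop, assuming scipy is available (from scipy import stats) and that each test batch holds one image's patches with a scalar quality label:

        # validation (hypothetical continuation of the epoch loop)
        model.eval()
        y_pred, y_true = [], []
        with torch.no_grad():
            for patches, label in test_loader:
                outputs = model(patches.to(device))
                y_pred.append(outputs.mean().item())  # average score over patches
                y_true.append(label.item())
        SROCC = stats.spearmanr(y_pred, y_true)[0]
        if SROCC > best_SROCC:
            best_SROCC = SROCC
            torch.save(model.state_dict(), 'best_model.pth')  # hypothetical path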
Example #4
def train(train_file_path,
          val_file_path,
          in_channels,
          num_class,
          batch_norm,
          dropout,
          n_epochs,
          batch_size,
          lr,
          momentum,
          weight_decay,
          optim_type,
          ckpt_path,
          max_ckpt_save_num,
          ckpt_save_interval,
          val_interval,
          resume,
          device='cpu'):
    '''
    The main training procedure
    ----------------------------
    :param train_file_path: file list of training image paths and labels
    :param val_file_path: file list of validation image paths and labels
    :param in_channels: channel number of image
    :param num_class: number of classes; in this task it is the 26 English letters
    :param batch_norm: whether to use batch normalization in convolutional layers and linear layers
    :param dropout: dropout ratio of dropout layer which ranges from 0 to 1
    :param n_epochs: number of training epochs
    :param batch_size: batch size of training
    :param lr: learning rate
    :param momentum: only used if optim_type == 'sgd'
    :param weight_decay: the factor of L2 penalty on network weights
    :param optim_type: optimizer, which can be set as 'sgd', 'adagrad', 'rmsprop', 'adam', or 'adadelta'
    :param ckpt_path: path to save checkpoint models
    :param max_ckpt_save_num: maximum number of checkpoint models to keep
    :param ckpt_save_interval: interval of saving checkpoint models, e.g., if ckpt_save_interval = 2, then save a checkpoint every 2 epochs
    :param val_interval: interval of validation, e.g., if val_interval = 5, then validate after every 5 training epochs
    :param resume: path of the model to resume from
    :param device: 'cpu' or 'cuda'; we can use 'cpu' for our homework if a GPU with CUDA support is not available
    '''

    # construct training and validation data loader
    train_loader = dataLoader(train_file_path,
                              norm_size=(32, 32),
                              batch_size=batch_size)
    val_loader = dataLoader(val_file_path, norm_size=(32, 32), batch_size=1)

    model = CNN(in_channels, num_class, batch_norm, dropout)

    # put the model on CPU or GPU
    model = model.to(device)

    # define loss function and optimizer
    loss_func = nn.CrossEntropyLoss()

    if optim_type == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr,
                              momentum=momentum,
                              weight_decay=weight_decay)
    elif optim_type == 'adagrad':
        optimizer = optim.Adagrad(model.parameters(),
                                  lr,
                                  weight_decay=weight_decay)
    elif optim_type == 'rmsprop':
        optimizer = optim.RMSprop(model.parameters(),
                                  lr,
                                  weight_decay=weight_decay)
    elif optim_type == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               lr,
                               weight_decay=weight_decay)
    elif optim_type == 'adadelta':
        optimizer = optim.Adadelta(model.parameters(),
                                   lr,
                                   weight_decay=weight_decay)
    else:
        raise NotImplementedError(
            'optim_type should be one of sgd, adagrad, rmsprop, adam, or adadelta'
        )

    if resume is not None:
        print('[Info] resuming model from %s ...' % resume)
        checkpoint = torch.load(resume)
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])

    # training
    # to save loss of each training epoch in a python "list" data structure
    losses = []
    # to save accuracy on validation set of each training epoch in a python "list" data structure
    accuracy_list = []
    val_epochs = []

    print('training...')
    for epoch in range(n_epochs):
        # set the model in training mode
        model.train()

        # to save total loss in one epoch
        total_loss = 0.

        for step, (images, label) in enumerate(train_loader):  # get a batch of data

            # set data type and device; 'images' avoids shadowing the
            # built-in input()
            images = images.type(torch.float).to(device)
            label = label.type(torch.long).to(device)

            # clear gradients in the optimizer
            optimizer.zero_grad()

            # run the model which is the forward process
            out = model(images)

            # compute the CrossEntropy loss, and call backward propagation function
            loss = loss_func(out, label)
            loss.backward()

            # update parameters of the model
            optimizer.step()

            # accumulate the total loss; loss.item() returns the value of the
            # tensor as a standard python number, detached from the graph
            total_loss += loss.item()

        # average loss over all iterations in this epoch
        avg_loss = total_loss / len(train_loader)
        losses.append(avg_loss)

        # evaluate model on validation set
        if (epoch + 1) % val_interval == 0:
            val_accuracy = eval_one_epoch(model, val_loader, device)
            accuracy_list.append(val_accuracy)
            val_epochs.append(epoch)
            print(
                'Epoch {:02d}: loss = {:.3f}, accuracy on validation set = {:.3f}'
                .format(epoch + 1, avg_loss, val_accuracy))

        if (epoch + 1) % ckpt_save_interval == 0:
            # get info of all saved checkpoints
            ckpt_list = glob.glob(os.path.join(ckpt_path, 'ckpt_epoch_*.pth'))
            # sort checkpoints by saving time
            ckpt_list.sort(key=os.path.getmtime)
            # remove the oldest checkpoint files if the count would exceed
            # max_ckpt_save_num
            if len(ckpt_list) >= max_ckpt_save_num:
                for old_ckpt in ckpt_list[:len(ckpt_list) - max_ckpt_save_num + 1]:
                    os.remove(old_ckpt)

            # save model parameters in a file
            ckpt_name = os.path.join(ckpt_path,
                                     'ckpt_epoch_%d.pth' % (epoch + 1))
            save_dict = {
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
                'configs': {
                    'in_channels': in_channels,
                    'num_class': num_class,
                    'batch_norm': batch_norm,
                    'dropout': dropout
                }
            }

            torch.save(save_dict, ckpt_name)
            print('Model saved in {}\n'.format(ckpt_name))

    plot(losses, accuracy_list, val_epochs, ckpt_path)
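
eval_one_epoch and plot are called above but not defined in this listing. A minimal sketch of eval_one_epoch consistent with its usage here (returning classification accuracy on the validation loader) could be the following; the exact metric and data handling are assumptions.

def eval_one_epoch(model, val_loader, device):
    # hypothetical sketch: top-1 accuracy on the validation set
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, label in val_loader:
            images = images.type(torch.float).to(device)
            label = label.type(torch.long).to(device)
            pred = model(images).argmax(dim=1)
            correct += (pred == label).sum().item()
            total += label.size(0)
    return correct / total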