Exemple #1
0
def train(total_epochs=1, interval=100, resume=False, ckpt_path = ''):
    """Train RetinaNet with focal loss on the Open Images training list.

    The network is always initialised from ``./model/net.pth``; when
    ``resume`` is set and ``ckpt_path`` is an existing file, the model,
    optimizer, epoch counter and best loss are then overwritten from that
    checkpoint.

    Args:
        total_epochs: Exclusive upper bound of the epoch range; a resumed run
            continues from the stored epoch up to this value.
        interval: Forwarded unchanged to ``train_one_epoch``.
        resume: Whether to restore state from ``ckpt_path`` before training.
        ckpt_path: Path of the checkpoint file used when ``resume`` is true.
    """
    print("Loading training dataset...")
    train_dset = OpenImagesDataset(root='./data/train',
                                   list_file='./data/tmp/train_images_bbox.csv',
                                   transform=transform, train=True, input_size=600)

    train_loader = data.DataLoader(train_dset, batch_size=4, shuffle=True,
                                   num_workers=4, collate_fn=train_dset.collate_fn)

    print("Loading completed.")

    # TODO: validation is currently disabled; build a validation dataset and
    # loader here and re-enable the validate() call in the epoch loop.

    net = RetinaNet()
    net.load_state_dict(torch.load('./model/net.pth'))

    criterion = FocalLoss()

    net.cuda()
    criterion.cuda()
    optimizer = optim.SGD(net.parameters(), lr=1e-3, momentum=0.9, weight_decay=1e-4)

    # Start from +inf so that the first measured loss is always an improvement,
    # whatever its magnitude (the previous sentinel of 1000 was arbitrary).
    best_val_loss = float('inf')
    start_epoch = 0

    if resume:
        if os.path.isfile(ckpt_path):
            print(f'Loading from the checkpoint {ckpt_path}')
            checkpoint = torch.load(ckpt_path)
            start_epoch = checkpoint['epoch']
            best_val_loss = checkpoint['best_val_loss']
            net.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print(f'Loaded checkpoint {ckpt_path}, epoch : {start_epoch}')
        else:
            print(f'No check point found at the path {ckpt_path}')

    for epoch in range(start_epoch, total_epochs):
        train_one_epoch(train_loader, net, criterion, optimizer, epoch, interval)
        # Placeholder until validation is re-enabled:
        # val_loss = validate(val_loader, net, criterion, interval)
        val_loss = 0

        # "<=" rather than "<": with the constant placeholder loss a strict
        # comparison succeeds only once, so later epochs (and every resumed
        # run, whose stored best is already 0) would never be checkpointed.
        if val_loss <= best_val_loss:
            best_val_loss = val_loss
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': net.state_dict(),
                'best_val_loss': best_val_loss,
                'optimizer': optimizer.state_dict(),
            }, is_best=True)
Exemple #2
0
def weight_init(m):
    """Xavier-initialise the weight and zero the bias of ``Linear``-like modules.

    Meant to be passed to ``Module.apply``. A module is selected when its
    class name contains ``'Linear'``; every other module is left untouched.

    Args:
        m: The module to (possibly) re-initialise in place.
    """
    if 'Linear' not in m.__class__.__name__:
        return
    # xavier_uniform_ is the in-place replacement for the deprecated
    # torch.nn.init.xavier_uniform alias.
    torch.nn.init.xavier_uniform_(m.weight.data)
    if m.bias is not None:
        m.bias.data.fill_(0)


# Re-initialise the model's modules through weight_init, then move the model
# to the GPU.
model.apply(weight_init)
model.cuda()

# label_smoothing = modeling.LabelSmoothing(len(vocab), 0, 0.1)
# label_smoothing.cuda()
# Focal loss with one class per vocabulary entry, kept on the GPU like the model.
focal_loss = FocalLoss(class_num=len(vocab))
focal_loss.cuda()
# Script-level settings; they are not read within this snippet.
# NOTE(review): presumably a checkpoint period, an epoch index for a penalty
# term and a plotting switch - confirm against the code that consumes them.
SAVE_EVERY = 5
PENALTY_EPOCH = -1
DRAW_LEARNING_CURVE = False
data = []

# Tokenized input
print('Tokenization...')
# NOTE(review): the file is opened with the platform default encoding -
# confirm that this matches how pair.csv was written.
with open('pair.csv') as PAIR:
    for line in tqdm(PAIR):
        # Each line must split into exactly three comma-separated fields,
        # otherwise this unpacking raises ValueError.
        [text, summary, _] = line.split(',')
        texts = []
        summaries = []
        # Paragraphs inside the text field are delimited by the literal
        # marker '<newline>'.
        paras = text.split('<newline>')
        for para in paras:
            # Segment the paragraph with jieba and collect its tokens.
            texts.extend(list(jieba.cut(para)))
# Test split: an ISIC dataset built from a CSV listing, served in fixed order
# in batches of 16.
# NOTE(review): the CSV path is an absolute, machine-specific location.
testCSVFile = '/home/lrh/git/Evaluating_Robustness_Of_Deep_Medical_Models/Dermothsis/test.csv'
testset = ISIC(csv_file=testCSVFile, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=16, shuffle=False)
print('\ndone')

print('======>loading the model')

# Two-class AttnVGG with attention enabled, paired with a focal loss.
net = AttnVGG(num_classes=2, attention=True, normalize_attn=False)
criterion = FocalLoss()
print('done\n')

print('\nmoving models to GPU')
# `clf` is the GPU-resident network wrapped in DataParallel; `net` still
# refers to the unwrapped module.
clf = net.cuda()
clf = torch.nn.DataParallel(clf)
cudnn.benchmark = True
criterion = criterion.cuda()
print('done\n')

# SGD with Nesterov momentum over the wrapped model's parameters.
learningRate = 0.001
optimizer = optim.SGD(clf.parameters(),
                      lr=learningRate,
                      momentum=0.9,
                      weight_decay=1e-4,
                      nesterov=True)
# Step schedule: the learning-rate multiplier is 0.1 ** (epoch // 10), i.e. it
# drops by a factor of ten every 10 epochs.
lr_lambda = lambda epoch: np.power(0.1, epoch // 10)
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

epochNum = 30

pgd_params_train = {
    'ord': np.inf,
Exemple #4
0
def main():
    """Train a single-class segmentation model configured through gflags.

    Flags: ``id`` (run identifier), ``epoch`` (number of epochs to run),
    ``pretrained`` (any value enables resuming and selects ``FocalLoss``
    instead of ``LogDiceLoss``), ``threshold``, ``batchsize``, ``gpu``, ``lr``
    and ``class_name``.

    Checkpoints are written to ``Train_id<id>/Checkpoint/<val_loss>.pth.tar``
    after every epoch; when resuming, the checkpoint with the lowest loss in
    that same directory is restored.
    """
    # Typed flag definitions: with DEFINE_string every value given on the
    # command line arrived as ``str`` (e.g. ``--gpu False`` was truthy and a
    # user-supplied threshold reached train()/val() as text).
    gflags.DEFINE_string('id', None, 'ID for Training')
    gflags.DEFINE_integer('epoch', 25, 'Number of Epochs')
    gflags.DEFINE_string('pretrained', None, 'Pretrained for Resuming Training')
    gflags.DEFINE_float('threshold', 0.5, 'Threshold probability for predicting class')
    gflags.DEFINE_integer('batchsize', 128, 'Batch Size')
    gflags.DEFINE_boolean('gpu', True, 'Use GPU or Not')
    gflags.DEFINE_float('lr', 0.001, 'Learning Rate')
    gflags.DEFINE_string('class_name', 'None', 'class name')
    gflags.FLAGS(sys.argv)
    flags = gflags.FLAGS

    # Directory Path for saving weights of Trained Model
    save_path = 'Train_id' + str(flags.id)
    checkpoint_dir = save_path + '/Checkpoint'
    class_name = flags.class_name
    threshold = flags.threshold
    resume = flags.pretrained is not None
    writer = SummaryWriter('./runs/{}'.format(flags.id))

    # Also creates the Checkpoint sub-directory when save_path already exists.
    os.makedirs(checkpoint_dir, exist_ok=True)

    train_dataset_path = 'data/train'
    val_dataset_path = 'data/valid'
    train_transform = transforms.Compose([
        ImgAugTransform(),
        ToTensor()
    ])
    valid_transform = transforms.Compose([
        ToTensor()
    ])

    # NOTE(review): the training set is built with valid_transform, so the
    # augmenting train_transform above is never applied - confirm intent.
    train_dataset = TrainDataset(path=train_dataset_path, transform=valid_transform, class_name=class_name)
    val_dataset = TrainDataset(path=val_dataset_path, transform=valid_transform, class_name=class_name)

    # NOTE(review): this sampler is constructed but never handed to the
    # training DataLoader, and the ``batchsize`` flag is not used either
    # (both loaders use batch_size=4) - confirm intent before wiring them in.
    sampler = WeightedRandomSampler(torch.DoubleTensor(train_dataset.weights), len(train_dataset.weights))

    train_dataloader = DataLoader(train_dataset, batch_size=4,
                                  pin_memory=True, num_workers=4)
    val_dataloader = DataLoader(val_dataset, batch_size=4, shuffle=False,
                                pin_memory=True, num_workers=4)

    # len() of a DataLoader is its number of batches, not of images; the
    # values are kept as-is because train()/val() receive them.
    size_train = len(train_dataloader)
    size_val = len(val_dataloader)

    print('Number of Training Images: {}'.format(size_train))
    print('Number of Validation Images: {}'.format(size_val))

    # Reads class weights from a Json file
    with open('class_weights.json', 'r') as fp:
        class_weights = json.load(fp)

    # NOTE(review): computed but not passed to any loss below.
    weight = torch.tensor([1/class_weights[class_name]])
    start_epoch = 0

    if class_name in ['Roads', 'Railway']:
        model = DinkNet34(num_classes=1)
    else:
        model = Unet(n_ch=4, n_classes=1)

    # The original tested an undefined local ``pretrained`` (NameError); the
    # flag of the same name is what is meant.
    if resume:
        criterion = FocalLoss()
    else:
        criterion = LogDiceLoss()
    criterion1 = torch.nn.BCEWithLogitsLoss()

    optimizer = torch.optim.Adam(model.parameters(), lr=float(flags.lr))

    if flags.gpu:
        model = model.cuda()
        criterion = criterion.cuda()
        criterion1 = criterion1.cuda()

    if resume:
        # Checkpoint files are named '<val_loss>.pth.tar'; stripping the
        # 8-character suffix leaves the loss, and the lowest one is restored.
        # They are read from the directory this function saves to.
        weight_path = min(os.listdir(checkpoint_dir), key=lambda x: float(x[:-8]))
        checkpoint = torch.load(checkpoint_dir + '/' + weight_path)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # There is no ``weight`` flag; report the epoch that was restored.
        print('Loaded Checkpoint of Epoch: {}'.format(start_epoch))

    for epoch in range(start_epoch, int(flags.epoch) + start_epoch):
        print("epoch {}".format(epoch))
        train(model, train_dataloader, criterion, criterion1, optimizer, epoch, writer, size_train, threshold)
        print('')
        val_loss = val(model, val_dataloader, criterion, criterion1, epoch, writer, size_val, threshold)
        print('')
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, filename=checkpoint_dir + '/' + str(val_loss) + '.pth.tar')
    # Write the scalar log inside the run directory (the separator was missing).
    writer.export_scalars_to_json(os.path.join(save_path, 'log.json'))
class Train:
    """Training and evaluation driver for a classification model.

    ``args`` is read for: ``num_epochs``, ``test_every``, ``cuda``,
    ``async_loading``, ``checkpoint_dir``, ``resume_from``,
    ``pretrained_path``, ``learning_rate``, ``learning_rate_decay``,
    ``momentum``, ``weight_decay``, ``loss_function``, ``gamma`` and
    ``classify``.
    """

    def __init__(self, model, trainloader, valloader, args):
        """Build loss/optimizer, then try to load pretrained weights and a checkpoint.

        Args:
            model: Network to train.
            trainloader: Iterable of ``(data, target)`` training batches.
            valloader: Iterable of validation batches; evaluation is skipped
                when it is falsy.
            args: Configuration namespace (see the class docstring).
        """
        self.model = model
        self.model_dict = self.model.state_dict()
        self.trainloader = trainloader
        self.valloader = valloader
        self.args = args
        self.start_epoch = 0
        self.best_top1 = 0.0

        # Loss function and Optimizer
        self.loss = None
        self.optimizer = None
        self.metric_fc = None
        self.create_optimization()

        # Model Loading (both steps are best-effort)
        self.load_pretrained_model()
        self.load_checkpoint(self.args.resume_from)

        # Tensorboard Writer
        self.summary_writer = SummaryWriter()

    def _to_device(self, data, target):
        """Return the batch moved to the GPU when ``args.cuda`` is set, else unchanged."""
        if self.args.cuda:
            # ``async`` is a reserved word since Python 3.7; PyTorch's keyword
            # for the same option is ``non_blocking``.
            data = data.cuda(non_blocking=self.args.async_loading)
            target = target.cuda(non_blocking=self.args.async_loading)
        return data, target

    def train(self):
        """Run the training loop from ``start_epoch`` up to ``args.num_epochs``.

        Every epoch logs averaged loss/top-1/top-5 to the summary writer,
        optionally evaluates on the validation loader, and saves a checkpoint
        that is flagged best when the epoch's training top-1 beats the best
        seen so far.
        """
        for cur_epoch in range(self.start_epoch, self.args.num_epochs):

            # Initialize tqdm
            tqdm_batch = tqdm(self.trainloader,
                              desc="Epoch-" + str(cur_epoch) + "-")

            # Learning rate adjustment
            self.adjust_learning_rate(self.optimizer, cur_epoch)

            # Meters for tracking the average values
            loss = AverageTracker()
            top1 = AverageTracker()
            top5 = AverageTracker()

            # Set the model to be in training mode (for dropout and batchnorm)
            self.model.train()

            for data, target in tqdm_batch:
                data, target = self._to_device(data, target)

                # Forward pass
                output = self.model(data)
                cur_loss = self.loss(output, target)

                # Optimization step
                self.optimizer.zero_grad()
                cur_loss.backward()
                self.optimizer.step()

                # Top-1 and Top-5 Accuracy Calculation
                cur_acc1, cur_acc5 = self.compute_accuracy(output.data,
                                                           target,
                                                           topk=(1, 5))
                loss.update(cur_loss.item())
                top1.update(cur_acc1.item())
                top5.update(cur_acc5.item())

            # Summary Writing
            self.summary_writer.add_scalar("epoch-loss", loss.avg, cur_epoch)
            self.summary_writer.add_scalar("epoch-top-1-acc", top1.avg,
                                           cur_epoch)
            self.summary_writer.add_scalar("epoch-top-5-acc", top5.avg,
                                           cur_epoch)

            # Print in console
            tqdm_batch.close()
            print("Epoch-" + str(cur_epoch) + " | " + "loss: " +
                  str(loss.avg) + " - acc-top1: " + str(top1.avg)[:7] +
                  "- acc-top5: " + str(top5.avg)[:7])

            # Evaluate on Validation Set
            if cur_epoch % self.args.test_every == 0 and self.valloader:
                self.test(self.valloader, cur_epoch)

            # Checkpointing
            is_best = top1.avg > self.best_top1
            self.best_top1 = max(top1.avg, self.best_top1)
            self.save_checkpoint(
                {
                    'epoch': cur_epoch + 1,
                    'state_dict': self.model.state_dict(),
                    'best_top1': self.best_top1,
                    'optimizer': self.optimizer.state_dict(),
                }, is_best)

    def test(self, testloader, cur_epoch=-1):
        """Evaluate on ``testloader`` and print averaged loss/top-1/top-5.

        Args:
            testloader: Iterable of ``(data, target)`` batches.
            cur_epoch: Epoch index used as the step for summary scalars; the
                default ``-1`` disables summary writing.
        """
        loss = AverageTracker()
        top1 = AverageTracker()
        top5 = AverageTracker()

        # Set the model to be in testing mode (for dropout and batchnorm)
        self.model.eval()

        for data, target in testloader:
            data, target = self._to_device(data, target)

            # Forward pass; no gradients are needed for evaluation
            with torch.no_grad():
                output = self.model(data)
                cur_loss = self.loss(output, target)

            # Top-1 and Top-5 Accuracy Calculation
            cur_acc1, cur_acc5 = self.compute_accuracy(output.data,
                                                       target,
                                                       topk=(1, 5))
            loss.update(cur_loss.item())
            top1.update(cur_acc1.item())
            top5.update(cur_acc5.item())

        if cur_epoch != -1:
            # Summary Writing
            self.summary_writer.add_scalar("test-loss", loss.avg, cur_epoch)
            self.summary_writer.add_scalar("test-top-1-acc", top1.avg,
                                           cur_epoch)
            self.summary_writer.add_scalar("test-top-5-acc", top5.avg,
                                           cur_epoch)

        print("Test Results" + " | " + "loss: " + str(loss.avg) +
              " - acc-top1: " + str(top1.avg)[:7] + "- acc-top5: " +
              str(top5.avg)[:7])

    def save_checkpoint(self, state, is_best, filename='checkpoint.pth.tar'):
        """Save ``state`` under ``args.checkpoint_dir`` and copy it to the best file if ``is_best``.

        ``args.checkpoint_dir`` is used as a plain string prefix, so it must
        already end with a path separator.
        """
        target_path = self.args.checkpoint_dir + filename
        torch.save(state, target_path)
        if is_best:
            shutil.copyfile(target_path,
                            self.args.checkpoint_dir + 'model_best.pth.tar')

    def compute_accuracy(self, output, target, topk=(1, )):
        """Computes the accuracy@k for the specified values of k.

        Args:
            output: Score tensor; ``topk`` is taken along dimension 1.
            target: Class indices, one per row of ``output``.
            topk: Values of k to evaluate.

        Returns:
            A list with one single-element tensor per k holding the fraction
            of rows whose target is among the k highest scores.
        """
        maxk = max(topk)
        batch_size = target.size(0)

        _, idx = output.topk(maxk, 1, True, True)
        idx = idx.t()
        correct = idx.eq(target.view(1, -1).expand_as(idx))

        acc_arr = []
        for k in topk:
            # ``correct`` derives from a transposed tensor, so its row slice
            # may be non-contiguous; reshape() handles that where view()
            # raises on current PyTorch.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            acc_arr.append(correct_k.mul_(1.0 / batch_size))
        return acc_arr

    def adjust_learning_rate(self, optimizer, epoch):
        """Set every param group's LR to ``learning_rate * learning_rate_decay ** epoch``."""
        learning_rate = self.args.learning_rate * (
            self.args.learning_rate_decay**epoch)
        for param_group in optimizer.param_groups:
            param_group['lr'] = learning_rate

    def create_optimization(self):
        """Create the loss and the RMSprop optimizer from ``args``.

        When ``args.classify`` is set, an ``ArcMarginModel`` head is created
        and its parameters are optimized together with the model's; otherwise
        only the model's parameters are optimized.
        """
        if self.args.loss_function == 'FocalLoss':
            self.loss = FocalLoss(gamma=self.args.gamma)
        else:
            self.loss = nn.CrossEntropyLoss()

        if self.args.cuda:
            self.loss.cuda()

        # The two optimizer branches were swapped in the original: the head
        # was referenced where it had never been created (AttributeError) and
        # left out of the optimizer where it had.
        if self.args.classify:
            self.metric_fc = ArcMarginModel(self.args)
            param_groups = [{
                'params': self.model.parameters()
            }, {
                'params': self.metric_fc.parameters()
            }]
        else:
            param_groups = self.model.parameters()

        self.optimizer = RMSprop(param_groups,
                                 self.args.learning_rate,
                                 momentum=self.args.momentum,
                                 weight_decay=self.args.weight_decay)

    def load_pretrained_model(self):
        """Best-effort copy of matching tensors from ``args.pretrained_path`` into the model.

        Only entries whose name exists in the model and whose size matches
        are copied. Any failure is reported and training continues without
        pretrained weights.
        """
        print("Loading ImageNet pretrained weights...")
        try:
            pretrained_dict = torch.load(self.args.pretrained_path)
            # One state_dict() call instead of one per parameter; its tensors
            # share storage with the model, so copy_ updates the model.
            model_state = self.model.state_dict()
            for params_name, tensor in pretrained_dict.items():
                if (params_name in model_state
                        and tensor.size() == model_state[params_name].size()):
                    model_state[params_name].copy_(tensor)
        except Exception as err:  # deliberate best-effort; Ctrl-C still propagates
            print("No ImageNet pretrained weights exist. Skipping...\n")
            print("Reason: {}".format(err))
        else:
            print("ImageNet pretrained weights loaded successfully.\n")

    def load_checkpoint(self, filename):
        """Best-effort restore of model, optimizer, epoch and best accuracy.

        Args:
            filename: File name appended to ``args.checkpoint_dir``.
        """
        filename = self.args.checkpoint_dir + filename
        print("Loading checkpoint '{}'".format(filename))
        try:
            checkpoint = torch.load(filename)
            epoch = checkpoint['epoch']
            best_top1 = checkpoint['best_top1']
            self.model.load_state_dict(checkpoint['state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
        except Exception as err:  # deliberate best-effort; Ctrl-C still propagates
            print("No checkpoint exists from '{}'. Skipping...\n".format(
                self.args.checkpoint_dir))
            print("Reason: {}".format(err))
        else:
            # Counters are committed only after the weights loaded, so a
            # failed restore cannot leave start_epoch pointing past zero.
            self.start_epoch = epoch
            self.best_top1 = best_top1
            print("Checkpoint loaded successfully from '{}' at (epoch {})\n".
                  format(self.args.checkpoint_dir, epoch))