Example #1
class Trainer:
    def __init__(self, device, config, model, criterion, dataloader, data_transformer=None, tensorboard=True, meta_data=None):
        super().__init__()

        config['running_loss_range'] = config['running_loss_range'] if 'running_loss_range' in config else 50

        self.device = torch.device('cuda' if torch.cuda.is_available() and device == 'cuda' else 'cpu')
        self.config = config

        self.model = model.to(self.device)
        self.criterion = criterion

        self.ds_length = len(dataloader.dataset)
        self.batch_size = dataloader.batch_size
        self.dataloader = dataloader
        self.data_transformer = data_transformer

        self.epoch_loss_history = []
        self.content_loss_history = []
        self.style_loss_history = []
        self.total_variation_loss_history = []
        self.loss_history = []
        self.lr_history = []
        self.meta_data = meta_data

        self.progress_bar = trange(
            math.ceil(self.ds_length / self.batch_size) * self.config['epochs'],
            leave=True
        )

        if self.config['lr_scheduler'] == 'CyclicLR':
            self.optimizer = optim.SGD(self.model.parameters(), lr=config['max_lr'], nesterov=True, momentum=0.9)
            self.scheduler = CyclicLR(
                optimizer=self.optimizer,
                base_lr=config['min_lr'],
                max_lr=config['max_lr'],
                step_size_up=self.config['lr_step_size'],
                mode='triangular2'
            )
        elif self.config['lr_scheduler'] == 'CosineAnnealingLR':
            self.optimizer = optim.Adam(self.model.parameters(), lr=config['max_lr'])
            self.scheduler = CosineAnnealingLR(
                optimizer=self.optimizer,
                T_max=self.config['lr_step_size']
            )
        elif self.config['lr_scheduler'] == 'ReduceLROnPlateau':
            self.optimizer = optim.Adam(self.model.parameters(), lr=config['max_lr'])
            self.scheduler = ReduceLROnPlateau(
                optimizer=self.optimizer,
                patience=100
            )
        elif self.config['lr_scheduler'] == 'StepLR':
            self.optimizer = optim.Adam(self.model.parameters(), lr=config['max_lr'])
            self.scheduler = StepLR(
                optimizer=self.optimizer,
                step_size=self.config['lr_step_size'],
                gamma=float(self.config['lr_multiplicator'])
            )
        elif self.config['lr_scheduler'] == 'CosineAnnealingWarmRestarts':
            self.optimizer = optim.Adam(self.model.parameters(), lr=config['max_lr'])
            self.scheduler = CosineAnnealingWarmRestarts(
                optimizer=self.optimizer,
                T_0=self.config['lr_step_size'],
                eta_min=config['min_lr'],
                T_mult=int(self.config['lr_multiplicator'])
            )
        else:
            self.optimizer = optim.Adam(self.model.parameters(), lr=config['max_lr'])
            self.scheduler = None

        if tensorboard:
            self.tensorboard_writer = SummaryWriter(log_dir=os.path.join('./runs', config['name']))
        else:
            self.tensorboard_writer = None

    def load_checkpoint(self):
        name = self.config['name']
        path = f'./checkpoints/{name}.pth'

        if os.path.exists(path):
            checkpoint = torch.load(path, map_location=self.device)

            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

            if self.scheduler:
                self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

            self.content_loss_history = checkpoint['content_loss_history']
            self.style_loss_history = checkpoint['style_loss_history']
            self.total_variation_loss_history = checkpoint['total_variation_loss_history']
            self.loss_history = checkpoint['loss_history']
            self.lr_history = checkpoint['lr_history']

            del checkpoint
            torch.cuda.empty_cache()

    def train(self):
        start = time()

        for epoch in range(self.config['epochs']):
            self.epoch_loss_history = []

            for i, batch in enumerate(self.dataloader):
                self.do_training_step(i, batch)
                self.do_progress_bar_step(epoch, self.config['epochs'], i)

                if self.config['lr_scheduler'] == 'ReduceLROnPlateau':
                    self.scheduler.step(self.loss_history[-1])
                elif self.scheduler:
                    self.scheduler.step()

                if i % self.config['save_checkpoint_interval'] == 0:
                    self.save_checkpoint(f'./checkpoints/{self.config["name"]}.pth')

                if self.config['max_runtime'] != 0 and time() - start >= self.config['max_runtime']:
                    break

        torch.cuda.empty_cache()

    def do_training_step(self, i, batch):
        self.model.train()

        with torch.autograd.detect_anomaly():
            try:

                if self.data_transformer:
                    x, y = self.data_transformer(batch)
                else:
                    x, y = batch

                x = x.to(self.device)
                y = y.to(self.device)
                self.optimizer.zero_grad()

                preds = self.model(x)

                loss = self.criterion(preds, y)
                loss.backward()
                self.optimizer.step()

                self.lr_history.append(self.optimizer.param_groups[0]['lr'])
                self.epoch_loss_history.append(self.criterion.loss_val)

                self.content_loss_history.append(self.criterion.content_loss_val)
                self.style_loss_history.append(self.criterion.style_loss_val)
                self.total_variation_loss_history.append(self.criterion.total_variation_loss_val)
                self.loss_history.append(self.criterion.loss_val)

                if self.tensorboard_writer and i % self.config['save_checkpoint_interval'] == 0:
                    grid_y = torchvision.utils.make_grid(y)
                    grid_preds = torchvision.utils.make_grid(preds)

                    self.tensorboard_writer.add_image('Inputs', grid_y, 0)
                    self.tensorboard_writer.add_image('Predictions', grid_preds, 0)

                    # writer.add_graph(network, images)
                    self.tensorboard_writer.add_scalar(
                        'Content Loss',
                        self.content_loss_history[-1],
                        len(self.content_loss_history) - 1
                    )

                    self.tensorboard_writer.add_scalar(
                        'Style Loss',
                        self.style_loss_history[-1],
                        len(self.style_loss_history) - 1
                    )

                    self.tensorboard_writer.add_scalar(
                        'TV Loss',
                        self.total_variation_loss_history[-1],
                        len(self.total_variation_loss_history) - 1
                    )

                    self.tensorboard_writer.add_scalar(
                        'Total Loss',
                        self.loss_history[-1],
                        len(self.loss_history) - 1
                    )

                    self.tensorboard_writer.add_scalar(
                        'Learning Rate',
                        self.lr_history[-1],
                        len(self.lr_history) - 1
                    )

                    self.tensorboard_writer.close()
            except Exception:
                # fall back to the last saved checkpoint if the step fails (e.g. an anomaly detected in backward)
                self.load_checkpoint()

    def do_validation_step(self):
        self.model.eval()

    def do_progress_bar_step(self, epoch, epochs, i):
        avg_epoch_loss = sum(self.epoch_loss_history) / (i + 1)

        if len(self.loss_history) >= self.config['running_loss_range']:
            running_loss = sum(
                self.loss_history[-self.config['running_loss_range']:]
            ) / self.config['running_loss_range']
        else:
            running_loss = 0

        if len(self.loss_history) > 0:
            self.progress_bar.set_description(
                f'Name: {self.config["name"]}, ' +
                f'Loss Network: {self.config["loss_network"]}, ' +
                f'Epoch: {epoch + 1}/{epochs}, ' +
                f'Average Epoch Loss: {avg_epoch_loss:,.2f}, ' +
                f'Running Loss: {running_loss:,.2f}, ' +
                f'Loss: {self.loss_history[-1]:,.2f}, ' +
                f'Learning Rate: {self.lr_history[-1]:,.6f}'
            )
        else:
            self.progress_bar.set_description(
                f'Name: {self.config["name"]}, ' +
                f'Loss Network: {self.config["loss_network"]}, ' +
                f'Epoch: {epoch + 1}/{epochs}, ' +
                f'Average Epoch Loss: {0:,.2f}, ' +
                f'Running Loss: {0:,.2f}, ' +
                f'Loss: {0:,.2f}, ' +
                f'Learning Rate: {0:,.6f}'
            )

        self.progress_bar.update(1)
        self.progress_bar.refresh()

    def save_checkpoint(self, path):
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
            'content_loss_history': self.content_loss_history,
            'style_loss_history': self.style_loss_history,
            'total_variation_loss_history': self.total_variation_loss_history,
            'loss_history': self.loss_history,
            'lr_history': self.lr_history,
            'content_image_size': self.config['content_image_size'],
            'style_image_size': self.config['style_image_size'],
            'network': str(self.config['network']),
            'content_weight': self.config['content_weight'],
            'style_weight': self.config['style_weight'],
            'total_variation_weight': self.config['total_variation_weight'],
            'bottleneck_size': self.config['bottleneck_size'],
            'bottleneck_type': str(self.config['bottleneck_type']),
            'channel_multiplier': self.config['channel_multiplier'],
            'expansion_factor': self.config['expansion_factor'],
            'intermediate_activation_fn': self.config['intermediate_activation_fn'],
            'final_activation_fn': self.config['final_activation_fn'],
            'meta_data': self.meta_data
        }, path)
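
# Hedged usage sketch (not part of the original snippet). It only lists the config keys the
# Trainer above actually reads; every value below is an illustrative assumption, model /
# criterion / dataloader are assumed to be built elsewhere, and the criterion is assumed to
# expose loss_val / content_loss_val / style_loss_val / total_variation_loss_val, as
# do_training_step requires.
def build_example_trainer(model, criterion, dataloader):
    config = {
        'name': 'style_transfer_run',
        'epochs': 2,
        'lr_scheduler': 'StepLR',         # or CyclicLR / CosineAnnealingLR / ReduceLROnPlateau / CosineAnnealingWarmRestarts
        'min_lr': 1e-5,
        'max_lr': 1e-3,
        'lr_step_size': 1000,
        'lr_multiplicator': 0.5,
        'save_checkpoint_interval': 500,
        'max_runtime': 0,                 # 0 disables the runtime limit
        'loss_network': 'vgg16',
        # keys read only by save_checkpoint:
        'content_image_size': 256, 'style_image_size': 256, 'network': 'TransformerNet',
        'content_weight': 1e5, 'style_weight': 1e10, 'total_variation_weight': 1e-6,
        'bottleneck_size': 128, 'bottleneck_type': 'residual', 'channel_multiplier': 1,
        'expansion_factor': 6, 'intermediate_activation_fn': 'relu', 'final_activation_fn': 'sigmoid',
    }
    trainer = Trainer('cuda', config, model, criterion, dataloader)
    trainer.load_checkpoint()
    return trainer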
Example #2
def train():
    epoch_size = len(trainset) // args.batch_size
    num_epochs = math.ceil(args.max_iter / epoch_size)
    start_epoch = 0
    iteration = 0

    df = pd.read_csv('./data/train.csv')
    tmp = np.sqrt(1 / np.sqrt(df['landmark_id'].value_counts().sort_index().values))
    margins = (tmp - tmp.min()) / (tmp.max() - tmp.min()) * 0.45 + 0.05

    print('Loading model...')
    model = EfficientNetLandmark(args.depth, args.num_classes)
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
    scheduler = CosineAnnealingWarmRestarts(optimizer, num_epochs-1)

    if args.resume is not None:
        state_dict = torch.load(args.resume)
        try:
            print('Resume all state...')
            modules_state_dict = state_dict['modules']
            optimizer_state_dict = state_dict['optimizer']
            scheduler_state_dict = state_dict['scheduler']
            optimizer.load_state_dict(optimizer_state_dict)
            scheduler.load_state_dict(scheduler_state_dict)
            start_epoch = state_dict['epoch']
            iteration = state_dict['iteration']
        except KeyError:
            print('Resume only modules...')
            modules_state_dict = state_dict
        
        model_state_dict = {k.replace('module.', ''): v for k, v in modules_state_dict.items() if k.replace('module.', '') in model.state_dict().keys()}
        model.load_state_dict(model_state_dict)

    num_gpus = list(range(torch.cuda.device_count()))
    if len(num_gpus) > 1:
        print('Using data parallel...')
        model = nn.DataParallel(model, device_ids=num_gpus)

    # logger = open('log.txt', 'w')
    losses = AverageMeter()
    scores = AverageMeter()

    start_train = datetime.now()
    print(num_epochs, start_epoch, iteration)
    model.train()
    for epoch in range(start_epoch, num_epochs):
        if (epoch+1)*epoch_size < iteration:
            continue

        if iteration == args.max_iter:
            break
        
        correct = 0
        input_size = 0
        start_time = datetime.now()
        for i, (inputs, targets) in enumerate(train_loader):
            optimizer.zero_grad()

            inputs = inputs.to(device)
            targets = targets.to(device)

            with autocast():
                outputs = model(inputs)
                loss = criterion(outputs, targets, margins)
            
            confs, preds = torch.max(outputs.detach(), dim=1)
            # loss.backward()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            # optimizer.step()
            scaler.update()

            losses.update(loss.item(), inputs.size(0))
            scores.update(gap(preds, confs, targets))
            correct += (preds == targets).float().sum()
            input_size += inputs.size(0)

            iteration += 1

            writer.add_scalar('loss', losses.val, iteration)
            writer.add_scalar('gap', scores.val, iteration)

            # log = {'epoch': epoch+1, 'iteration': iteration, 'loss': losses.val, 'acc': corrects.val, 'gap': scores.val}
            # logger.write(str(log) + '\n')
            if iteration % args.verbose_eval == 0:
                print(
                    f'[{epoch+1}/{iteration}] Loss: {losses.val:.5f} Acc: {correct/input_size:.5f}' \
                    f' GAP: {scores.val:.5f} LR: {optimizer.param_groups[0]["lr"]} Time: {datetime.now() - start_time}')
            
            if iteration > 100000 and iteration % args.save_interval == 0:
                print('Save model...')
                save_checkpoint({
                    'modules': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'epoch': epoch,
                    'iteration': iteration,
                }, f'effnet_b{args.depth}_{args.max_size}_{args.batch_size}_{epoch+1}_{iteration}.pth')

            scheduler.step(epoch + i / len(train_loader))
        print()

    # logger.close()
    writer.close()
    print(datetime.now() - start_train)
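
# Hedged sketch of the AverageMeter helper the snippet above assumes (losses / scores call
# .update(value, n) and are read via .val and .avg). This is the common PyTorch-examples
# pattern, not necessarily the original implementation.
class AverageMeter:
    """Tracks the latest value and a running average."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count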
Example #3
def run_training(opt):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    work_dir, epochs, train_batch, valid_batch, weights = \
        opt.work_dir, opt.epochs, opt.train_bs, opt.valid_bs, opt.weights

    # Directories
    last = os.path.join(work_dir, 'last.pt')
    best = os.path.join(work_dir, 'best.pt')

    # --------------------------------------
    # Setup train and validation set
    # --------------------------------------
    data = pd.read_csv(opt.train_csv)
    images_path = opt.data_dir

    n_classes = 6  # fixed coding :V

    data['class'] = data.apply(lambda row: categ[row["class"]], axis=1)

    train_loader, val_loader = prepare_dataloader(data,
                                                  opt.fold,
                                                  train_batch,
                                                  valid_batch,
                                                  opt.img_size,
                                                  opt.num_workers,
                                                  data_root=images_path)

    # Per-domain validation loaders; the valid_one_epoch branch below needs them when opt.ovr_val is not set.
    if not opt.ovr_val:
        handwritten_data = pd.read_csv(opt.handwritten_csv)
        printed_data = pd.read_csv(opt.printed_csv)
        handwritten_data['class'] = handwritten_data.apply(lambda row: categ[row["class"]], axis=1)
        printed_data['class'] = printed_data.apply(lambda row: categ[row["class"]], axis=1)
        _, handwritten_val_loader = prepare_dataloader(
            handwritten_data, opt.fold, train_batch, valid_batch, opt.img_size, opt.num_workers, data_root=images_path)

        _, printed_val_loader = prepare_dataloader(
            printed_data, opt.fold, train_batch, valid_batch, opt.img_size, opt.num_workers, data_root=images_path)

    # --------------------------------------
    # Models
    # --------------------------------------

    model = Classifier(model_name=opt.model_name,
                       n_classes=n_classes,
                       pretrained=True).to(device)

    if opt.weights is not None:
        cp = torch.load(opt.weights)
        model.load_state_dict(cp['model'])

    # -------------------------------------------
    # Setup optimizer, scheduler, criterion loss
    # -------------------------------------------

    optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-6)
    scheduler = CosineAnnealingWarmRestarts(optimizer,
                                            T_0=10,
                                            T_mult=1,
                                            eta_min=1e-6,
                                            last_epoch=-1)
    scaler = GradScaler()

    loss_tr = nn.CrossEntropyLoss().to(device)
    loss_fn = nn.CrossEntropyLoss().to(device)

    # --------------------------------------
    # Setup training
    # --------------------------------------
    if not os.path.exists(work_dir):
        os.mkdir(work_dir)

    best_loss = 1e5
    start_epoch = 0
    best_epoch = 0  # for early stopping

    if opt.resume:
        checkpoint = torch.load(last)

        start_epoch = checkpoint["epoch"]
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint["scheduler"])
        best_loss = checkpoint["best_loss"]

    # --------------------------------------
    # Start training
    # --------------------------------------
    print("[INFO] Start training...")
    for epoch in range(start_epoch, epochs):
        train_one_epoch(epoch,
                        model,
                        loss_tr,
                        optimizer,
                        train_loader,
                        device,
                        scheduler=scheduler,
                        scaler=scaler)
        with torch.no_grad():
            if opt.ovr_val:
                val_loss = valid_one_epoch_overall(epoch,
                                                   model,
                                                   loss_fn,
                                                   val_loader,
                                                   device,
                                                   scheduler=None)
            else:
                val_loss = valid_one_epoch(epoch,
                                           model,
                                           loss_fn,
                                           handwritten_val_loader,
                                           printed_val_loader,
                                           device,
                                           scheduler=None)

            if val_loss < best_loss:
                best_loss = val_loss
                best_epoch = epoch
                torch.save(
                    {
                        'epoch': epoch,
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict(),
                        'best_loss': best_loss
                    }, best)

                print('best model found for epoch {}'.format(epoch + 1))

        torch.save(
            {
                'epoch': epoch,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'best_loss': best_loss
            }, last)

        if epoch - best_epoch > opt.patience:
            print("Early stop achieved at", epoch + 1)
            break

    del model, optimizer, train_loader, val_loader, scheduler, scaler
    torch.cuda.empty_cache()
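
# Hedged sketch of a train_one_epoch compatible with the call above
# (epoch, model, loss_fn, optimizer, loader, device, scheduler=..., scaler=...).
# The AMP pattern and the fractional scheduler step are assumptions, not the original code.
from torch.cuda.amp import autocast

def train_one_epoch(epoch, model, loss_fn, optimizer, loader, device, scheduler=None, scaler=None):
    model.train()
    n_steps = len(loader)
    for step, (images, labels) in enumerate(loader):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        with autocast():                      # mixed-precision forward pass
            loss = loss_fn(model(images), labels)
        scaler.scale(loss).backward()         # scaled backward to avoid fp16 underflow
        scaler.step(optimizer)
        scaler.update()
        if scheduler is not None:             # fractional epochs, as CosineAnnealingWarmRestarts supports
            scheduler.step(epoch + step / n_steps)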
Example #4
def main():
    global args, best_loss, weight_decay, momentum

    net = Network()

    epoch = 0
    saved = load_checkpoint()

    # This part still needs revision
    dataTrain = get_train_set(
        data_dir='C:/Users/hasee/Desktop/Master_Project/Step2/Plan_B/Label')
    dataVal = get_val_set(
        data_dir='C:/Users/hasee/Desktop/Master_Project/Step2/Plan_B/Label')

    train_loader = torch.utils.data.DataLoader(dataTrain,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=workers,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(dataVal,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             num_workers=workers,
                                             pin_memory=True)

    # If the results are acceptable, consider dropping weight_decay and trying again
    optimizer = torch.optim.SGD(net.parameters(),
                                lr,
                                momentum=momentum,
                                weight_decay=weight_decay)
    # Cosine annealing
    if Cosine_lr:
        lr_scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=5)
    else:
        lr_scheduler = StepLR(optimizer, step_size=1, gamma=0.92)

    if doLoad:
        if saved:
            print('Loading checkpoint for epoch %05d ...' % (saved['epoch']))
            state = saved['model_state']
            try:
                net.module.load_state_dict(state)
            except AttributeError:  # net is not wrapped in DataParallel, so it has no .module
                net.load_state_dict(state)
            epoch = saved['epoch']
            best_loss = saved['best_loss']
            optimizer.load_state_dict(saved['optim_state'])
            lr_scheduler.load_state_dict(saved['scheule_state'])
        else:
            print('Warning: Could not read checkpoint!')

    # Quick test
    if doTest:
        validate(val_loader, net, epoch)
        return
    '''
    for epoch in range(0, epoch):
        adjust_learning_rate(optimizer, epoch)
    '''

    m.begin_run(train_loader)
    print("start to run!")
    for epoch in range(epoch, epochs):

        # adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, net, optimizer, epoch)
        lr_scheduler.step()

        # evaluate on validation set
        loss_total = validate(val_loader, net, epoch)

        # remember best loss and save checkpoint
        is_best = loss_total < best_loss
        best_loss = min(loss_total, best_loss)
        state = {
            'epoch': epoch + 1,
            'model_state': net.state_dict(),
            'optim_state': optimizer.state_dict(),
            'scheule_state': lr_scheduler.state_dict(),
            'best_loss': best_loss,
        }
        save_checkpoint(state, is_best)
    # End of this run
    m.end_run()
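
# Hedged sketch of the load_checkpoint / save_checkpoint helpers used above; the file name
# and the "best" copy are assumptions. The returned dict mirrors the keys saved in the loop
# (epoch, model_state, optim_state, scheule_state, best_loss).
import os
import shutil

CHECKPOINT_PATH = 'checkpoint.pth.tar'

def load_checkpoint():
    # returns the saved state, or None if no checkpoint has been written yet
    if not os.path.isfile(CHECKPOINT_PATH):
        return None
    return torch.load(CHECKPOINT_PATH)

def save_checkpoint(state, is_best):
    torch.save(state, CHECKPOINT_PATH)
    if is_best:
        shutil.copyfile(CHECKPOINT_PATH, 'best_' + CHECKPOINT_PATH)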
Example #5
def main(opt):
    torch.manual_seed(opt.seed)

    if os.path.isdir("./artifacts_train"):
        log_versions = [
            int(name.split("_")[-1]) 
            for name in os.listdir(os.path.join("./artifacts_train", "logs")) 
            if os.path.isdir(os.path.join("./artifacts_train", "logs", name))
        ]
        current_version = f"version_{max(log_versions) + 1}"
    else:
        os.makedirs(os.path.join("./artifacts_train", "logs"), exist_ok=True)
        os.makedirs(os.path.join("./artifacts_train", "checkpoints"), exist_ok=True)
        current_version = "version_0"
    logger = SummaryWriter(logdir=os.path.join("./artifacts_train", "logs", current_version))
    os.makedirs(os.path.join("./artifacts_train", "checkpoints", current_version), exist_ok=True)

    device = torch.device("cuda:0")

    # Train Val Split
    path_to_train = os.path.join(opt.data_dir, "train_images/")
    train_df = pd.read_csv(os.path.join(opt.data_dir, "train.csv"))
    train_df["image_id"] = train_df["image_id"].apply(lambda x: os.path.join(path_to_train, x))
    train_df = train_df.sample(frac=1, random_state=opt.seed).reset_index(drop=True)
    train_df.columns = ["path", "label"]

    val_df = train_df.loc[int(len(train_df)*opt.train_split/100):].reset_index(drop=True)
    train_df = train_df.loc[:int(len(train_df)*opt.train_split/100)].reset_index(drop=True)

    # Augmentations
    train_trans = albu.Compose([
            albu.RandomResizedCrop(*opt.input_shape),
            albu.VerticalFlip(),
            albu.HorizontalFlip(),
            albu.Transpose(p=0.5),
            albu.ShiftScaleRotate(p=0.5),
            albu.HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),
            albu.RandomBrightnessContrast(brightness_limit=(-0.1,0.1), contrast_limit=(-0.1, 0.1), p=0.5),
            albu.CoarseDropout(p=0.5),
            albu.Cutout(p=0.5),
            albu.Normalize()
        ])

    val_trans = albu.Compose([
            albu.RandomResizedCrop(*opt.input_shape),
            albu.VerticalFlip(),
            albu.HorizontalFlip(),
            albu.ShiftScaleRotate(p=0.5),
            albu.HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),
            albu.RandomBrightnessContrast(brightness_limit=(-0.1,0.1), contrast_limit=(-0.1, 0.1), p=0.5),
            albu.Normalize() 
        ])

    # Dataset init
    data_train = LeafData(train_df, transforms=train_trans)
    data_val = LeafData(val_df, transforms=val_trans, tta=opt.tta)
    weights = get_weights(train_df)
    sampler_train = WeightedRandomSampler(weights, len(data_train))
    # pass the class-balancing sampler (a DataLoader cannot take both a sampler and shuffle=True)
    dataloader_train = DataLoader(data_train, batch_size=opt.batch_size, sampler=sampler_train, num_workers=opt.num_workers)
    dataloader_val = DataLoader(data_val, shuffle=True, batch_size=8, num_workers=opt.num_workers)

    # Model init
    model = timm.create_model(opt.model_arch, pretrained=False)
    pretrained_path = get_pretrained(opt)
    model.load_state_dict(torch.load(pretrained_path, map_location=device))
    model.classifier = nn.Linear(model.classifier.in_features, 5)
    model.to(device)

    # freeze first opt.freeze_percent params
    param_count = len(list(model.parameters()))
    for param in list(model.parameters())[:int(param_count*opt.freeze_percent/100)]:
        param.requires_grad = False

    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr, weight_decay=1e-6)
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=opt.num_epoch, T_mult=1, eta_min=1e-6, last_epoch=-1)
    criterion = get_loss(opt)

    best_acc = 0
    iteration_per_epoch = opt.iteration_per_epoch if opt.iteration_per_epoch else len(dataloader_train)
    for epoch in range(opt.num_epoch):
        # Train
        model.train()
        dataloader_iterator = iter(dataloader_train)
        pbar = tqdm(range(iteration_per_epoch), desc=f"Train : Epoch: {epoch + 1}/{opt.num_epoch}")
        
        for step in pbar:     
            try:
                images, labels = next(dataloader_iterator)
            except StopIteration:
                dataloader_iterator = iter(dataloader_train)
                images, labels = next(dataloader_iterator)
            
            images = images.to(device)
            labels = labels.to(device)

            if opt.adversarial_attack:
                idx_for_attack = list(rng.choice(labels.size(0), size=labels.size(0) // 4))
                images[idx_for_attack] = pgd_attack(images[idx_for_attack], labels[idx_for_attack], model, criterion)
            
            logit = model(images)
            loss = criterion(logit, labels)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step(epoch + step/iteration_per_epoch)
            
            _, predicted = torch.max(logit.data, 1)
            accuracy = 100 * (predicted == labels).sum().item() / labels.size(0)
            pbar.set_postfix({"Accuracy": accuracy, "Loss": loss.cpu().data.numpy().item(), "LR": optimizer.param_groups[0]["lr"]})
            logger.add_scalar('Loss/Train', loss.cpu().data.numpy().item(), epoch*iteration_per_epoch + step + 1)
            logger.add_scalar('Accuracy/Train', accuracy, epoch*iteration_per_epoch + step + 1)
            logger.add_scalar('LR/Train', optimizer.param_groups[0]["lr"], epoch*iteration_per_epoch + step + 1)
        
        # Val
        print(f"Eval start! Epoch {epoch + 1}/{opt.num_epoch}")
        correct = 0
        total = 0
        loss_sum = 0

        model.eval()
        dataloader_iterator = iter(dataloader_val)
        pbar = tqdm(range(len(dataloader_val)), desc=f"Eval : Epoch: {epoch + 1}/{opt.num_epoch}")
        for step in pbar: 
            try:
                images, labels = next(dataloader_iterator)
            except StopIteration:
                dataloader_iterator = iter(dataloader_val)
                images, labels = next(dataloader_iterator)

            labels = labels.to(device)

            if opt.tta:
                predicts = []
                loss_tta = 0
                for i in range(opt.tta):
                    img = images[i].to(device)
            
                    with torch.no_grad():
                        logit = model(img)

                    loss = criterion(logit, labels)
                    loss_tta += loss.cpu().data.numpy().item() / opt.tta
                    predicts.append(F.softmax(logit, dim=-1)[None, ...])
                
                predicts = torch.cat(predicts, dim=0).mean(dim=0)
                loss_sum += loss_tta

            else:
                images = images.to(device)
                with torch.no_grad():
                    logit = model(images)

                loss = criterion(logit, labels)
                predicts = F.softmax(logit, dim=-1)
                loss_sum += loss.cpu().data.numpy().item()

            #accuracy
            _, predicted = torch.max(predicts.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            pbar.set_postfix({"Accuracy": 100 * (predicted == labels).sum().item() / labels.size(0), "Loss": loss.cpu().data.numpy().item()})
            
        accuracy = 100 * correct / total
        loss_mean = loss_sum / len(dataloader_val)
        logger.add_scalar('Loss/Val', loss_mean, epoch*iteration_per_epoch + step + 1)
        logger.add_scalar('Accuracy/Val', accuracy, epoch*iteration_per_epoch + step + 1)
        print(f"Epoch: {epoch + 1}, Accuracy: {accuracy:.5f}, Loss {loss_mean:.5f}")
        if accuracy > best_acc:
            print("Saved checkpoint!")
            best_acc = accuracy
            torch.save({
                    "epoch": epoch + 1,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler": scheduler.state_dict(),
                    "accuracy": round(accuracy, 5),
                    "loss": round(loss_mean, 5),
                    "config": opt,
                }, 
                os.path.join("./artifacts_train", "checkpoints", current_version, f"{epoch + 1}_accuracy_{accuracy:.5f}.pth"))
Example #6
def main(args):

    args = parse_args()
    tag = args.tag
    device = torch.device('cuda:0')

    no_epochs = args.epochs
    batch_size = args.batch

    linear_hidden = args.linear
    conv_hidden = args.conv

    #Get train test paths -> later on implement cross val
    steps = get_paths(as_tuples=True, shuffle=True, tag=tag)
    steps_train = steps[:int(len(steps) * .8)]
    steps_test = steps[int(len(steps) * .8):]

    transform = transforms.Compose(
        [DepthSegmentationPreprocess(no_data_points=1),
         ToSupervised()])

    dataset_train = SimpleDataset(ids=steps_train,
                                  batch_size=batch_size,
                                  transform=transform,
                                  **SENSORS)
    dataset_test = SimpleDataset(ids=steps_test,
                                 batch_size=batch_size,
                                 transform=transform,
                                 **SENSORS)

    dataloader_params = {
        'batch_size': batch_size,
        'shuffle': True,
        'num_workers': 8
    }  #we've already shuffled paths

    dataset_train = DataLoader(dataset_train, **dataloader_params)
    dataset_test = DataLoader(dataset_test, **dataloader_params)

    batch = next(iter(dataset_test))
    action_shape = batch['action'][0].shape
    img_shape = batch['img'][0].shape
    #Nets
    net = DDPGActor(img_shape=img_shape,
                    numeric_shape=[len(NUMERIC_FEATURES)],
                    output_shape=[2],
                    linear_hidden=linear_hidden,
                    conv_filters=conv_hidden)
    # net = DDPGCritic(actor_out_shape=action_shape, img_shape=img_shape, numeric_shape=[len(NUMERIC_FEATURES)],
    #                         linear_hidden=linear_hidden, conv_filters=conv_filters)

    print(len(steps))
    print(net)
    print(get_n_params(net))
    # save path
    net_path = f'../data/models/imitation/{DATE_TIME}/{net.name}'
    os.makedirs(net_path, exist_ok=True)
    optim_steps = args.optim_steps
    logging_idx = int(len(dataset_train.dataset) / (batch_size * optim_steps))

    writer_train = SummaryWriter(f'{net_path}/train',
                                 max_queue=30,
                                 flush_secs=5)
    writer_test = SummaryWriter(f'{net_path}/test', max_queue=1, flush_secs=5)

    #Optimizers
    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=0.001,
                                 weight_decay=0.0005)

    if args.scheduler == 'cos':
        scheduler = CosineAnnealingWarmRestarts(optimizer,
                                                T_0=optim_steps,
                                                T_mult=2)
    elif args.scheduler == 'one_cycle':
        scheduler = OneCycleLR(optimizer,
                               max_lr=0.001,
                               epochs=no_epochs,
                               steps_per_epoch=optim_steps)
    else:
        # scheduler is used unconditionally below, so fail fast on an unknown choice
        raise ValueError(f'Unknown scheduler: {args.scheduler}')

    #Loss function
    loss_function = torch.nn.MSELoss(reduction='sum')
    test_loss_function = torch.nn.MSELoss(reduction='sum')

    best_train_loss = 1e10
    best_test_loss = 1e10

    for epoch_idx in range(no_epochs):
        train_loss = .0
        running_loss = .0
        # critic_running_loss = .0
        avg_max_grad = 0.
        avg_avg_grad = 0.
        for idx, batch in enumerate(iter(dataset_train)):
            global_step = int((len(dataset_train.dataset) / batch_size *
                               epoch_idx) + idx)
            batch = unpack_batch(batch=batch, device=device)
            loss, grad = train(input=batch,
                               label=batch['action'],
                               net=net,
                               optimizer=optimizer,
                               loss_fn=loss_function)
            # loss, grad = train(input=batch, label=batch['q'], net=net, optimizer=optimizer, loss_fn=loss_function)

            avg_max_grad += max([element.max() for element in grad])
            avg_avg_grad += sum([element.mean()
                                 for element in grad]) / len(grad)

            running_loss += loss
            train_loss += loss

            writer_train.add_scalar(tag=f'{net.name}/running_loss',
                                    scalar_value=loss / batch_size,
                                    global_step=global_step)
            writer_train.add_scalar(tag=f'{net.name}/max_grad',
                                    scalar_value=avg_max_grad,
                                    global_step=global_step)
            writer_train.add_scalar(tag=f'{net.name}/mean_grad',
                                    scalar_value=avg_avg_grad,
                                    global_step=global_step)

            if idx % logging_idx == logging_idx - 1:
                print(
                    f'Actor Epoch: {epoch_idx + 1}, Batch: {idx+1}, Loss: {running_loss/logging_idx}, Lr: {scheduler.get_last_lr()[0]}'
                )
                if (running_loss / logging_idx) < best_train_loss:
                    best_train_loss = running_loss / logging_idx
                    torch.save(net.state_dict(), f'{net_path}/train/train.pt')

                writer_train.add_scalar(
                    tag=f'{net.name}/lr',
                    scalar_value=scheduler.get_last_lr()[0],
                    global_step=global_step)
                running_loss = 0.0
                avg_max_grad = 0.
                avg_avg_grad = 0.
                scheduler.step()

        print(
            f'{net.name} best train loss for epoch {epoch_idx+1} - {best_train_loss}'
        )
        writer_train.add_scalar(tag=f'{net.name}/global_loss',
                                scalar_value=train_loss /
                                len(dataset_train.dataset),
                                global_step=(epoch_idx + 1))
        test_loss = .0
        with torch.no_grad():
            for idx, batch in enumerate(iter(dataset_test)):
                batch = unpack_batch(batch=batch, device=device)
                pred = net(**batch)
                loss = test_loss_function(pred, batch['action'])
                # loss = test_loss_function(pred.view(-1), batch['q'])

                test_loss += loss

        if (test_loss / len(dataset_test)) < best_test_loss:
            best_test_loss = (test_loss / len(dataset_test))

        torch.save(net.state_dict(), f'{net_path}/test/test_{epoch_idx+1}.pt')

        print(f'{net.name} test loss {(test_loss/len(dataset_test)):.3f}')
        print(f'{net.name} best test loss {best_test_loss:.3f}')
        writer_test.add_scalar(tag=f'{net.name}/global_loss',
                               scalar_value=(test_loss /
                                             len(dataset_test.dataset)),
                               global_step=(epoch_idx + 1))

    torch.save(optimizer.state_dict(),
               f=f'{net_path}/{optimizer.__class__.__name__}.pt')
    torch.save(scheduler.state_dict(),
               f=f'{net_path}/{scheduler.__class__.__name__}.pt')
    json.dump(vars(args),
              fp=open(f'{net_path}/args.json', 'w'),
              sort_keys=True,
              indent=4)

    writer_train.flush()
    writer_test.flush()
    writer_train.close()
    writer_test.close()

    batch = next(iter(dataset_test))
    batch = unpack_batch(batch=batch, device=device)
    y = net(**batch)
    g = make_dot(y, params=dict(net.named_parameters()))
    g.save(filename=f'{DATE_TIME}_{net.name}.dot', directory=net_path)
    check_call([
        'dot', '-Tpng', '-Gdpi=200', f'{net_path}/{DATE_TIME}_{net.name}.dot',
        '-o', f'{net_path}/{DATE_TIME}_{net.name}.png'
    ])
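
# Hedged sketch of the unpack_batch helper used above: it is assumed to simply move every
# tensor in the batch dict onto the target device, so the result can be indexed
# (batch['action']) and unpacked into the network call (net(**batch)).
def unpack_batch(batch, device):
    return {key: value.to(device) if torch.is_tensor(value) else value
            for key, value in batch.items()}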
Example #7
        print('Loading model: {}. Resuming from epoch: {}'.format(
            args.load_model, epoch))
    else:
        print('Model: {} not found'.format(args.load_model))

for epoch in range(args.epochs):
    v_loss = execute_graph(model, loader, optimizer, scheduler, epoch,
                           use_cuda)

    if v_loss < best_loss:
        best_loss = v_loss
        print('Writing model checkpoint')
        state = {
            'epoch': epoch,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'val_loss': v_loss
        }
        t = time.localtime()
        timestamp = time.strftime('%b-%d-%Y_%H%M', t)
        file_name = 'models/{}_{}_{}_{:04.4f}.pt'.format(
            timestamp, args.uid, epoch, v_loss)

        torch.save(state, file_name)

# TensorboardX logger
logger.close()

# save model / restart training
Example #8
class Learner:
    def __init__(self, model, train_loader, valid_loader, config):
        self.config = config
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.model = model.to(self.config.device)

        self.logger = init_logger(self.config.log_dir, 'train_main.log')
        self.tb_logger = init_tb_logger(self.config.log_dir, 'train_main')
        self.log('\n'.join(
            [f"{k} = {v}" for k, v in self.config.__dict__.items()]))

        self.summary_loss = AverageMeter()
        self.evaluator = Evaluator()

        self.criterion = torch.nn.CrossEntropyLoss(
            ignore_index=self.config.ignore_index)
        self.u_criterion = torch.nn.CrossEntropyLoss(
            ignore_index=self.config.ignore_index)
        train_params = [{
            'params': getattr(model, 'encoder').parameters(),
            'lr': self.config.lr
        }, {
            'params': getattr(model, 'decoder').parameters(),
            'lr': self.config.lr * 10
        }]
        self.optimizer = RAdam(train_params,
                               weight_decay=self.config.weight_decay)

        self.scheduler = CosineAnnealingWarmRestarts(self.optimizer,
                                                     T_0=2,
                                                     T_mult=2,
                                                     eta_min=1e-6)

        self.n_ensemble = 0
        self.epoch = 0
        self.best_epoch = 0
        self.best_loss = np.inf
        self.best_score = -np.inf

    def train_one_epoch(self):
        self.model.train()
        self.summary_loss.reset()
        iters = len(self.train_loader)
        for step, (images, scribbles, weights) in enumerate(self.train_loader):
            self.tb_logger.add_scalar('Train/lr',
                                      self.optimizer.param_groups[0]['lr'],
                                      iters * self.epoch + step)
            scribbles = scribbles.to(self.config.device).long()
            images = images.to(self.config.device)
            batch_size = images.shape[0]

            self.optimizer.zero_grad()
            outputs = self.model(images)
            if self.epoch < self.config.thr_epoch:
                loss = self.criterion(outputs, scribbles)
            else:
                x_loss = self.criterion(outputs, scribbles)

                scribbles = scribbles.cpu()
                mean = weights[..., 0]
                u_labels = torch.where(
                    ((mean < (1 - self.config.thr_conf)) |
                     (mean > self.config.thr_conf)) &
                    (scribbles == self.config.ignore_index),
                    mean.round().long(),
                    self.config.ignore_index * torch.ones_like(scribbles)).to(
                        self.config.device)
                u_loss = self.u_criterion(outputs, u_labels)
                loss = x_loss + 0.5 * u_loss

            loss.backward()
            self.summary_loss.update(loss.detach().item(), batch_size)
            self.optimizer.step()
            if self.scheduler.__class__.__name__ != 'ReduceLROnPlateau':
                self.scheduler.step()

        return self.summary_loss.avg

    def validation(self):
        self.model.eval()
        self.summary_loss.reset()
        self.evaluator.reset()
        for step, (_, images, _, targets) in enumerate(self.valid_loader):
            with torch.no_grad():
                targets = targets.to(self.config.device).long()
                batch_size = images.shape[0]
                images = images.to(self.config.device)
                outputs = self.model(images)
                loss = self.criterion(outputs, targets)

                targets = targets.cpu().numpy()
                outputs = torch.argmax(outputs, dim=1)
                outputs = outputs.data.cpu().numpy()
                self.evaluator.add_batch(targets, outputs)
                self.summary_loss.update(loss.detach().item(), batch_size)

        if self.scheduler.__class__.__name__ == 'ReduceLROnPlateau':
            self.scheduler.step(self.evaluator.IoU)
        return self.summary_loss.avg, self.evaluator.IoU

    def ensemble_prediction(self):
        ds = self.train_loader.dataset
        transforms = Compose([Normalize(), ToTensorV2()])
        for idx, images in tqdm(ds.images.items(), total=len(ds)):
            augmented = transforms(image=images['image'])
            img = augmented['image'].unsqueeze(0).to(self.config.device)
            with torch.no_grad():
                pred = torch.nn.functional.softmax(self.model(img), dim=1)
            weight = torch.tensor(images['weight'])
            pred = pred.squeeze(0).cpu()
            x = pred[1]
            weight[..., 0] = self.config.alpha * x + (
                1 - self.config.alpha) * weight[..., 0]
            self.train_loader.dataset.images[idx]['weight'] = weight.numpy()
        self.n_ensemble += 1

    def fit(self, epochs):
        for e in range(epochs):
            t = time.time()
            loss = self.train_one_epoch()

            self.log(
                f'[Train] \t Epoch: {self.epoch}, loss: {loss:.5f}, time: {(time.time() - t):.2f}'
            )
            self.tb_log(loss, None, 'Train', self.epoch)

            t = time.time()
            loss, score = self.validation()

            self.log(
                f'[Valid] \t Epoch: {self.epoch}, loss: {loss:.5f}, IoU: {score:.4f}, time: {(time.time() - t):.2f}'
            )
            self.tb_log(loss, score, 'Valid', self.epoch)
            self.post_processing(loss, score)

            if (self.epoch + 1) % self.config.period_epoch == 0:
                self.log(
                    f'[Ensemble] \t the {self.n_ensemble}th Prediction Ensemble ...'
                )
                self.ensemble_prediction()

            self.epoch += 1
        self.log(
            f'best epoch: {self.best_epoch}, best loss: {self.best_loss}, best_score: {self.best_score}'
        )

    def post_processing(self, loss, score):
        if loss < self.best_loss:
            self.best_loss = loss

        if score > self.best_score:
            self.best_score = score
            self.best_epoch = self.epoch

            self.model.eval()
            torch.save(
                {
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'scheduler_state_dict': self.scheduler.state_dict(),
                    'best_score': self.best_score,
                    'epoch': self.epoch,
                }, f'{os.path.join(self.config.log_dir, "best_model.pth")}')
            self.log(f'best model: {self.epoch} epoch - {score:.4f}')

    def load(self, path):
        checkpoint = torch.load(path)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.best_score = checkpoint['best_score']
        self.epoch = checkpoint['epoch'] + 1

    def log(self, text):
        self.logger.info(text)

    def tb_log(self, loss, IoU, split, step):
        if loss is not None:
            self.tb_logger.add_scalar(f'{split}/Loss', loss, step)
        if IoU is not None:
            self.tb_logger.add_scalar(f'{split}/IoU', IoU, step)
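
# Hedged sketch of the Evaluator assumed above (reset / add_batch / IoU). A confusion-matrix
# mean-IoU over NumPy arrays; the two-class default is an assumption (np is assumed imported,
# as elsewhere in this snippet).
class Evaluator:
    def __init__(self, num_class=2):
        self.num_class = num_class
        self.reset()

    def reset(self):
        self.confusion = np.zeros((self.num_class, self.num_class), dtype=np.int64)

    def add_batch(self, targets, preds):
        # accumulate a confusion matrix over flattened label / prediction maps,
        # skipping labels outside [0, num_class) such as the ignore_index
        mask = (targets >= 0) & (targets < self.num_class)
        idx = self.num_class * targets[mask].astype(int) + preds[mask]
        self.confusion += np.bincount(idx, minlength=self.num_class ** 2).reshape(self.num_class, self.num_class)

    @property
    def IoU(self):
        intersection = np.diag(self.confusion)
        union = self.confusion.sum(axis=1) + self.confusion.sum(axis=0) - intersection
        return np.nanmean(intersection / np.maximum(union, 1))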