Example #1
File: train.py Project: v3551G/AAI
def main():
    train_data, test_data, _ = train_validate_test_loader(
        "../data/Childers/M/speech",
        "../data/Childers/M/egg",
        split={
            "train": 0.7,
            "validate": 0.1,
            "test": 0.2
        },
        batch_size=1,
        workers=2,
        stride={
            "train": 2,
            "validate": 20
        },
        pin_memory=False,
        model_folder="data/irish_clean_data",
    )

    model_G = SpeechEggEncoder()
    model_D = Discriminator()
    save_model = Saver("checkpoints/vmodels/childers_clean_l2")

    encoder = EGGEncoder()
    save_encoder = Saver_Encoder("encoder")
    encoder, _, _ = save_encoder.load_checkpoint(encoder,
                                                 file_name="epoch_65.pt")

    use_cuda = True
    epochs = 100

    # optimizer_G updates only the first 12 parameter tensors of model_G
    # (presumably the encoder layers); optimizer_R updates all of them
    optimizer_G = optim.Adam(list(model_G.parameters())[:12], lr=2e-3)
    optimizer_R = optim.Adam(model_G.parameters(), lr=2e-3)
    optimizer_D = optim.Adam(model_D.parameters(), lr=2e-3)
    scheduler_G = optim.lr_scheduler.StepLR(optimizer_G, 10, 0.9)
    scheduler_D = optim.lr_scheduler.StepLR(optimizer_D, 10, 0.9)
    scheduler_R = optim.lr_scheduler.StepLR(optimizer_R, 10, 0.5)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        for i in range(1, epochs + 1):

            net_loss, D_loss, G_loss, R_loss, D_real_prob, D_fake_prob = train(
                model_G,
                model_D,
                encoder,
                optimizer_G,
                optimizer_R,
                optimizer_D,
                train_data,
                use_cuda,
            )
            print(
                "Train loss {:4.4} D_loss {:4.4} G_loss {:4.4} reconstruction loss {:4.4} Real D prob. {:4.4} Fake D prob. {:4.4} @epoch {}"
                .format(net_loss, D_loss, G_loss, R_loss, D_real_prob,
                        D_fake_prob, i))
            if i % 5 == 0:
                checkpoint = save_model.create_checkpoint(
                    model_G,
                    model_D,
                    optimizer_G,
                    optimizer_R,
                    optimizer_D,
                    {
                        "win": 100,
                        "stride": 3
                    },
                )

                save_model.save_checkpoint(checkpoint,
                                           file_name="epoch_{}.pt".format(i),
                                           append_time=False)
                test(model_G, model_D, encoder, test_data, use_cuda)

            if scheduler_G is not None:
                scheduler_G.step()
                scheduler_D.step()
                scheduler_R.step()
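
The checkpoints written every five epochs can be restored the same way the EGG encoder is restored above. A minimal sketch, assuming Saver exposes the same load_checkpoint(model, file_name=...) signature as Saver_Encoder; the file name epoch_50.pt is hypothetical:

# Hedged sketch: resume model_G from a checkpoint written by save_model.
# Assumes Saver.load_checkpoint mirrors Saver_Encoder.load_checkpoint above;
# "epoch_50.pt" is a hypothetical file name.
model_G = SpeechEggEncoder()
saver = Saver("checkpoints/vmodels/childers_clean_l2")
model_G, _, _ = saver.load_checkpoint(model_G, file_name="epoch_50.pt")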
Example #2
        loss.backward()
        optimizer.step()

        # Print loss
        iterator.set_description(
            'Epoch [{epoch}/{epochs}] :: Train Loss {loss:.4f}'.format(epoch=epoch, epochs=args.epochs,
                                                                       loss=loss.item()))
        # loss_type is assumed to be defined earlier in the source file
        writer.add_scalar('train/{loss_type}/total_loss_iter'.format(loss_type=loss_type),
                          loss.item(), epoch * len(dataloader) + i)

        # visualize roughly 10 batches per epoch
        if i % (len(dataloader) // 10) == 0:
            summary.visualize_image(writer, data, seg, epoch * len(dataloader) + i)

    # epoch % 1 is always 0, so a checkpoint is saved every epoch
    if not epoch % 1:
        saver.save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.module.state_dict(),
            'optimizer': optimizer.state_dict(),
            'best_pred': best_pred,
        }, is_best)

        if args.validate:
            # Validate
            with torch.no_grad():
                model.eval()
                iterator = tqdm(dataloader_val,
                                leave=True,
                                dynamic_ncols=True,
                                desc='Validation ::')
                input = dataset_val.img[
                        dataset_val.effective_lable_idx[0][0]:dataset_val.effective_lable_idx[0][1],
                        dataset_val.effective_lable_idx[1][0]:dataset_val.effective_lable_idx[1][1],
                        dataset_val.effective_lable_idx[2][0]:dataset_val.effective_lable_idx[2][1]]
Example #3
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(args.logdir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        dltrain = DLDataset('trainval', "./data/pascal_voc_seg/tfrecord/")
        dlval = DLDataset('val', "./data/pascal_voc_seg/tfrecord/")
        # dltrain = DLDataset('trainval', "./data/pascal_voc_seg/VOCdevkit/VOC2012/")
        # dlval = DLDataset('val', "./data/pascal_voc_seg/VOCdevkit/VOC2012/")
        self.train_loader = DataLoader(dltrain,
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       num_workers=args.workers,
                                       pin_memory=True)
        self.val_loader = DataLoader(dlval,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True)

        # Define network
        model = Deeplab()

        train_params = [{
            'params': model.get_1x_lr_params(),
            'lr': args.lr
        }, {
            'params': model.get_10x_lr_params(),
            'lr': args.lr * 10
        }]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)

        # Define Criterion
        # whether to use class balanced weights
        self.criterion = nn.CrossEntropyLoss(ignore_index=255)
        if args.cuda:
            self.criterion = self.criterion.cuda()
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(21)
        # Define lr scheduler
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer=optimizer)

        # Using cuda (DataParallel left disabled)
        # self.model = torch.nn.DataParallel(self.model)
        if args.cuda:
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # DataParallel is disabled above, so the bare model receives
            # the weights whether or not CUDA is used
            self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, (image, target) in enumerate(tbar):
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()

            self.optimizer.zero_grad()

            output = self.model(image)
            loss = self.criterion(output, target.long())
            loss.backward()
            self.optimizer.step()

            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            # if i % (num_img_tr // 10) == 0:
            if i % 10 == 0:
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.args.dataset,
                                             image, target, output,
                                             global_step)
        self.scheduler.step(train_loss)
        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, (image, target) in enumerate(tbar):

            if self.args.cuda:
                image, target = image.cuda(), target.cuda()

            with torch.no_grad():
                output = self.model(image)

            loss = self.criterion(output, target.long())
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
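
A Trainer like this is usually driven by a short epoch loop in main(). A minimal sketch, assuming the argparse namespace carries the start_epoch, epochs, and no_val fields already used inside the class; parse_args() is a hypothetical helper:

def main():
    args = parse_args()  # hypothetical: builds the argparse namespace
    trainer = Trainer(args)
    for epoch in range(trainer.args.start_epoch, trainer.args.epochs):
        trainer.training(epoch)
        if not trainer.args.no_val:
            trainer.validation(epoch)
    trainer.writer.close()

if __name__ == '__main__':
    main()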