Example no. 1
0
def train_model(train_dl, model):
    """Train `model` on `train_dl` with cross-entropy loss and SGD.

    Relies on module-level globals: `num_epoch` (epoch count), `device`
    (target device) and `save_checkpoint` (checkpoint writer). A checkpoint
    of model + optimizer state is written every 2nd epoch, before training
    that epoch.

    Args:
        train_dl: iterable of (inputs, targets) mini-batches.
        model: the torch.nn.Module to optimize in place.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)
    n_total_steps = len(train_dl)
    for epoch in range(num_epoch):
        # Periodic snapshot of model + optimizer state.
        if epoch % 2 == 0:
            checkpoint = {'state_dict': model.state_dict(),
                          'optimizer': optimizer.state_dict()}
            save_checkpoint(checkpoint)

        for i, (inputs, targets) in enumerate(train_dl):
            inputs = inputs.to(device)
            targets = targets.to(device)

            # Forward pass
            yhat = model(inputs)
            loss = criterion(yhat, targets)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # BUGFIX: the original printed `epoch+1/num_epoch`, which Python
        # evaluates as `epoch + (1/num_epoch)` (same for the step field).
        # Report "current/total" explicitly instead.
        if (i + 1) % 2 == 0:
            print("Epoch:", "{}/{}".format(epoch + 1, num_epoch),
                  "Step:", "{}/{}".format(i + 1, n_total_steps),
                  "Loss:", loss.item())
def main():
    """End-to-end training entry point driven by the module-level `args`.

    Builds train/test loaders for the configured dataset, constructs the
    architecture, then trains for `args.epochs` epochs with SGD + StepLR,
    appending one TSV row per epoch to `<outdir>/log.txt` and overwriting a
    single checkpoint file in `args.outdir`.
    """
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    train_dataset = get_dataset(args.dataset, 'train')
    test_dataset = get_dataset(args.dataset, 'test')
    # Pinned host memory only pays off for the large imagenet transfers.
    pin_memory = (args.dataset == "imagenet")
    train_loader = DataLoader(train_dataset,
                              shuffle=True,
                              batch_size=args.batch,
                              num_workers=args.workers,
                              pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset,
                             shuffle=False,
                             batch_size=args.batch,
                             num_workers=args.workers,
                             pin_memory=pin_memory)

    model = get_architecture(args.arch, args.dataset)

    logfilename = os.path.join(args.outdir, 'log.txt')
    init_logfile(logfilename,
                 "epoch\ttime\tlr\ttrain loss\ttrain acc\ttestloss\ttest acc")

    criterion = CrossEntropyLoss().cuda()
    optimizer = SGD(model.parameters(),
                    lr=args.lr,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay)
    scheduler = StepLR(optimizer,
                       step_size=args.lr_step_size,
                       gamma=args.gamma)

    for epoch in range(args.epochs):
        # NOTE(review): stepping the scheduler before the epoch's optimizer
        # updates mirrors the original code; newer torch expects
        # scheduler.step() after optimizer.step().
        scheduler.step(epoch)
        before = time.time()
        train_loss, train_acc = train(train_loader, model, criterion,
                                      optimizer, epoch, args.noise_sd)
        test_loss, test_acc = test(test_loader, model, criterion,
                                   args.noise_sd)
        after = time.time()

        # BUGFIX: the elapsed-time column used "{:.3}", which truncates the
        # formatted timedelta *string* to its first three characters
        # (e.g. "0:01:23" -> "0:0"). Use a plain "{}" field for it.
        log(
            logfilename, "{}\t{}\t{:.3}\t{:.3}\t{:.3}\t{:.3}\t{:.3}".format(
                epoch, str(datetime.timedelta(seconds=(after - before))),
                scheduler.get_lr()[0], train_loss, train_acc, test_loss,
                test_acc))

        # Overwrites the same file each epoch: only the latest state is kept.
        torch.save(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, os.path.join(args.outdir, 'checkpoint.pth.tar'))
Example no. 3
0
def save_checkpoints(epoch: int, model: Module, optimizer: SGD, loss: _Loss,
                     path: str):
    """Serialize a training snapshot to `path` via ``torch.save``.

    The snapshot bundles the epoch counter, the model and optimizer state
    dicts, and the loss object, under the keys 'epoch', 'model_state_dict',
    'optimizer_state_dict' and 'loss'.
    """
    snapshot = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }
    torch.save(snapshot, path)
Example no. 4
0
def main():
    """Train an image classifier on the 40X-magnification image folders.

    Reads all hyperparameters from the module-level `Config` object and the
    `DATA_DIR` constant; tracks best accuracies via module globals and
    checkpoints after every epoch.
    """
    cudnn.benchmark = True
    # Effective batch size scales with the number of GPUs in use.
    batch_size = Config.gpu_count * Config.image_per_gpu
    EPOCHS = Config.epoch

    workers = Config.workers
    global best_val_acc, best_test_acc
    Config.distributed = Config.gpu_count > 4 # TODO!

    model = set_model()
    #if Config.gpu is not None:
    model = model.cuda()
    # Wrap in DataParallel only for multi-GPU runs.
    if Config.gpu_count > 1:
        model = torch.nn.DataParallel(model).cuda()

    criterion = nn.CrossEntropyLoss().cuda()
    #weights = torch.FloatTensor(np.array([0.7, 0.3])).cuda()
    #criterion = WeightCrossEntropy(num_classes=Config.out_class, weight=weights).cuda()
    #criterion = LGMLoss(num_classes=Config.out_class, feat_dim=128).cuda()
    optimizer = SGD(model.parameters(), lr=Config.lr, momentum=0.9,nesterov=True, weight_decay=0.0001)
    #optimizer = Adam(model.parameters())
    train_dir = os.path.join(DATA_DIR, 'train', '40X')
    val_dir = os.path.join(DATA_DIR, 'val', '40X')
    test_dir = os.path.join(DATA_DIR, 'test', '40X')

    TRANSFORM_IMG = transforms.Compose([
        transforms.Resize((256, 256)),
        #ImageTransform(),
        #lambda x: PIL.Image.fromarray(x),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5],
                             std=[0.2, 0.2, 0.2])
    ])

    train_loader = DataLoader(ImageFolder(root=train_dir, transform=TRANSFORM_IMG),
                              batch_size=batch_size, shuffle=True, pin_memory=True,
                              num_workers=workers)
    # NOTE(review): this "validation" loader reads test_dir, leaving val_dir
    # unused — confirm whether validating on the test split is intentional.
    val_loader = DataLoader(ImageFolder(root=test_dir, transform=TRANSFORM_IMG),
                            batch_size=batch_size, shuffle=True, pin_memory=True,
                            num_workers=workers)
    #test_loader = DataLoader(ImageFolder(root=test_dir, transform=TRANSFORM_IMG),
     #                        batch_size=batch_size, shuffle=True, pin_memory=True,
      #                       num_workers=workers)

    for epoch in range(EPOCHS):
        adjust_learing_rate(optimizer, epoch)
        train_losses, train_acc = train_epoch(train_loader, model, criterion, optimizer, epoch)
        val_losses, val_acc = validate(val_loader, model, criterion)
        # NOTE(review): best_val_acc is never reassigned inside this loop, so
        # is_best compares against the initial global value every epoch —
        # verify that save_checkpoint (or another module) updates it.
        is_best = val_acc.avg > best_val_acc
        print('>>>>>>>>>>>>>>>>>>>>>>')
        print('Epoch: {} train loss: {}, train acc: {}, valid loss: {}, valid acc: {}'.format(epoch, train_losses.avg, train_acc.avg,
                                                                                    val_losses.avg, val_acc.avg))
        print('>>>>>>>>>>>>>>>>>>>>>>')
        save_checkpoint({'epoch': epoch + 1,
                         'state_dict': model.state_dict(),
                         'best_val_acc': best_val_acc,
                         'optimizer': optimizer.state_dict(),}, is_best)
Example no. 5
0
def main(args):
    """Train a CNN on pre-extracted .npy samples, checkpointing on best
    validation accuracy.

    Args:
        args: namespace providing `model` ('1d-conv' selects RectCNN, anything
            else PaperCNN), `batch_size`, `device`, `epochs`, plus the flags
            RunManager is told to ignore ('device', 'evaluate', 'no_cuda').
    """
    run = RunManager(args,
                     ignore=('device', 'evaluate', 'no_cuda'),
                     main='model')
    print(run)

    train_dataset = DatasetFolder('data/train',
                                  load_sample, ('.npy', ),
                                  transform=normalize_sample)
    val_dataset = DatasetFolder('data/val',
                                load_sample, ('.npy', ),
                                transform=normalize_sample)

    print(train_dataset)
    print(val_dataset)

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              num_workers=8,
                              shuffle=True)
    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            num_workers=8)

    if args.model == '1d-conv':
        model = RectCNN(282)
    else:
        model = PaperCNN()
    model = model.double().to(args.device)

    optimizer = SGD(model.parameters(), lr=1e-2)

    # evaluate(val_loader, model, args)
    best = 0
    # BUGFIX: trange(1, args.epochs) ran only args.epochs - 1 iterations;
    # use an inclusive upper bound so exactly args.epochs epochs run.
    progress = trange(1, args.epochs + 1)
    for epoch in progress:
        progress.set_description('TRAIN [CurBestAcc={:.2%}]'.format(best))
        train(train_loader, model, optimizer, args)
        progress.set_description('EVAL [CurBestAcc={:.2%}]'.format(best))
        metrics = evaluate(val_loader, model, args)

        # Only persist a checkpoint when validation accuracy improves.
        is_best = metrics['acc'] > best
        best = max(metrics['acc'], best)
        if is_best:
            run.save_checkpoint(
                {
                    'epoch': epoch,
                    'params': vars(args),
                    'model': model.state_dict(),
                    'optim': optimizer.state_dict(),
                    'metrics': metrics
                }, is_best)

        metrics.update({'epoch': epoch})
        run.pushLog(metrics)
Example no. 6
0
def main():
    """Dict-driven variant of the training entry point (module-level `args`).

    Loads data and model, then trains for args['epochs'] epochs with
    SGD + StepLR, appending one TSV row per epoch to `<outdir>/log.txt` and
    overwriting checkpoint.pth.tar in args['outdir'].
    """
    if args['gpu']:
        os.environ['CUDA_VISIBLE_DEVICES'] = args['gpu']

    if not os.path.exists(args['outdir']):
        os.mkdir(args['outdir'])

    train_loader, test_loader = loaddata(args)

    # NOTE(review): `model` is only bound when CUDA is available; on a
    # CPU-only host the optimizer construction below raises NameError.
    if torch.cuda.is_available():
        model = loadmodel(args)
        model = model.cuda()

    logfilename = os.path.join(args['outdir'], 'log.txt')
    init_logfile(logfilename,
                 "epoch\ttime\tlr\ttrain loss\ttrain acc\ttestloss\ttest acc")

    criterion = CrossEntropyLoss().cuda()
    optimizer = SGD(model.parameters(),
                    lr=args['lr'],
                    momentum=args['momentum'],
                    weight_decay=args['weight_decay'])
    scheduler = StepLR(optimizer,
                       step_size=args['lr_step_size'],
                       gamma=args['gamma'])

    for epoch in range(args['epochs']):
        scheduler.step(epoch)
        before = time.time()
        train_loss, train_acc = train(train_loader, model, criterion,
                                      optimizer, epoch, args['noise_sd'])
        test_loss, test_acc = test(test_loader, model, criterion,
                                   args['noise_sd'])
        after = time.time()

        # BUGFIX: "{:.3}" truncated the formatted timedelta string to three
        # characters; the time column now uses a plain "{}" field.
        log(
            logfilename, "{}\t{}\t{:.3}\t{:.3}\t{:.3}\t{:.3}\t{:.3}".format(
                epoch, str(datetime.timedelta(seconds=(after - before))),
                scheduler.get_lr()[0], train_loss, train_acc, test_loss,
                test_acc))

        # Overwrites the same file each epoch: only the latest state is kept.
        torch.save(
            {
                'epoch': epoch + 1,
                'dataset': args['dataset'],
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, os.path.join(args['outdir'], 'checkpoint.pth.tar'))
Example no. 7
0
def main_train_worker(args):
    """Single-process training worker.

    Builds a CIFAR-style model for `args.arch`/`args.dataset`, trains it for
    `args.epochs` epochs with plain SGD, validating and checkpointing to a
    hyperparameter-encoded path after every epoch.
    """
    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))
    print("=> creating model '{}'".format(args.arch))

    model = MetaLearnerModelBuilder.construct_cifar_model(
        args.arch, args.dataset)

    # Destination path encodes dataset, architecture and hyperparameters.
    model_path = '{}/train_pytorch_model/real_image_model/{}@{}@epoch_{}@lr_{}@batch_{}.pth.tar'.format(
        PY_ROOT, args.dataset, args.arch, args.epochs, args.lr,
        args.batch_size)
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    print("after train, model will be saved to {}".format(model_path))

    model.cuda()
    cudnn.benchmark = True

    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = SGD(model.parameters(),
                    args.lr,
                    weight_decay=args.weight_decay)

    train_loader = DataLoaderMaker.get_img_label_data_loader(
        args.dataset, args.batch_size, True)
    val_loader = DataLoaderMaker.get_img_label_data_loader(
        args.dataset, args.batch_size, False)

    for epoch in range(args.epochs):
        # adjust_learning_rate(optimizer, epoch, args)
        # one pass over the training set
        train(train_loader, model, criterion, optimizer, epoch, args)
        # evaluate accuracy on the validation set
        validate(val_loader, model, criterion, args)
        # persist latest epoch state (same path every epoch)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            },
            filename=model_path)
Example no. 8
0
def train(cont=False):
    """Run the full train/validate loop for the configured segmentation net.

    Relies on module-level configuration: `net`, `data_dir`, `batch_size`,
    `lrate`, `momentum`, `wdecay`, `epoch_num`, `iteration_num`,
    `fine_tune_ratio`, `print_freq`, `tensorboard_freq`,
    `early_stop_tolerance`, `ckpt_src`, `best_ckpt_src`, `device`,
    `CLASS_NUM`, `IMG_DIM`, `use_dice_loss`.

    Args:
        cont: when True, resume model/optimizer/epoch/best-loss from
            `best_ckpt_src` instead of starting fresh.
    """
    # for tensorboard tracking
    logger = get_logger()
    logger.info("(1) Initiating Training ... ")
    logger.info("Training on device: {}".format(device))
    writer = SummaryWriter()

    # init model; SETR variants also expose auxiliary output layers
    aux_layers = None
    if net == "SETR-PUP":
        aux_layers, model = get_SETR_PUP()
    elif net == "SETR-MLA":
        aux_layers, model = get_SETR_MLA()
    elif net == "TransUNet-Base":
        model = get_TransUNet_base()
    elif net == "TransUNet-Large":
        model = get_TransUNet_large()
    elif net == "UNet":
        model = UNet(CLASS_NUM)

    # prepare dataset
    cluster_model = get_clustering_model(logger)
    train_dataset = CityscapeDataset(img_dir=data_dir,
                                     img_dim=IMG_DIM,
                                     mode="train",
                                     cluster_model=cluster_model)
    valid_dataset = CityscapeDataset(img_dir=data_dir,
                                     img_dim=IMG_DIM,
                                     mode="val",
                                     cluster_model=cluster_model)
    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              shuffle=True)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=batch_size,
                              shuffle=False)

    logger.info("(2) Dataset Initiated. ")

    # optimizer; derive epoch count from the iteration budget when epoch_num
    # is not set explicitly
    epochs = epoch_num if epoch_num > 0 else iteration_num // len(
        train_loader) + 1
    optim = SGD(model.parameters(),
                lr=lrate,
                momentum=momentum,
                weight_decay=wdecay)
    # optim = Adam(model.parameters(), lr=lrate)
    scheduler = lr_scheduler.MultiStepLR(
        optim, milestones=[int(epochs * fine_tune_ratio)], gamma=0.1)

    cur_epoch = 0
    best_loss = float('inf')
    epochs_since_improvement = 0

    # for continue training: restore state and fast-forward the scheduler to
    # the resumed epoch
    if cont:
        model, optim, cur_epoch, best_loss = load_ckpt_continue_training(
            best_ckpt_src, model, optim, logger)
        logger.info("Current best loss: {0}".format(best_loss))
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            for i in range(cur_epoch):
                scheduler.step()
    else:
        model = nn.DataParallel(model)
        model = model.to(device)

    logger.info("(3) Model Initiated ... ")
    logger.info("Training model: {}".format(net) + ". Training Started.")

    # loss
    ce_loss = CrossEntropyLoss()
    if use_dice_loss:
        dice_loss = DiceLoss(CLASS_NUM)

    # loop over epochs
    iter_count = 0
    epoch_bar = tqdm.tqdm(total=epochs,
                          desc="Epoch",
                          position=cur_epoch,
                          leave=True)
    logger.info("Total epochs: {0}. Starting from epoch {1}.".format(
        epochs, cur_epoch + 1))

    for e in range(epochs - cur_epoch):
        epoch = e + cur_epoch

        # Training.
        model.train()
        trainLossMeter = LossMeter()
        train_batch_bar = tqdm.tqdm(total=len(train_loader),
                                    desc="TrainBatch",
                                    position=0,
                                    leave=True)

        for batch_num, (orig_img, mask_img) in enumerate(train_loader):
            orig_img, mask_img = orig_img.float().to(
                device), mask_img.float().to(device)

            if net == "TransUNet-Base" or net == "TransUNet-Large":
                pred = model(orig_img)
            elif net == "SETR-PUP" or net == "SETR-MLA":
                if aux_layers is not None:
                    pred, _ = model(orig_img)
                else:
                    pred = model(orig_img)
            elif net == "UNet":
                pred = model(orig_img)

            loss_ce = ce_loss(pred, mask_img[:].long())
            if use_dice_loss:
                loss_dice = dice_loss(pred, mask_img, softmax=True)
                loss = 0.5 * (loss_ce + loss_dice)
            else:
                loss = loss_ce

            # Backward Propagation, Update weight and metrics
            optim.zero_grad()
            loss.backward()
            optim.step()

            # update learning rate: polynomial decay over the iteration budget
            for param_group in optim.param_groups:
                orig_lr = param_group['lr']
                param_group['lr'] = orig_lr * (1.0 -
                                               iter_count / iteration_num)**0.9
            iter_count += 1

            # Update loss
            trainLossMeter.update(loss.item())

            # print status
            if (batch_num + 1) % print_freq == 0:
                status = 'Epoch: [{0}][{1}/{2}]\t' \
                    'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(epoch+1, batch_num+1, len(train_loader), loss=trainLossMeter)
                logger.info(status)

            # log loss to tensorboard
            if (batch_num + 1) % tensorboard_freq == 0:
                writer.add_scalar(
                    'Train_Loss_{0}'.format(tensorboard_freq),
                    trainLossMeter.avg,
                    epoch * (len(train_loader) / tensorboard_freq) +
                    (batch_num + 1) / tensorboard_freq)
            train_batch_bar.update(1)

        writer.add_scalar('Train_Loss_epoch', trainLossMeter.avg, epoch)

        # Validation.
        model.eval()
        validLossMeter = LossMeter()
        valid_batch_bar = tqdm.tqdm(total=len(valid_loader),
                                    desc="ValidBatch",
                                    position=0,
                                    leave=True)
        with torch.no_grad():
            for batch_num, (orig_img, mask_img) in enumerate(valid_loader):
                orig_img, mask_img = orig_img.float().to(
                    device), mask_img.float().to(device)

                if net == "TransUNet-Base" or net == "TransUNet-Large":
                    pred = model(orig_img)
                elif net == "SETR-PUP" or net == "SETR-MLA":
                    if aux_layers is not None:
                        pred, _ = model(orig_img)
                    else:
                        pred = model(orig_img)
                elif net == "UNet":
                    pred = model(orig_img)

                loss_ce = ce_loss(pred, mask_img[:].long())
                if use_dice_loss:
                    loss_dice = dice_loss(pred, mask_img, softmax=True)
                    loss = 0.5 * (loss_ce + loss_dice)
                else:
                    loss = loss_ce

                # Update loss
                validLossMeter.update(loss.item())

                # BUGFIX: the status / tensorboard / progress-bar blocks below
                # were dedented outside this batch loop and so ran only once
                # per epoch (and the bar advanced by 1 total); they now mirror
                # the training loop above.

                # print status
                if (batch_num + 1) % print_freq == 0:
                    status = 'Validation: [{0}][{1}/{2}]\t' \
                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(epoch+1, batch_num+1, len(valid_loader), loss=validLossMeter)
                    logger.info(status)

                # log loss to tensorboard
                if (batch_num + 1) % tensorboard_freq == 0:
                    writer.add_scalar(
                        'Valid_Loss_{0}'.format(tensorboard_freq),
                        validLossMeter.avg,
                        epoch * (len(valid_loader) / tensorboard_freq) +
                        (batch_num + 1) / tensorboard_freq)
                valid_batch_bar.update(1)

        valid_loss = validLossMeter.avg
        writer.add_scalar('Valid_Loss_epoch', valid_loss, epoch)
        logger.info("Validation Loss of epoch [{0}/{1}]: {2}\n".format(
            epoch + 1, epochs, valid_loss))

        # update optim scheduler
        scheduler.step()

        # save checkpoint only on improvement; early-stop after
        # `early_stop_tolerance` consecutive stale epochs
        is_best = valid_loss < best_loss
        best_loss_tmp = min(valid_loss, best_loss)
        if not is_best:
            epochs_since_improvement += 1
            logger.info("Epochs since last improvement: %d\n" %
                        (epochs_since_improvement, ))
            if epochs_since_improvement == early_stop_tolerance:
                break  # early stopping.
        else:
            epochs_since_improvement = 0
            state = {
                'epoch': epoch,
                'loss': best_loss_tmp,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optim.state_dict(),
            }
            torch.save(state, ckpt_src)
            logger.info("Checkpoint updated.")
            best_loss = best_loss_tmp
        epoch_bar.update(1)
    writer.close()
Example no. 9
0
def train(start_path, beta):
    """Continuation training for one Pareto solution on MultiMNIST (CPMTL).

    Loads an initial MultiLeNet checkpoint from `start_path`, then performs
    `num_steps` SGD updates whose gradients come from a MINRES-based KKT
    solver steered by the preference vector `beta`, evaluating on the test
    set and saving a checkpoint after every step.

    Args:
        start_path: pathlib.Path to the starting checkpoint (must exist).
        beta: tensor preference vector; moved onto the training device.
    """

    # prepare hyper-parameters

    seed = 42

    cuda_enabled = True
    cuda_deterministic = False

    batch_size = 2048
    num_workers = 2

    shared = False

    # KKT/MINRES solver settings
    stochastic = False
    kkt_momentum = 0.0
    create_graph = False
    grad_correction = False
    shift = 0.0
    tol = 1e-5
    damping = 0.1
    maxiter = 50

    # SGD settings
    lr = 0.1
    momentum = 0.0
    weight_decay = 0.0

    num_steps = 10

    verbose = False

    # prepare path

    ckpt_name = start_path.name.split('.')[0]
    root_path = Path(__file__).resolve().parent
    dataset_path = root_path / 'MultiMNIST'
    ckpt_path = root_path / 'cpmtl' / ckpt_name

    if not start_path.is_file():
        raise RuntimeError('Pareto solutions not found.')

    # NOTE(review): root_path is the directory containing this script, so the
    # first mkdir is effectively a no-op kept for safety.
    root_path.mkdir(parents=True, exist_ok=True)
    dataset_path.mkdir(parents=True, exist_ok=True)
    ckpt_path.mkdir(parents=True, exist_ok=True)

    # fix random seed

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if cuda_enabled and torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # prepare device

    if cuda_enabled and torch.cuda.is_available():
        import torch.backends.cudnn as cudnn
        device = torch.device('cuda')
        if cuda_deterministic:
            # reproducibility at the cost of speed
            cudnn.benchmark = False
            cudnn.deterministic = True
        else:
            cudnn.benchmark = True
    else:
        device = torch.device('cpu')

    # prepare dataset

    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])

    trainset = MultiMNIST(dataset_path,
                          train=True,
                          download=True,
                          transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=num_workers)

    testset = MultiMNIST(dataset_path,
                         train=False,
                         download=True,
                         transform=transform)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             num_workers=num_workers)

    # prepare network

    network = MultiLeNet()
    network.to(device)

    # initialize network from the provided Pareto solution

    start_ckpt = torch.load(start_path, map_location='cpu')
    network.load_state_dict(start_ckpt['state_dict'])

    # prepare losses: one closure per MultiMNIST task (two digit labels)

    criterion = F.cross_entropy
    closures = [
        lambda n, l, t: criterion(l[0], t[:, 0]),
        lambda n, l, t: criterion(l[1], t[:, 1])
    ]

    # prepare HVP (Hessian-vector product) solver

    hvp_solver = VisionHVPSolver(network,
                                 device,
                                 trainloader,
                                 closures,
                                 shared=shared)
    hvp_solver.set_grad(batch=False)
    hvp_solver.set_hess(batch=True)

    # prepare KKT solver

    kkt_solver = MINRESKKTSolver(network,
                                 hvp_solver,
                                 device,
                                 stochastic=stochastic,
                                 kkt_momentum=kkt_momentum,
                                 create_graph=create_graph,
                                 grad_correction=grad_correction,
                                 shift=shift,
                                 tol=tol,
                                 damping=damping,
                                 maxiter=maxiter)

    # prepare optimizer

    optimizer = SGD(network.parameters(),
                    lr=lr,
                    momentum=momentum,
                    weight_decay=weight_decay)

    # first evaluation (step 0 baseline)

    losses, tops = evaluate(network, testloader, device, closures,
                            f'{ckpt_name}')

    # prepare utilities
    top_trace = TopTrace(len(closures))
    top_trace.print(tops, show=False)

    beta = beta.to(device)

    # training

    for step in range(1, num_steps + 1):

        network.train(True)
        optimizer.zero_grad()
        # the KKT solver writes the combined multi-task gradient into .grad
        kkt_solver.backward(beta, verbose=verbose)
        optimizer.step()

        losses, tops = evaluate(network, testloader, device, closures,
                                f'{ckpt_name}: {step}/{num_steps}')

        top_trace.print(tops)

        # one checkpoint per step, named <step>.pth
        ckpt = {
            'state_dict': network.state_dict(),
            'optimizer': optimizer.state_dict(),
            'beta': beta,
        }
        record = {'losses': losses, 'tops': tops}
        ckpt['record'] = record
        torch.save(ckpt, ckpt_path / f'{step:d}.pth')

    hvp_solver.close()
Example no. 10
0
class Trainer(object):
    """
    Trainer encapsulates all the logic necessary for
    training the Recurrent Attention Model.

    All hyperparameters are provided by the user in the
    config file.
    """
    def __init__(self, config, data_loader):
        """
        Construct a new Trainer instance.

        Args
        ----
        - config: object containing command line arguments.
        - data_loader: data iterator. In training mode this is a
          (train_loader, valid_loader) pair; otherwise a single test loader.
        """
        self.config = config

        # glimpse network params
        self.patch_size = config.patch_size
        self.glimpse_scale = config.glimpse_scale
        self.num_patches = config.num_patches
        self.loc_hidden = config.loc_hidden
        self.glimpse_hidden = config.glimpse_hidden

        # core network params
        self.num_glimpses = config.num_glimpses
        self.hidden_size = config.hidden_size

        # reinforce params
        self.std = config.std
        self.M = config.M

        # data params: loaders are unpacked differently for train vs. test
        if config.is_train:
            self.train_loader = data_loader[0]
            self.valid_loader = data_loader[1]
            # loaders are expected to use an index-based sampler here
            self.num_train = len(self.train_loader.sampler.indices)
            self.num_valid = len(self.valid_loader.sampler.indices)
        else:
            self.test_loader = data_loader
            self.num_test = len(self.test_loader.dataset)
        # NOTE(review): hard-coded for 10-class single-channel input
        # (MNIST-like) — confirm against the dataset in use.
        self.num_classes = 10
        self.num_channels = 1

        # training params
        self.epochs = config.epochs
        self.start_epoch = 0
        self.momentum = config.momentum
        self.lr = config.init_lr

        # misc params
        self.use_gpu = config.use_gpu
        self.best = config.best
        self.ckpt_dir = config.ckpt_dir
        self.logs_dir = config.logs_dir
        self.best_valid_acc = 0.
        self.counter = 0
        self.patience = config.patience
        self.use_tensorboard = config.use_tensorboard
        self.resume = config.resume
        self.print_freq = config.print_freq
        self.plot_freq = config.plot_freq
        # model name encodes glimpse count, patch geometry and scale
        self.model_name = 'ram_{}_{}x{}_{}'.format(config.num_glimpses,
                                                   config.patch_size,
                                                   config.patch_size,
                                                   config.glimpse_scale)

        self.plot_dir = './plots/' + self.model_name + '/'
        if not os.path.exists(self.plot_dir):
            os.makedirs(self.plot_dir)

        # configure tensorboard logging
        if self.use_tensorboard:
            tensorboard_dir = self.logs_dir + self.model_name
            print('[*] Saving tensorboard logs to {}'.format(tensorboard_dir))
            if not os.path.exists(tensorboard_dir):
                os.makedirs(tensorboard_dir)
            configure(tensorboard_dir)

        # build RAM model
        self.model = RecurrentAttention(
            self.patch_size,
            self.num_patches,
            self.glimpse_scale,
            self.num_channels,
            self.loc_hidden,
            self.glimpse_hidden,
            self.std,
            self.hidden_size,
            self.num_classes,
        )
        if self.use_gpu:
            self.model.cuda()

        print('[*] Number of model parameters: {:,}'.format(
            sum([p.data.nelement() for p in self.model.parameters()])))

        # initialize optimizer and scheduler
        # NOTE(review): self.batch_size is not set here; train_one_epoch
        # assigns it per minibatch before reset() reads it.
        self.optimizer = SGD(
            self.model.parameters(),
            lr=self.lr,
            momentum=self.momentum,
        )
        self.scheduler = ReduceLROnPlateau(self.optimizer,
                                           'min',
                                           patience=self.patience)

    def reset(self):
        """
        Initialize the hidden state of the core network
        and the location vector.

        This is called once every time a new minibatch
        `x` is introduced.
        """
        dtype = torch.cuda.FloatTensor if self.use_gpu else torch.FloatTensor

        h_t = torch.zeros(self.batch_size, self.hidden_size)
        h_t = Variable(h_t).type(dtype)

        l_t = torch.Tensor(self.batch_size, 2).uniform_(-1, 1)
        l_t = Variable(l_t).type(dtype)

        return h_t, l_t

    def train(self):
        """Run the full training loop over `self.epochs` epochs.

        Each epoch trains on the training set, evaluates on the validation
        set, and steps the plateau scheduler on the validation loss. A
        checkpoint is saved every epoch, tagged separately when validation
        accuracy improves; training stops early once `self.counter` exceeds
        `self.patience`.
        """
        # optionally pick up from the most recent (non-best) checkpoint
        if self.resume:
            self.load_checkpoint(best=False)

        print("\n[*] Train on {} samples, validate on {} samples".format(
            self.num_train, self.num_valid))

        for epoch in range(self.start_epoch, self.epochs):

            print('\nEpoch: {}/{} - LR: {:.6f}'.format(epoch + 1, self.epochs,
                                                       self.lr))

            # one training pass, then a validation pass
            train_loss, train_acc = self.train_one_epoch(epoch)
            valid_loss, valid_acc = self.validate(epoch)

            # reduce lr if validation loss plateaus
            self.scheduler.step(valid_loss)

            is_best = valid_acc > self.best_valid_acc
            template = "train loss: {:.3f} - train acc: {:.3f} " \
                       "- val loss: {:.3f} - val acc: {:.3f}"
            if is_best:
                template += " [*]"
            print(template.format(train_loss, train_acc, valid_loss,
                                  valid_acc))

            # NOTE(review): self.counter is never reset on improvement here,
            # so training halts after `patience` non-best epochs in total —
            # confirm that is the intended early-stop policy.
            if not is_best:
                self.counter += 1
            if self.counter > self.patience:
                print("[!] No improvement in a while, stopping training.")
                return
            self.best_valid_acc = max(valid_acc, self.best_valid_acc)
            self.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'model_state': self.model.state_dict(),
                    'optim_state': self.optimizer.state_dict(),
                    'best_valid_acc': self.best_valid_acc,
                }, is_best)

    def train_one_epoch(self, epoch):
        """
        Train the model for 1 epoch of the training set.

        An epoch corresponds to one full pass through the entire
        training set in successive mini-batches.

        This is used by train() and should not be called manually.

        Args
        ----
        - epoch: 0-based index of the current epoch; used to decide when
          to dump glimpse plots and to number tensorboard iterations.

        Returns
        -------
        - losses.avg: average hybrid loss over the epoch.
        - accs.avg: average accuracy (in %) over the epoch.
        """
        batch_time = AverageMeter()
        losses = AverageMeter()
        accs = AverageMeter()

        tic = time.time()
        with tqdm(total=self.num_train) as pbar:
            for i, (x, y) in enumerate(self.train_loader):
                if self.use_gpu:
                    x, y = x.cuda(), y.cuda()
                # Variable wrapping is the pre-0.4 PyTorch autograd idiom
                x, y = Variable(x), Variable(y)

                # dump glimpse plots only on the first batch of every
                # plot_freq-th epoch
                plot = False
                if (epoch % self.plot_freq == 0) and (i == 0):
                    plot = True

                # initialize location vector and hidden state
                self.batch_size = x.shape[0]
                h_t, l_t = self.reset()

                # save images (first 9 samples only, for plotting)
                imgs = []
                imgs.append(x[0:9])

                # extract the glimpses
                locs = []       # glimpse locations (first 9 samples)
                log_pi = []     # log-probs of the sampled locations (policy)
                baselines = []  # baseline predictions (variance reduction)
                for t in range(self.num_glimpses - 1):
                    # forward pass through model
                    h_t, l_t, b_t, p = self.model(x, l_t, h_t)

                    # store
                    locs.append(l_t[0:9])
                    baselines.append(b_t)
                    log_pi.append(p)

                # last iteration: last=True also yields class log-probabilities
                h_t, l_t, b_t, log_probas, p = self.model(x,
                                                          l_t,
                                                          h_t,
                                                          last=True)
                log_pi.append(p)
                baselines.append(b_t)
                locs.append(l_t[0:9])

                # convert list to tensors and reshape to (batch, num_glimpses)
                baselines = torch.stack(baselines).transpose(1, 0)
                log_pi = torch.stack(log_pi).transpose(1, 0)

                # calculate reward: 1 for a correct final prediction else 0,
                # replicated across all glimpse steps
                predicted = torch.max(log_probas, 1)[1]
                R = (predicted.detach() == y).float()
                R = R.unsqueeze(1).repeat(1, self.num_glimpses)

                # compute losses for differentiable modules
                loss_action = F.nll_loss(log_probas, y)
                loss_baseline = F.mse_loss(baselines, R)

                # compute reinforce loss: REINFORCE with the (detached)
                # baseline subtracted from the reward to reduce variance
                adjusted_reward = R - baselines.detach()
                loss_reinforce = torch.mean(-log_pi * adjusted_reward)

                # sum up into a hybrid loss
                loss = loss_action + loss_baseline + loss_reinforce

                # compute accuracy (in percent)
                correct = (predicted == y).float()
                acc = 100 * (correct.sum() / len(y))

                # store (.data[0] is the pre-0.4 equivalent of .item())
                losses.update(loss.data[0], x.size()[0])
                accs.update(acc.data[0], x.size()[0])

                # compute gradients and update SGD
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # measure elapsed time (cumulative since epoch start)
                toc = time.time()
                batch_time.update(toc - tic)

                pbar.set_description(
                    ("{:.1f}s - loss: {:.3f} - acc: {:.3f}".format(
                        (toc - tic), loss.data[0], acc.data[0])))
                pbar.update(self.batch_size)

                # dump the glimpses and locs to pickle for later visualization
                if plot:
                    if self.use_gpu:
                        imgs = [g.cpu().data.numpy().squeeze() for g in imgs]
                        locs = [l.cpu().data.numpy() for l in locs]
                    else:
                        imgs = [g.data.numpy().squeeze() for g in imgs]
                        locs = [l.data.numpy() for l in locs]
                    pickle.dump(
                        imgs,
                        open(self.plot_dir + "g_{}.p".format(epoch + 1), "wb"))
                    pickle.dump(
                        locs,
                        open(self.plot_dir + "l_{}.p".format(epoch + 1), "wb"))

                # log to tensorboard (running averages, per iteration)
                if self.use_tensorboard:
                    iteration = epoch * len(self.train_loader) + i
                    log_value('train_loss', losses.avg, iteration)
                    log_value('train_acc', accs.avg, iteration)

            return losses.avg, accs.avg

    def validate(self, epoch):
        """
        Evaluate the model on the validation set.

        Each image is duplicated `self.M` times so the stochastic
        glimpse locations are Monte-Carlo averaged; the predictions,
        baselines and location log-probs are averaged over duplicates.

        Args
        ----
        - epoch: 0-based epoch index, used only for tensorboard
          iteration numbering.

        Returns
        -------
        - losses.avg: average hybrid loss over the validation set.
        - accs.avg: average accuracy (in %) over the validation set.
        """
        losses = AverageMeter()
        accs = AverageMeter()

        for i, (x, y) in enumerate(self.valid_loader):
            if self.use_gpu:
                x, y = x.cuda(), y.cuda()
            # FIX: use volatile=True (pre-0.4 inference idiom, same as in
            # test()) so no autograd graph is built. backward() is never
            # called here, so this only saves memory and cannot change
            # any computed value.
            x, y = Variable(x, volatile=True), Variable(y)

            # duplicate M times for Monte-Carlo averaging
            x = x.repeat(self.M, 1, 1, 1)

            # initialize location vector and hidden state
            self.batch_size = x.shape[0]
            h_t, l_t = self.reset()

            # extract the glimpses
            log_pi = []
            baselines = []
            for t in range(self.num_glimpses - 1):
                # forward pass through model
                h_t, l_t, b_t, p = self.model(x, l_t, h_t)

                # store
                baselines.append(b_t)
                log_pi.append(p)

            # last iteration: also yields class log-probabilities
            h_t, l_t, b_t, log_probas, p = self.model(x, l_t, h_t, last=True)
            log_pi.append(p)
            baselines.append(b_t)

            # convert list to tensors and reshape to (batch*M, num_glimpses)
            baselines = torch.stack(baselines).transpose(1, 0)
            log_pi = torch.stack(log_pi).transpose(1, 0)

            # average over the M duplicates
            log_probas = log_probas.view(self.M, -1, log_probas.shape[-1])
            log_probas = torch.mean(log_probas, dim=0)

            baselines = baselines.contiguous().view(self.M, -1,
                                                    baselines.shape[-1])
            baselines = torch.mean(baselines, dim=0)

            log_pi = log_pi.contiguous().view(self.M, -1, log_pi.shape[-1])
            log_pi = torch.mean(log_pi, dim=0)

            # calculate reward: 1 for a correct prediction, replicated
            # across all glimpse steps
            predicted = torch.max(log_probas, 1)[1]
            R = (predicted.detach() == y).float()
            R = R.unsqueeze(1).repeat(1, self.num_glimpses)

            # compute losses for differentiable modules
            loss_action = F.nll_loss(log_probas, y)
            loss_baseline = F.mse_loss(baselines, R)

            # compute reinforce loss (baseline-adjusted REINFORCE)
            adjusted_reward = R - baselines.detach()
            loss_reinforce = torch.mean(-log_pi * adjusted_reward)

            # sum up into a hybrid loss
            loss = loss_action + loss_baseline + loss_reinforce

            # compute accuracy (in percent)
            correct = (predicted == y).float()
            acc = 100 * (correct.sum() / len(y))

            # store (.data[0] is the pre-0.4 equivalent of .item())
            losses.update(loss.data[0], x.size()[0])
            accs.update(acc.data[0], x.size()[0])

            # log to tensorboard (running averages)
            if self.use_tensorboard:
                iteration = epoch * len(self.valid_loader) + i
                log_value('valid_loss', losses.avg, iteration)
                log_value('valid_acc', accs.avg, iteration)

        return losses.avg, accs.avg

    def test(self):
        """
        Test the model on the held-out test data.

        Restores the checkpoint selected by `self.best`, then runs the
        model for `self.num_glimpses` steps per image, averaging the
        class log-probabilities over `self.M` duplicated samples.
        This function should only be called at the very end once the
        model has finished training.
        """
        # restore either the best or the most recent checkpoint
        self.load_checkpoint(best=self.best)

        num_correct = 0
        for x, y in self.test_loader:
            if self.use_gpu:
                x, y = x.cuda(), y.cuda()
            # volatile=True: pre-0.4 idiom for inference without autograd
            x, y = Variable(x, volatile=True), Variable(y)

            # replicate each image M times for Monte-Carlo averaging
            x = x.repeat(self.M, 1, 1, 1)

            # fresh location vector and hidden state for this batch
            self.batch_size = x.shape[0]
            h_t, l_t = self.reset()

            # run all but the final glimpse
            for _ in range(self.num_glimpses - 1):
                h_t, l_t, b_t, p = self.model(x, l_t, h_t)

            # final glimpse also produces the class log-probabilities
            h_t, l_t, b_t, log_probas, p = self.model(x, l_t, h_t, last=True)

            # average predictions over the M duplicates
            log_probas = log_probas.view(self.M, -1, log_probas.shape[-1])
            log_probas = torch.mean(log_probas, dim=0)

            pred = log_probas.data.max(1, keepdim=True)[1]
            num_correct += pred.eq(y.data.view_as(pred)).cpu().sum()

        perc = (100. * num_correct) / (self.num_test)
        print('[*] Test Acc: {}/{} ({:.2f}%)'.format(num_correct,
                                                     self.num_test, perc))

    def save_checkpoint(self, state, is_best):
        """
        Save a copy of the model so that it can be loaded at a future
        date. This function is used when the model is being evaluated
        on the test data.

        The checkpoint always goes to
        `<ckpt_dir>/<model_name>_ckpt.pth.tar`; when `is_best` is True
        it is additionally copied to a separate file with the suffix
        `model_best`, which load_checkpoint(best=True) reads.
        """
        ckpt_path = os.path.join(self.ckpt_dir,
                                 self.model_name + '_ckpt.pth.tar')
        torch.save(state, ckpt_path)

        if is_best:
            best_path = os.path.join(self.ckpt_dir,
                                     self.model_name + '_model_best.pth.tar')
            shutil.copyfile(ckpt_path, best_path)

    def load_checkpoint(self, best=False):
        """
        Load the best copy of a model. This is useful for 2 cases:

        - Resuming training with the most recent model checkpoint.
        - Loading the best validation model to evaluate on the test data.

        Params
        ------
        - best: if set to True, loads the best model. Use this if you want
          to evaluate your model on the test data. Else, set to False in
          which case the most recent version of the checkpoint is used.
        """
        print("[*] Loading model from {}".format(self.ckpt_dir))

        suffix = '_model_best.pth.tar' if best else '_ckpt.pth.tar'
        filename = self.model_name + suffix
        ckpt = torch.load(os.path.join(self.ckpt_dir, filename))

        # restore training progress plus model/optimizer weights
        self.start_epoch = ckpt['epoch']
        self.best_valid_acc = ckpt['best_valid_acc']
        self.model.load_state_dict(ckpt['model_state'])
        self.optimizer.load_state_dict(ckpt['optim_state'])

        if best:
            print("[*] Loaded {} checkpoint @ epoch {} "
                  "with best valid acc of {:.3f}".format(
                      filename, ckpt['epoch'] + 1, ckpt['best_valid_acc']))
        else:
            print("[*] Loaded {} checkpoint @ epoch {}".format(
                filename, ckpt['epoch'] + 1))
Ejemplo n.º 11
0
            loss.backward()
            optimizer.step()

            trian_loss += loss.item() / len(train_loader)

            _, y_pred = torch.max(y_pred, 1)
            trian_acc += (y_pred == batch_y).sum().item() / len(y)

        train_loss_temp.append(trian_loss)
        train_acc_temp.append(trian_acc)

        if epoch % checkpoint == 0:
            torch.save(model.state_dict(),
                       f"./data/prob{prob_num}_model_ckpt_{epoch}.bin")
            torch.save(optimizer.state_dict(),
                       f"./data/prob{prob_num}_optimizer_ckpt_{epoch}.bin")

        model = model.eval()
        with torch.no_grad():
            y_pred = model(X_valid)
            loss = loss_fn(y_pred, y_valid)

            _, y_pred = torch.max(y_pred, 1)
            acc = (y_pred == y_valid).sum().item() / len(y_valid)
            valid_loss_temp.append(loss)
            valid_acc_temp.append(acc)

    train_loss_all.append(train_loss_temp)
    train_acc_all.append(train_acc_temp)
    valid_loss_all.append(valid_loss_temp)
Ejemplo n.º 12
0
class Trainer:
    """
    Train and evaluate a saliency-prediction model on the SALICON
    dataset (train/val splits loaded from pickle files under
    `dataset_root`). Logs metrics to tensorboard via `summary_writer`.
    """

    def __init__(self,
                 model: nn.Module,
                 dataset_root: str,
                 summary_writer: SummaryWriter,
                 device: Device,
                 batch_size: int = 128,
                 cc_loss: bool = False):
        """
        Args:
            model: network to train; moved onto `device`.
            dataset_root: directory containing "train.pkl" and "val.pkl".
            summary_writer: tensorboard writer used for all logging.
            device: device to run training on.
            batch_size: mini-batch size for both loaders.
            cc_loss: if True use CCLoss, otherwise nn.MSELoss.
        """
        # load train/test splits of SALICON dataset
        train_dataset = Salicon(dataset_root + "train.pkl")
        test_dataset = Salicon(dataset_root + "val.pkl")

        self.train_loader = DataLoader(
            train_dataset,
            shuffle=True,
            batch_size=batch_size,
            pin_memory=True,
            num_workers=1,
        )
        self.val_loader = DataLoader(
            test_dataset,
            shuffle=False,
            batch_size=batch_size,
            num_workers=1,
            pin_memory=True,
        )
        self.model = model.to(device)
        self.device = device
        if cc_loss:
            # NOTE(review): CCLoss is assigned without being instantiated,
            # unlike nn.MSELoss() below -- confirm CCLoss is meant to be
            # used as a plain callable.
            self.criterion = CCLoss
        else:
            self.criterion = nn.MSELoss()
        # SGD hyperparameters from the paper; lr is annealed in train()
        self.optimizer = SGD(self.model.parameters(),
                             lr=0.03,
                             momentum=0.9,
                             weight_decay=0.0005,
                             nesterov=True)
        self.summary_writer = summary_writer
        self.step = 0

    def train(self,
              epochs: int,
              val_frequency: int,
              log_frequency: int = 5,
              start_epoch: int = 0):
        """
        Run the main training loop.

        Args:
            epochs: total number of epochs to train for.
            val_frequency: run validation every `val_frequency` epochs.
            log_frequency: log/print metrics every `log_frequency` steps.
            start_epoch: epoch index to resume from.
        """
        # LR decay: anneal linearly between 0.03 and 0.0001 over the run
        # (according to paper)
        lrs = np.linspace(0.03, 0.0001, epochs)
        for epoch in range(start_epoch, epochs):
            # BUG FIX: the previous code rebuilt the SGD optimizer every
            # batch with the scheduled lr, then called
            # load_state_dict(optimstate) -- which also restores the OLD
            # param-group lr, so the decay schedule was never applied.
            # Updating the param groups in place applies the schedule
            # once per epoch and keeps the momentum buffers intact.
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lrs[epoch]

            self.model.train()
            for batch, gts in self.train_loader:
                self.optimizer.zero_grad()
                # load batch to device
                batch = batch.to(self.device)
                gts = gts.to(self.device)

                # train step
                step_start_time = time.time()
                output = self.model.forward(batch)
                loss = self.criterion(output, gts)
                loss.backward()
                self.optimizer.step()

                # log step
                if ((self.step + 1) % log_frequency) == 0:
                    step_time = time.time() - step_start_time
                    self.log_metrics(epoch, loss, step_time)
                    self.print_metrics(epoch, loss, step_time)

                # count steps
                self.step += 1

            # log epoch
            self.summary_writer.add_scalar("epoch", epoch, self.step)

            # validate
            if ((epoch + 1) % val_frequency) == 0:
                self.validate()
                # validate() switches to eval mode; switch back
                self.model.train()
            # periodic checkpoint every 10 epochs
            if (epoch + 1) % 10 == 0:
                save(self.model, "checkp_model.pkl")

    def print_metrics(self, epoch, loss, step_time):
        """Print the current epoch/step, batch loss and step time."""
        epoch_step = self.step % len(self.train_loader)
        print(f"epoch: [{epoch}], "
              f"step: [{epoch_step}/{len(self.train_loader)}], "
              f"batch loss: {loss:.5f}, "
              f"step time: {step_time:.5f}")

    def log_metrics(self, epoch, loss, step_time):
        """Log epoch, train loss and step time to tensorboard."""
        self.summary_writer.add_scalar("epoch", epoch, self.step)

        self.summary_writer.add_scalars("loss", {"train": float(loss.item())},
                                        self.step)
        self.summary_writer.add_scalar("time/data", step_time, self.step)

    def validate(self):
        """
        Evaluate the model on the validation set and log the average
        loss to tensorboard. Leaves the model in eval mode.
        """
        results = {"preds": [], "gts": []}
        total_loss = 0
        self.model.eval()

        # No need to track gradients for validation, we're not optimizing.
        with no_grad():
            for batch, gts in self.val_loader:
                batch = batch.to(self.device)
                gts = gts.to(self.device)
                output = self.model(batch)
                loss = self.criterion(output, gts)
                total_loss += loss.item()
                preds = output.cpu().numpy()
                results["preds"].extend(list(preds))
                results["gts"].extend(list(gts.cpu().numpy()))

        average_loss = total_loss / len(self.val_loader)

        self.summary_writer.add_scalars("loss", {"test": average_loss},
                                        self.step)
        print(f"validation loss: {average_loss:.5f}")
Ejemplo n.º 13
0
def main(args: argparse.Namespace):
    """
    Entry point for RegDA keypoint-detection domain adaptation.

    Builds source/target datasets and loaders, creates a RegDAPoseResNet
    with a main and an adversarial head, pretrains on the source domain
    (or resumes from a checkpoint), then alternates adaptation training
    and validation, keeping the checkpoint with the best target accuracy.

    Args:
        args: parsed command-line arguments (dataset names/roots, model
          arch, optimization hyperparameters; `args.phase` selects
          'test'-only evaluation vs. training).
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    # optional full determinism (slower; see warning below)
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    # train-time augmentation: rotation, scale jitter, color jitter, blur;
    # validation only resizes. Normalization uses ImageNet statistics.
    normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    train_transform = T.Compose([
        T.RandomRotation(args.rotation),
        T.RandomResizedCrop(size=args.image_size, scale=args.resize_scale),
        T.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25),
        T.GaussianBlur(),
        T.ToTensor(), normalize
    ])
    val_transform = T.Compose(
        [T.Resize(args.image_size),
         T.ToTensor(), normalize])
    image_size = (args.image_size, args.image_size)
    heatmap_size = (args.heatmap_size, args.heatmap_size)
    # source-domain train/test splits (dataset class looked up by name)
    source_dataset = datasets.__dict__[args.source]
    train_source_dataset = source_dataset(root=args.source_root,
                                          transforms=train_transform,
                                          image_size=image_size,
                                          heatmap_size=heatmap_size)
    train_source_loader = DataLoader(train_source_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)
    val_source_dataset = source_dataset(root=args.source_root,
                                        split='test',
                                        transforms=val_transform,
                                        image_size=image_size,
                                        heatmap_size=heatmap_size)
    val_source_loader = DataLoader(val_source_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   pin_memory=True)

    # target-domain train/test splits
    target_dataset = datasets.__dict__[args.target]
    train_target_dataset = target_dataset(root=args.target_root,
                                          transforms=train_transform,
                                          image_size=image_size,
                                          heatmap_size=heatmap_size)
    train_target_loader = DataLoader(train_target_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)
    val_target_dataset = target_dataset(root=args.target_root,
                                        split='test',
                                        transforms=val_transform,
                                        image_size=image_size,
                                        heatmap_size=heatmap_size)
    val_target_loader = DataLoader(val_target_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   pin_memory=True)

    print("Source train:", len(train_source_loader))
    print("Target train:", len(train_target_loader))
    print("Source test:", len(val_source_loader))
    print("Target test:", len(val_target_loader))

    # endless iterators so train() can draw a fixed number of steps
    # regardless of dataset length
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model: ImageNet-pretrained backbone + upsampling +
    # main/adversarial regression heads
    backbone = models.__dict__[args.arch](pretrained=True)
    upsampling = Upsampling(backbone.out_features)
    num_keypoints = train_source_dataset.num_keypoints
    model = RegDAPoseResNet(backbone,
                            upsampling,
                            256,
                            num_keypoints,
                            num_head_layers=args.num_head_layers,
                            finetune=True).to(device)
    # define loss function
    criterion = JointsKLLoss()
    pseudo_label_generator = PseudoLabelGenerator(num_keypoints,
                                                  args.heatmap_size,
                                                  args.heatmap_size)
    regression_disparity = RegressionDisparity(pseudo_label_generator,
                                               JointsKLLoss(epsilon=1e-7))

    # define optimizer and lr scheduler -- three optimizers: features
    # (backbone + upsampling), main head, adversarial head. The base lrs
    # (0.1 / 1.) are multiplied by lr_decay_function via LambdaLR, which
    # already folds in args.lr.
    optimizer_f = SGD([
        {
            'params': backbone.parameters(),
            'lr': 0.1
        },
        {
            'params': upsampling.parameters(),
            'lr': 0.1
        },
    ],
                      lr=0.1,
                      momentum=args.momentum,
                      weight_decay=args.wd,
                      nesterov=True)
    optimizer_h = SGD(model.head.parameters(),
                      lr=1.,
                      momentum=args.momentum,
                      weight_decay=args.wd,
                      nesterov=True)
    optimizer_h_adv = SGD(model.head_adv.parameters(),
                          lr=1.,
                          momentum=args.momentum,
                          weight_decay=args.wd,
                          nesterov=True)
    # inverse-decay schedule: args.lr * (1 + gamma*step)^(-decay)
    lr_decay_function = lambda x: args.lr * (1. + args.lr_gamma * float(x))**(
        -args.lr_decay)
    lr_scheduler_f = LambdaLR(optimizer_f, lr_decay_function)
    lr_scheduler_h = LambdaLR(optimizer_h, lr_decay_function)
    lr_scheduler_h_adv = LambdaLR(optimizer_h_adv, lr_decay_function)
    start_epoch = 0

    if args.resume is None:
        if args.pretrain is None:
            # first pretrain the backbone and upsampling
            print("Pretraining the model on source domain.")
            args.pretrain = logger.get_checkpoint_path('pretrain')
            pretrained_model = PoseResNet(backbone, upsampling, 256,
                                          num_keypoints, True).to(device)
            optimizer = SGD(pretrained_model.get_parameters(lr=args.lr),
                            momentum=args.momentum,
                            weight_decay=args.wd,
                            nesterov=True)
            lr_scheduler = MultiStepLR(optimizer, args.lr_step, args.lr_factor)
            best_acc = 0
            for epoch in range(args.pretrain_epochs):
                # NOTE(review): scheduler is stepped at the TOP of each
                # epoch, before any optimizer.step(); PyTorch >= 1.1
                # expects the opposite order, and this shifts each lr
                # milestone one epoch early -- confirm intended.
                lr_scheduler.step()
                print(lr_scheduler.get_lr())

                pretrain(train_source_iter, pretrained_model, criterion,
                         optimizer, epoch, args)
                source_val_acc = validate(val_source_loader, pretrained_model,
                                          criterion, None, args)

                # remember best acc and save checkpoint
                if source_val_acc['all'] > best_acc:
                    best_acc = source_val_acc['all']
                    torch.save({'model': pretrained_model.state_dict()},
                               args.pretrain)
                print("Source: {} best: {}".format(source_val_acc['all'],
                                                   best_acc))

        # load from the pretrained checkpoint
        pretrained_dict = torch.load(args.pretrain,
                                     map_location='cpu')['model']
        model_dict = model.state_dict()
        # remove keys from pretrained dict that doesn't appear in model dict
        # (the adaptation model has extra adversarial-head parameters)
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }
        model.load_state_dict(pretrained_dict, strict=False)
    else:
        # optionally resume from a checkpoint (model, all three
        # optimizers/schedulers and the epoch counter)
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        optimizer_f.load_state_dict(checkpoint['optimizer_f'])
        optimizer_h.load_state_dict(checkpoint['optimizer_h'])
        optimizer_h_adv.load_state_dict(checkpoint['optimizer_h_adv'])
        lr_scheduler_f.load_state_dict(checkpoint['lr_scheduler_f'])
        lr_scheduler_h.load_state_dict(checkpoint['lr_scheduler_h'])
        lr_scheduler_h_adv.load_state_dict(checkpoint['lr_scheduler_h_adv'])
        start_epoch = checkpoint['epoch'] + 1

    # define visualization function (undo normalization, back to PIL)
    tensor_to_image = Compose([
        Denormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ToPILImage()
    ])

    def visualize(image, keypoint2d, name, heatmaps=None):
        """
        Args:
            image (tensor): image in shape 3 x H x W
            keypoint2d (tensor): keypoints in shape K x 2
            name: name of the saving image
            heatmaps: unused in this implementation; kept for caller
              compatibility.
        """
        train_source_dataset.visualize(
            tensor_to_image(image), keypoint2d,
            logger.get_image_path("{}.jpg".format(name)))

    if args.phase == 'test':
        # evaluate on validation set only, then exit
        source_val_acc = validate(val_source_loader, model, criterion, None,
                                  args)
        target_val_acc = validate(val_target_loader, model, criterion,
                                  visualize, args)
        print("Source: {:4.3f} Target: {:4.3f}".format(source_val_acc['all'],
                                                       target_val_acc['all']))
        for name, acc in target_val_acc.items():
            print("{}: {:4.3f}".format(name, acc))
        return

    # start training
    best_acc = 0
    print("Start regression domain adaptation.")
    for epoch in range(start_epoch, args.epochs):
        logger.set_epoch(epoch)
        print(lr_scheduler_f.get_lr(), lr_scheduler_h.get_lr(),
              lr_scheduler_h_adv.get_lr())

        # train for one epoch
        train(train_source_iter, train_target_iter, model, criterion,
              regression_disparity, optimizer_f, optimizer_h, optimizer_h_adv,
              lr_scheduler_f, lr_scheduler_h, lr_scheduler_h_adv, epoch,
              visualize if args.debug else None, args)

        # evaluate on validation set
        source_val_acc = validate(val_source_loader, model, criterion, None,
                                  args)
        target_val_acc = validate(val_target_loader, model, criterion,
                                  visualize if args.debug else None, args)

        # remember best acc and save checkpoint (every epoch; the best one
        # is additionally copied to the 'best' path below)
        torch.save(
            {
                'model': model.state_dict(),
                'optimizer_f': optimizer_f.state_dict(),
                'optimizer_h': optimizer_h.state_dict(),
                'optimizer_h_adv': optimizer_h_adv.state_dict(),
                'lr_scheduler_f': lr_scheduler_f.state_dict(),
                'lr_scheduler_h': lr_scheduler_h.state_dict(),
                'lr_scheduler_h_adv': lr_scheduler_h_adv.state_dict(),
                'epoch': epoch,
                'args': args
            }, logger.get_checkpoint_path(epoch))
        if target_val_acc['all'] > best_acc:
            shutil.copy(logger.get_checkpoint_path(epoch),
                        logger.get_checkpoint_path('best'))
            best_acc = target_val_acc['all']
        print("Source: {:4.3f} Target: {:4.3f} Target(best): {:4.3f}".format(
            source_val_acc['all'], target_val_acc['all'], best_acc))
        for name, acc in target_val_acc.items():
            print("{}: {:4.3f}".format(name, acc))

    logger.close()
Ejemplo n.º 14
0
class Trainer(object):
    """Multi-label image-classification trainer.

    Builds the train/val data pipelines, model, optimizer, LR scheduler,
    loss and metric trackers from ``args``, then drives the training loop
    with per-epoch checkpointing, ReduceLROnPlateau scheduling on val mAP,
    and early stopping.
    """

    def __init__(self, args):
        super(Trainer, self).__init__()
        # Training augmentation: up-scale, crop at a random size, then
        # resize to the network input size and apply ImageNet normalization.
        train_transform = transforms.Compose([
            transforms.Resize((args.scale_size, args.scale_size)),
            transforms.RandomChoice([
                transforms.RandomCrop(640),
                transforms.RandomCrop(576),
                transforms.RandomCrop(512),
                transforms.RandomCrop(384),
                transforms.RandomCrop(320)
            ]),
            transforms.Resize((args.crop_size, args.crop_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        train_dataset = MLDataset(args.train_path, args.label_path,
                                  train_transform)
        self.train_loader = DataLoader(dataset=train_dataset,
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       num_workers=args.num_workers,
                                       pin_memory=True,
                                       drop_last=True)
        # Deterministic resize + center-crop pipeline for validation.
        val_transform = transforms.Compose([
            transforms.Resize((args.scale_size, args.scale_size)),
            transforms.CenterCrop(args.crop_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        val_dataset = MLDataset(args.val_path, args.label_path, val_transform)
        self.val_loader = DataLoader(dataset=val_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=False,
                                     num_workers=args.num_workers,
                                     pin_memory=True)

        self.model = model_factory[args.model](args, args.num_classes)
        self.model.cuda()

        # Only optimize parameters that are not frozen.
        trainable_parameters = filter(lambda param: param.requires_grad,
                                      self.model.parameters())
        if args.optimizer == 'Adam':
            self.optimizer = Adam(trainable_parameters, lr=args.lr)
        elif args.optimizer == 'SGD':
            self.optimizer = SGD(trainable_parameters, lr=args.lr)
        else:
            # Fail fast instead of silently leaving self.optimizer undefined,
            # which would otherwise surface later as an AttributeError.
            raise ValueError(
                'unsupported optimizer: {}'.format(args.optimizer))

        # mAP is the monitored quantity, hence mode='max'.
        self.lr_scheduler = lr_scheduler.ReduceLROnPlateau(self.optimizer,
                                                           mode='max',
                                                           patience=2,
                                                           verbose=True)
        if args.loss == 'BCElogitloss':
            self.criterion = nn.BCEWithLogitsLoss()
        elif args.loss == 'tencentloss':
            self.criterion = TencentLoss(args.num_classes)
        elif args.loss == 'focalloss':
            self.criterion = FocalLoss()
        else:
            # Same rationale as the optimizer check above.
            raise ValueError('unsupported loss: {}'.format(args.loss))
        self.early_stopping = EarlyStopping(patience=5)

        self.voc12_mAP = VOC12mAP(args.num_classes)
        self.average_loss = AverageLoss()
        self.average_topk_meter = TopkAverageMeter(args.num_classes,
                                                   topk=args.topk)
        self.average_threshold_meter = ThresholdAverageMeter(
            args.num_classes, threshold=args.threshold)

        self.args = args
        self.global_step = 0
        self.writer = SummaryWriter(log_dir=args.log_dir)

    def run(self):
        """Train until max_epoch or early stopping, resuming if requested.

        After every epoch the latest state is checkpointed; the best model
        (by validation mAP) is saved separately.
        """
        s_epoch = 0
        if self.args.resume:
            checkpoint = torch.load(self.args.ckpt_latest_path)
            # 'epoch' was stored as (finished_epoch + 1), i.e. the next one.
            s_epoch = checkpoint['epoch']
            self.global_step = checkpoint['global_step']
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optim_state_dict'])
            self.early_stopping.best_score = checkpoint['best_score']
            print('loading checkpoint success (epoch {})'.format(s_epoch))

        for epoch in range(s_epoch, self.args.max_epoch):
            self.train(epoch)
            save_dict = {
                'epoch': epoch + 1,
                'global_step': self.global_step,
                'model_state_dict': self.model.state_dict(),
                'optim_state_dict': self.optimizer.state_dict(),
                'best_score': self.early_stopping.best_score
            }
            torch.save(save_dict, self.args.ckpt_latest_path)

            mAP = self.validation(epoch)
            self.lr_scheduler.step(mAP)
            is_save, is_terminate = self.early_stopping(mAP)
            if is_terminate:
                break
            if is_save:
                torch.save(self.model.state_dict(), self.args.ckpt_best_path)

    def train(self, epoch):
        """Run one training epoch over self.train_loader."""
        self.model.train()
        if self.args.model == 'ssgrl':
            # SSGRL keeps most of the ResNet backbone frozen in eval mode;
            # only its last stage is fine-tuned.
            self.model.resnet_101.eval()
            self.model.resnet_101.layer4.train()
        for _, batch in enumerate(self.train_loader):
            x, y = batch[0].cuda(), batch[1].cuda()
            pred_y = self.model(x)
            loss = self.criterion(pred_y, y)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Log every 400 optimizer steps.
            if self.global_step % 400 == 0:
                self.writer.add_scalar('Loss/train', loss, self.global_step)
                print('TRAIN [epoch {}] loss: {:4f}'.format(epoch, loss))

            self.global_step += 1

    def validation(self, epoch):
        """Evaluate on the validation set and return the mAP."""
        self.model.eval()
        self.voc12_mAP.reset()
        self.average_loss.reset()
        self.average_topk_meter.reset()
        self.average_threshold_meter.reset()
        with torch.no_grad():
            for _, batch in enumerate(self.val_loader):
                x, y = batch[0].cuda(), batch[1].cuda()
                pred_y = self.model(x)
                loss = self.criterion(pred_y, y)

                y = y.cpu().numpy()
                pred_y = pred_y.cpu().numpy()
                loss = loss.cpu().numpy()
                self.voc12_mAP.update(pred_y, y)
                self.average_loss.update(loss, x.size(0))
                self.average_topk_meter.update(pred_y, y)
                self.average_threshold_meter.update(pred_y, y)

        _, mAP = self.voc12_mAP.compute()
        mLoss = self.average_loss.compute()
        self.average_topk_meter.compute()
        self.average_threshold_meter.compute()
        self.writer.add_scalar('Loss/val', mLoss, self.global_step)
        self.writer.add_scalar('mAP/val', mAP, self.global_step)

        print("Validation [epoch {}] mAP: {:.4f} loss: {:.4f}".format(
            epoch, mAP, mLoss))
        return mAP
Ejemplo n.º 15
0
        optimizer.step()
        print(torch.cuda.max_memory_allocated(device=0))
        print(output)
        with open('./Logs/log_triplet_new.txt', 'a') as f:
            val_list = [
                epoch + 1, batch_idx,
                float(output),
                float(triplet_loss_sum)
            ]
            log = '\t'.join(str(value) for value in val_list)
            f.writelines(log + '\n')

    avg_triplet_loss = triplet_loss_sum / batches_per_epoch

    with open('./Logs/log_triplet_new.txt', 'a') as f:
        val_list = ['FINAL', epoch + 1, float(avg_triplet_loss)]
        log = '\t'.join(str(value) for value in val_list)
        f.writelines(log + '\n')

    print('Epoch {}:\tAverage Triplet Loss: {:.4f}\t'.format(
        epoch + 1, avg_triplet_loss))

# Persist the final training state so the run can be resumed or inspected.
final_state = {
    'epoch': epoch,
    'model_state_dict': net.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'avg_triplet_loss': avg_triplet_loss,
}
# The file name encodes the epoch and the rounded average triplet loss.
checkpoint_file = './Train_Checkpoints/checkpoint_{}_{}.tar'.format(
    epoch, round(float(avg_triplet_loss), 4))
torch.save(final_state, checkpoint_file)
Ejemplo n.º 16
0
class MetaFrameWork(object):
    """Episodic meta-learning framework for domain-generalized segmentation.

    Trains a backbone on several source domains; every ``train_num``-th
    iteration performs a MAML-style meta-step (inner update on a cloned
    network, meta-test loss back-propagated to the original), and the target
    domain is evaluated periodically with best-checkpoint tracking.
    """

    def __init__(self, name='normal_all', train_num=1, source='GSIM',
                 target='C', network=Net, resume=True, dataset=DGMetaDataSets,
                 inner_lr=1e-3, outer_lr=5e-3, train_size=8, test_size=16, no_source_test=True, bn='torch'):
        super(MetaFrameWork, self).__init__()
        self.no_source_test = no_source_test
        self.train_num = train_num  # meta-step every `train_num` iterations
        self.exp_name = name
        self.resume = resume

        self.inner_update_lr = inner_lr  # lr of the cloned ("inner") update
        self.outer_update_lr = outer_lr  # lr of the real optimizer
        self.network = network
        self.dataset = dataset
        self.train_size = train_size
        self.test_size = test_size
        self.source = source
        self.target = target
        self.bn = bn

        self.epoch = 1
        # Best-so-far bookkeeping for target- and source-domain accuracy.
        self.best_target_acc = 0
        self.best_target_acc_source = 0
        self.best_target_epoch = 1

        self.best_source_acc = 0
        self.best_source_acc_target = 0
        self.best_source_epoch = 0

        self.total_epoch = 120
        self.save_interval = 1  # validate/save every epoch
        self.save_path = Path(self.exp_name)
        self.init()

    def init(self):
        """Build networks, losses, data loaders, optimizer and logger."""
        kwargs = {'bn': self.bn, 'output_stride': 8}
        self.backbone = nn.DataParallel(self.network(**kwargs)).cuda()
        kwargs.update({'pretrained': False})
        # Second copy of the network receiving the inner (meta-train) update.
        self.updated_net = nn.DataParallel(self.network(**kwargs)).cuda()
        self.ce = nn.CrossEntropyLoss(ignore_index=-1)
        self.nim = NaturalImageMeasure(nclass=19)

        batch_size = self.train_size
        # NOTE(review): assumes `source` is sized (e.g. a string of domain
        # codes) — 4 workers per source domain; confirm against callers.
        workers = len(self.source) * 4

        dataloader = functools.partial(DataLoader, num_workers=workers, pin_memory=True, batch_size=batch_size, shuffle=True)
        self.train_loader = dataloader(self.dataset(mode='train', domains=self.source, force_cache=True))

        dataloader = functools.partial(DataLoader, num_workers=workers, pin_memory=True, batch_size=self.test_size, shuffle=False)
        self.source_val_loader = dataloader(self.dataset(mode='val', domains=self.source, force_cache=True))

        target_dataset, folder = get_dataset(self.target)
        self.target_loader = dataloader(target_dataset(root=ROOT + folder, mode='val'))
        self.target_test_loader = dataloader(target_dataset(root=ROOT + 'cityscapes', mode='test'))

        # Only the backbone is optimized; updated_net gets its weights from
        # get_updated_network during the meta-step.
        self.opt_old = SGD(self.backbone.parameters(), lr=self.outer_update_lr, momentum=0.9, weight_decay=5e-4)
        self.scheduler_old = PolyLR(self.opt_old, self.total_epoch, len(self.train_loader), 0, True, power=0.9)

        self.logger = get_logger('train', self.exp_name)
        self.log('exp_name : {}, train_num = {}, source domains = {}, target_domain = {}, lr : inner = {}, outer = {},'
                 'dataset : {}, net : {}, bn : {}\n'.
                 format(self.exp_name, self.train_num, self.source, self.target, self.inner_update_lr, self.outer_update_lr, self.dataset,
                        self.network, self.bn))
        self.log(self.exp_name + '\n')
        self.train_timer, self.test_timer = Timer(), Timer()

    def train(self, epoch, it, inputs):
        """Plain supervised step on all source domains (no meta update).

        Returns (losses dict, accuracy dict, current lr).
        """
        # imgs : batch x domains x C x H x W
        # targets : batch x domains x 1 x H x W
        imgs, targets = inputs
        B, D, C, H, W = imgs.size()
        # Fold the domain axis into the batch axis.
        meta_train_imgs = imgs.view(-1, C, H, W)
        meta_train_targets = targets.view(-1, 1, H, W)

        tr_logits = self.backbone(meta_train_imgs)[0]
        tr_logits = make_same_size(tr_logits, meta_train_targets)
        ds_loss = self.ce(tr_logits, meta_train_targets[:, 0])
        with torch.no_grad():
            self.nim(tr_logits, meta_train_targets)

        self.opt_old.zero_grad()
        ds_loss.backward()
        self.opt_old.step()
        self.scheduler_old.step(epoch, it)
        losses = {
            'dg': 0,
            'ds': ds_loss.item()
        }
        acc = {
            'iou': self.nim.get_res()[0]
        }
        return losses, acc, self.scheduler_old.get_lr(epoch, it)[0]

    def meta_train(self, epoch, it, inputs):
        """One meta-step: inner update on random source split, outer update
        from the meta-test loss of the updated network.

        Returns (losses dict, accuracy dict, current lr).
        """
        # imgs : batch x domains x C x H x W
        # targets : batch x domains x 1 x H x W

        imgs, targets = inputs
        B, D, C, H, W = imgs.size()
        # Randomly split the domains into meta-train / meta-test subsets.
        split_idx = np.random.permutation(D)
        i = np.random.randint(1, D)
        train_idx = split_idx[:i]
        test_idx = split_idx[i:]
        # train_idx = split_idx[:D // 2]
        # test_idx = split_idx[D // 2:]

        # self.print(split_idx, B, D, C, H, W)'
        meta_train_imgs = imgs[:, train_idx].reshape(-1, C, H, W)
        meta_train_targets = targets[:, train_idx].reshape(-1, 1, H, W)
        meta_test_imgs = imgs[:, test_idx].reshape(-1, C, H, W)
        meta_test_targets = targets[:, test_idx].reshape(-1, 1, H, W)

        # Meta-Train
        tr_logits = self.backbone(meta_train_imgs)[0]
        tr_logits = make_same_size(tr_logits, meta_train_targets)
        ds_loss = self.ce(tr_logits, meta_train_targets[:, 0])

        # Update new network
        self.opt_old.zero_grad()
        # retain_graph so the meta-test backward below can flow through the
        # same graph into the original backbone parameters.
        ds_loss.backward(retain_graph=True)
        updated_net = get_updated_network(self.backbone, self.updated_net, self.inner_update_lr).train().cuda()

        # Meta-Test
        te_logits = updated_net(meta_test_imgs)[0]
        # te_logits = test_res[0]
        te_logits = make_same_size(te_logits, meta_test_targets)
        dg_loss = self.ce(te_logits, meta_test_targets[:, 0])
        with torch.no_grad():
            self.nim(te_logits, meta_test_targets)

        # Update old network
        dg_loss.backward()
        self.opt_old.step()
        self.scheduler_old.step(epoch, it)
        losses = {
            'dg': dg_loss.item(),
            'ds': ds_loss.item()
        }
        acc = {
            'iou': self.nim.get_res()[0],
        }
        return losses, acc, self.scheduler_old.get_lr(epoch, it)[0]

    def do_train(self):
        """Full training loop: resume, train, validate, checkpoint, log."""
        if self.resume:
            self.load()

        self.writer = SummaryWriter(str(self.save_path / 'tensorboard'), filename_suffix=time.strftime('_%Y-%m-%d_%H-%M-%S'))
        self.log('Start epoch : {}\n'.format(self.epoch))

        for epoch in range(self.epoch, self.total_epoch + 1):
            loss_meters, acc_meters = MeterDicts(), MeterDicts(averaged=['iou'])
            self.nim.clear_cache()
            self.backbone.train()
            self.epoch = epoch
            with self.train_timer:
                for it, (paths, imgs, target) in enumerate(self.train_loader):
                    # Alternate: meta-step every train_num-th iteration.
                    meta = (it + 1) % self.train_num == 0
                    if meta:
                        losses, acc, lr = self.meta_train(epoch - 1, it, to_cuda([imgs, target]))
                    else:
                        losses, acc, lr = self.train(epoch - 1, it, to_cuda([imgs, target]))

                    # 'dg' is only produced by meta steps; skip it otherwise.
                    loss_meters.update_meters(losses, skips=['dg'] if not meta else [])
                    acc_meters.update_meters(acc)

                    self.print(self.get_string(epoch, it, loss_meters, acc_meters, lr, meta), end='')
                    self.tfb_log(epoch, it, loss_meters, acc_meters)
            self.print(self.train_timer.get_formatted_duration())
            self.log(self.get_string(epoch, it, loss_meters, acc_meters, lr, meta) + '\n')

            self.save('ckpt')
            if epoch % self.save_interval == 0:
                with self.test_timer:
                    city_acc = self.val(self.target_loader)
                    self.save_best(city_acc, epoch)

            # Rough ETA based on last epoch's train + test duration.
            total_duration = self.train_timer.duration + self.test_timer.duration
            self.print('Time Left : ' + self.train_timer.get_formatted_duration(total_duration * (self.total_epoch - epoch)) + '\n')

        self.log('Best city acc : \n  city : {}, origin : {}, epoch : {}\n'.format(
            self.best_target_acc, self.best_target_acc_source, self.best_target_epoch))
        self.log('Best origin acc : \n  city : {}, origin : {}, epoch : {}\n'.format(
            self.best_source_acc_target, self.best_source_acc, self.best_source_epoch))

    def save_best(self, city_acc, epoch):
        """Track best target/source accuracies and snapshot the best models."""
        self.writer.add_scalar('acc/citys', city_acc, epoch)
        if not self.no_source_test:
            origin_acc = self.val(self.source_val_loader)
            self.writer.add_scalar('acc/origin', origin_acc, epoch)
        else:
            origin_acc = 0

        self.writer.flush()
        if city_acc > self.best_target_acc:
            self.best_target_acc = city_acc
            self.best_target_acc_source = origin_acc
            self.best_target_epoch = epoch
            self.save('best_city')

        if origin_acc > self.best_source_acc:
            self.best_source_acc = origin_acc
            self.best_source_acc_target = city_acc
            self.best_source_epoch = epoch
            self.save('best_origin')

    def val(self, dataset):
        """Standard eval-mode validation; returns the primary accuracy."""
        self.backbone.eval()
        with torch.no_grad():
            self.nim.clear_cache()
            self.nim.set_max_len(len(dataset))
            for p, img, target in dataset:
                img, target = to_cuda(get_img_target(img, target))
                logits = self.backbone(img)[0]
                self.nim(logits, target)
        self.log('\nNormal validation : {}\n'.format(self.nim.get_acc()))
        if hasattr(dataset.dataset, 'format_class_iou'):
            self.log(dataset.dataset.format_class_iou(self.nim.get_class_acc()[0]) + '\n')
        return self.nim.get_acc()[0]

    def target_specific_val(self, loader):
        """Validation with dropout removed and BN statistics not tracked.

        The backbone runs in train() mode (per-batch BN) but under no_grad;
        returns the primary accuracy.
        """
        self.nim.clear_cache()
        self.nim.set_max_len(len(loader))
        # eval for dropout
        self.backbone.module.remove_dropout()
        self.backbone.module.not_track()
        for idx, (p, img, target) in enumerate(loader):
            # Accept both multi-domain (B, D, C, H, W) and single (B, C, H, W).
            if len(img.size()) == 5:
                B, D, C, H, W = img.size()
            else:
                B, C, H, W = img.size()
                D = 1
            img, target = to_cuda([img.reshape(B, D, C, H, W), target.reshape(B, D, 1, H, W)])
            for d in range(img.size(1)):
                img_d, target_d, = img[:, d], target[:, d]
                self.backbone.train()
                with torch.no_grad():
                    new_logits = self.backbone(img_d)[0]
                    self.nim(new_logits, target_d)

        self.backbone.module.recover_dropout()
        self.log('\nTarget specific validation : {}\n'.format(self.nim.get_acc()))
        if hasattr(loader.dataset, 'format_class_iou'):
            self.log(loader.dataset.format_class_iou(self.nim.get_class_acc()[0]) + '\n')
        return self.nim.get_acc()[0]

    def predict_target(self, load_path='best_city', color=False, train=False, output_path='predictions'):
        """Write per-image predictions for the target test set to disk.

        With color=True the raw class map is rendered as a color image;
        otherwise class ids are remapped via the dataset's predict().
        """
        self.load(load_path)
        import skimage.io as skio
        dataset = self.target_test_loader

        output_path = Path(self.save_path / output_path)
        output_path.mkdir(exist_ok=True)

        if train:
            # train=True mimics target_specific_val: per-batch BN, no dropout.
            self.backbone.module.remove_dropout()
            self.backbone.train()
        else:
            self.backbone.eval()

        with torch.no_grad():
            self.nim.clear_cache()
            self.nim.set_max_len(len(dataset))
            for names, img, target in tqdm(dataset):
                img = to_cuda(img)
                logits = self.backbone(img)[0]
                logits = F.interpolate(logits, img.size()[2:], mode='bilinear', align_corners=True)
                preds = get_prediction(logits).cpu().numpy()
                if color:
                    trainId_preds = preds
                else:
                    trainId_preds = dataset.dataset.predict(preds)

                for pred, name in zip(trainId_preds, names):
                    file_name = name.split('/')[-1]
                    if color:
                        pred = class_map_2_color_map(pred).transpose(1, 2, 0).astype(np.uint8)
                    skio.imsave(str(output_path / file_name), pred)

    def get_string(self, epoch, it, loss_meters, acc_meters, lr, meta):
        """Format one progress line (leading \\r for in-place updates)."""
        string = '\repoch {:4}, iter : {:4}, '.format(epoch, it)
        for k, v in loss_meters.items():
            string += k + ' : {:.4f}, '.format(v.avg)
        for k, v in acc_meters.items():
            string += k + ' : {:.4f}, '.format(v.avg)

        string += 'lr : {:.6f}, meta : {}'.format(lr, meta)
        return string

    def log(self, strs):
        """Write to the experiment logger."""
        self.logger.info(strs)

    def print(self, strs, **kwargs):
        """Console output (kept as a method so it can be silenced/overridden)."""
        print(strs, **kwargs)

    def tfb_log(self, epoch, it, losses, acc):
        """Push current meter values to TensorBoard at a global iteration."""
        iteration = epoch * len(self.train_loader) + it
        for k, v in losses.items():
            self.writer.add_scalar('loss/' + k, v.val, iteration)
        for k, v in acc.items():
            self.writer.add_scalar('acc/' + k, v.val, iteration)

    def save(self, name='ckpt'):
        """Checkpoint backbone, optimizer, epoch and best-metric info."""
        info = [self.best_source_acc, self.best_source_acc_target, self.best_source_epoch,
                self.best_target_acc, self.best_target_acc_source, self.best_target_epoch]
        dicts = {
            'backbone': self.backbone.state_dict(),
            'opt': self.opt_old.state_dict(),
            'epoch': self.epoch + 1,  # next epoch to run on resume
            'best': self.best_target_acc,
            'info': info
        }
        self.print('Saving epoch : {}'.format(self.epoch))
        torch.save(dicts, self.save_path / '{}.pth'.format(name))

    def load(self, path=None, strict=False):
        """Restore a checkpoint; returns True on success, False otherwise.

        `path` may be None (default ckpt), a name inside save_path, or a
        full .pth path. Missing/corrupt checkpoints reset epoch to 1.
        """
        if path is None:
            path = self.save_path / 'ckpt.pth'
        else:
            if 'pth' in path:
                path = path
            else:
                path = self.save_path / '{}.pth'.format(path)

        try:
            dicts = torch.load(path, map_location='cpu')
            msg = self.backbone.load_state_dict(dicts['backbone'], strict=strict)
            self.print(msg)
            if 'opt' in dicts:
                self.opt_old.load_state_dict(dicts['opt'])
            if 'epoch' in dicts:
                self.epoch = dicts['epoch']
            else:
                self.epoch = 1
            if 'best' in dicts:
                self.best_target_acc = dicts['best']
            if 'info' in dicts:
                self.best_source_acc, self.best_source_acc_target, self.best_source_epoch, \
                self.best_target_acc, self.best_target_acc_source, self.best_target_epoch = dicts['info']
            self.log('Loaded from {}, next epoch : {}, best_target : {}, best_epoch : {}\n'
                     .format(str(path), self.epoch, self.best_target_acc, self.best_target_epoch))
            return True
        except Exception as e:
            self.print(e)
            self.log('No ckpt found in {}\n'.format(str(path)))
            self.epoch = 1
            return False
Ejemplo n.º 17
0
def main(args):
    """Entry point: train / test / plot MNIST classifiers.

    Modes (``args.mode``):
      * ``train``         -- train the selected model, checkpointing every epoch
      * ``test``          -- evaluate every saved checkpoint on the test set
      * ``graph``         -- plot loss/accuracy curves for one model
      * ``graph_compare`` -- plot curves of several pre-defined models

    The optimizer is SGD and the cost function is CrossEntropyLoss.
    """

    # Configuration
    mode = args.mode
    model_name = args.model
    options = args.o

    if mode == 'train':
        train_data_dir = args.d + '/train/'
    elif mode == 'test':
        test_data_dir = args.d + '/test/'
    elif mode == 'graph_compare':
        # Pre-defined model groups to plot against each other.
        if model_name == 'LeNet5':
            models_name = [
                'LeNet5', 'LeNet5_insert_noise_s0.1_m0.0',
                'LeNet5_insert_noise_s0.2_m0.0',
                'LeNet5_insert_noise_s0.3_m0.0', 'LeNet5_weight_decay_0.0001',
                'LeNet5_weight_decay_0.001', 'LeNet5_weight_decay_0.01'
            ]
        elif model_name == 'CustomMLP_6':
            models_name = [
                'CustomMLP_6', 'CustomMLP_6_weight_decay_1e-05',
                'CustomMLP_6_weight_decay_0.0001',
                'CustomMLP_6_weight_decay_0.001'
            ]
        else:
            models_name = [
                'LeNet5', 'CustomMLP_1', 'CustomMLP_2', 'CustomMLP_3',
                'CustomMLP_4', 'CustomMLP_5', 'CustomMLP_6'
            ]

    model_path = args.m
    device = torch.device("cuda:" + str(args.cuda))
    lr = 0.01
    momentum = 0.6

    batch_size = args.b
    epoch = args.e

    use_ckpt = args.c

    # Hidden-layer widths for each CustomMLP variant.
    if model_name == "CustomMLP_1":
        layer_option = [54, 47, 35, 10, 39]
    elif model_name == "CustomMLP_2":
        layer_option = [55, 35, 30, 34]
    elif model_name == "CustomMLP_3":
        layer_option = [55, 34, 33, 31]
    elif model_name == "CustomMLP_4":
        layer_option = [55, 41, 41]
    elif model_name == "CustomMLP_5":
        layer_option = [56, 51]
    elif model_name == "CustomMLP_6":
        layer_option = [58]

    ##change models
    if mode != "graph_compare":
        if model_name.split('_')[0] == "LeNet5":
            model = LeNet5(device).to(device)

        elif model_name.split('_')[0] == "CustomMLP":
            model = CustomMLP(layer_option).to(device)

    ##change model name
    if options:
        model_name = model_name + '_' + options

    # Defaults first, so train mode works even with no regularization option
    # (previously the noise parameters were left undefined in that case and
    # dataset construction crashed with NameError).
    weight_decay = 0.
    gausian_noise_mean = 0.
    gausian_noise_std = 0.
    if options == "weight_decay":
        weight_decay = args.w
        model_name += '_' + str(weight_decay)
    elif options == "insert_noise":
        gausian_noise_mean = args.mean
        gausian_noise_std = args.std
        model_name += '_s' + str(gausian_noise_std) + "_m" + str(
            gausian_noise_mean)

    ##change criterion
    criterion = CrossEntropyLoss()

    #Custom TimeModule
    mytime = CheckTime()

    if mode == "train":

        # Load Dataset
        print(
            "{} Start Loading Train Dataset ==================================="
            .format(mytime.get_running_time_str()))

        train_dataset = dataset.MNIST(train_data_dir, gausian_noise_mean,
                                      gausian_noise_std)
        train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True)

        # initiate optimizer
        optimizer = SGD(model.parameters(),
                        lr=lr,
                        momentum=momentum,
                        weight_decay=weight_decay)

        # If use checkpoint ...

        if use_ckpt:
            ckpt_files = glob(model_path + '{}_model_*.pt'.format(model_name))
            ckpt_files.sort()

            ckpt_model_path = ckpt_files[-1]
            epoch_info = torch.load(ckpt_model_path, map_location=device)

            # 'epoch' stores the index of the last completed epoch (-1 for
            # the untrained snapshot), so resume at the following one.
            start_epoch = epoch_info['epoch'] + 1

            model.load_state_dict(epoch_info['model'])
            # Accept both the fixed key and the misspelled 'opimizer' key
            # written by checkpoints from older versions of this script.
            # (Also fixes the former load_state_dic typo.)
            optim_state = epoch_info.get('optimizer',
                                         epoch_info.get('opimizer'))
            if optim_state is not None:
                optimizer.load_state_dict(optim_state)

            total_trn_loss = epoch_info['total_trn_loss']
            total_trn_acc = epoch_info['total_trn_acc']

        else:
            start_epoch = 0
            total_trn_loss = []
            total_trn_acc = []

        # Check Random Parameter Model Loss & Accuracy
        print(
            "{} Check Random Parameter Model {}========================================= "
            .format(mytime.get_running_time_str(), model_name))

        with torch.no_grad():
            trn_loss, trn_acc = test(model, train_dataloader, device,
                                     criterion)
            total_trn_loss.append(trn_loss.item())
            total_trn_acc.append(trn_acc.item())
            i = 0

            torch.save(
                {
                    # -1 marks "no epochs trained yet" so that resuming
                    # from this snapshot starts at epoch 0.
                    'epoch': i - 1,
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'total_trn_loss': total_trn_loss,
                    'total_trn_acc': total_trn_acc
                }, model_path + '{}_model_{:04d}.pt'.format(model_name, i))

        print("{} train {} // epoch: {} // loss: {:.6f} // accuracy: {:.2f} ".
              format(mytime.get_running_time_str(), model_name, i, trn_loss,
                     trn_acc))

        # Start traing model
        print("{} Start Training {}========================================= ".
              format(mytime.get_running_time_str(), model_name))
        for i in range(start_epoch, epoch):
            trn_loss, trn_acc = train(model, train_dataloader, device,
                                      criterion, optimizer)

            total_trn_loss.append(trn_loss.item())
            total_trn_acc.append(trn_acc.item())

            # Checkpoint after every epoch; file index i+1, dict records the
            # completed epoch index i.
            torch.save(
                {
                    'epoch': i,
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'total_trn_loss': total_trn_loss,
                    'total_trn_acc': total_trn_acc
                }, model_path + '{}_model_{:04d}.pt'.format(model_name, i + 1))

            print(
                "{} train {} // epoch: {} // loss: {:.6f} // accuracy: {:.2f} "
                .format(mytime.get_running_time_str(), model_name, i + 1,
                        trn_loss, trn_acc))

    if mode == "test":
        #Start Loading Test Dataset
        print(
            "{} Start Loading Test Dataset ==================================="
            .format(mytime.get_running_time_str()))
        test_dataset = dataset.MNIST(test_data_dir)
        test_dataloader = DataLoader(test_dataset, batch_size, shuffle=True)

        # Evaluate every checkpoint and append the results back into it.
        with torch.no_grad():

            ckpt_files = glob(model_path + '{}_model_*.pt'.format(model_name))
            ckpt_files.sort()

            total_tst_loss = []
            total_tst_acc = []

            for i, ckpt_model_path in enumerate(ckpt_files):

                epoch_info = torch.load(ckpt_model_path, map_location=device)

                model.load_state_dict(epoch_info['model'])

                tst_loss, tst_acc = test(model, test_dataloader, device,
                                         criterion)

                total_tst_loss.append(tst_loss.item())
                total_tst_acc.append(tst_acc.item())

                epoch_info['total_tst_loss'] = total_tst_loss
                epoch_info['total_tst_acc'] = total_tst_acc

                torch.save(epoch_info, ckpt_model_path)

                print(
                    "{} test {} // model_num: {} // loss: {:.6f} // accuracy: {:.2f} "
                    .format(mytime.get_running_time_str(), model_name, i,
                            tst_loss, tst_acc))

    if mode == "graph":

        #Load models to draw graph
        ckpt_files = glob(model_path + '{}_model_*.pt'.format(model_name))
        ckpt_files.sort()

        # The latest checkpoint carries the full train/test histories.
        epoch_info = torch.load(ckpt_files[-1])

        #initiate loss and accuracy dictionary
        loss_dic = {}
        acc_dic = {}

        #add loss and accuracy list
        loss_dic['train'] = epoch_info['total_trn_loss']
        loss_dic['test'] = epoch_info['total_tst_loss']

        acc_dic['train'] = epoch_info['total_trn_acc']
        acc_dic['test'] = epoch_info['total_tst_acc']

        num_epoch = len(loss_dic['train'])

        #Draw Graph per model: trn_loss + tst_loss
        graph_name = "Loss (model - {}) ".format(model_name)
        draw_model_graph(graph_name,
                         num_epoch,
                         loss_dic,
                         graph_mode="loss",
                         save=args.s,
                         zoom_plot=args.z)

        #Draw Graph per model: trn_acc + tst_acc
        graph_name = "Accuracy (model - {}) ".format(model_name)
        draw_model_graph(graph_name,
                         num_epoch,
                         acc_dic,
                         graph_mode="acc",
                         save=args.s,
                         zoom_plot=args.z)

    if mode == "graph_compare":
        tst_loss_dic = {}
        tst_acc_dic = {}

        print(models_name)

        #Load pre-defined models
        for model_name in models_name:
            model_file_name = model_path + '{}_model_{:04d}.pt'.format(
                model_name, epoch)
            epoch_info = torch.load(model_file_name)

            tst_loss_dic[model_name] = epoch_info['total_tst_loss']
            tst_acc_dic[model_name] = epoch_info['total_tst_acc']

            num_epoch = len(tst_loss_dic[model_name])

        #Comparison models: tst_loss
        graph_name = "Compare Loss"
        draw_model_graph(graph_name,
                         num_epoch,
                         tst_loss_dic,
                         graph_mode="loss",
                         save=args.s,
                         zoom_plot=args.z)

        #Comparison models: tst_acc
        graph_name = "Compare Accuracy"
        draw_model_graph(graph_name,
                         num_epoch,
                         tst_acc_dic,
                         graph_mode="acc",
                         save=args.s,
                         zoom_plot=args.z)
Ejemplo n.º 18
0
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best-{}.pth.tar'.format(arch))


# --- top-level training script -------------------------------------------
# NOTE(review): `lr`, `momentum`, `weight_decay`, `epochs`, `arch`,
# `train_loader` and `val_loader` are assumed to be defined earlier in the
# file -- confirm before moving this block.
model = Model().cuda()
# Only parameters with requires_grad are optimized (frozen layers skipped).
optimizer = SGD(filter(lambda p: p.requires_grad, model.parameters()), lr,
                momentum, weight_decay)
criterion = nn.CrossEntropyLoss().cuda()

# Initial best; validation loss is expected to drop below this quickly.
best_loss = 10
for epoch in range(epochs):
    # train for one epoch
    train(train_loader, model, criterion, optimizer, epoch)

    # evaluate on validation set
    loss = validate(val_loader, model, criterion)

    # remember best prec@1 and save checkpoint
    is_best = loss < best_loss
    best_loss = min(loss, best_loss)
    print(' * Best Loss: {}'.format(best_loss))
    save_checkpoint(
        {
            'epoch': epoch + 1,
            'arch': arch,
            'state_dict': model.state_dict(),
            'best_loss': best_loss,
            'optimizer': optimizer.state_dict(),
        }, is_best)
Ejemplo n.º 19
0
class Trainer(object):
    """Trains an AttentionalFactorizationMachine on CSV data with MSE loss.

    Handles TensorBoard logging, LR scheduling, best-model checkpointing and
    early stopping driven by validation loss.
    """

    def __init__(self, config):
        """Build model, data loaders, optimizer and logging from `config`."""
        self.config = config
        self.device = torch.device(
            'cuda') if torch.cuda.is_available() else torch.device('cpu')

        # Each run logs/checkpoints into a unique timestamped directory.
        start_time = datetime.datetime.now().strftime('%m%d_%H%M%S')
        self.log_path = os.path.join(config['train']['save_dir'], start_time)

        tb_path = os.path.join(self.log_path, 'logs')
        mkdir_p(tb_path)
        self.writer = WriterTensorboardX(tb_path)

        data_manager = CSVDataManager(config['data'])
        self.data_loader = data_manager.get_loader('train')
        self.valid_data_loader = data_manager.get_loader('val')

        self.model = AttentionalFactorizationMachine(data_manager.dims, config)
        self.model = self.model.to(self.device)

        # Optimize only parameters that require gradients.
        trainable_params = filter(lambda p: p.requires_grad,
                                  self.model.parameters())
        self.optimizer = SGD(trainable_params, **config['optimizer'])
        self.lr_scheduler = StepLR(self.optimizer, **config['lr_scheduler'])

        # Early-stopping state: best validation loss seen so far and the
        # number of consecutive epochs without improvement.
        self.best_val_loss = float('inf')
        self.satur_count = 0

    def _train_epoch(self, epoch):
        """Run one training epoch; returns a log dict with train/val loss."""
        self.model.train()

        total_loss = 0
        self.writer.set_step(epoch)
        _trange = tqdm(self.data_loader, leave=True, desc='')

        for batch_idx, batch in enumerate(_trange):
            batch = [b.to(self.device) for b in batch]

            # Last tensor in the batch is the regression target; the rest are
            # model inputs (e.g. users, items, gens).
            data, target = batch[:-1], batch[-1]

            self.optimizer.zero_grad()
            output = self.model(data)

            loss = F.mse_loss(output, target)
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()

            if batch_idx % 10 == 0:
                _str = 'Train Epoch: {} Loss: {:.6f}'.format(
                    epoch, loss.item())
                _trange.set_description(_str)

        loss = total_loss / len(self.data_loader)
        self.writer.add_scalar('loss', loss)

        log = {'loss': loss}

        val_log = self._valid_epoch(epoch)
        log = {**log, **val_log}

        self.lr_scheduler.step()

        return log

    def _valid_epoch(self, epoch):
        """Evaluate on the validation loader; returns {'val_loss': ...}."""
        self.model.eval()
        total_val_loss = 0

        self.writer.set_step(epoch, 'valid')

        with torch.no_grad():

            for batch_idx, batch in enumerate(self.valid_data_loader):
                batch = [b.to(self.device) for b in batch]

                data, target = batch[:-1], batch[-1]

                output = self.model(data)
                loss = F.mse_loss(output, target)

                total_val_loss += loss.item()

            val_loss = total_val_loss / len(self.valid_data_loader)

            self.writer.add_scalar('loss', val_loss)

        return {'val_loss': val_loss}

    def train(self):
        """Full training loop with best-checkpoint saving and early stopping."""
        print(self.model)

        for epoch in range(1, self.config['train']['epochs'] + 1):

            result = self._train_epoch(epoch)

            c_lr = self.optimizer.param_groups[0]['lr']
            self.writer.add_scalar('lr', c_lr)

            log = pd.DataFrame([result]).T
            log.columns = ['']
            print(log)

            if self.best_val_loss > result['val_loss']:
                print('[IMPROVED]')
                # BUGFIX: record the new best loss and reset the saturation
                # counter. Previously best_val_loss was never updated, so it
                # stayed at +inf, every epoch counted as an improvement, and
                # early stopping could never trigger.
                self.best_val_loss = result['val_loss']
                self.satur_count = 0

                chk_path = os.path.join(self.log_path, 'checkpoints')
                mkdir_p(chk_path)

                state = {
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict()
                }

                torch.save(state, os.path.join(chk_path, 'model_best.pth'))
                with open(os.path.join(chk_path, 'config.json'), 'w') as wj:
                    json.dump(self.config, wj)
            else:
                self.satur_count += 1

            # Stop once validation loss has not improved for more than
            # `early_stop` consecutive epochs.
            if self.satur_count > self.config['train']['early_stop']:
                break
Ejemplo n.º 20
0
class Trainer(object):
    """Base trainer for symbol detectors over simulated communication channels.

    Loads hyperparameters from a YAML config (and/or kwargs), builds train/val
    channel datasets, and provides plain training, MAML-style meta-training,
    self-supervised online adaptation and SER (symbol error rate) evaluation.
    Subclasses are expected to supply a concrete detector via
    initialize_detector(), initialize_meta_detector() and calc_loss().
    """

    def __init__(self, config_path=None, **kwargs):
        # All attributes start as None so that param_parser() can tell which
        # ones were already supplied via kwargs and only fill in the missing
        # ones from the config file.

        # general
        self.run_name = None

        # code parameters
        self.use_ecc = None
        self.n_symbols = None

        # channel
        self.memory_length = None
        self.channel_type = None
        self.channel_coefficients = None
        self.noisy_est_var = None
        self.fading_in_channel = None
        self.fading_in_decoder = None
        self.fading_taps_type = None
        self.subframes_in_frame = None
        self.gamma = None

        # validation hyperparameters
        self.val_block_length = None
        self.val_frames = None
        self.val_SNR_start = None
        self.val_SNR_end = None
        self.val_SNR_step = None
        self.eval_mode = None

        # training hyperparameters
        self.train_block_length = None
        self.train_frames = None
        self.train_minibatch_num = None
        self.train_minibatch_size = None
        self.train_SNR_start = None
        self.train_SNR_end = None
        self.train_SNR_step = None
        self.lr = None  # learning rate
        self.loss_type = None
        self.optimizer_type = None

        # self-supervised online training
        self.self_supervised = None
        self.self_supervised_iterations = None
        self.ser_thresh = None
        self.meta_lr = None
        self.MAML = None
        self.online_meta = None
        self.weights_init = None
        self.window_size = None
        self.buffer_empty = None
        self.meta_train_iterations = None
        self.meta_j_num = None
        self.meta_subframes = None

        # seed
        self.noise_seed = None
        self.word_seed = None

        # weights dir
        self.weights_dir = None

        # if any kwargs are passed, initialize the dict with them
        self.initialize_by_kwargs(**kwargs)

        # initializes all none parameters above from config
        self.param_parser(config_path)

        # initializes word and noise generator from seed
        self.rand_gen = np.random.RandomState(self.noise_seed)
        self.word_rand_gen = np.random.RandomState(self.word_seed)
        # one detector state per possible channel memory content
        self.n_states = 2**self.memory_length

        # initialize matrices, datasets and detector
        self.initialize_dataloaders()
        self.initialize_detector()
        self.initialize_meta_detector()

        # calculate data subframes indices. We will calculate ser only over these values.
        # Subframe 0 of each frame is a pilot; the remaining subframes carry data.
        self.data_indices = torch.Tensor(
            list(
                filter(lambda x: x % self.subframes_in_frame != 0, [
                    i for i in range(self.val_frames * self.subframes_in_frame)
                ]))).long()

    def initialize_by_kwargs(self, **kwargs):
        """Set every keyword argument directly as an attribute of the trainer."""
        for k, v in kwargs.items():
            setattr(self, k, v)

    def param_parser(self, config_path: str):
        """
        Parse the config, load all attributes into the trainer
        :param config_path: path to config
        """
        if config_path is None:
            config_path = CONFIG_PATH

        with open(config_path) as f:
            self.config = yaml.load(f, Loader=yaml.FullLoader)

        # set attribute of Trainer with every config item
        # (config values only fill attributes still set to None, so kwargs win)
        for k, v in self.config.items():
            try:
                if getattr(self, k) is None:
                    setattr(self, k, v)
            except AttributeError:
                pass

        if self.weights_dir is None:
            self.weights_dir = os.path.join(WEIGHTS_DIR, self.run_name)
            if not os.path.exists(self.weights_dir) and len(self.weights_dir):
                os.makedirs(self.weights_dir)
                # save config in output dir
                copyfile(config_path,
                         os.path.join(self.weights_dir, "config.yaml"))

    def get_name(self):
        # NOTE(review): `self.__name__()` raises AttributeError unless a
        # subclass defines a `__name__` method; likely meant
        # `type(self).__name__` — confirm against subclasses before changing.
        return self.__name__()

    def initialize_detector(self):
        """
        Every trainer must have some base detector model
        """
        self.detector = None
        pass

    def initialize_meta_detector(self):
        """
        Every trainer must have some base detector model
        """
        self.meta_detector = None
        pass

    def check_eval_mode(self):
        # Only 'aggregated' (average over SNRs) and 'by_word' (sequential,
        # with online adaptation) evaluation modes are supported.
        if self.eval_mode != 'aggregated' and self.eval_mode != 'by_word':
            raise ValueError("No such eval mode!!!")

    # calculate train loss
    def calc_loss(self, soft_estimation: torch.Tensor,
                  transmitted_words: torch.Tensor) -> torch.Tensor:
        """
         Every trainer must have some loss calculation
        """
        pass

    # setup the optimization algorithm
    def deep_learning_setup(self):
        """
        Sets up the optimizer and loss criterion
        """
        # Only parameters with requires_grad=True are optimized.
        if self.optimizer_type == 'Adam':
            self.optimizer = Adam(filter(lambda p: p.requires_grad,
                                         self.detector.parameters()),
                                  lr=self.lr)
        elif self.optimizer_type == 'RMSprop':
            self.optimizer = RMSprop(filter(lambda p: p.requires_grad,
                                            self.detector.parameters()),
                                     lr=self.lr)
        elif self.optimizer_type == 'SGD':
            self.optimizer = SGD(filter(lambda p: p.requires_grad,
                                        self.detector.parameters()),
                                 lr=self.lr)
        else:
            raise NotImplementedError("No such optimizer implemented!!!")
        if self.loss_type == 'BCE':
            self.criterion = BCELoss().to(device)
        elif self.loss_type == 'CrossEntropy':
            self.criterion = CrossEntropyLoss().to(device)
        elif self.loss_type == 'MSE':
            self.criterion = MSELoss().to(device)
        else:
            raise NotImplementedError("No such loss function implemented!!!")

    def initialize_dataloaders(self):
        """
        Sets up the data loader - a generator from which we draw batches, in iterations
        """
        self.snr_range = {
            'train':
            np.arange(self.train_SNR_start,
                      self.train_SNR_end + 1,
                      step=self.train_SNR_step),
            'val':
            np.arange(self.val_SNR_start,
                      self.val_SNR_end + 1,
                      step=self.val_SNR_step)
        }
        self.frames_per_phase = {
            'train': self.train_frames,
            'val': self.val_frames
        }
        self.block_lengths = {
            'train': self.train_block_length,
            'val': self.val_block_length
        }
        # NOTE(review): this replaces the scalar `channel_coefficients`
        # attribute loaded from config with a per-phase dict; training is
        # always forced to 'time_decay' while validation keeps the config
        # value.
        self.channel_coefficients = {
            'train': 'time_decay',
            'val': self.channel_coefficients
        }
        # With ECC enabled, each block carries 8 extra bits per coded symbol.
        self.transmission_lengths = {
            'train':
            self.train_block_length if not self.use_ecc else
            self.train_block_length + 8 * self.n_symbols,
            'val':
            self.val_block_length
            if not self.use_ecc else self.val_block_length + 8 * self.n_symbols
        }
        self.channel_dataset = {
            phase: ChannelModelDataset(
                channel_type=self.channel_type,
                block_length=self.block_lengths[phase],
                transmission_length=self.transmission_lengths[phase],
                words=self.frames_per_phase[phase] * self.subframes_in_frame,
                memory_length=self.memory_length,
                channel_coefficients=self.channel_coefficients[phase],
                random=self.rand_gen,
                word_rand_gen=self.word_rand_gen,
                noisy_est_var=self.noisy_est_var,
                use_ecc=self.use_ecc,
                n_symbols=self.n_symbols,
                fading_taps_type=self.fading_taps_type,
                fading_in_channel=self.fading_in_channel,
                fading_in_decoder=self.fading_in_decoder,
                phase=phase)
            for phase in ['train', 'val']
        }
        self.dataloaders = {
            phase: torch.utils.data.DataLoader(self.channel_dataset[phase])
            for phase in ['train', 'val']
        }

    def online_training(self, tx: torch.Tensor, rx: torch.Tensor):
        """Hook for subclasses: adapt the detector online on one (tx, rx) pair."""
        pass

    def single_eval_at_point(self, snr: float, gamma: float) -> float:
        """
        Evaluation at a single snr.
        :param snr: indice of snr in the snrs vector
        :return: ser for batch
        """
        # draw words of given gamma for all snrs
        transmitted_words, received_words = self.channel_dataset[
            'val'].__getitem__(snr_list=[snr], gamma=gamma)

        # decode and calculate accuracy
        detected_words = self.detector(received_words, 'val', snr, gamma)

        if self.use_ecc:
            decoded_words = [
                decode(detected_word, self.n_symbols)
                for detected_word in detected_words.cpu().numpy()
            ]
            detected_words = torch.Tensor(decoded_words).to(device)

        # SER is computed over data subframes only (pilots excluded).
        ser, fer, err_indices = calculate_error_rates(
            detected_words[self.data_indices],
            transmitted_words[self.data_indices])

        return ser

    def gamma_eval(self, gamma: float) -> np.ndarray:
        """
        Evaluation at a single gamma value.
        :return: ser for batch.
        """
        ser_total = np.zeros(len(self.snr_range['val']))
        for snr_ind, snr in enumerate(self.snr_range['val']):
            # load the checkpoint matching this (snr, gamma) point
            self.load_weights(snr, gamma)
            ser_total[snr_ind] = self.single_eval_at_point(snr, gamma)
        return ser_total

    def evaluate_at_point(self) -> np.ndarray:
        """
        Monte-Carlo simulation over validation SNRs range
        :return: ber, fer, iterations vectors
        """
        ser_total = np.zeros(len(self.snr_range['val']))
        with torch.no_grad():
            print(f'Starts evaluation at gamma {self.gamma}')
            start = time()
            ser_total += self.gamma_eval(self.gamma)
            print(f'Done. time: {time() - start}, ser: {ser_total}')
        return ser_total

    def eval_by_word(self, snr: float,
                     gamma: float) -> Union[float, np.ndarray]:
        """Sequential word-by-word evaluation with optional online/meta adaptation.

        Maintains a buffer of confidently-detected (tx, rx) pairs, used for
        self-supervised fine-tuning and periodic meta-training.
        :return: per-word SER vector
        """
        if self.self_supervised:
            self.deep_learning_setup()
        total_ser = 0
        # draw words of given gamma for all snrs
        transmitted_words, received_words = self.channel_dataset[
            'val'].__getitem__(snr_list=[snr], gamma=gamma)
        ser_by_word = np.zeros(transmitted_words.shape[0])
        # saved detector is used to initialize the decoder in meta learning loops
        self.saved_detector = copy.deepcopy(self.detector)
        # query for all detected words
        if self.buffer_empty:
            buffer_rx = torch.empty([0, received_words.shape[1]]).to(device)
            buffer_tx = torch.empty([0, received_words.shape[1]]).to(device)
            buffer_ser = torch.empty([0]).to(device)
        else:
            # draw words from different channels
            buffer_tx, buffer_rx = self.channel_dataset['train'].__getitem__(
                snr_list=[snr], gamma=gamma)
            buffer_ser = torch.zeros(buffer_rx.shape[0]).to(device)
            # re-encode the drawn words so the buffer holds coded symbols
            buffer_tx = torch.cat([
                torch.Tensor(
                    encode(transmitted_word.int().cpu().numpy(),
                           self.n_symbols).reshape(1, -1)).to(device)
                for transmitted_word in buffer_tx
            ],
                                  dim=0)

        # relative indices of the (support, query) split used in meta-training:
        # the last `window_size` buffered words support, the newest is the query
        support_idx = torch.arange(-self.window_size - 1, -1).long().to(device)
        query_idx = -1 * torch.ones(1).long().to(device)

        for count, (transmitted_word, received_word) in enumerate(
                zip(transmitted_words, received_words)):
            transmitted_word, received_word = transmitted_word.reshape(
                1, -1), received_word.reshape(1, -1)
            # detect
            detected_word = self.detector(received_word, 'val', snr, gamma,
                                          count)
            if count in self.data_indices:
                # decode
                decoded_word = [
                    decode(detected_word, self.n_symbols)
                    for detected_word in detected_word.cpu().numpy()
                ]
                decoded_word = torch.Tensor(decoded_word).to(device)
                # calculate accuracy
                ser, fer, err_indices = calculate_error_rates(
                    decoded_word, transmitted_word)
                # encode word again
                decoded_word_array = decoded_word.int().cpu().numpy()
                encoded_word = torch.Tensor(
                    encode(decoded_word_array,
                           self.n_symbols).reshape(1, -1)).to(device)
                errors_num = torch.sum(torch.abs(encoded_word -
                                                 detected_word)).item()
                print('*' * 20)
                print(f'current: {count, ser, errors_num}')
                total_ser += ser
                ser_by_word[count] = ser
            else:
                # pilot subframe: transmitted word is known, SER not counted
                print('*' * 20)
                print(f'current: {count}, Pilot')
                # encode word again
                decoded_word_array = transmitted_word.int().cpu().numpy()
                encoded_word = torch.Tensor(
                    encode(decoded_word_array,
                           self.n_symbols).reshape(1, -1)).to(device)
                ser = 0

            # save the encoded word in the buffer
            # (only words detected confidently enough, i.e. ser <= threshold)
            if ser <= self.ser_thresh:
                buffer_rx = torch.cat([buffer_rx, received_word])
                buffer_tx = torch.cat([
                    buffer_tx,
                    detected_word.reshape(1, -1)
                    if ser > 0 else encoded_word.reshape(1, -1)
                ],
                                      dim=0)
                buffer_ser = torch.cat(
                    [buffer_ser,
                     torch.FloatTensor([ser]).to(device)])
                if not self.buffer_empty:
                    # keep the buffer at a fixed size: drop the oldest entry
                    buffer_rx = buffer_rx[1:]
                    buffer_tx = buffer_tx[1:]
                    buffer_ser = buffer_ser[1:]

            # periodic meta-training over random (support, query) pairs
            # sampled from the buffer
            if self.online_meta and count % self.meta_subframes == 0 and count >= self.meta_subframes and \
                    buffer_rx.shape[0] > 2:  # self.subframes_in_frame
                print('meta-training')
                self.meta_weights_init()
                for i in range(self.meta_train_iterations):
                    j_hat_values = torch.unique(
                        torch.randint(low=0,
                                      high=buffer_rx.shape[0] - 2,
                                      size=[self.meta_j_num])).to(device)
                    for j_hat in j_hat_values:
                        cur_support_idx = j_hat + support_idx + 1
                        cur_query_idx = j_hat + query_idx + 1
                        self.meta_train_loop(buffer_rx, buffer_tx,
                                             cur_support_idx, cur_query_idx)
                copy_model(source_model=self.detector,
                           dest_model=self.saved_detector)

            if self.self_supervised and ser <= self.ser_thresh:
                # use last word inserted in the buffer for training
                self.online_training(buffer_tx[-1].reshape(1, -1),
                                     buffer_rx[-1].reshape(1, -1))

            if (count + 1) % 10 == 0:
                print(
                    f'Self-supervised: {count + 1}/{transmitted_words.shape[0]}, SER {total_ser / (count + 1)}'
                )

        # NOTE(review): averaged over ALL words, pilots included (pilots
        # contribute ser=0) — confirm this matches the intended SER metric.
        total_ser /= transmitted_words.shape[0]
        print(f'Final ser: {total_ser}')
        return ser_by_word

    def meta_weights_init(self):
        """Reset detector weights before a meta-training burst, per config."""
        if self.weights_init == 'random':
            self.initialize_detector()
            self.deep_learning_setup()
        elif self.weights_init == 'last_frame':
            copy_model(source_model=self.saved_detector,
                       dest_model=self.detector)
        elif self.weights_init == 'meta_training':
            snr = self.snr_range['val'][0]
            self.load_weights(snr, self.gamma)
        else:
            raise ValueError('No such weights init!!!')

    def evaluate(self) -> np.ndarray:
        """
        Evaluation either happens in a point aggregation way, or in a word-by-word fashion
        """
        # eval with training
        self.check_eval_mode()
        if self.eval_mode == 'by_word':
            if not self.use_ecc:
                raise ValueError('Only supports ecc')
            snr = self.snr_range['val'][0]
            self.load_weights(snr, self.gamma)
            return self.eval_by_word(snr, self.gamma)
        else:
            return self.evaluate_at_point()

    def meta_train(self):
        """
        Main meta-training loop. Runs in minibatches, each minibatch is split to pairs of following words.
        The pairs are comprised of (support,query) words.
        Evaluates performance over validation SNRs.
        Saves weights every so and so iterations.
        """
        # initialize weights and loss
        for snr in self.snr_range['train']:

            print(f'SNR - {snr}, Gamma - {self.gamma}')
            # initialize weights and loss
            self.initialize_detector()
            self.deep_learning_setup()

            for minibatch in range(1, self.train_minibatch_num + 1):
                # draw words from different channels
                transmitted_words, received_words = self.channel_dataset[
                    'train'].__getitem__(snr_list=[snr], gamma=self.gamma)
                support_idx = torch.arange(-self.window_size - 1,
                                           -1).long().to(device)
                query_idx = -1 * torch.ones(1).long().to(device)
                j_hat_values = torch.unique(
                    torch.randint(low=self.window_size,
                                  high=transmitted_words.shape[0],
                                  size=[self.meta_j_num])).to(device)
                if self.use_ecc:
                    transmitted_words = torch.cat([
                        torch.Tensor(
                            encode(transmitted_word.int().cpu().numpy(),
                                   self.n_symbols).reshape(1, -1)).to(device)
                        for transmitted_word in transmitted_words
                    ],
                                                  dim=0)

                loss_query = 0
                for j_hat in j_hat_values:
                    cur_support_idx = j_hat + support_idx + 1
                    cur_query_idx = j_hat + query_idx + 1
                    loss_query += self.meta_train_loop(received_words,
                                                       transmitted_words,
                                                       cur_support_idx,
                                                       cur_query_idx)

                # evaluate performance
                ser = self.single_eval_at_point(snr, self.gamma)
                print(
                    f'Minibatch {minibatch}, ser - {ser}, loss - {loss_query}')
                # save best weights
                self.save_weights(float(loss_query), snr, self.gamma)

    def meta_train_loop(self, received_words: torch.Tensor,
                        transmitted_words: torch.Tensor,
                        support_idx: torch.Tensor, query_idx: torch.Tensor):
        """One MAML step: inner update on the support pair, meta-gradient on the query pair."""
        # divide the words to following pairs - (support,query)
        support_tx, support_rx = transmitted_words[
            support_idx], received_words[support_idx]
        query_tx, query_rx = transmitted_words[query_idx], received_words[
            query_idx]

        # local update (with support set)
        # zip() of a single iterable yields 1-tuples; this just materializes
        # the parameter list
        para_list_detector = list(
            map(lambda p: p[0], zip(self.detector.parameters())))
        soft_estimation_supp = self.meta_detector(support_rx, 'train',
                                                  para_list_detector)
        loss_supp = self.calc_loss(soft_estimation=soft_estimation_supp,
                                   transmitted_words=support_tx)

        # set create_graph to True for MAML, False for FO-MAML
        local_grad = torch.autograd.grad(loss_supp,
                                         para_list_detector,
                                         create_graph=self.MAML)
        # one SGD step with the meta learning rate: theta' = theta - meta_lr * grad
        updated_para_list_detector = list(
            map(lambda p: p[1] - self.meta_lr * p[0],
                zip(local_grad, para_list_detector)))

        # meta-update (with query set) should be same channel with support set
        soft_estimation_query = self.meta_detector(query_rx, 'train',
                                                   updated_para_list_detector)
        loss_query = self.calc_loss(soft_estimation=soft_estimation_query,
                                    transmitted_words=query_tx)
        meta_grad = torch.autograd.grad(loss_query,
                                        para_list_detector,
                                        create_graph=False)

        # write the meta-gradients into the detector's .grad fields and step
        ind_param = 0
        for param in self.detector.parameters():
            param.grad = None  # zero_grad
            param.grad = meta_grad[ind_param]
            ind_param += 1

        self.optimizer.step()
        return loss_query

    def train(self):
        """
        Main training loop. Runs in minibatches.
        Evaluates performance over validation SNRs.
        Saves weights every so and so iterations.
        """
        # batches loop
        for snr in self.snr_range['train']:
            print(f'SNR - {snr}, Gamma - {self.gamma}')

            # initialize weights and loss
            self.initialize_detector()
            self.deep_learning_setup()
            best_ser = math.inf

            for minibatch in range(1, self.train_minibatch_num + 1):
                # draw words
                transmitted_words, received_words = self.channel_dataset[
                    'train'].__getitem__(snr_list=[snr], gamma=self.gamma)
                # run training loops
                current_loss = 0
                for i in range(self.train_frames * self.subframes_in_frame):
                    # pass through detector
                    soft_estimation = self.detector(
                        received_words[i].reshape(1, -1), 'train')
                    current_loss += self.run_train_loop(
                        soft_estimation, transmitted_words[i].reshape(1, -1))

                # evaluate performance
                ser = self.single_eval_at_point(snr, self.gamma)
                print(
                    f'Minibatch {minibatch}, ser - {ser}, loss {current_loss}')
                # save best weights
                if ser < best_ser:
                    self.save_weights(current_loss, snr, self.gamma)
                    best_ser = ser

            print(f'best ser - {best_ser}')
            print('*' * 50)

    def run_train_loop(self, soft_estimation: torch.Tensor,
                       transmitted_words: torch.Tensor):
        """Single backward/step on one word; returns the scalar loss (or NaN)."""
        # calculate loss
        loss = self.calc_loss(soft_estimation=soft_estimation,
                              transmitted_words=transmitted_words)
        # if loss is Nan inform the user
        if torch.sum(torch.isnan(loss)):
            print('Nan value')
            return np.nan
        current_loss = loss.item()
        # back propagation
        # (clearing .grad manually is equivalent to optimizer.zero_grad())
        for param in self.detector.parameters():
            param.grad = None
        loss.backward()
        self.optimizer.step()
        return current_loss

    def save_weights(self, current_loss: float, snr: float, gamma: float):
        """Checkpoint detector + optimizer state for this (snr, gamma) point."""
        torch.save(
            {
                'model_state_dict': self.detector.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'loss': current_loss
            }, os.path.join(self.weights_dir, f'snr_{snr}_gamma_{gamma}.pt'))

    def load_weights(self, snr: float, gamma: float):
        """
        Loads detector's weights defined by the [snr,gamma] from checkpoint, if exists
        """
        # NOTE(review): os.path.join(...) returns a non-empty string, so this
        # condition is always True and the `else` branch below is unreachable;
        # the author probably meant os.path.isfile(...) here — confirm, since
        # the missing-file case is instead handled by the self-training
        # fallback inside the branch.
        if os.path.join(self.weights_dir, f'snr_{snr}_gamma_{gamma}.pt'):
            print(f'loading model from snr {snr} and gamma {gamma}')
            weights_path = os.path.join(self.weights_dir,
                                        f'snr_{snr}_gamma_{gamma}.pt')
            if not os.path.isfile(weights_path):
                # if weights do not exist, train on the synthetic channel. Then validate on the test channel.
                self.fading_taps_type = 1
                os.makedirs(self.weights_dir, exist_ok=True)
                self.train()
                self.fading_taps_type = 2
            checkpoint = torch.load(weights_path)
            try:
                self.detector.load_state_dict(checkpoint['model_state_dict'])
            except Exception:
                raise ValueError("Wrong run directory!!!")
        else:
            print(
                f'No checkpoint for snr {snr} and gamma {gamma} in run "{self.run_name}", starting from scratch'
            )

    def select_batch(
        self, gt_examples: torch.LongTensor, soft_estimation: torch.Tensor
    ) -> Tuple[torch.LongTensor, torch.Tensor]:
        """
        Select a batch from the input and gt labels
        :param gt_examples: training labels
        :param soft_estimation: the soft approximation, distribution over states (per word)
        :return: selected batch from the entire "epoch", contains both labels and the NN soft approximation
        """
        # multinomial over arange gives a weighted sample of indices
        # (index 0 has weight 0, so it is never drawn) — sampling without
        # replacement by default
        rand_ind = torch.multinomial(
            torch.arange(gt_examples.shape[0]).float(),
            self.train_minibatch_size).long().to(device)
        return gt_examples[rand_ind], soft_estimation[rand_ind]
Ejemplo n.º 21
0
def train(model,
          state,
          path,
          annotations,
          val_path,
          val_annotations,
          resize,
          max_size,
          jitter,
          batch_size,
          iterations,
          val_iterations,
          mixed_precision,
          lr,
          warmup,
          milestones,
          gamma,
          rank=0,
          world=1,
          no_apex=False,
          use_dali=True,
          verbose=True,
          metrics_url=None,
          logdir=None,
          rotate_augment=False,
          augment_brightness=0.0,
          augment_contrast=0.0,
          augment_hue=0.0,
          augment_saturation=0.0,
          regularization_l2=0.0001,
          rotated_bbox=False,
          absolute_angle=False):
    """Train the model on the given dataset.

    Runs up to ``iterations`` SGD steps over the training data, with optional
    mixed precision (apex AMP when ``no_apex`` is False, otherwise native
    ``autocast``/``GradScaler``), optional distributed data parallelism when
    ``world`` > 1, periodic console/TensorBoard logging, checkpointing via
    ``nn_model.save(state)``, and validation every ``val_iterations`` steps
    when ``val_annotations`` is given.

    ``state`` carries resumable training state: its 'iteration', 'optimizer'
    and 'scheduler' entries are read on entry and written back at each
    checkpoint. The remaining arguments configure the data iterators (paths,
    annotations, sizes, augmentations), the LR schedule (``lr``, ``warmup``,
    ``milestones``, ``gamma``) and logging (``metrics_url``, ``logdir``,
    ``verbose``).
    """

    # Prepare model
    # Keep a handle to the unwrapped model: checkpoints are saved through
    # nn_model.save(state) even after the AMP/DDP wrapping below.
    nn_model = model
    stride = model.stride

    model = convert_fixedbn_model(model)
    if torch.cuda.is_available():
        model = model.to(memory_format=torch.channels_last).cuda()

    # Setup optimizer and schedule
    optimizer = SGD(model.parameters(),
                    lr=lr,
                    weight_decay=regularization_l2,
                    momentum=0.9)

    is_master = rank == 0
    if not no_apex:
        # NOTE(review): loss_scale is the *string* "128.0" on the non-DALI
        # path; apex accepts numeric strings, but confirm this is intended.
        loss_scale = "dynamic" if use_dali else "128.0"
        model, optimizer = amp.initialize(
            model,
            optimizer,
            opt_level='O2' if mixed_precision else 'O0',
            keep_batchnorm_fp32=True,
            loss_scale=loss_scale,
            verbosity=is_master)

    if world > 1:
        model = DDP(model, device_ids=[rank]) if no_apex else ADDP(model)
    model.train()

    # Resume optimizer state when restarting from a checkpoint
    if 'optimizer' in state:
        optimizer.load_state_dict(state['optimizer'])

    def schedule(train_iter):
        # Linear warmup from 0.1x to 1.0x of the base lr, then step decay
        # by `gamma` for each milestone already passed.
        if warmup and train_iter <= warmup:
            return 0.9 * train_iter / warmup + 0.1
        return gamma**len([m for m in milestones if m <= train_iter])

    scheduler = LambdaLR(optimizer, schedule)
    if 'scheduler' in state:
        scheduler.load_state_dict(state['scheduler'])

    # Prepare dataset
    if verbose: print('Preparing dataset...')
    if rotated_bbox:
        if use_dali:
            raise NotImplementedError(
                "This repo does not currently support DALI for rotated bbox detections."
            )
        data_iterator = RotatedDataIterator(
            path,
            jitter,
            max_size,
            batch_size,
            stride,
            world,
            annotations,
            training=True,
            rotate_augment=rotate_augment,
            augment_brightness=augment_brightness,
            augment_contrast=augment_contrast,
            augment_hue=augment_hue,
            augment_saturation=augment_saturation,
            absolute_angle=absolute_angle)
    else:
        data_iterator = (DaliDataIterator if use_dali else DataIterator)(
            path,
            jitter,
            max_size,
            batch_size,
            stride,
            world,
            annotations,
            training=True,
            rotate_augment=rotate_augment,
            augment_brightness=augment_brightness,
            augment_contrast=augment_contrast,
            augment_hue=augment_hue,
            augment_saturation=augment_saturation)
    if verbose: print(data_iterator)

    if verbose:
        print('    device: {} {}'.format(
            world, 'cpu' if not torch.cuda.is_available() else
            'GPU' if world == 1 else 'GPUs'))
        print('     batch: {}, precision: {}'.format(
            batch_size, 'mixed' if mixed_precision else 'full'))
        print(' BBOX type:', 'rotated' if rotated_bbox else 'axis aligned')
        print('Training model for {} iterations...'.format(iterations))

    # Create TensorBoard writer
    # (writer exists only when is_master and logdir is set; every use below
    # is guarded by the same condition)
    if is_master and logdir is not None:
        from torch.utils.tensorboard import SummaryWriter
        if verbose:
            print('Writing TensorBoard logs to: {}'.format(logdir))
        writer = SummaryWriter(log_dir=logdir)

    # GradScaler is only exercised on the native-AMP path (no_apex=True)
    scaler = GradScaler()
    profiler = Profiler(['train', 'fw', 'bw'])
    # Resume the global step counter if present in state
    iteration = state.get('iteration', 0)
    while iteration < iterations:
        cls_losses, box_losses = [], []
        for i, (data, target) in enumerate(data_iterator):
            if iteration >= iterations:
                break

            # Forward pass
            profiler.start('fw')

            optimizer.zero_grad()
            if not no_apex:
                cls_loss, box_loss = model([
                    data.contiguous(memory_format=torch.channels_last), target
                ])
            else:
                with autocast():
                    cls_loss, box_loss = model([
                        data.contiguous(memory_format=torch.channels_last),
                        target
                    ])
            del data
            profiler.stop('fw')

            # Backward pass
            profiler.start('bw')
            if not no_apex:
                with amp.scale_loss(cls_loss + box_loss,
                                    optimizer) as scaled_loss:
                    scaled_loss.backward()
                optimizer.step()
            else:
                scaler.scale(cls_loss + box_loss).backward()
                scaler.step(optimizer)
                scaler.update()

            scheduler.step()

            # Reduce all losses
            cls_loss, box_loss = cls_loss.mean().clone(), box_loss.mean(
            ).clone()
            if world > 1:
                torch.distributed.all_reduce(cls_loss)
                torch.distributed.all_reduce(box_loss)
                cls_loss /= world
                box_loss /= world
            if is_master:
                cls_losses.append(cls_loss)
                box_losses.append(box_loss)

            # Abort rather than keep training on a NaN/inf loss
            if is_master and not isfinite(cls_loss + box_loss):
                raise RuntimeError('Loss is diverging!\n{}'.format(
                    'Try lowering the learning rate.'))

            del cls_loss, box_loss
            profiler.stop('bw')

            iteration += 1
            profiler.bump('train')
            # Log and checkpoint roughly once per minute of training time,
            # and always at the final iteration.
            if is_master and (profiler.totals['train'] > 60
                              or iteration == iterations):
                focal_loss = torch.stack(list(cls_losses)).mean().item()
                box_loss = torch.stack(list(box_losses)).mean().item()
                learning_rate = optimizer.param_groups[0]['lr']
                if verbose:
                    msg = '[{:{len}}/{}]'.format(iteration,
                                                 iterations,
                                                 len=len(str(iterations)))
                    msg += ' focal loss: {:.3f}'.format(focal_loss)
                    msg += ', box loss: {:.3f}'.format(box_loss)
                    msg += ', {:.3f}s/{}-batch'.format(profiler.means['train'],
                                                       batch_size)
                    msg += ' (fw: {:.3f}s, bw: {:.3f}s)'.format(
                        profiler.means['fw'], profiler.means['bw'])
                    msg += ', {:.1f} im/s'.format(batch_size /
                                                  profiler.means['train'])
                    msg += ', lr: {:.2g}'.format(learning_rate)
                    print(msg, flush=True)

                if is_master and logdir is not None:
                    writer.add_scalar('focal_loss', focal_loss, iteration)
                    writer.add_scalar('box_loss', box_loss, iteration)
                    writer.add_scalar('learning_rate', learning_rate,
                                      iteration)
                    del box_loss, focal_loss

                if metrics_url:
                    post_metrics(
                        metrics_url, {
                            'focal loss': mean(cls_losses),
                            'box loss': mean(box_losses),
                            'im_s': batch_size / profiler.means['train'],
                            'lr': learning_rate
                        })

                # Save model weights
                # (presumably ignore_sigint blocks Ctrl-C so the checkpoint
                # write cannot be interrupted mid-file — confirm in helper)
                state.update({
                    'iteration': iteration,
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                })
                with ignore_sigint():
                    nn_model.save(state)

                profiler.reset()
                del cls_losses[:], box_losses[:]

            # Periodic validation (and once at the very last iteration)
            if val_annotations and (iteration == iterations
                                    or iteration % val_iterations == 0):
                stats = infer(model,
                              val_path,
                              None,
                              resize,
                              max_size,
                              batch_size,
                              annotations=val_annotations,
                              mixed_precision=mixed_precision,
                              is_master=is_master,
                              world=world,
                              use_dali=use_dali,
                              no_apex=no_apex,
                              is_validation=True,
                              verbose=False,
                              rotated_bbox=rotated_bbox)
                model.train()
                if is_master and logdir is not None and stats is not None:
                    writer.add_scalar('Validation_Precision/mAP', stats[0],
                                      iteration)
                    writer.add_scalar('Validation_Precision/[email protected]',
                                      stats[1], iteration)
                    writer.add_scalar('Validation_Precision/[email protected]',
                                      stats[2], iteration)
                    writer.add_scalar('Validation_Precision/mAP (small)',
                                      stats[3], iteration)
                    writer.add_scalar('Validation_Precision/mAP (medium)',
                                      stats[4], iteration)
                    writer.add_scalar('Validation_Precision/mAP (large)',
                                      stats[5], iteration)
                    writer.add_scalar('Validation_Recall/mAR (max 1 Dets)',
                                      stats[6], iteration)
                    writer.add_scalar('Validation_Recall/mAR (max 10 Dets)',
                                      stats[7], iteration)
                    writer.add_scalar('Validation_Recall/mAR (max 100 Dets)',
                                      stats[8], iteration)
                    writer.add_scalar('Validation_Recall/mAR (small)',
                                      stats[9], iteration)
                    writer.add_scalar('Validation_Recall/mAR (medium)',
                                      stats[10], iteration)
                    writer.add_scalar('Validation_Recall/mAR (large)',
                                      stats[11], iteration)

            # NOTE(review): the exit test differs by bbox type — equality for
            # axis-aligned, strictly-greater for rotated. Confirm the
            # asymmetry is intentional.
            if (iteration == iterations
                    and not rotated_bbox) or (iteration > iterations
                                              and rotated_bbox):
                break

    if is_master and logdir is not None:
        writer.close()
Ejemplo n.º 22
0
def main(args: argparse.Namespace):
    """Entry point for domain-adversarial semantic segmentation.

    Builds source/target datasets and loaders, a segmentation model and a
    domain discriminator, then either evaluates once (``args.phase ==
    'test'``) or trains with a domain-adversarial entropy loss, validating
    after each epoch and keeping the checkpoint with the best mean IoU over
    ``train_source_dataset.evaluate_classes``.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    source_dataset = datasets.__dict__[args.source]
    train_source_dataset = source_dataset(
        root=args.source_root,
        transforms=T.Compose([
            T.RandomResizedCrop(size=args.train_size,
                                ratio=args.resize_ratio,
                                scale=(0.5, 1.)),
            T.ColorJitter(brightness=0.3, contrast=0.3),
            T.RandomHorizontalFlip(),
            T.NormalizeAndTranspose(),
        ]),
    )
    train_source_loader = DataLoader(train_source_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)

    # Target-domain training data: fixed 2:1 crop ratio, no color jitter
    target_dataset = datasets.__dict__[args.target]
    train_target_dataset = target_dataset(
        root=args.target_root,
        transforms=T.Compose([
            T.RandomResizedCrop(size=args.train_size,
                                ratio=(2., 2.),
                                scale=(0.5, 1.)),
            T.RandomHorizontalFlip(),
            T.NormalizeAndTranspose(),
        ]),
    )
    train_target_loader = DataLoader(train_target_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)
    val_target_dataset = target_dataset(
        root=args.target_root,
        split='val',
        transforms=T.Compose([
            T.Resize(image_size=args.test_input_size,
                     label_size=args.test_output_size),
            T.NormalizeAndTranspose(),
        ]),
    )
    val_target_loader = DataLoader(val_target_dataset,
                                   batch_size=1,
                                   shuffle=False,
                                   pin_memory=True)

    # Never-ending iterators over the loaders (presumably the epoch length is
    # controlled elsewhere via args.iters_per_epoch — confirm in train())
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    num_classes = train_source_dataset.num_classes
    model = models.__dict__[args.arch](num_classes=num_classes).to(device)
    discriminator = Discriminator(num_classes=num_classes).to(device)

    # define optimizer and lr scheduler
    # Polynomial decay over args.epochs * args.iters_per_epoch total steps.
    # NOTE(review): the model scheduler's lambda also multiplies by args.lr,
    # although LambdaLR already scales the optimizer's base lr (args.lr) —
    # the effective lr is args.lr**2 * decay. Confirm this is intentional.
    optimizer = SGD(model.get_parameters(),
                    lr=args.lr,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay)
    optimizer_d = Adam(discriminator.parameters(),
                       lr=args.lr_d,
                       betas=(0.9, 0.99))
    lr_scheduler = LambdaLR(
        optimizer, lambda x: args.lr *
        (1. - float(x) / args.epochs / args.iters_per_epoch)**(args.lr_power))
    lr_scheduler_d = LambdaLR(
        optimizer_d, lambda x:
        (1. - float(x) / args.epochs / args.iters_per_epoch)**(args.lr_power))

    # optionally resume from a checkpoint
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        discriminator.load_state_dict(checkpoint['discriminator'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        optimizer_d.load_state_dict(checkpoint['optimizer_d'])
        lr_scheduler_d.load_state_dict(checkpoint['lr_scheduler_d'])
        args.start_epoch = checkpoint['epoch'] + 1

    # define loss function (criterion)
    criterion = torch.nn.CrossEntropyLoss(
        ignore_index=args.ignore_label).to(device)
    dann = DomainAdversarialEntropyLoss(discriminator)
    # Upsamplers to bring logits to label resolution for train/val
    interp_train = nn.Upsample(size=args.train_size[::-1],
                               mode='bilinear',
                               align_corners=True)
    interp_val = nn.Upsample(size=args.test_output_size[::-1],
                             mode='bilinear',
                             align_corners=True)

    # define visualization function
    decode = train_source_dataset.decode_target

    def visualize(image, pred, label, prefix):
        """
        Args:
            image (tensor): 3 x H x W
            pred (tensor): C x H x W
            label (tensor): H x W
            prefix: prefix of the saving image
        """
        image = image.detach().cpu().numpy()
        pred = pred.detach().max(dim=0)[1].cpu().numpy()
        label = label.cpu().numpy()
        for tensor, name in [
            (Image.fromarray(np.uint8(DeNormalizeAndTranspose()(image))),
             "image"), (decode(label), "label"), (decode(pred), "pred")
        ]:
            tensor.save(logger.get_image_path("{}_{}.png".format(prefix,
                                                                 name)))

    if args.phase == 'test':
        confmat = validate(val_target_loader, model, interp_val, criterion,
                           visualize, args)
        print(confmat)
        return

    # start training
    best_iou = 0.
    for epoch in range(args.start_epoch, args.epochs):
        logger.set_epoch(epoch)
        # NOTE(review): get_lr() is deprecated in newer torch in favor of
        # get_last_lr(); kept as-is here.
        print(lr_scheduler.get_lr(), lr_scheduler_d.get_lr())
        # train for one epoch
        train(train_source_iter, train_target_iter, model, interp_train,
              criterion, dann, optimizer, lr_scheduler, optimizer_d,
              lr_scheduler_d, epoch, visualize if args.debug else None, args)

        # evaluate on validation set
        confmat = validate(val_target_loader, model, interp_val, criterion,
                           None, args)
        print(confmat.format(train_source_dataset.classes))
        acc_global, acc, iu = confmat.compute()

        # calculate the mean iou over partial classes
        indexes = [
            train_source_dataset.classes.index(name)
            for name in train_source_dataset.evaluate_classes
        ]
        iu = iu[indexes]
        mean_iou = iu.mean()

        # remember best acc@1 and save checkpoint
        torch.save(
            {
                'model': model.state_dict(),
                'discriminator': discriminator.state_dict(),
                'optimizer': optimizer.state_dict(),
                'optimizer_d': optimizer_d.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'lr_scheduler_d': lr_scheduler_d.state_dict(),
                'epoch': epoch,
                'args': args
            }, logger.get_checkpoint_path(epoch))
        # copy the per-epoch checkpoint to 'best' when the mIoU improves
        if mean_iou > best_iou:
            shutil.copy(logger.get_checkpoint_path(epoch),
                        logger.get_checkpoint_path('best'))
        best_iou = max(best_iou, mean_iou)
        print("Target: {} Best: {}".format(mean_iou, best_iou))

    logger.close()
Ejemplo n.º 23
0
def main_worker(gpu, ngpus_per_node, args):
    """Per-GPU worker: fine-tune a Bayesian variant of a pretrained model.

    Converts a pretrained model to its Bayesian form (``to_bayesian``),
    trains the mean ('mu') and variance ('psi') parameter groups with two
    separate optimizers, optionally resumes from a checkpoint, saves a
    checkpoint each epoch (rank 0 only), and runs a final evaluation on the
    test/adversarial/fake loaders.
    """
    global best_acc
    args.gpu = gpu
    assert args.gpu is not None
    print("Use GPU: {} for training".format(args.gpu))

    # `log` is carried around as a (file handle, gpu id) pair; print_log
    # receives the tuple and log[0] is closed at the end of this function.
    log = open(
        os.path.join(
            args.save_path,
            'log_seed{}{}.txt'.format(args.manualSeed,
                                      '_eval' if args.evaluate else '')), 'w')
    log = (log, args.gpu)

    # Start from a pretrained deterministic net, drop dropout, then convert
    # to a Bayesian net with trainable psi (variance) parameters.
    net = models.__dict__[args.arch](pretrained=True)
    disable_dropout(net)
    net = to_bayesian(net, args.psi_init_range)
    net.apply(unfreeze)

    print_log("Python version : {}".format(sys.version.replace('\n', ' ')),
              log)
    print_log("PyTorch  version : {}".format(torch.__version__), log)
    print_log("CuDNN  version : {}".format(torch.backends.cudnn.version()),
              log)
    print_log(
        "Number of parameters: {}".format(
            sum([p.numel() for p in net.parameters()])), log)
    print_log(str(args), log)

    if args.distributed:
        if args.multiprocessing_distributed:
            # Global rank = node rank * GPUs per node + local GPU index
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url + ":" +
                                args.dist_port,
                                world_size=args.world_size,
                                rank=args.rank)
        torch.cuda.set_device(args.gpu)
        net.cuda(args.gpu)
        # Split the global batch size evenly across the node's GPUs
        args.batch_size = int(args.batch_size / ngpus_per_node)
        net = torch.nn.parallel.DistributedDataParallel(net,
                                                        device_ids=[args.gpu])
    else:
        torch.cuda.set_device(args.gpu)
        net = net.cuda(args.gpu)

    criterion = torch.nn.CrossEntropyLoss().cuda(args.gpu)

    # Split parameters by name: 'psi' (variance) params get their own
    # optimizer; everything else is treated as a mean ('mu') parameter.
    mus, psis = [], []
    for name, param in net.named_parameters():
        if 'psi' in name: psis.append(param)
        else: mus.append(param)
    mu_optimizer = SGD(mus,
                       args.learning_rate,
                       args.momentum,
                       weight_decay=args.decay,
                       nesterov=(args.momentum > 0.0))

    psi_optimizer = PsiSGD(psis,
                           args.learning_rate,
                           args.momentum,
                           weight_decay=args.decay,
                           nesterov=(args.momentum > 0.0))

    recorder = RecorderMeter(args.epochs)
    if args.resume:
        if args.resume == 'auto':
            args.resume = os.path.join(args.save_path, 'checkpoint.pth.tar')
        if os.path.isfile(args.resume):
            print_log("=> loading checkpoint '{}'".format(args.resume), log)
            checkpoint = torch.load(args.resume,
                                    map_location='cuda:{}'.format(args.gpu))
            recorder = checkpoint['recorder']
            recorder.refresh(args.epochs)
            args.start_epoch = checkpoint['epoch']
            # Non-distributed nets need the DDP 'module.' prefix stripped
            net.load_state_dict(
                checkpoint['state_dict'] if args.distributed else {
                    k.replace('module.', ''): v
                    for k, v in checkpoint['state_dict'].items()
                })
            mu_optimizer.load_state_dict(checkpoint['mu_optimizer'])
            psi_optimizer.load_state_dict(checkpoint['psi_optimizer'])
            best_acc = recorder.max_accuracy(False)
            print_log(
                "=> loaded checkpoint '{}' accuracy={} (epoch {})".format(
                    args.resume, best_acc, checkpoint['epoch']), log)
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume),
                      log)
    else:
        print_log("=> do not use any checkpoint for the model", log)

    cudnn.benchmark = True

    train_loader, ood_train_loader, test_loader, adv_loader, \
        fake_loader, adv_loader2 = load_dataset_ft(args)
    # PsiSGD reads the dataset size (presumably to scale its update /
    # regularization term) — confirm against the PsiSGD implementation.
    psi_optimizer.num_data = len(train_loader.dataset)

    if args.evaluate:
        evaluate(test_loader, adv_loader, fake_loader, adv_loader2, net,
                 criterion, args, log, 20, 100)
        return

    start_time = time.time()
    epoch_time = AverageMeter()
    train_los = -1

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_loader.sampler.set_epoch(epoch)
            ood_train_loader.sampler.set_epoch(epoch)
        cur_lr, cur_slr = adjust_learning_rate(mu_optimizer, psi_optimizer,
                                               epoch, args)

        # Estimate remaining wall-clock time from the running epoch average
        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)

        print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f} {:6.4f}]'.format(
                                    time_string(), epoch, args.epochs, need_time, cur_lr, cur_slr) \
                    + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False), 100-recorder.max_accuracy(False)), log)

        train_acc, train_los = train(train_loader, ood_train_loader, net,
                                     criterion, mu_optimizer, psi_optimizer,
                                     epoch, args, log)
        # NOTE(review): validation is skipped during training (val_acc is
        # fixed at 0), so is_best can never become True unless best_acc < 0.
        val_acc, val_los = 0, 0
        recorder.update(epoch, train_los, train_acc, val_acc, val_los)

        is_best = False
        if val_acc > best_acc:
            is_best = True
            best_acc = val_acc

        # Only the rank-0 / gpu-0 process writes checkpoints
        if args.gpu == 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': net.state_dict(),
                    'recorder': recorder,
                    'mu_optimizer': mu_optimizer.state_dict(),
                    'psi_optimizer': psi_optimizer.state_dict(),
                }, False, args.save_path, 'checkpoint.pth.tar')

        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        recorder.plot_curve(os.path.join(args.save_path, 'log.png'))

    evaluate(test_loader, adv_loader, fake_loader, adv_loader2, net, criterion,
             args, log, 20, 100)

    log[0].close()
Ejemplo n.º 24
0
        triplet_loss = TripletLoss(margin=margin).forward(
            anchor=anc_hard_embedding,
            positive=pos_hard_embedding,
            negative=neg_hard_embedding).cuda()

        triplet_loss_sum += triplet_loss.item()
        num_valid_training_triplets += len(anc_hard_embedding)

        optimizer_model.zero_grad()
        triplet_loss.backward()
        optimizer_model.step()

    avg_triplet_loss = 0 if (
        num_valid_training_triplets
        == 0) else triplet_loss_sum / num_valid_training_triplets

    print(
        'Epoch {}:\tAverage Triplet Loss: {:.4f}\tNumber of valid training triplets in epoch: {}'
        .format(epoch + 1, avg_triplet_loss, num_valid_training_triplets))

torch.save(
    {
        'epoch': epoch,
        'model_state_dict': net.state_dict(),
        'optimizer_state_dict': optimizer_model.state_dict(),
        'avg_triplet_loss': avg_triplet_loss,
        'valid_training_triplets': num_valid_training_triplets
    }, './train_checkpoints/' + 'checkpoint_' + str(total_triplets) + '_' +
    str(epoch) + '_' + str(num_valid_training_triplets) + '.tar')
Ejemplo n.º 25
0
def main(input_len, epochs_num, hidden_size, batch_size, output_size, lr):
    """Train and evaluate the stock-movement ``Predictor`` model.

    Builds train/test datasets via ``mkDataSet`` over the 1999-01-08 to
    2016-12-31 window, trains with a summed BCE loss for ``epochs_num``
    epochs (checkpointing after every epoch), then runs a final evaluation
    on ``mkTestModelset`` data and saves the weights to ``weight.pth``.

    :param input_len: length of each input sequence window
    :param epochs_num: number of training epochs
    :param hidden_size: hidden size of the Predictor
    :param batch_size: minibatch size for the train/test loaders
    :param output_size: output dimension of the Predictor
    :param lr: SGD learning rate
    """

    start = datetime.datetime(1999, 1, 8)
    end = datetime.datetime(2016, 12, 31)

    train_x, train_t, test_x, test_t = mkDataSet(start, end, input_len)

    # Predictor input size is fixed at 6 features per timestep.
    model = Predictor(6, hidden_size, output_size)

    train_x = torch.Tensor(train_x)
    train_t = torch.Tensor(train_t)
    test_x = torch.Tensor(test_x)
    test_t = torch.Tensor(test_t)

    loader_train = DataLoader(TensorDataset(train_x, train_t),
                              batch_size=batch_size, shuffle=True)
    loader_test = DataLoader(TensorDataset(test_x, test_t),
                             batch_size=batch_size, shuffle=False)

    optimizer = SGD(model.parameters(), lr)

    # BUG FIX: BCELoss(size_average=False) is deprecated (removed in recent
    # torch); reduction='sum' is the exact modern equivalent.
    criterion = torch.nn.BCELoss(reduction='sum')

    # BUG FIX: MultiStepLR milestones must be integer epoch indices; the
    # original passed floats. NOTE(review): the scheduler is created but
    # never stepped (scheduler.step() was commented out in the original) —
    # preserved as-is; wire it into the epoch loop if decay is wanted.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=[int(epochs_num * 0.3), int(epochs_num * 0.7)],
        gamma=0.1,
        last_epoch=-1)

    for epoch in range(epochs_num):
        # ---- training ----
        running_loss = 0.0
        training_accuracy = 0.0
        training_num = 0
        model.train()

        for i, data in enumerate(loader_train, 0):
            # split into inputs and labels
            inputs, labels = data

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            labels = labels.float()
            # binary cross-entropy against the label data
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # accumulate loss (scaled by 100 for display only)
            running_loss += loss.item() * 100

            # a prediction counts as correct when within 0.5 of its label
            training_accuracy += np.sum(
                np.abs((outputs.data - labels.data).numpy()) <= 0.5)
            # counts every element (the != 10000 comparison is always true)
            training_num += np.sum(
                np.abs((outputs.data - labels.data).numpy()) != 10000)

        # ---- test ----
        test_accuracy = 0.0
        test_num = 0
        model.eval()

        for i, data in enumerate(loader_test, 0):
            inputs, labels = data

            outputs = model(inputs)
            labels = labels.float()

            test_accuracy += np.sum(
                np.abs((outputs.data - labels.data).numpy()) <= 0.5)
            test_num += np.sum(
                np.abs((outputs.data - labels.data).numpy()) != 100000)

        training_accuracy /= training_num
        test_accuracy /= test_num

        print('%d loss: %.3f, training_accuracy: %.5f, test_accuracy: %.5f' %
              (epoch + 1, running_loss, training_accuracy, test_accuracy))

        # checkpoint after every epoch
        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
            }, 'jikkenpath')

    # (the original's for/else ran this once after the loop finished)
    print(training_num)
    print(test_num)

    # final evaluation on a separately constructed test set
    test_x, test_t = mkTestModelset(input_len)
    test_x = torch.Tensor(test_x)
    test_t = torch.Tensor(test_t)
    loader_test = DataLoader(TensorDataset(test_x, test_t),
                             batch_size=5, shuffle=False)
    av_testac = 0
    model.eval()

    for i, data in enumerate(loader_test, 0):
        inputs, labels = data

        outputs = model(inputs)
        labels = labels.float()

        # per-batch accuracy: fraction of predictions within 0.5 of the label
        test_accuracy = np.sum(
            np.abs((outputs.data - labels.data).numpy()) <= 0.5)
        test_num = np.sum(
            np.abs((outputs.data - labels.data).numpy()) != 100000)

        test_accuracy /= test_num
        av_testac += test_accuracy

        print(i, test_accuracy)

    # average accuracy over all batches
    print(av_testac / (i + 1))

    torch.save(model.state_dict(), 'weight.pth')
Ejemplo n.º 26
0
def train(train_dir, model_dir, config_path, checkpoint_path,
          n_steps, save_every, test_every, decay_every,
          n_speakers, n_utterances, seg_len):
    """Train a d-vector speaker-embedding network with the GE2E loss.

    Args:
        train_dir: directory of utterances consumed by ``SEDataset``.
        model_dir: output directory for TensorBoard logs and checkpoints.
        config_path: path to the d-vector network config file.
        checkpoint_path: optional ``.tar`` checkpoint to resume from, or None.
        n_steps: number of optimizer steps to run.
        save_every: checkpoint period, in steps.
        test_every: validation period, in steps.
        decay_every: ``StepLR`` decay period, in steps.
        n_speakers: speakers per batch (batch size).
        n_utterances: utterances per speaker.
        seg_len: segment length fed to the network.
    """

    # setup
    total_steps = 0

    # load data: hold out 2 * n_speakers groups for validation
    dataset = SEDataset(train_dir, n_utterances, seg_len)
    train_set, valid_set = random_split(dataset, [len(dataset)-2*n_speakers,
                                                  2*n_speakers])
    train_loader = DataLoader(train_set, batch_size=n_speakers,
                              shuffle=True, num_workers=4,
                              collate_fn=pad_batch, drop_last=True)
    valid_loader = DataLoader(valid_set, batch_size=n_speakers,
                              shuffle=True, num_workers=4,
                              collate_fn=pad_batch, drop_last=True)
    train_iter = iter(train_loader)

    assert len(train_set) >= n_speakers
    assert len(valid_set) >= n_speakers
    print(f"Training starts with {len(train_set)} speakers. "
          f"(and {len(valid_set)} speakers for validation)")

    # build network and training tools
    # NOTE: GE2ELoss has learnable parameters (w, b), so they are optimized
    # jointly with the network.
    dvector = DVector().load_config_file(config_path)
    criterion = GE2ELoss()
    optimizer = SGD(list(dvector.parameters()) +
                    list(criterion.parameters()), lr=0.01)
    scheduler = StepLR(optimizer, step_size=decay_every, gamma=0.5)

    # load checkpoint
    if checkpoint_path is not None:
        ckpt = torch.load(checkpoint_path)
        total_steps = ckpt["total_steps"]
        dvector.load_state_dict(ckpt["state_dict"])
        criterion.load_state_dict(ckpt["criterion"])
        optimizer.load_state_dict(ckpt["optimizer"])
        scheduler.load_state_dict(ckpt["scheduler"])

    # prepare for training
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dvector = dvector.to(device)
    criterion = criterion.to(device)
    writer = SummaryWriter(model_dir)
    pbar = tqdm.trange(n_steps)

    # start training
    for step in pbar:

        total_steps += 1

        # cycle the loader indefinitely
        try:
            batch = next(train_iter)
        except StopIteration:
            train_iter = iter(train_loader)
            batch = next(train_iter)

        # embeddings grouped as (speakers, utterances, embedding_dim)
        embd = dvector(batch.to(device)).view(n_speakers, n_utterances, -1)

        loss = criterion(embd)

        optimizer.zero_grad()
        loss.backward()

        # clip the joint gradient, then damp the embedding / loss-parameter
        # gradients as in the GE2E paper
        grad_norm = torch.nn.utils.clip_grad_norm_(
            list(dvector.parameters()) + list(criterion.parameters()), max_norm=3)
        dvector.embedding.weight.grad.data *= 0.5
        criterion.w.grad.data *= 0.01
        criterion.b.grad.data *= 0.01

        optimizer.step()
        scheduler.step()

        pbar.set_description(f"global = {total_steps}, loss = {loss:.4f}")
        writer.add_scalar("Training loss", loss, total_steps)
        writer.add_scalar("Gradient norm", grad_norm, total_steps)

        if (step + 1) % test_every == 0:
            # validation forward pass only — no autograd graph needed
            # (fix: the original built gradients here for no reason)
            with torch.no_grad():
                batch = next(iter(valid_loader))
                embd = dvector(batch.to(device)).view(n_speakers, n_utterances, -1)
                loss = criterion(embd)
            writer.add_scalar("validation loss", loss, total_steps)

        if (step + 1) % save_every == 0:
            ckpt_path = os.path.join(model_dir, f"ckpt-{total_steps}.tar")
            ckpt_dict = {
                "total_steps": total_steps,
                "state_dict": dvector.state_dict(),
                "criterion": criterion.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
            }
            torch.save(ckpt_dict, ckpt_path)

    print("Training completed.")
Ejemplo n.º 27
0
class LightHeadRCNN_Learner(Module):
    """Training / inference wrapper for a Light-Head R-CNN detector.

    Owns the ResNet-101 feature extractor, the RPN and the detection head,
    together with the optimizer, warm-up / decay LR schedules, TensorBoard
    logging, COCO evaluation helpers and checkpoint save/load/resume logic.
    """

    def __init__(self, training=True):
        super(LightHeadRCNN_Learner, self).__init__()
        self.conf = Config()
        self.class_2_color = get_class_colors(self.conf)

        self.extractor = ResNet101Extractor(self.conf.pretrained_model_path).to(self.conf.device)
        self.rpn = RegionProposalNetwork().to(self.conf.device)
        # bbox regression targets are normalized with these statistics.
        # fix: the original line ended with a trailing comma, which wrapped
        # the mean in an extra 1-tuple: ((0., 0., 0., 0.),)
        self.loc_normalize_mean = (0., 0., 0., 0.)
        self.loc_normalize_std = (0.1, 0.1, 0.2, 0.2)
        self.head = LightHeadRCNNResNet101_Head(self.conf.class_num + 1, 
                                                self.conf.roi_size, 
                                                roi_align = self.conf.use_roi_align).to(self.conf.device)
        self.detections = namedtuple('detections', ['roi_cls_locs', 'roi_scores', 'rois'])

        if training:
            self.train_dataset = coco_dataset(self.conf, mode = 'train')
            self.train_length = len(self.train_dataset)
            self.val_dataset =  coco_dataset(self.conf, mode = 'val')
            self.val_length = len(self.val_dataset)
            self.anchor_target_creator = AnchorTargetCreator()
            self.proposal_target_creator = ProposalTargetCreator(loc_normalize_mean = self.loc_normalize_mean, 
                                                                 loc_normalize_std = self.loc_normalize_std)
            self.step = 0
            # the first 8 head parameters train with a 3x learning rate
            self.optimizer = SGD([
                {'params' : get_trainables(self.extractor.parameters())},
                {'params' : self.rpn.parameters()},
                {'params' : [*self.head.parameters()][:8], 'lr' : self.conf.lr*3},
                {'params' : [*self.head.parameters()][8:]},
            ], lr = self.conf.lr, momentum=self.conf.momentum, weight_decay=self.conf.weight_decay)
            self.base_lrs = [params['lr'] for params in self.optimizer.param_groups]
            self.warm_up_duration = 5000
            self.warm_up_rate = 1 / 5
            self.train_outputs = namedtuple('train_outputs',
                                            ['loss_total', 
                                             'rpn_loc_loss', 
                                             'rpn_cls_loss', 
                                             'ohem_roi_loc_loss', 
                                             'ohem_roi_cls_loss',
                                             'total_roi_loc_loss',
                                             'total_roi_cls_loss'])                                      
            self.writer = SummaryWriter(self.conf.log_path)
            # periodic-action intervals, all expressed in steps
            self.board_loss_every = self.train_length // self.conf.board_loss_interval
            self.evaluate_every = self.train_length // self.conf.eval_interval
            self.eva_on_coco_every = self.train_length // self.conf.eval_coco_interval
            self.board_pred_image_every = self.train_length // self.conf.board_pred_image_interval
            self.save_every = self.train_length // self.conf.save_interval

    def set_training(self):
        """Switch to train mode but keep the extractor's BN layers frozen."""
        self.train()
        self.extractor.set_bn_eval()

    def lr_warmup(self):
        """Linearly ramp every param group's LR from warm_up_rate to 1x base."""
        assert self.step <= self.warm_up_duration, 'stop warm up after {} steps'.format(self.warm_up_duration)
        rate = self.warm_up_rate + (1 - self.warm_up_rate) * self.step / self.warm_up_duration
        for i, params in enumerate(self.optimizer.param_groups):
            params['lr'] = self.base_lrs[i] * rate

    def lr_schedule(self, epoch):
        """Step-decay the LR: 1x before epoch 13, 0.1x before 16, 0.01x after."""
        if epoch < 13:
            return
        elif epoch < 16:
            rate = 0.1
        else:
            rate = 0.01
        for i, params in enumerate(self.optimizer.param_groups):
            params['lr'] = self.base_lrs[i] * rate
        print(self.optimizer)

    def forward(self, img_tensor, scale, bboxes=None, labels=None, force_eval=False):
        """Run the detector.

        In training (or with ``force_eval``) returns a ``train_outputs``
        namedtuple of losses; in inference returns a ``detections``
        namedtuple of raw head outputs.
        """
        img_tensor = img_tensor.to(self.conf.device)
        img_size = (img_tensor.shape[2], img_tensor.shape[3]) # H,W
        rpn_feature, roi_feature = self.extractor(img_tensor)
        rpn_locs, rpn_scores, rois, roi_indices, anchor = self.rpn(rpn_feature, img_size, scale)
        if self.training or force_eval:
            gt_rpn_loc, gt_rpn_labels = self.anchor_target_creator(bboxes, anchor, img_size)
            gt_rpn_labels = torch.tensor(gt_rpn_labels, dtype=torch.long).to(self.conf.device)
            # image without ground-truth boxes: only the RPN cls loss applies
            if len(bboxes) == 0:                
                rpn_cls_loss = F.cross_entropy(rpn_scores[0], gt_rpn_labels, ignore_index = -1)
                return self.train_outputs(rpn_cls_loss, 0, 0, 0, 0, 0, 0)
            sample_roi, gt_roi_locs, gt_roi_labels = self.proposal_target_creator(rois, bboxes, labels)
            roi_cls_locs, roi_scores = self.head(roi_feature, sample_roi)

            gt_rpn_loc = torch.tensor(gt_rpn_loc, dtype=torch.float).to(self.conf.device)
            gt_roi_locs = torch.tensor(gt_roi_locs, dtype=torch.float).to(self.conf.device)
            gt_roi_labels = torch.tensor(gt_roi_labels, dtype=torch.long).to(self.conf.device)

            rpn_loc_loss = fast_rcnn_loc_loss(rpn_locs[0], gt_rpn_loc, gt_rpn_labels, sigma=self.conf.rpn_sigma)

            rpn_cls_loss = F.cross_entropy(rpn_scores[0], gt_rpn_labels, ignore_index = -1)

            # head losses computed with online hard example mining
            ohem_roi_loc_loss, \
            ohem_roi_cls_loss, \
            total_roi_loc_loss, \
            total_roi_cls_loss = OHEM_loss(roi_cls_locs, 
                                           roi_scores, 
                                           gt_roi_locs, 
                                           gt_roi_labels, 
                                           self.conf.n_ohem_sample, 
                                           self.conf.roi_sigma)

            loss_total = rpn_loc_loss + rpn_cls_loss + ohem_roi_loc_loss + ohem_roi_cls_loss

            return self.train_outputs(loss_total, 
                                      rpn_loc_loss.item(), 
                                      rpn_cls_loss.item(), 
                                      ohem_roi_loc_loss.item(), 
                                      ohem_roi_cls_loss.item(),
                                      total_roi_loc_loss,
                                      total_roi_cls_loss)

        else:
            roi_cls_locs, roi_scores = self.head(roi_feature, rois)
            return self.detections(roi_cls_locs, roi_scores, rois)

    def eval_predict(self, img, preset = 'evaluate', use_softnms = False):
        """Predict on one CHW uint8 array; returns ([bboxes], [labels], [scores])."""
        if type(img) == list:
            img = img[0]
        img = Image.fromarray(img.transpose(1,2,0).astype('uint8'))
        bboxes, labels, scores = self.predict_on_img(img, preset, use_softnms, original_size = True)
        bboxes = y1x1y2x2_2_x1y1x2y2(bboxes)
        return [bboxes], [labels], [scores]

    def predict_on_img(self, img, preset = 'evaluate', use_softnms=False, return_img = False, with_scores = False, original_size = False):
        '''
        inputs :
        imgs : PIL Image
        return : PIL Image (if return_img) or bboxes_group and labels_group
        '''
        self.eval()
        self.use_preset(preset)
        with torch.no_grad():
            orig_size = img.size # W,H
            img = np.asarray(img).transpose(2,0,1)
            img, scale = prepare_img(self.conf, img, -1)
            img = torch.tensor(img).unsqueeze(0)
            img_size = (img.shape[2], img.shape[3]) # H,W
            detections = self.forward(img, scale)
            n_sample = len(detections.roi_cls_locs)
            n_class = self.conf.class_num + 1
            # de-normalize the predicted offsets, one row per (roi, class)
            roi_cls_locs = detections.roi_cls_locs.reshape((n_sample, -1, 4)).reshape([-1,4])
            roi_cls_locs = roi_cls_locs * torch.tensor(self.loc_normalize_std, device=self.conf.device) + torch.tensor(self.loc_normalize_mean, device=self.conf.device)
            rois = torch.tensor(detections.rois.repeat(n_class,0), dtype=torch.float).to(self.conf.device)
            raw_cls_bboxes = loc2bbox(rois, roi_cls_locs)
            # clamp boxes to the (resized) image frame
            torch.clamp(raw_cls_bboxes[:,0::2], 0, img_size[1], out = raw_cls_bboxes[:,0::2] )
            torch.clamp(raw_cls_bboxes[:,1::2], 0, img_size[0], out = raw_cls_bboxes[:,1::2] )
            raw_cls_bboxes = raw_cls_bboxes.reshape([n_sample, n_class, 4])
            raw_prob = F.softmax(detections.roi_scores, dim=1)
            bboxes, labels, scores = self._suppress(raw_cls_bboxes, raw_prob, use_softnms)
            if len(bboxes) == len(labels) == len(scores) == 0:
                if not return_img:  
                    return [], [], []
                else:
                    return to_img(self.conf, img[0])
            # keep only the highest-scoring max_n_predict detections
            _, indices = scores.sort(descending=True)
            bboxes = bboxes[indices]
            labels = labels[indices]
            scores = scores[indices]
            if len(bboxes) > self.max_n_predict:
                bboxes = bboxes[:self.max_n_predict]
                labels = labels[:self.max_n_predict]
                scores = scores[:self.max_n_predict]
        # now, implement drawing
        bboxes = bboxes.cpu().numpy()
        labels = labels.cpu().numpy()
        scores = scores.cpu().numpy()
        if original_size:
            bboxes = adjust_bbox(scale, bboxes, detect=True)
        if not return_img:        
            return bboxes, labels, scores
        else:
            if with_scores:
                scores_ = scores
            else:
                scores_ = []
            predicted_img =  to_img(self.conf, img[0])
            if original_size:
                predicted_img = predicted_img.resize(orig_size)
            if len(bboxes) != 0 and len(labels) != 0:
                predicted_img = draw_bbox_class(self.conf, 
                                                predicted_img, 
                                                labels, 
                                                bboxes, 
                                                self.conf.correct_id_2_class, 
                                                self.class_2_color, 
                                                scores = scores_)

            return predicted_img

    def _suppress(self, raw_cls_bboxes, raw_prob, use_softnms):
        """Per-class thresholding + (soft-)NMS over the raw head outputs."""
        bbox = []
        label = []
        prob = []
        for l in range(1, self.conf.class_num + 1):
            cls_bbox_l = raw_cls_bboxes[:, l, :]
            prob_l = raw_prob[:, l]
            mask = prob_l > self.score_thresh
            if not mask.any():
                continue
            cls_bbox_l = cls_bbox_l[mask]
            prob_l = prob_l[mask]
            if use_softnms:
                keep, _  = soft_nms(torch.cat((cls_bbox_l, prob_l.unsqueeze(-1)), dim=1).cpu().numpy(),
                                    Nt = self.conf.softnms_Nt,
                                    method = self.conf.softnms_method,
                                    sigma = self.conf.softnms_sigma,
                                    min_score = self.conf.softnms_min_score)
                keep = keep.tolist()
            else:
                keep = nms(torch.cat((cls_bbox_l, prob_l.unsqueeze(-1)), dim=1), self.nms_thresh).tolist()
            bbox.append(cls_bbox_l[keep])
            # The labels are in [0, 79].
            label.append((l - 1) * torch.ones((len(keep),), dtype = torch.long))
            prob.append(prob_l[keep])
        if len(bbox) == 0:
            print("looks like there is no prediction have a prob larger than thresh")
            return [], [], []
        bbox = torch.cat(bbox)
        label = torch.cat(label)
        prob = torch.cat(prob)
        return bbox, label, prob

    def board_scalars(self, 
                      key, 
                      loss_total, 
                      rpn_loc_loss, 
                      rpn_cls_loss, 
                      ohem_roi_loc_loss, 
                      ohem_roi_cls_loss, 
                      total_roi_loc_loss, 
                      total_roi_cls_loss):
        """Log one set of loss scalars to TensorBoard under the given key prefix."""
        self.writer.add_scalar('{}_loss_total'.format(key), loss_total, self.step)
        self.writer.add_scalar('{}_rpn_loc_loss'.format(key), rpn_loc_loss, self.step)
        self.writer.add_scalar('{}_rpn_cls_loss'.format(key), rpn_cls_loss, self.step)
        self.writer.add_scalar('{}_ohem_roi_loc_loss'.format(key), ohem_roi_loc_loss, self.step)
        self.writer.add_scalar('{}_ohem_roi_cls_loss'.format(key), ohem_roi_cls_loss, self.step)
        self.writer.add_scalar('{}_total_roi_loc_loss'.format(key), total_roi_loc_loss, self.step)
        self.writer.add_scalar('{}_total_roi_cls_loss'.format(key), total_roi_cls_loss, self.step)

    def use_preset(self, preset):
        """Use the given preset during prediction.

        This method changes values of :obj:`self.nms_thresh` and
        :obj:`self.score_thresh`. These values are a threshold value
        used for non maximum suppression and a threshold value
        to discard low confidence proposals in :meth:`predict`,
        respectively.

        If the attributes need to be changed to something
        other than the values provided in the presets, please modify
        them by directly accessing the public attributes.

        Args:
            preset ({'visualize', 'evaluate', 'debug'): A string to determine the
                preset to use.

        """
        if preset == 'visualize':
            self.nms_thresh = 0.5
            self.score_thresh = 0.25
            self.max_n_predict = 40
        elif preset == 'evaluate':
            # 0.5 NMS threshold (instead of 0.3) improves mmAP by ~0.6 points
            # via better recall on crowded scenes
            self.nms_thresh = 0.5
            self.score_thresh = 0.0
            self.max_n_predict = 100
        elif preset == 'debug':
            self.nms_thresh = 0.5
            self.score_thresh = 0.0
            self.max_n_predict = 10
        else:
            # fix: message used to omit the supported 'debug' preset
            raise ValueError('preset must be visualize, evaluate or debug')

    def fit(self, epochs=20, resume=False, from_save_folder=False):
        """Main training loop with periodic logging, evaluation and saving."""
        if resume:
            self.resume_training_load(from_save_folder)
        self.set_training()        
        running_loss = 0.
        running_rpn_loc_loss = 0.
        running_rpn_cls_loss = 0.
        running_ohem_roi_loc_loss = 0.
        running_ohem_roi_cls_loss = 0.
        running_total_roi_loc_loss = 0.
        running_total_roi_cls_loss = 0.
        map05 = None
        val_loss = None

        epoch = self.step // self.train_length
        while epoch <= epochs:
            print('start the training of epoch : {}'.format(epoch))
            self.lr_schedule(epoch)
            for index in tqdm(range(self.train_length), total = self.train_length):
                try:
                    inputs = self.train_dataset[index]
                except Exception:
                    # fix: stray '}' removed from the original message
                    print('loading index {} from train dataset failed'.format(index))
                    continue
                self.optimizer.zero_grad()
                train_outputs = self.forward(torch.tensor(inputs.img).unsqueeze(0),
                                             inputs.scale,
                                             inputs.bboxes,
                                             inputs.labels)
                train_outputs.loss_total.backward()
                if epoch == 0:
                    if self.step <= self.warm_up_duration:
                        self.lr_warmup()
                self.optimizer.step()
                torch.cuda.empty_cache()

                running_loss += train_outputs.loss_total.item()
                running_rpn_loc_loss += train_outputs.rpn_loc_loss
                running_rpn_cls_loss += train_outputs.rpn_cls_loss
                running_ohem_roi_loc_loss += train_outputs.ohem_roi_loc_loss
                running_ohem_roi_cls_loss += train_outputs.ohem_roi_cls_loss
                running_total_roi_loc_loss += train_outputs.total_roi_loc_loss
                running_total_roi_cls_loss += train_outputs.total_roi_cls_loss

                if self.step != 0:
                    if self.step % self.board_loss_every == 0:
                        self.board_scalars('train', 
                                           running_loss / self.board_loss_every, 
                                           running_rpn_loc_loss / self.board_loss_every, 
                                           running_rpn_cls_loss / self.board_loss_every,
                                           running_ohem_roi_loc_loss / self.board_loss_every, 
                                           running_ohem_roi_cls_loss / self.board_loss_every,
                                           running_total_roi_loc_loss / self.board_loss_every, 
                                           running_total_roi_cls_loss / self.board_loss_every)
                        running_loss = 0.
                        running_rpn_loc_loss = 0.
                        running_rpn_cls_loss = 0.
                        running_ohem_roi_loc_loss = 0.
                        running_ohem_roi_cls_loss = 0.
                        running_total_roi_loc_loss = 0.
                        running_total_roi_cls_loss = 0.

                    if self.step % self.evaluate_every == 0:
                        val_loss, val_rpn_loc_loss, \
                        val_rpn_cls_loss, \
                        ohem_val_roi_loc_loss, \
                        ohem_val_roi_cls_loss, \
                        total_val_roi_loc_loss, \
                        total_val_roi_cls_loss = self.evaluate(num = self.conf.eva_num_during_training)
                        self.set_training() 
                        self.board_scalars('val', 
                                           val_loss, 
                                           val_rpn_loc_loss, 
                                           val_rpn_cls_loss, 
                                           ohem_val_roi_loc_loss,
                                           ohem_val_roi_cls_loss,
                                           total_val_roi_loc_loss,
                                           total_val_roi_cls_loss)

                    if self.step % self.eva_on_coco_every == 0:
                        try:
                            cocoEval = self.eva_on_coco(limit = self.conf.coco_eva_num_during_training)
                            self.set_training() 
                            map05 = cocoEval[1]
                            mmap = cocoEval[0]
                        except Exception:
                            print('eval on coco failed')
                            map05 = -1
                            mmap = -1
                        self.writer.add_scalar('0.5IoU MAP', map05, self.step)
                        self.writer.add_scalar('0.5::0.9 - MMAP', mmap, self.step)

                    if self.step % self.board_pred_image_every == 0:
                        for i in range(20):
                            img, _, _, _ , _= self.val_dataset.orig_dataset[i]  
                            img = Image.fromarray(img.astype('uint8').transpose(1,2,0))
                            predicted_img = self.predict_on_img(img, preset='visualize', return_img=True, with_scores=True, original_size=True) 
                            self.writer.add_image('pred_image_{}'.format(i), trans.ToTensor()(predicted_img), global_step=self.step)
                            self.set_training()

                    if self.step % self.save_every == 0:
                        try:
                            self.save_state(val_loss, map05)
                        except Exception:
                            print('save state failed')
                            self.step += 1
                            continue

                self.step += 1
            epoch = self.step // self.train_length
            try:
                self.save_state(val_loss, map05, to_save_folder=True)
            except Exception:
                print('save state failed')

    def eva_on_coco(self, limit = 1000, preset = 'evaluate', use_softnms = False):
        """Run the official COCO evaluation on the validation set."""
        self.eval() 
        return eva_coco(self.val_dataset.orig_dataset, lambda x : self.eval_predict(x, preset, use_softnms), limit, preset)

    def evaluate(self, num=None):
        """Average the validation losses over ``num`` samples (all if None)."""
        self.eval()        
        running_loss = 0.
        running_rpn_loc_loss = 0.
        running_rpn_cls_loss = 0.
        running_ohem_roi_loc_loss = 0.
        running_ohem_roi_cls_loss = 0.
        running_total_roi_loc_loss = 0.
        running_total_roi_cls_loss = 0.
        if num is None:
            total_num = self.val_length
        else:
            total_num = num
        with torch.no_grad():
            for index in tqdm(range(total_num)):
                inputs = self.val_dataset[index]
                if inputs.bboxes == []:
                    continue
                val_outputs = self.forward(torch.tensor(inputs.img).unsqueeze(0),
                                           inputs.scale,
                                           inputs.bboxes,
                                           inputs.labels,
                                           force_eval = True)
                running_loss += val_outputs.loss_total.item()
                running_rpn_loc_loss += val_outputs.rpn_loc_loss
                running_rpn_cls_loss += val_outputs.rpn_cls_loss
                running_ohem_roi_loc_loss += val_outputs.ohem_roi_loc_loss
                running_ohem_roi_cls_loss += val_outputs.ohem_roi_cls_loss
                running_total_roi_loc_loss += val_outputs.total_roi_loc_loss
                running_total_roi_cls_loss += val_outputs.total_roi_cls_loss
        return running_loss / total_num, \
                running_rpn_loc_loss / total_num, \
                running_rpn_cls_loss / total_num, \
                running_ohem_roi_loc_loss / total_num, \
                running_ohem_roi_cls_loss / total_num,\
                running_total_roi_loc_loss / total_num, \
                running_total_roi_cls_loss / total_num

    def save_state(self, val_loss, map05, to_save_folder=False, model_only=False):
        """Save model (and optionally optimizer) state, tagged with metrics and step."""
        if to_save_folder:
            save_path = self.conf.work_space/'save'
        else:
            save_path = self.conf.work_space/'model'
        time = get_time()
        torch.save(
            self.state_dict(), save_path /
            ('model_{}_val_loss:{}_map05:{}_step:{}.pth'.format(time,
                                                                val_loss, 
                                                                map05, 
                                                                self.step)))
        if not model_only:
            torch.save(
                self.optimizer.state_dict(), save_path /
                ('optimizer_{}_val_loss:{}_map05:{}_step:{}.pth'.format(time,
                                                                        val_loss, 
                                                                        map05, 
                                                                        self.step)))

    def load_state(self, fixed_str, from_save_folder=False, model_only=False):
        """Load model (and optionally optimizer) state saved by :meth:`save_state`."""
        if from_save_folder:
            save_path = self.conf.work_space/'save'
        else:
            save_path = self.conf.work_space/'model'          
        self.load_state_dict(torch.load(save_path/'model_{}'.format(fixed_str)))
        print('load model_{}'.format(fixed_str))
        if not model_only:
            self.optimizer.load_state_dict(torch.load(save_path/'optimizer_{}'.format(fixed_str)))
            print('load optimizer_{}'.format(fixed_str))

    def resume_training_load(self, from_save_folder=False):
        """Find the newest matching model/optimizer checkpoint pair and load it."""
        if from_save_folder:
            save_path = self.conf.work_space/'save'
        else:
            save_path = self.conf.work_space/'model'  
        sorted_files = sorted([*save_path.iterdir()],  key=lambda x: os.path.getmtime(x), reverse=True)
        seeking_flag = True
        index = 0
        while seeking_flag:
            if index > len(sorted_files) - 2:
                break
            file_a = sorted_files[index]
            file_b = sorted_files[index + 1]
            if file_a.name.startswith('model'):
                # strip the 'model_' prefix; the step is encoded after the last ':'
                fix_str = file_a.name[6:]
                self.step = int(fix_str.split(':')[-1].split('.')[0]) + 1
                if file_b.name == ''.join(['optimizer', '_', fix_str]):                    
                    self.load_state(fix_str, from_save_folder)
                    return
                else:
                    index += 1
                    continue
            elif file_a.name.startswith('optimizer'):
                fix_str = file_a.name[10:]
                self.step = int(fix_str.split(':')[-1].split('.')[0]) + 1
                if file_b.name == ''.join(['model', '_', fix_str]):
                    self.load_state(fix_str, from_save_folder)
                    return
                else:
                    index += 1
                    continue
            else:
                index += 1
                continue
        print('no available files founded')
        return      
Ejemplo n.º 28
0
def train(model,
          state,
          path,
          annotations,
          val_path,
          val_annotations,
          resize,
          max_size,
          jitter,
          batch_size,
          iterations,
          val_iterations,
          mixed_precision,
          lr,
          warmup,
          milestones,
          gamma,
          is_master=True,
          world=1,
          use_dali=True,
          verbose=True,
          metrics_url=None,
          logdir=None):
    """Train the detection model on the given dataset.

    Iteration-based (not epoch-based) training loop: runs until `iterations`
    steps have been taken, periodically logging, checkpointing via `state`,
    and validating with `infer`.

    Args:
        model: detection model; must expose `.stride` and `.save(state)`.
        state: checkpoint dict; may carry 'optimizer' and 'iteration' for
            resuming, and is updated in place before each save.
        path, annotations: training images directory and annotations file.
        val_path, val_annotations: validation data; validation is skipped
            entirely when `val_annotations` is falsy.
        resize, max_size, jitter: input sizing passed to the data iterators.
        batch_size: per-step batch size.
        iterations: total number of training steps.
        val_iterations: validate every this many steps.
        mixed_precision: use apex amp opt level O2 when True, O0 otherwise.
        lr, warmup, milestones, gamma: learning-rate schedule parameters.
        is_master: only the master process logs, saves, and posts metrics.
        world: number of distributed workers; losses are all-reduced when > 1.
        use_dali: select DaliDataIterator over DataIterator.
        verbose, metrics_url, logdir: console / HTTP / TensorBoard logging.
    """

    # Prepare model
    nn_model = model  # keep the unwrapped model for .save() at checkpoints
    stride = model.stride

    # NOTE(review): presumably converts BatchNorm layers to a fixed/frozen
    # form for fine-tuning — confirm against convert_fixedbn_model.
    model = convert_fixedbn_model(model)
    if torch.cuda.is_available():
        model = model.cuda()

    # Setup optimizer and schedule
    optimizer = SGD(model.parameters(),
                    lr=lr,
                    weight_decay=0.0001,
                    momentum=0.9)

    # apex amp wraps model+optimizer; static loss scale of 128 for FP16.
    model, optimizer = amp.initialize(
        model,
        optimizer,
        opt_level='O2' if mixed_precision else 'O0',
        keep_batchnorm_fp32=True,
        loss_scale=128.0,
        verbosity=is_master)

    if world > 1:
        model = DistributedDataParallel(model)
    model.train()

    # Resume optimizer state when present in the checkpoint dict.
    if 'optimizer' in state:
        optimizer.load_state_dict(state['optimizer'])

    def schedule(train_iter):
        # LR multiplier: linear warmup from 0.1 to 1.0 over `warmup` iters,
        # then decay by `gamma` once per milestone already passed.
        if warmup and train_iter <= warmup:
            return 0.9 * train_iter / warmup + 0.1
        return gamma**len([m for m in milestones if m <= train_iter])

    # NOTE(review): under amp the scheduler is attached to
    # optimizer.optimizer — presumably the inner optimizer apex wraps;
    # confirm for the apex version in use.
    scheduler = LambdaLR(optimizer.optimizer if mixed_precision else optimizer,
                         schedule)

    # Prepare dataset
    if verbose: print('Preparing dataset...')
    data_iterator = (DaliDataIterator if use_dali else DataIterator)(
        path,
        jitter,
        max_size,
        batch_size,
        stride,
        world,
        annotations,
        training=True)
    if verbose: print(data_iterator)

    if verbose:
        print('    device: {} {}'.format(
            world, 'cpu' if not torch.cuda.is_available() else
            'gpu' if world == 1 else 'gpus'))
        print('    batch: {}, precision: {}'.format(
            batch_size, 'mixed' if mixed_precision else 'full'))
        print('Training model for {} iterations...'.format(iterations))

    # Create TensorBoard writer
    if logdir is not None:
        from tensorboardX import SummaryWriter
        if is_master and verbose:
            print('Writing TensorBoard logs to: {}'.format(logdir))
        writer = SummaryWriter(log_dir=logdir)

    profiler = Profiler(['train', 'fw', 'bw'])
    iteration = state.get('iteration', 0)  # resume step count from checkpoint
    while iteration < iterations:
        cls_losses, box_losses = [], []
        for i, (data, target) in enumerate(data_iterator):
            scheduler.step(iteration)

            # Forward pass
            profiler.start('fw')

            optimizer.zero_grad()
            cls_loss, box_loss = model([data, target])
            del data  # release the batch before backward to lower peak memory
            profiler.stop('fw')

            # Backward pass
            profiler.start('bw')
            # amp scales the loss to avoid FP16 gradient underflow.
            with amp.scale_loss(cls_loss + box_loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            optimizer.step()

            # Reduce all losses
            cls_loss, box_loss = cls_loss.mean().clone(), box_loss.mean(
            ).clone()
            if world > 1:
                # Average the losses across workers for consistent logging.
                torch.distributed.all_reduce(cls_loss)
                torch.distributed.all_reduce(box_loss)
                cls_loss /= world
                box_loss /= world
            if is_master:
                cls_losses.append(cls_loss)
                box_losses.append(box_loss)

            # Fail fast on NaN/inf loss instead of training further.
            if is_master and not isfinite(cls_loss + box_loss):
                raise RuntimeError('Loss is diverging!\n{}'.format(
                    'Try lowering the learning rate.'))

            del cls_loss, box_loss
            profiler.stop('bw')

            iteration += 1
            profiler.bump('train')
            # Log + checkpoint roughly once per minute of train time,
            # and unconditionally on the final iteration.
            if is_master and (profiler.totals['train'] > 60
                              or iteration == iterations):
                focal_loss = torch.stack(list(cls_losses)).mean().item()
                box_loss = torch.stack(list(box_losses)).mean().item()
                learning_rate = optimizer.param_groups[0]['lr']
                if verbose:
                    msg = '[{:{len}}/{}]'.format(iteration,
                                                 iterations,
                                                 len=len(str(iterations)))
                    msg += ' focal loss: {:.3f}'.format(focal_loss)
                    msg += ', box loss: {:.3f}'.format(box_loss)
                    msg += ', {:.3f}s/{}-batch'.format(profiler.means['train'],
                                                       batch_size)
                    msg += ' (fw: {:.3f}s, bw: {:.3f}s)'.format(
                        profiler.means['fw'], profiler.means['bw'])
                    msg += ', {:.1f} im/s'.format(batch_size /
                                                  profiler.means['train'])
                    msg += ', lr: {:.2g}'.format(learning_rate)
                    print(msg, flush=True)

                if logdir is not None:
                    writer.add_scalar('focal_loss', focal_loss, iteration)
                    writer.add_scalar('box_loss', box_loss, iteration)
                    writer.add_scalar('learning_rate', learning_rate,
                                      iteration)
                    del box_loss, focal_loss

                if metrics_url:
                    post_metrics(
                        metrics_url, {
                            'focal loss': mean(cls_losses),
                            'box loss': mean(box_losses),
                            'im_s': batch_size / profiler.means['train'],
                            'lr': learning_rate
                        })

                # Save model weights
                state.update({
                    'iteration': iteration,
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                })
                # NOTE(review): ignore_sigint presumably defers Ctrl-C during
                # the save so checkpoints are never left half-written.
                with ignore_sigint():
                    nn_model.save(state)

                profiler.reset()
                del cls_losses[:], box_losses[:]

            # Periodic validation, plus a final one at the last iteration.
            if val_annotations and (iteration == iterations
                                    or iteration % val_iterations == 0):
                infer(model,
                      val_path,
                      None,
                      resize,
                      max_size,
                      batch_size,
                      annotations=val_annotations,
                      mixed_precision=mixed_precision,
                      is_master=is_master,
                      world=world,
                      use_dali=use_dali,
                      verbose=False)
                model.train()  # infer() presumably flips to eval; restore

            if iteration == iterations:
                break

    if logdir is not None:
        writer.close()
Ejemplo n.º 29
0
def main():
    """Build the model, optionally resume from a checkpoint, then either
    evaluate once or run the full FER2013 train/test loop.

    Populates module-level `args` from `parser` and tracks `best_prec1`
    across epochs; saves a checkpoint after every epoch.
    """
    global args, best_prec1
    args = parser.parse_args()

    model = create_model(args.arch,
                         args.pretrained,
                         args.finetune,
                         num_classes=args.num_classes)

    # define loss function (criterion) and optimizer
    criterion = CrossEntropyLoss().cuda()

    optimizer = SGD(
        filter(lambda p: p.requires_grad,
               model.parameters()),  # Only finetunable params
        args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            load_model_from_checkpoint(args, model, optimizer)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True  # let cuDNN pick the fastest conv algorithms

    # Data loading. Initialize both loaders to None so a missing split fails
    # below with a clear error instead of an UnboundLocalError at first use.
    train_loader = None
    test_loader = None
    train_path = os.path.join(args.data, 'train')
    test_path = os.path.join(args.data, 'test')
    if os.path.exists(train_path):
        train_loader = read_fer2013_data(train_path,
                                         dataset_type='train',
                                         batch_size=args.batch_size,
                                         num_workers=args.workers)
    if os.path.exists(test_path):
        test_loader = read_fer2013_data(test_path,
                                        dataset_type='test',
                                        batch_size=args.batch_size,
                                        num_workers=args.workers)

    # Both the evaluate-only path and the training loop need the test split.
    if test_loader is None:
        raise FileNotFoundError(
            "test split not found at '{}'".format(test_path))

    if args.evaluate:
        test(test_loader, model, criterion, args.print_freq)
        return

    if train_loader is None:
        raise FileNotFoundError(
            "train split not found at '{}'".format(train_path))

    summary_writer = SummaryWriter()
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(args.lr, optimizer, epoch, args.lr_decay,
                             args.lr_decay_freq)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch,
              args.print_freq, summary_writer)

        # evaluate on test set
        prec1 = test(test_loader, model, criterion, args.print_freq)

        # remember best prec@1 and save all checkpoints
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, is_best)
    summary_writer.close()
Ejemplo n.º 30
0
def train_model(train_loader,
                test_loader,
                device,
                lr,
                epochs,
                output_path,
                valid_loader=False):
    """Train a CNN on `train_loader`, evaluating on `test_loader` each epoch.

    After training, optionally scores `valid_loader` (with forward hooks
    registered on every layer), then saves a checkpoint and loss/accuracy
    plots into `output_path`.

    Args:
        train_loader, test_loader: dataloaders consumed by `loop_dataset`.
        device: torch device the model is moved to.
        lr: SGD learning rate.
        epochs: number of training epochs.
        output_path: directory receiving 'checkpoint.pth', 'loss.png',
            'accuracy.png'.
        valid_loader: optional validation dataloader; the default `False`
            acts as a "no validation" sentinel.
    """
    model = CNN().to(device)
    optimizer = SGD(model.parameters(), lr=lr)

    # Per-epoch histories for the plots at the end.
    average_loss_train = []
    average_loss_test = []

    accuracy_train = []
    accuracy_test = []

    for epoch in range(epochs):
        # Training pass: passing the optimizer makes loop_dataset update
        # the weights (see the test pass below, which omits it).
        model.train()
        correct_train, loss_train, _ = loop_dataset(model, train_loader,
                                                    device, optimizer)

        print(
            f'Epoch {epoch} : average train loss - {np.mean(loss_train)}, train accuracy - {correct_train}'
        )

        average_loss_train.append(np.mean(loss_train))
        accuracy_train.append(correct_train)

        # Evaluation pass: no optimizer, weights stay fixed.
        model.eval()
        correct_test, loss_test, _ = loop_dataset(model, test_loader, device)

        print(
            f'Epoch {epoch} : average test loss - {np.mean(loss_test)}, test accuracy - {correct_test}'
        )

        average_loss_test.append(np.mean(loss_test))
        accuracy_test.append(correct_test)

    model.eval()

    # Register forward hooks on every layer before the validation pass.
    # NOTE(review): forward_hook is defined elsewhere; its purpose is not
    # visible from here — confirm what it records.
    for i in range(0, len(model.layers)):
        model.layers[i].register_forward_hook(forward_hook)
    if valid_loader:
        correct_valid, _, output = loop_dataset(model, valid_loader, device)

        # '\033[...m' are ANSI escape codes for colored console output.
        print('\033[99m' + f'Accuracy on VALID test: {correct_valid}' +
              '\033[0m')

    # Persist a fresh architecture object plus learned weights and
    # optimizer state.
    checkpoint = {
        'model': CNN(),
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict()
    }
    torch.save(checkpoint, os.path.join(output_path, 'checkpoint.pth'))

    # Loss curves (green = train, red = test).
    plt.figure()
    plt.plot(range(epochs), average_loss_train, lw=0.3, c='g')
    plt.plot(range(epochs), average_loss_test, lw=0.3, c='r')
    plt.legend(['train loss', 'test_loss'])
    plt.xlabel('#Epoch')
    plt.ylabel('Loss')
    # NOTE(review): jpath — presumably an os.path.join alias; confirm.
    plt.savefig(jpath(output_path, 'loss.png'))

    # Accuracy curves (green = train, red = test).
    plt.figure()
    plt.plot(range(epochs), accuracy_train, lw=0.3, c='g')
    plt.plot(range(epochs), accuracy_test, lw=0.3, c='r')
    plt.legend(['train_acc', 'test_acc'])
    plt.xlabel('#Epoch')
    plt.ylabel('Accuracy')
    plt.savefig(jpath(output_path, 'accuracy.png'))