def main(opts):
    if not os.path.exists(opts.load_dir):
        os.mkdir(opts.load_dir)
    train_dataset = MultiBddDetection('dataset/bdd_detection',
                                      split="train",
                                      scales=[1, 0.5, 0.25])
    train_loader = DataLoader(train_dataset,
                              batch_size=opts.batch_size,
                              shuffle=True,
                              num_workers=opts.num_workers)

    test_dataset = MultiBddDetection('dataset/bdd_detection',
                                     split='val',
                                     scales=[1, 0.5, 0.25])
    test_loader = DataLoader(test_dataset,
                             shuffle=False,
                             batch_size=opts.batch_size,
                             num_workers=opts.num_workers)

    attention_models = []
    for s in opts.scales:
        attention_model = AttentionModelBddDetection(squeeze_channels=True,
                                                     softmax_smoothing=1e-4)
        attention_models.append(attention_model)
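    # Note: the attention_models list built above is not used further; the
    # parallel branch below passes only the last attention_model from the loop,
    # and the non-parallel branch constructs its own attention model.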
    feature_model = FeatureModelBddDetection(in_channels=3,
                                             strides=[1, 2, 2, 2],
                                             filters=[32, 32, 32, 32])
    classification_head = ClassificationHead(in_channels=32,
                                             num_classes=len(
                                                 train_dataset.CLASSES))

    ats_model = None
    logger = None
    if opts.map_parallel:
        print("Run parallel model.")
        print("n patches for high res, and another n for low res.")
        ats_model = MultiParallelATSModel(attention_model,
                                          feature_model,
                                          classification_head,
                                          n_patches=opts.n_patches,
                                          patch_size=opts.patch_size,
                                          scales=opts.scales)
        ats_model = ats_model.to(opts.device)

        logger = AttentionSaverMultiParallelBddDetection(
            opts.output_dir, ats_model, test_dataset, opts)

    else:
        print("Run unparallel model.")
        attention_model = AttentionModelMultiBddDetection(
            squeeze_channels=True, softmax_smoothing=1e-4)
        if opts.area_norm:
            print("Merge before softmax with area normalization.")
        else:
            print("Merge before softmax without area normalization.")
        ats_model = MultiATSModel(attention_model,
                                  feature_model,
                                  classification_head,
                                  n_patches=opts.n_patches,
                                  patch_size=opts.patch_size,
                                  scales=opts.scales,
                                  area_norm=bool(opts.area_norm))
        ats_model = ats_model.to(opts.device)

        logger = AttentionSaverMultiBddDetection(opts.output_dir, ats_model,
                                                 test_dataset, opts)

    optimizer = optim.Adam(
        [{
            'params': ats_model.attention_model.part1.parameters(),
            'weight_decay': 1e-5
        }, {
            'params': ats_model.attention_model.part2.parameters()
        }, {
            'params': ats_model.feature_model.parameters()
        }, {
            'params': ats_model.classifier.parameters()
        }, {
            'params': ats_model.sampler.parameters()
        }, {
            'params': ats_model.expectation.parameters()
        }],
        lr=opts.lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=opts.decrease_lr_at,
                                                gamma=0.1)

    class_weights = train_dataset.class_frequencies
    class_weights = torch.from_numpy(
        (1. / len(class_weights)) / class_weights).to(opts.device)
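    # The resulting weight for class c is (1 / num_classes) / frequency_c, so
    # rarer classes get proportionally larger weights in the loss below.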

    criterion = nn.CrossEntropyLoss(weight=class_weights)
    entropy_loss_func = MultinomialEntropy(opts.regularizer_strength)

    start_epoch = 0
    opts.checkpoint_path = os.path.join(opts.output_dir, "checkpoint")
    if not os.path.exists(opts.checkpoint_path):
        os.mkdir(opts.checkpoint_path)
    if opts.resume:
        resume_path = os.path.join(
            opts.load_dir, "checkpoint{:02d}.pth".format(opts.load_epoch))
        ats_model, optimizer, start_epoch = load_checkpoint(
            ats_model, optimizer, resume_path)
        start_epoch += 1
        print("Loaded %s successfully." % resume_path)
    else:
        print("Nothing to load; training from scratch.")

    for epoch in range(start_epoch, opts.epochs):
        print("Start epoch %d" % epoch)
        train_loss, train_metrics = trainMultiRes(ats_model, optimizer,
                                                  train_loader, criterion,
                                                  entropy_loss_func, opts)
        if epoch % 2 == 0:
            save_checkpoint(
                ats_model, optimizer,
                os.path.join(opts.checkpoint_path,
                             "checkpoint{:02d}.pth".format(epoch)), epoch)
            print("Save " + os.path.join(
                opts.checkpoint_path, "checkpoint{:02d}.pth".format(epoch)) +
                  " successfully.")
        print("Epoch {}, train loss: {:.3f}, train metrics: {:.3f}".format(
            epoch, train_loss, train_metrics["accuracy"]))
        with torch.no_grad():
            test_loss, test_metrics = evaluateMultiRes(ats_model, test_loader,
                                                       criterion,
                                                       entropy_loss_func, opts)

        logger(epoch, (train_loss, test_loss), (train_metrics, test_metrics))
        print("Epoch {}, test loss: {:.3f}, test metrics: {:.3f}".format(
            epoch, test_loss, test_metrics["accuracy"]))
        scheduler.step()
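
The training loop above depends on `save_checkpoint` and `load_checkpoint` helpers that are not shown here. A minimal sketch that matches the call sites in `main`; the dictionary keys used below are assumptions and the real helpers may store additional state:

def save_checkpoint(model, optimizer, path, epoch):
    # Persist model and optimizer state together with the epoch index.
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, path)


def load_checkpoint(model, optimizer, path):
    # Restore model and optimizer state; also return the stored epoch index.
    checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return model, optimizer, checkpoint['epoch']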
Example 2
def train_pipeline(out_dir,
                   weak_perc,
                   n_latents=20,
                   batch_size=128,
                   epochs=20,
                   lr=1e-3,
                   log_interval=10,
                   cuda=False):
    """Pipeline to train and test MultimodalVAE on MNIST dataset. This is 
    identical to the code in train.py.

    :param out_dir: directory to store trained models
    :param weak_perc: percent of time to show a relation pair (vs no relation pair)
    :param n_latents: size of latent variable (default: 20)
    :param batch_size: number of examples to show at once (default: 128)
    :param epochs: number of loops over dataset (default: 20)
    :param lr: learning rate (default: 1e-3)
    :param log_interval: interval of printing (default: 10)
    :param cuda: whether to use cuda or not (default: False)
    """
    # create loaders for MNIST
    train_loader = torch.utils.data.DataLoader(datasets.MNIST(
        './data', train=True, download=True, transform=transforms.ToTensor()),
                                               batch_size=batch_size,
                                               shuffle=True)
    test_loader = torch.utils.data.DataLoader(datasets.MNIST(
        './data', train=False, download=True, transform=transforms.ToTensor()),
                                              batch_size=batch_size,
                                              shuffle=True)

    # load multimodal VAE
    vae = MultimodalVAE(n_latents=n_latents)
    if cuda:
        vae.cuda()

    optimizer = optim.Adam(vae.parameters(), lr=lr)

    def train(epoch):
        random.seed(42)
        np.random.seed(42)
        # Re-seed every epoch so the weak-supervision coin flips below make the
        # same paired/unpaired choices; otherwise different examples would be
        # shown as paired in different epochs.
        vae.train()

        joint_loss_meter = AverageMeter()
        image_loss_meter = AverageMeter()
        text_loss_meter = AverageMeter()

        for batch_idx, (image, text) in enumerate(train_loader):
            if cuda:
                image, text = image.cuda(), text.cuda()
            image, text = Variable(image), Variable(text)
            image = image.view(-1, 784)  # flatten image
            optimizer.zero_grad()

            # Depending on this flip, we either show the model a full paired
            # example or single modalities (in which case the full joint loss
            # cannot be computed).
            flip = np.random.random()
            if flip < weak_perc:  # here we show a paired example
                recon_image_1, recon_text_1, mu_1, logvar_1 = vae(image, text)
                loss_1 = loss_function(mu_1,
                                       logvar_1,
                                       recon_image=recon_image_1,
                                       image=image,
                                       recon_text=recon_text_1,
                                       text=text,
                                       lambda_xy=1.,
                                       lambda_yx=1.)
                recon_image_2, recon_text_2, mu_2, logvar_2 = vae(image=image)
                loss_2 = loss_function(mu_2,
                                       logvar_2,
                                       recon_image=recon_image_2,
                                       image=image,
                                       recon_text=recon_text_2,
                                       text=text,
                                       lambda_xy=1.,
                                       lambda_yx=1.)
                recon_image_3, recon_text_3, mu_3, logvar_3 = vae(text=text)
                loss_3 = loss_function(mu_3,
                                       logvar_3,
                                       recon_image=recon_image_3,
                                       image=image,
                                       recon_text=recon_text_3,
                                       text=text,
                                       lambda_xy=0.,
                                       lambda_yx=1.)

                loss = loss_1 + loss_2 + loss_3
                joint_loss_meter.update(loss_1.data[0], len(image))

            else:  # here we show individual modalities
                recon_image_2, _, mu_2, logvar_2 = vae(image=image)
                loss_2 = loss_function(mu_2,
                                       logvar_2,
                                       recon_image=recon_image_2,
                                       image=image,
                                       lambda_xy=1.,
                                       lambda_yx=0.)
                _, recon_text_3, mu_3, logvar_3 = vae(text=text)
                loss_3 = loss_function(mu_3,
                                       logvar_3,
                                       recon_text=recon_text_3,
                                       text=text,
                                       lambda_xy=0.,
                                       lambda_yx=1.)
                loss = loss_2 + loss_3

            image_loss_meter.update(loss_2.data[0], len(image))
            text_loss_meter.update(loss_3.data[0], len(text))

            loss.backward()
            optimizer.step()

            if batch_idx % log_interval == 0:
                print(
                    '[Weak {:.0f}%] Train Epoch: {} [{}/{} ({:.0f}%)]\tJoint Loss: {:.6f}\tImage Loss: {:.6f}\tText Loss: {:.6f}'
                    .format(100. * weak_perc, epoch, batch_idx * len(image),
                            len(train_loader.dataset),
                            100. * batch_idx / len(train_loader),
                            joint_loss_meter.avg, image_loss_meter.avg,
                            text_loss_meter.avg))

        print(
            '====> [Weak {:.0f}%] Epoch: {} Joint loss: {:.4f}\tImage loss: {:.4f}\tText loss: {:.4f}'
            .format(100. * weak_perc, epoch, joint_loss_meter.avg,
                    image_loss_meter.avg, text_loss_meter.avg))

    def test():
        vae.eval()
        test_joint_loss = 0
        test_image_loss = 0
        test_text_loss = 0

        for batch_idx, (image, text) in enumerate(test_loader):
            if cuda:
                image, text = image.cuda(), text.cuda()
            image, text = Variable(image), Variable(text)
            image = image.view(-1, 784)  # flatten image

            # At test time we always care about the joint loss, so we do not
            # subsample paired examples the way the training loop does.
            recon_image_1, recon_text_1, mu_1, logvar_1 = vae(image, text)
            recon_image_2, recon_text_2, mu_2, logvar_2 = vae(image=image)
            recon_image_3, recon_text_3, mu_3, logvar_3 = vae(text=text)

            loss_1 = loss_function(mu_1,
                                   logvar_1,
                                   recon_image=recon_image_1,
                                   image=image,
                                   recon_text=recon_text_1,
                                   text=text,
                                   lambda_xy=1.,
                                   lambda_yx=1.)
            loss_2 = loss_function(mu_2,
                                   logvar_2,
                                   recon_image=recon_image_2,
                                   image=image,
                                   recon_text=recon_text_2,
                                   text=text,
                                   lambda_xy=1.,
                                   lambda_yx=1.)
            loss_3 = loss_function(mu_3,
                                   logvar_3,
                                   recon_image=recon_image_3,
                                   image=image,
                                   recon_text=recon_text_3,
                                   text=text,
                                   lambda_xy=0.,
                                   lambda_yx=1.)

            test_joint_loss += loss_1.data[0]
            test_image_loss += loss_2.data[0]
            test_text_loss += loss_3.data[0]

        test_loss = test_joint_loss + test_image_loss + test_text_loss
        test_joint_loss /= len(test_loader)
        test_image_loss /= len(test_loader)
        test_text_loss /= len(test_loader)
        test_loss /= len(test_loader)

        print(
            '====> [Weak {:.0f}%] Test joint loss: {:.4f}\timage loss: {:.4f}\ttext loss:{:.4f}'
            .format(100. * weak_perc, test_joint_loss, test_image_loss,
                    test_text_loss))

        return test_loss, (test_joint_loss, test_image_loss, test_text_loss)

    best_loss = float('inf')
    for epoch in range(1, epochs + 1):
        train(epoch)
        loss, (joint_loss, image_loss, text_loss) = test()

        is_best = loss < best_loss
        best_loss = min(loss, best_loss)

        save_checkpoint(
            {
                'state_dict': vae.state_dict(),
                'best_loss': best_loss,
                'joint_loss': joint_loss,
                'image_loss': image_loss,
                'text_loss': text_loss,
                'optimizer': optimizer.state_dict(),
            },
            is_best,
            folder=out_dir)
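
`AverageMeter` is used above to track running averages of the per-batch losses but is not defined in this snippet. A common minimal implementation, consistent with the `update(value, n)` / `.avg` usage here:

class AverageMeter(object):
    """Tracks a running sum and count and exposes their ratio as .avg."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, value, n=1):
        # value is a per-example (or per-batch mean) quantity, n the batch size.
        self.sum += value * n
        self.count += n
        self.avg = self.sum / self.count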
Example 3
        print("=> no checkpoint found at '{}'".format(args.resume))

best_prec1 = 0.
for epoch in range(args.start_epoch, args.epochs):
    if epoch in [args.epochs*0.5, args.epochs*0.75]:
        for param_group in optimizer.param_groups:
            param_group['lr'] *= 0.1
    avg_loss, train_acc = train(
        model, 
        optimizer, 
        epoch=epoch, 
        device=device,
        train_loader=train_loader,
        valid=args.valid, 
        valid_len=valid_len, 
        log_interval=args.log_interval)
    if args.valid:
        prec1 = valid(model, device, valid_loader, valid_len=valid_len)
        test(model, device, test_loader)
    else:
        prec1 = test(model, device, test_loader)
    is_best = prec1 > best_prec1
    best_prec1 = max(prec1, best_prec1)
    save_checkpoint({
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'best_prec1': best_prec1,
        'optimizer': optimizer.state_dict(),
        'cfg': model.cfg
    }, is_best, filepath=args.save)
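
The manual decay above multiplies the learning rate by 0.1 at 50% and 75% of training. The same schedule can be expressed with a built-in scheduler; a sketch assuming the same `optimizer` and `args` as in the snippet:

scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[int(args.epochs * 0.5), int(args.epochs * 0.75)],
    gamma=0.1)

for epoch in range(args.start_epoch, args.epochs):
    # ... train / validate / save as in the loop above ...
    scheduler.step()  # replaces the manual param_group['lr'] update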
Example 4
def main():
    global epochs_since_improvement, start_epoch, best_loss, epoch, checkpoint

    # The three class labels are balanced beforehand in ../code/original/fix.py
    train_data = Pneumonia(txt=balanced_txt,
                           mode='train',
                           class_to_idx=class_to_idx,
                           transforms=data_transforms['train'])
    train_data, valid_data, test_data = torch.utils.data.random_split(
        train_data, [7000, 2000, 3000])
    print('train_data size: ', len(train_data))
    print('valid_data_size: ', len(valid_data))
    print('test_data_size: ', len(test_data))

    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=batch_size,
                                               num_workers=0,
                                               shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=batch_size,
                                              num_workers=0,
                                              shuffle=True)
    valid_loader = torch.utils.data.DataLoader(valid_data,
                                               batch_size=batch_size,
                                               num_workers=0,
                                               shuffle=True)

    # Start from a pretrained model and replace only the classifier head below;
    # all layers remain trainable (requires_grad=True).
    model = models.densenet121(pretrained=True)
    #model = models.resnet50(pretrained=True)
    for param in model.parameters():
        param.requires_grad = True

    model.classifier = nn.Sequential(
        OrderedDict([
            ('fcl1', nn.Linear(1024, 256)),
            ('dp1', nn.Dropout(0.3)),
            ('r1', nn.ReLU()),
            ('fcl2', nn.Linear(256, 32)),
            ('dp2', nn.Dropout(0.3)),
            ('r2', nn.ReLU()),
            ('fcl3', nn.Linear(32, 3)),
            #('out', nn.Softmax(dim=1)),
        ]))

    train_on_gpu = torch.cuda.is_available()
    if train_on_gpu:
        print('GPU is available :)  Training on GPU ...')
    else:
        print('GPU is not available :(  Training on CPU ...')

    # Load an existing checkpoint if there is one; the first training run starts
    # from scratch with a fresh optimizer.
    checkpoint_file = '/home/tianshu/pneumonia/code/checkpoint.pth.tar'
    checkpoint = None
    if os.path.isfile(checkpoint_file):
        checkpoint = torch.load(checkpoint_file,
                                map_location={'cuda:2': 'cuda:0'})
    if checkpoint is None:
        optimizer = optim.Adadelta(model.parameters())
    else:
        # Restore the training state stored in the checkpoint
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epoch_since_improvement']
        best_loss = checkpoint['best_loss']
        print(
            '\nLoaded checkpoint from epoch %d. Best loss so far is %.3f.\n' %
            (start_epoch, best_loss))
        model = checkpoint['model']
        optimizer = checkpoint['optimizer']

    criterion = nn.CrossEntropyLoss()

    #train the model
    for epoch in range(start_epoch, epochs):
        val_loss = train_function(model,
                                  train_loader,
                                  valid_loader,
                                  criterion=criterion,
                                  optimizer=optimizer,
                                  train_on_gpu=train_on_gpu,
                                  epoch=epoch,
                                  device=device,
                                  scheduler=None)

        # Did validation loss improve?
        is_best = val_loss < best_loss
        best_loss = min(val_loss, best_loss)

        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))

        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(epoch, epochs_since_improvement, model, optimizer,
                        val_loss, best_loss, is_best)

    test_function(model, test_loader, device, criterion, cat_to_name)
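
The `save_checkpoint` helper called at the end of each epoch is not shown. A sketch consistent with the keys this script reads back when resuming ('epoch', 'epoch_since_improvement', 'best_loss', 'model', 'optimizer'); the file names below are assumptions:

def save_checkpoint(epoch, epochs_since_improvement, model, optimizer,
                    val_loss, best_loss, is_best):
    state = {
        'epoch': epoch,
        'epoch_since_improvement': epochs_since_improvement,
        'val_loss': val_loss,
        'best_loss': best_loss,
        'model': model,          # this script stores the full model object
        'optimizer': optimizer,  # and the full optimizer object
    }
    torch.save(state, 'checkpoint.pth.tar')
    if is_best:
        # Keep a separate copy of the best model seen so far.
        torch.save(state, 'best_checkpoint.pth.tar')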
Example 5
    # Evaluate on validation set
    plot_data = train.validate(val_loader, model, criterion, print_freq,
                               plot_data, gpu)

    # Remember best model and save checkpoint
    is_best = plot_data['val_loss'][epoch] < best_loss
    if is_best:
        print("New best model by loss. Val Loss = " +
              str(plot_data['val_loss'][epoch]))
        best_loss = plot_data['val_loss'][epoch]
        filename = dataset + '/models/' + training_id + '_epoch_' + str(
            epoch) + '_ValLoss_' + str(round(plot_data['val_loss'][epoch], 2))
        prefix_len = len('_epoch_' + str(epoch) + '_ValLoss_' +
                         str(round(plot_data['val_loss'][epoch], 2)))
        train.save_checkpoint(model, filename, prefix_len)

    # Remember best model and save checkpoint
    is_best = plot_data['val_prec50'][epoch] > best_acc
    if is_best:
        print("New best model by Acc. Val Acc at 50 = " +
              str(plot_data['val_prec50'][epoch]))
        best_acc = plot_data['val_prec50'][epoch]
        filename = dataset + '/models/' + training_id + '_ByACC_epoch_' + str(
            epoch) + '_ValAcc_' + str(round(plot_data['val_prec50'][epoch], 2))
        prefix_len = len('_epoch_' + str(epoch) + '_ValAcc_' +
                         str(round(plot_data['val_prec50'][epoch], 2)))
        train.save_checkpoint(model, filename, prefix_len)

    if plot:
Example 6
    plot_data = train.train(train_loader, model, criterion, optimizer, epoch,
                            print_freq, plot_data)

    # Evaluate on validation set
    plot_data = train.validate(val_loader, model, criterion, epoch, print_freq,
                               plot_data)

    # Remember best model and save checkpoint
    is_best = plot_data['val_loss'][epoch] < best_loss
    if is_best:
        print("New best model by loss. Val Loss = " +
              str(plot_data['val_loss'][epoch]))
        best_loss = plot_data['val_loss'][epoch]
        filename = dataset + '/models/' + training_id + '_epoch_' + str(
            epoch) + '_ValLoss_' + str(round(plot_data['val_loss'][epoch], 4))
        train.save_checkpoint(model, filename)

    if plot:

        ax1.plot(it_axes[0:epoch + 1], plot_data['train_loss'][0:epoch + 1],
                 'r')
        ax2.plot(it_axes[0:epoch + 1],
                 plot_data['train_correct_triplets'][0:epoch + 1], 'b')

        ax1.plot(it_axes[0:epoch + 1], plot_data['val_loss'][0:epoch + 1], 'y')
        ax2.plot(it_axes[0:epoch + 1],
                 plot_data['val_correct_triplets'][0:epoch + 1], 'g')

        plt.title(training_id)
        plt.grid(True)
        # plt.ion()
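
The plotting block above assumes a two-axis figure (loss on `ax1`, correct triplets on `ax2`) and an `it_axes` array created elsewhere. A typical setup for that pattern; the number of epochs and the axis labels are assumptions:

import matplotlib.pyplot as plt
import numpy as np

num_epochs = 100                 # assumed total number of epochs
it_axes = np.arange(num_epochs)  # one x-axis point per epoch

fig, ax1 = plt.subplots()
ax2 = ax1.twinx()                # second y-axis sharing the same x-axis
ax1.set_xlabel('epoch')
ax1.set_ylabel('loss')
ax2.set_ylabel('correct triplets')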
Example 7
def main():
    # Record the best epoch and accuracy
    best_result = {'epoch': 1, 'accuracy': 0.}

    args = parse_args()
    # Use the model name as the visdom environment name
    args.env = args.model
    vis = Visualize(env=args.env) if not args.close_visdom else None

    # Create directories to store results, checkpoints, and pretrained models
    if args.root_path != '':
        args.result_path = os.path.join(args.root_path, args.result_path)
        args.checkpoint_path = os.path.join(args.root_path,
                                            args.checkpoint_path)
        args.pretrained_models_path = os.path.join(args.root_path,
                                                   args.pretrained_models_path)
        if not os.path.exists(args.result_path):
            os.mkdir(args.result_path)
        if not os.path.exists(args.checkpoint_path):
            os.mkdir(args.checkpoint_path)
        if not os.path.exists(args.pretrained_models_path):
            os.mkdir(args.pretrained_models_path)
        if args.resume_path:
            args.resume_path = os.path.join(args.checkpoint_path,
                                            args.resume_path)

    # Set manual seed so random values are reproducible
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)

    # Create model
    if not torch.cuda.is_available():
        args.device = torch.device('cpu')
    else:
        args.device = torch.device(args.device)
    model = get_model(args)
    model.to(args.device)
    print(model)

    # Define loss function, optimizer, and scheduler to adjust the lr
    criterion = nn.CrossEntropyLoss().to(args.device)
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=args.step_size,
                                          gamma=args.gamma)
    # optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Continue training from checkpoint epoch with checkpoint parameters
    if args.resume_path:
        if os.path.isfile(args.resume_path):
            print("=> loading checkpoint '{}'...".format(args.resume_path))
            checkpoint = torch.load(args.resume_path)
            args.begin_epoch = checkpoint['epoch'] + 1
            best_result = checkpoint['best_result']
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        else:
            print("=> no checkpoint found at '{}'".format(args.resume_path))

    # Load dataset
    train_loader = data_loader(args, train=True)
    val_loader = data_loader(args, val=True)
    test_loader = data_loader(args, test=True)
    # Begin to train
    since = time.time()
    for epoch in range(args.begin_epoch, args.epochs + 1):
        adjust_learning_rate(scheduler)
        train_accuracy, train_loss = train_model(args, epoch, model,
                                                 train_loader, criterion,
                                                 optimizer, scheduler, vis)
        # Verify accuracy and loss after training
        val_accuracy, val_loss = val_model(args, epoch, best_result, model,
                                           val_loader, criterion, vis)

        # Plot train and val accuracy and loss for this epoch (skipped when
        # visdom is disabled)
        if vis is not None:
            accuracy = [[train_accuracy], [val_accuracy]]
            loss = [[train_loss], [val_loss]]
            vis.plot2('accuracy', accuracy, ['train', 'val'])
            vis.plot2('loss', loss, ['train', 'val'])

        # Save a checkpoint every checkpoint_interval epochs (and at the final
        # epoch), keeping only the most recent one
        if epoch % args.checkpoint_interval == 0 or epoch == args.epochs:
            save_checkpoint(args, epoch, best_result, model, optimizer,
                            scheduler)
    # Total time to train
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60,
                                                        time_elapsed % 60))

    # Test model with the best val model parameters
    best_model_path = os.path.join(
        args.result_path, '{}_{}.pth'.format(args.model, best_result['epoch']))
    print("Using '{}' for test...".format(best_model_path))
    model.load_state_dict(torch.load(best_model_path))
    test_model(args, model, test_loader)
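
The `save_checkpoint` used in the training loop above is not part of this snippet. A sketch that matches the keys the resume branch expects ('epoch', 'best_result', 'model', 'optimizer'); the file name and the scheduler entry are assumptions:

def save_checkpoint(args, epoch, best_result, model, optimizer, scheduler):
    state = {
        'epoch': epoch,
        'best_result': best_result,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
    }
    # Overwrite a single file so only the most recent checkpoint is kept.
    torch.save(state, os.path.join(args.checkpoint_path,
                                   '{}_checkpoint.pth'.format(args.model)))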