Example #1
0
def main():
    """Train a two-stage spaceship model: localization, then classification.

    Stage I trains the full network to regress the spaceship bounding box
    using a DIoU-style loss (``cal_diou``). Stage II freezes the
    convolutional backbone and the localization head, switches the model
    to classification mode, and fine-tunes the remaining parameters with
    BCE. The final weights are saved to ``model.pth.tar``.
    """
    model = Net()

    # Part I - Train model to localize spaceship on images containing spaceship
    print("Start localization training")

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # eps=1e-07 mirrors the Keras Adam default — presumably for parity with a
    # reference implementation; TODO confirm.
    optimizer = optim.Adam(model.parameters(), eps=1e-07)

    # Input sizes are fixed, so let cuDNN auto-tune the fastest kernels.
    cudnn.benchmark = True
    criterion = cal_diou

    epochs = 40
    steps_per_epoch = 3125
    batch_size = 64

    for epoch in range(epochs):
        adjust_learning_rate(optimizer, epoch)
        train(model, optimizer, epoch, device, steps_per_epoch, batch_size,
              criterion)

    # Part II - Apply transfer learning to train pre-trained model to detect whether spaceship exists
    print("Start classification training")

    model.mode = 'classification'
    criterion = nn.BCELoss()

    # Freeze backbone and localization head; only the classification layers
    # keep requires_grad=True. (Frozen params are still passed to Adam below;
    # they simply never receive gradients.)
    for param in model.convnet.parameters():
        param.requires_grad = False

    for param in model.localizer.parameters():
        param.requires_grad = False

    batch_size = 64
    steps_per_epoch = 500
    epochs = 2

    # Fresh optimizer so stage-I momentum/second-moment state does not leak
    # into the classification fine-tuning.
    optimizer = optim.Adam(model.parameters(), eps=1e-07)

    for epoch in range(epochs):
        train(model,
              optimizer,
              epoch,
              device,
              steps_per_epoch,
              batch_size,
              criterion,
              classification=True)

    # Save model. (Was an f-string with no placeholders; plain literal
    # produces the identical path.)
    path = 'model.pth.tar'
    torch.save(model.state_dict(), path)
Example #2
0
            save.save_model_w_condition(model=vgg,
                                        model_dir=model_dir,
                                        model_name='best_model_protos_opt',
                                        accu=acc,
                                        target_accu=best_acc,
                                        log=log)

            is_best = acc > best_acc
            best_acc = max(acc, best_acc)
            if is_best:
                best_epoch = epoch

        if (epoch + 1) % args.decay == 0:
            log('lowered lrs by factor of 10')
            adjust_learning_rate(optimizers)

    best_acc1 = best_acc
    best_epoch1 = best_epoch

    log("optimize joint")
    for epoch in range(joint_opt):

        log('epoch: \t{0}'.format(epoch))

        # layer = getattr(vgg,"root_layer")
        # weights = [p.data for p in layer.parameters()][0]
        # weights = np.array([[np.round(weight.item(),2) for weight in beta] for beta in weights])
        # print("root weights")
        # print(weights)
Example #3
0
def main_worker(gpu, ngpus_per_node, args):
    """Per-process training worker.

    Places the model on the right device(s) — DistributedDataParallel,
    single GPU, or DataParallel — builds the optimizer and data loaders,
    optionally resumes from a checkpoint, then runs the epoch loop of
    train/validate, checkpointing with the best top-1 accuracy.

    Args:
        gpu: GPU index assigned to this process (None means CPU/all-GPU
            DataParallel fallback).
        ngpus_per_node: GPUs on this node; used to compute the global rank
            and to split batch size / worker count per process.
        args: parsed CLI namespace (reads .gpu, .distributed, .rank,
            .dist_url, .batch_size, .workers, .resume, .data, .test,
            .evaluate, .start_epoch, .epochs, among others).
    """
    # Module-level best accuracy, shared with checkpoint resume below and
    # updated across epochs.
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)
    # create model
    # if args.pretrained:
    #     print("=> using pre-trained model '{}'".format(args.arch))
    #     model = models.__dict__[args.arch](pretrained=True)
    #     model = autofit(model, args.arch, args.num_classes)
    # else:
    #     print("=> creating model '{}'".format(args.arch))
    #     model = models.__dict__[args.arch](num_classes=args.num_classes)
    model = AutoFitNet(arch=args.arch,
                       pretrained=args.pretrained,
                       num_classes=args.num_classes)

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(args.workers / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    testdir = os.path.join(args.data, 'test')
    # ImageNet channel statistics — assumed appropriate for this dataset;
    # TODO confirm.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # train_dataset = datasets.ImageFolder(
    #     traindir,
    #     transforms.Compose([
    #         transforms.Resize(256),
    #         transforms.RandomResizedCrop(224),
    #         # transforms.RandomHorizontalFlip(),
    #         transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    #         transforms.ToTensor(),
    #         normalize,
    #     ]))
    train_dataset = CityFuncDataset(
        traindir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.1,
                                   contrast=0.1,
                                   saturation=0.1,
                                   hue=0.1),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    # shuffle must be off when a DistributedSampler supplies the ordering.
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(CityFuncDataset(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # Test mode: evaluate on the held-out test split and exit.
    if args.test:
        test_loader = torch.utils.data.DataLoader(CityFuncDataset(
            testdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])),
                                                  batch_size=args.batch_size,
                                                  shuffle=False,
                                                  num_workers=args.workers,
                                                  pin_memory=True)
        validate(test_loader, model, criterion, args)
        return

    # Evaluate-only mode: validate and exit without training.
    if args.evaluate:
        validate(val_loader, model, criterion, args)
        return

    # Track per-epoch wall time to report an ETA.
    epoch_time = AverageMeter('Time', ':6.3f', 's')
    end = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            # Re-seed the sampler so each epoch sees a different shard order.
            train_sampler.set_epoch(epoch)
        # learning rate decay
        adjust_learning_rate(optimizer, epoch, args.lr)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, args)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        # Only rank 0 on each node writes checkpoints in the
        # multiprocessing-distributed case, to avoid duplicate files.
        if not args.multiprocessing_distributed or (
                args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer': optimizer.state_dict(),
                }, is_best)

        # measure elapsed time
        epoch_time.update(time.time() - end)
        eta = (args.epochs - epoch - 1) * epoch_time.avg
        eta_str = str(datetime.timedelta(seconds=int(eta)))
        print(
            'Epoch: [{epoch:d}]\tTime:{time:6.3f}s\tETA:{eta:6.3f}s ({eta_str:s})'
            .format(epoch=epoch, time=epoch_time.val, eta=eta,
                    eta_str=eta_str))
        end = time.time()
    # NOTE(review): the statements below reference names not defined anywhere
    # in this file's visible scope (utilsdata, custom_dset, OOD_testlist,
    # data_transform, batch_size, num_epochs, helpers, training_helpers, net,
    # ID_trainloader, NUM_HOLDOUT_CLASSES, ITER, MEAN, STD, ...). This looks
    # like a fragment pasted from a different training script — confirm its
    # provenance before relying on or modifying it.
    OOD_testloader = utilsdata.DataLoader(
        custom_dset.Dataset_fromPythonList(OOD_testlist,
                                           transform=data_transform),
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        timeout=1000,
    )

    #################################################################################################################
    # Training Loop
    final_test_acc = 0.
    for epoch in range(num_epochs):

        # Decay learning rate according to decay schedule
        helpers.adjust_learning_rate(optimizer, epoch, learning_rate_table)
        print("Starting Epoch {}/{}. lr = {}".format(
            epoch, num_epochs, learning_rate_table[epoch]))
        # Train
        train_acc, train_loss, percent_real = training_helpers.train_model(
            net, optimizer, ID_trainloader, 10 - NUM_HOLDOUT_CLASSES)
        print(
            "[{}] Epoch [ {} / {} ]; lr: {} TrainAccuracy: {:.5f} TrainLoss: {:.5f} %-Real: {}"
            .format(ITER, epoch, num_epochs, learning_rate_table[epoch],
                    train_acc, train_loss, percent_real))
        # Test
        test_acc, test_loss = helpers.test_model(net, device, ID_testloader,
                                                 MEAN, STD)
        print(
            "\t[{}] Epoch [ {} / {} ]; TestAccuracy: {:.5f} TestLoss: {:.5f}".
            format(ITER, epoch, num_epochs, test_acc, test_loss))
def train_and_validation(args):
    """Initialize generator, discriminator, memory_network and run the train and validation process.

    Each epoch runs a 'train' phase then a 'val' phase. With the memory
    network enabled, each batch (1) trains the spatial feature extractor,
    (2) updates the memory module, (3) optionally trains the feature
    integrator, then (4) trains the GAN generator and (5) discriminator.
    Validation reuses the same forward passes with gradients disabled.
    Checkpoints and a sample image are written after each val phase
    according to ``args.save_freq``.
    """
    # init_training returns extra handles (mem, feature_integrator) when the
    # memory network is enabled.
    if args.use_memory == True:
        global_step, device, data_loaders, mem, feature_integrator, generator, discriminator, optimizers, losses = init_training(args)
    else:
        global_step, device, data_loaders, generator, discriminator, optimizers, losses = init_training(args)
    #  run training process
    for epoch in range(args.start_epoch + 1, args.end_epoch + 1):
        print('\n========== EPOCH {} =========='.format(epoch))

        for phase in ['train', 'val']:

            # running losses for generator
            epoch_gen_adv_loss = 0.0
            epoch_gen_l1_loss = 0.0

            # running losses for discriminator
            epoch_disc_real_loss = 0.0
            epoch_disc_fake_loss = 0.0
            epoch_disc_real_acc = 0.0
            epoch_disc_fake_acc = 0.0

            if phase == 'train':
                print('TRAINING:')
            else:
                print('VALIDATION:')

            for idx, batch in enumerate(data_loaders[phase]):

                # L channel scaled by 100, ab by 110 — presumably normalizing
                # Lab ranges to roughly [0, 1] / [-1, 1]; confirm against the
                # dataset code.
                res_input = batch['res_input'].to(device)
                color_feat = batch['color_feat'].to(device)
                img_l = (batch['img_l'] / 100.0).to(device)
                img_ab = (batch['img_ab'] / 110.0).to(device)
                img_id = batch['img_id'].to(device)
                real_img_lab = torch.cat([img_l, img_ab], dim=1).to(device)

                # generate targets
                target_ones = torch.ones(img_l.size(0), 1).to(device)
                target_zeros = torch.zeros(img_l.size(0), 1).to(device)

                if phase == 'train':
                    # adjust LR (step-based exponential decay for every
                    # optimizer in play)
                    global_step += 1
                    adjust_learning_rate(optimizers['gen'], global_step, base_lr=args.base_lr_gen,
                                         lr_decay_rate=args.lr_decay_rate, lr_decay_steps=args.lr_decay_steps)
                    adjust_learning_rate(optimizers['disc'], global_step, base_lr=args.base_lr_disc,
                                         lr_decay_rate=args.lr_decay_rate, lr_decay_steps=args.lr_decay_steps)
                    if args.use_memory == True:
                        adjust_learning_rate(optimizers['mem'], global_step, base_lr=args.base_lr_mem,
                                             lr_decay_rate=args.lr_decay_rate, lr_decay_steps=args.lr_decay_steps)
                        adjust_learning_rate(optimizers['feat'], global_step, base_lr=args.base_lr_feat,
                                             lr_decay_rate=args.lr_decay_rate, lr_decay_steps=args.lr_decay_steps)

                if args.use_memory == True:
                    ### 1) Train spatial feature extractor
                    if phase == 'train':
                        optimizers['mem'].zero_grad()

                    with torch.set_grad_enabled(phase == 'train'):
                        res_feature = mem(res_input)

                        if phase == 'train':
                            mem_loss = mem.unsupervised_loss(res_feature, color_feat, args.color_thres)
                            mem_loss.backward()
                            optimizers['mem'].step()

                    ### 2) Update Memory module
                    # Re-extract features after the optimizer step so the
                    # memory is updated with post-update features.
                    if phase == 'train':
                        with torch.no_grad():
                            res_feature = mem(res_input)
                            mem.memory_update(res_feature, color_feat, args.color_thres, img_id)

                    ### 3) Train Feature_Integrator
                    if args.use_feat_integrator == True:
                        if phase == 'train':
                            optimizers['feat'].zero_grad()

                        with torch.set_grad_enabled(phase == 'train'):
                            top_features, ref_img_ids = mem.topk_feature(res_feature, 3)
                            top_features = torch.transpose(top_features, dim0 = 1, dim1 = 2)
                            combined_features = feature_integrator(top_features)
                            feat_loss = losses['KLD'](combined_features, color_feat)

                            if phase == 'train':
                                feat_loss.backward()
                                optimizers['feat'].step()

                    # During validation, condition the generator on memory
                    # retrievals instead of the ground-truth color features.
                    if phase == 'val':
                        top1_feature, ref_img_ids = mem.topk_feature(res_feature, 3)
                        color_feat = top1_feature[:, 0, :]
                        if args.use_feat_integrator == True:
                            color_feat = combined_features

                ### 3) Train Generator
                if phase == 'train':
                    optimizers['gen'].zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    # Broadcast the 1-D color feature to a (B, C, H, W) map so
                    # the discriminator can consume it spatially.
                    dis_color_feat = torch.cat([torch.unsqueeze(color_feat, 2) for _ in range(args.img_size)], dim = 2)
                    dis_color_feat = torch.cat([torch.unsqueeze(dis_color_feat, 3) for _ in range(args.img_size)], dim = 3)
                    fake_img_ab = generator(img_l, color_feat)
                    fake = discriminator(fake_img_ab, img_l, dis_color_feat)
                    fake_img_lab = torch.cat([img_l, fake_img_ab], dim=1).to(device)

                    # Generator objective: adversarial + L1, mixed by l1_weight.
                    g_loss_GAN = losses['disc'](fake, target_ones)
                    g_loss_L1 = losses['l1'](fake_img_ab, img_ab)
                    g_loss = (1.0 - args.l1_weight) * g_loss_GAN + (args.l1_weight * g_loss_L1)

                    if phase == 'train':
                        g_loss.backward()
                        optimizers['gen'].step()

                epoch_gen_adv_loss += g_loss_GAN.item()
                epoch_gen_l1_loss += g_loss_L1.item()

                ### 4) Train Discriminator
                if phase == 'train':
                    optimizers['disc'].zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    # detach() stops discriminator gradients from flowing into
                    # the generator.
                    prediction_real = discriminator(img_ab, img_l, dis_color_feat)
                    prediction_fake = discriminator(fake_img_ab.detach(), img_l, dis_color_feat)

                    # One-sided label smoothing on the real targets.
                    d_loss_real = losses['disc'](prediction_real, target_ones * args.smoothing)
                    d_loss_fake = losses['disc'](prediction_fake, target_zeros)
                    d_loss = d_loss_real + d_loss_fake

                    if phase == 'train':
                        d_loss.backward()
                        optimizers['disc'].step()

                epoch_disc_real_loss += d_loss_real.item()
                epoch_disc_fake_loss += d_loss_fake.item()
                epoch_disc_real_acc += np.mean(prediction_real.detach().cpu().numpy() > 0.5)
                epoch_disc_fake_acc += np.mean(prediction_fake.detach().cpu().numpy() <= 0.5)

                # save the first sample for later
                if phase == 'val' and idx == 0:
                    sample_real_img_lab = real_img_lab
                    sample_fake_img_lab = fake_img_lab
                    # NOTE(review): ref_img_ids is only ever bound inside the
                    # args.use_memory branch above; with use_memory disabled
                    # this line raises NameError on the first val batch —
                    # confirm and guard if that configuration is supported.
                    sample_ref_img_ids = ref_img_ids

            # display losses (print_losses presumably averages by the batch
            # count passed in — verify)
            print_losses(epoch_gen_adv_loss, epoch_gen_l1_loss,
                         epoch_disc_real_loss, epoch_disc_fake_loss,
                         epoch_disc_real_acc, epoch_disc_fake_acc,
                         len(data_loaders[phase]), args.l1_weight)

            if phase == 'val':
                if epoch % args.save_freq == 0 or epoch == args.end_epoch - 1:
                    gen_path = os.path.join(args.save_path, 'checkpoint_ep{}_gen.pt'.format(epoch))
                    disc_path = os.path.join(args.save_path, 'checkpoint_ep{}_disc.pt'.format(epoch))
                    if args.use_memory == True:
                        mem_path = os.path.join(args.save_path, 'checkpoint_ep{}_mem.pt'.format(epoch))
                        if args.use_feat_integrator == True:
                            feat_path = os.path.join(args.save_path, 'checkpoint_ep{}_feat.pt'.format(epoch))
                    torch.save(generator.state_dict(), gen_path)
                    torch.save(discriminator.state_dict(), disc_path)
                    if args.use_memory == True:
                        # Memory checkpoint bundles the model weights plus the
                        # key/value/age/id buffers needed to restore retrieval.
                        torch.save({'mem_model' : mem.state_dict(),
                                     'mem_key' : mem.spatial_key.cpu(),
                                     'mem_value' : mem.color_value.cpu(),
                                     'mem_age' : mem.age.cpu(),
                                     'img_id' : mem.img_id.cpu()}, mem_path)
                        if args.use_feat_integrator == True:
                            torch.save(feature_integrator.state_dict(), feat_path)
                    print('Checkpoint.')

                # display sample images
                save_sample(
                    sample_real_img_lab,
                    sample_fake_img_lab,
                    sample_ref_img_ids,
                    args.img_size,
                    os.path.join(args.save_path, 'sample_ep{}.png'.format(epoch)),
                    os.path.join(args.save_path, 'ref_ep{}.png'.format(epoch))
                )
Example #6
0
def run_training(args):
    """Initialize and run the training process.

    Each epoch runs a 'train' phase then a 'test' phase over the matching
    data loader. The generator minimizes an adversarial + L1 objective
    mixed by ``args.l1_weight``; the discriminator minimizes real/fake
    losses with one-sided label smoothing on the real targets. Checkpoints
    and a sample image are written after every test phase according to
    ``args.save_freq``.
    """
    global_step, device, data_loaders, generator, discriminator, optimizers, losses = init_training(
        args)
    #  run training process
    for epoch in range(args.start_epoch, args.max_epoch):
        print('\n========== EPOCH {} =========='.format(epoch))

        for phase in ['train', 'test']:

            # running losses for generator
            epoch_gen_adv_loss = 0.0
            epoch_gen_l1_loss = 0.0

            # running losses for discriminator
            epoch_disc_real_loss = 0.0
            epoch_disc_fake_loss = 0.0
            epoch_disc_real_acc = 0.0
            epoch_disc_fake_acc = 0.0

            if phase == 'train':
                print('TRAINING:')
            else:
                print('VALIDATION:')

            for idx, sample in enumerate(data_loaders[phase]):

                # get data: channel 0 is L (generator input), the full tensor
                # is the real Lab image — presumably (B, 3, H, W); confirm
                # against the dataset.
                img_l, real_img_lab = sample[:, 0:1, :, :].float().to(
                    device), sample.float().to(device)

                # generate targets
                target_ones = torch.ones(real_img_lab.size(0), 1).to(device)
                target_zeros = torch.zeros(real_img_lab.size(0), 1).to(device)

                if phase == 'train':
                    # adjust LR (step-based decay for both optimizers)
                    global_step += 1
                    adjust_learning_rate(optimizers['gen'],
                                         global_step,
                                         base_lr=args.base_lr_gen,
                                         lr_decay_rate=args.lr_decay_rate,
                                         lr_decay_steps=args.lr_decay_steps)
                    adjust_learning_rate(optimizers['disc'],
                                         global_step,
                                         base_lr=args.base_lr_disc,
                                         lr_decay_rate=args.lr_decay_rate,
                                         lr_decay_steps=args.lr_decay_steps)

                    # reset generator gradients
                    optimizers['gen'].zero_grad()

                # train / inference the generator
                with torch.set_grad_enabled(phase == 'train'):
                    fake_img_ab = generator(img_l)
                    fake_img_lab = torch.cat([img_l, fake_img_ab],
                                             dim=1).to(device)

                    # adv loss
                    adv_loss = losses['disc'](discriminator(fake_img_lab),
                                              target_ones)
                    # l1 loss
                    l1_loss = losses['l1'](real_img_lab[:, 1:, :, :],
                                           fake_img_ab)
                    # full gen loss
                    full_gen_loss = (1.0 - args.l1_weight) * adv_loss + (
                        args.l1_weight * l1_loss)

                    if phase == 'train':
                        full_gen_loss.backward()
                        optimizers['gen'].step()

                epoch_gen_adv_loss += adv_loss.item()
                epoch_gen_l1_loss += l1_loss.item()

                if phase == 'train':
                    # reset discriminator gradients
                    optimizers['disc'].zero_grad()

                # train / inference the discriminator
                with torch.set_grad_enabled(phase == 'train'):
                    prediction_real = discriminator(real_img_lab)
                    # detach() keeps discriminator gradients out of the
                    # generator graph.
                    prediction_fake = discriminator(fake_img_lab.detach())

                    # One-sided label smoothing on the real targets.
                    loss_real = losses['disc'](prediction_real,
                                               target_ones * args.smoothing)
                    loss_fake = losses['disc'](prediction_fake, target_zeros)
                    full_disc_loss = loss_real + loss_fake

                    if phase == 'train':
                        full_disc_loss.backward()
                        optimizers['disc'].step()

                epoch_disc_real_loss += loss_real.item()
                epoch_disc_fake_loss += loss_fake.item()
                epoch_disc_real_acc += np.mean(
                    prediction_real.detach().cpu().numpy() > 0.5)
                epoch_disc_fake_acc += np.mean(
                    prediction_fake.detach().cpu().numpy() <= 0.5)

                # save the first sample for later
                if phase == 'test' and idx == 0:
                    sample_real_img_lab = real_img_lab
                    sample_fake_img_lab = fake_img_lab

            # display losses (print_losses presumably averages by the batch
            # count passed in — verify)
            print_losses(epoch_gen_adv_loss, epoch_gen_l1_loss,
                         epoch_disc_real_loss, epoch_disc_fake_loss,
                         epoch_disc_real_acc, epoch_disc_fake_acc,
                         len(data_loaders[phase]), args.l1_weight)

            # save after every nth epoch
            if phase == 'test':
                if epoch % args.save_freq == 0 or epoch == args.max_epoch - 1:
                    gen_path = os.path.join(
                        args.save_path, 'checkpoint_ep{}_gen.pt'.format(epoch))
                    disc_path = os.path.join(
                        args.save_path,
                        'checkpoint_ep{}_disc.pt'.format(epoch))
                    torch.save(generator.state_dict(), gen_path)
                    torch.save(discriminator.state_dict(), disc_path)
                    print('Checkpoint.')

                # display sample images
                save_sample(
                    sample_real_img_lab, sample_fake_img_lab,
                    os.path.join(args.save_path,
                                 'sample_ep{}.png'.format(epoch)))