Example #1
0
def main(args):
    """Train an ``AutoEncoder`` on the boards listed in ``args.boards_file``.

    The ``idxs`` array is loaded from the file, optionally shuffled in place,
    and split so that the last ``args.num_test`` entries form the test set.
    The model is optimised with Adam. The loss is printed every
    ``args.log_interval`` iterations; every ``args.save_interval`` iterations
    a checkpoint is written and the loss curve is plotted to
    ``vis/ae_losses.png``.

    Args:
        args: Object exposing the attributes ``boards_file``, ``num_gpus``,
            ``shuffle``, ``num_test``, ``batch_size``, ``model_loadname``,
            ``lr``, ``init_epoch``, ``epochs``, ``log_interval``,
            ``save_interval`` and ``model_savename``.
    """
    print('Loading data')
    # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects;
    # only use this with archives from a trusted source.
    idxs = np.load(args.boards_file, allow_pickle=True)['idxs']
    print(f'Number of Boards: {len(idxs)}')

    if torch.cuda.is_available() and args.num_gpus > 0:
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    if args.shuffle:
        np.random.shuffle(idxs)

    # Split on an explicit non-negative index. The negative-slice form
    # ``idxs[:-num_test]`` silently yields an EMPTY training set when
    # num_test == 0 (because -0 == 0).
    split = max(len(idxs) - args.num_test, 0)
    train_idxs = idxs[:split]
    test_idxs = idxs[split:]

    # The loader itself does not shuffle; ordering is decided by the optional
    # in-place shuffle of ``idxs`` above.
    train_loader = DataLoader(Boards(train_idxs),
                              batch_size=args.batch_size,
                              shuffle=False)
    # Built for parity with the training loader; not consumed in this function.
    test_loader = DataLoader(Boards(test_idxs), batch_size=args.batch_size)

    model = AutoEncoder().to(device)
    if args.model_loadname:
        # map_location lets a checkpoint saved on GPU be restored on a
        # CPU-only machine (and vice versa).
        model.load_state_dict(
            torch.load(args.model_loadname, map_location=device))

    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    model.train()
    losses = []
    total_iters = 0

    for epoch in range(args.init_epoch, args.epochs):
        print(f'Running epoch {epoch} / {args.epochs}\n')
        for board in tqdm(train_loader, total=len(train_loader)):
            board = board.to(device)
            optimizer.zero_grad()
            loss = model.loss(board)
            loss.backward()
            optimizer.step()

            # Read the scalar once; every .item() call forces a device sync.
            loss_value = loss.item()
            losses.append(loss_value)

            if total_iters % args.log_interval == 0:
                # tqdm.write keeps the progress bar intact.
                tqdm.write(f'Loss: {loss_value}')

            if total_iters % args.save_interval == 0:
                torch.save(
                    model.state_dict(),
                    append_to_modelname(args.model_savename, total_iters))
                plot_losses(losses, 'vis/ae_losses.png')
            total_iters += 1
def main():
    """Train the ``AutoEncoder`` on the Dspites dataset with a BCE loss.

    Settings are read from ``config.json``. The dataset is split into train /
    validation / test parts, the model is trained for 25 epochs, the weights
    with the lowest validation loss are written to ``save_path``, and those
    weights are finally evaluated on the test split.

    Relies on names that are not defined inside this function and are expected
    to exist at module level: ``load_path``, ``save_path``,
    ``freeze_encoder`` and ``augment_transform_list1``.
    """
    with open("config.json") as json_file:
        conf = json.load(json_file)
    dataset_path = os.path.join(conf['data']['dataset_path'],
                                conf['data']['dataset_file'])
    device = conf['train']['device']

    model = AutoEncoder(in_channels=1,
                        dec_channels=1,
                        latent_size=conf['model']['latent_size'])
    model = model.to(device)
    # Only restore weights when a checkpoint path is configured, and map the
    # tensors onto the configured device so a GPU checkpoint loads on CPU.
    if load_path:
        model.load_state_dict(torch.load(load_path, map_location=device))

    dspites_dataset = Dspites(dataset_path)
    train_val = train_val_split(dspites_dataset)
    # Second split of the held-out part: its 'train' portion is used as the
    # test set, its 'val' portion as the validation set.
    val_test = train_val_split(train_val['val'], val_split=0.2)

    data_loader_train = DataLoader(train_val['train'],
                                   batch_size=conf['train']['batch_size'],
                                   shuffle=True,
                                   num_workers=2)
    data_loader_val = DataLoader(val_test['val'],
                                 batch_size=200,
                                 shuffle=False,
                                 num_workers=1)
    data_loader_test = DataLoader(val_test['train'],
                                  batch_size=200,
                                  shuffle=False,
                                  num_workers=1)

    print('autoencoder training')
    print('frozen encoder: ', freeze_encoder)
    print('train dataset length: ', len(train_val['train']))
    print('val dataset length: ', len(val_test['val']))
    print('test dataset length: ', len(val_test['train']))

    print('latent space size:', conf['model']['latent_size'])
    print('batch size:', conf['train']['batch_size'])

    loss_function = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    model.train()
    if freeze_encoder:
        model.freeze_encoder()

    # Lowest validation loss seen so far; +inf guarantees that the first
    # finite result is saved.
    best_validation_loss = float('inf')

    for epoch in range(25):
        # After epoch 15 the learning rate of every param group is divided by
        # the configured decay factor each epoch, floored at 1e-5.
        if epoch > 15:
            for group in optimizer.param_groups:
                group['lr'] = max(0.00001,
                                  group['lr'] / conf['train']['lr_decay'])
                print('lr: ', group['lr'])

        loss_list = []
        model.train()

        for batch_i, batch in enumerate(data_loader_train):
            augment_transform = np.random.choice(augment_transform_list1)
            # NOTE(review): the augmented batch is computed but never used;
            # the training step below receives the un-augmented ``batch``.
            # Confirm whether ``batch1`` was meant to be passed instead.
            batch1 = image_batch_transformation(batch, augment_transform)
            loss = autoencoder_step(model, batch, device, loss_function)
            loss_list.append(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        mean_epoch_loss = sum(loss_list) / len(loss_list)
        model.eval()
        validation_loss = autoencoder_validation(data_loader_val, model,
                                                 device, loss_function)
        print('epoch {0}, loss: {1:2.5f}, validation: {2:2.5f}'.format(
            epoch, mean_epoch_loss, validation_loss))
        # Checkpoint whenever the validation loss matches or beats the best.
        if validation_loss <= best_validation_loss:
            best_validation_loss = validation_loss
            torch.save(model.state_dict(), save_path)

    # Evaluate the best checkpoint (not the last epoch) on the test split.
    model.load_state_dict(torch.load(save_path, map_location=device))
    test_results = autoencoder_validation(data_loader_test, model, device,
                                          loss_function)
    print('test result: ', test_results)
# Script-level setup: build the autoencoder and siamese network, optionally
# restore both from checkpoints, create one Adam optimizer over the parameters
# of both networks, and construct the image-folder training dataset.
config = Config()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

autoencoder = AutoEncoder(config)
siamese_network = SiameseNetwork(config)

# File names carry a leading '/' because they are appended to the folder path
# by plain string concatenation below.
autoencoder_file = '/autoencoder_epoch175_loss1.1991.pth'
siamese_file = '/siamese_network_epoch175_loss1.1991.pth'

if config.load_model:
    # map_location makes the checkpoints loadable when ``device`` falls back
    # to CPU even if they were saved from a GPU run.
    autoencoder.load_state_dict(
        torch.load(config.saved_models_folder + autoencoder_file,
                   map_location=device))
    siamese_network.load_state_dict(
        torch.load(config.saved_models_folder + siamese_file,
                   map_location=device))

autoencoder.to(device)
autoencoder.train()

siamese_network.to(device)
siamese_network.train()

# Both networks are optimised jointly by a single optimizer.
params = list(autoencoder.parameters()) + list(siamese_network.parameters())

optimizer = torch.optim.Adam(params, lr=config.lr, betas=(0.9, 0.999))

# Images are converted to single-channel grayscale tensors; the random
# crop / rotation augmentations are currently disabled.
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    # transforms.RandomCrop(size=128),
    # transforms.RandomRotation(degrees=10),
    transforms.ToTensor(),
])
train_data = torchvision.datasets.ImageFolder(config.data_folder, transform=transform)
Example #4
0
    return mean_loss


#model.freeze_encoder()
# Script-level training loop (35 epochs). Uses ``optimizer``, ``conf``,
# ``model``, ``data_loader_train``, ``loss_function`` and ``autoencoder_step``
# which are defined outside this fragment.
for epoch in range(35):
    # After epoch 15, divide the learning rate of every param group by the
    # configured decay factor each epoch, never going below 3e-5.
    if epoch > 15:
        for param in optimizer.param_groups:
            param['lr'] = max(0.00003, param['lr'] / conf['train']['lr_decay'])
            print('lr: ', param['lr'])

    # Per-batch values collected for this epoch's averages.
    loss_list = []
    emb_loss_list = []
    delta_loss_list = []
    reconstr_loss_list = []

    model.train()

    for batch_i, batch in enumerate(data_loader_train):
        # autoencoder_step returns four values here; only ``loss`` is
        # back-propagated, the other three are recorded for reporting.
        loss, emb_loss, delta_coeff, reconstr_loss = autoencoder_step(
            model, batch, device, loss_function)
        loss_list.append(loss.item())
        emb_loss_list.append(emb_loss.item())
        delta_loss_list.append(delta_coeff.item())
        reconstr_loss_list.append(reconstr_loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Epoch means. These divisions raise ZeroDivisionError if the loader
    # yields no batches.
    mean_epoch_loss = sum(loss_list) / len(loss_list)
    # NOTE(review): "ebedding" looks like a typo for "embedding"; left as-is
    # because code outside this fragment may reference the name.
    mean_ebedding_loss = sum(emb_loss_list) / len(emb_loss_list)
    mean_delta_loss = sum(delta_loss_list) / len(delta_loss_list)
Example #5
0
def main():
    """Train the image ``AutoEncoder`` with an SSIM or MS-SSIM criterion.

    Options come from ``get_argparser()``. The CLIC train and valid folders
    are concatenated into one training set; the Kodak folder is used for
    validation. After every epoch the current weights are written to
    ``latest_model.pt``, and to ``best_model.pt`` whenever the validation
    score exceeds the best score of all previous epochs.
    """
    opts = get_argparser().parse_args()

    # dataset
    train_transform = transforms.Compose([
        transforms.RandomCrop(size=512, pad_if_needed=True),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),
    ])

    val_transform = transforms.Compose([transforms.ToTensor()])

    train_set = data.ConcatDataset([
        ImageDataset(root='datasets/data/CLIC/train',
                     transform=train_transform),
        ImageDataset(root='datasets/data/CLIC/valid',
                     transform=train_transform),
    ])
    train_loader = data.DataLoader(train_set,
                                   batch_size=opts.batch_size,
                                   shuffle=True,
                                   num_workers=2,
                                   drop_last=True)

    val_loader = data.DataLoader(ImageDataset(root='datasets/data/kodak',
                                              transform=val_transform),
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=1)

    # Restrict the visible GPUs before the first CUDA query below.
    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print("Train set: %d, Val set: %d" %
          (len(train_loader.dataset), len(val_loader.dataset)))
    model = AutoEncoder(C=128, M=128, in_chan=3, out_chan=3).to(device)

    # optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=1e-4,
                                 weight_decay=1e-5)

    # checkpoint
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        # map_location keeps GPU-saved checkpoints loadable on CPU.
        model.load_state_dict(torch.load(opts.ckpt, map_location=device))
    else:
        print("[!] Retrain")

    if opts.loss_type == 'ssim':
        criterion = SSIM_Loss(data_range=1.0, size_average=True, channel=3)
    else:
        criterion = MS_SSIM_Loss(data_range=1.0,
                                 size_average=True,
                                 channel=3,
                                 nonnegative_ssim=True)

    # Best validation score across ALL epochs. It must be initialised once,
    # outside the loop: resetting it every epoch would overwrite
    # best_model.pt with any model whose score is merely positive.
    best_score = 0.0

    #==========   Train Loop   ==========#
    for cur_epoch in range(opts.total_epochs):
        # =====  Train  =====
        model.train()
        for cur_step, images in enumerate(train_loader):
            images = images.to(device, dtype=torch.float32)
            optimizer.zero_grad()
            outputs = model(images)

            loss = criterion(outputs, images)
            loss.backward()

            optimizer.step()

            if cur_step % opts.log_interval == 0:
                print("Epoch %d, Batch %d/%d, loss=%.6f" %
                      (cur_epoch, cur_step, len(train_loader), loss.item()))

        # =====  Save Latest Model  =====
        torch.save(model.state_dict(), 'latest_model.pt')

        # =====  Validation  =====
        print("Val on Kodak dataset...")
        cur_score = test(opts, model, val_loader, criterion, device)
        print("%s = %.6f" % (opts.loss_type, cur_score))
        # =====  Save Best Model  =====
        # NOTE(review): assumes a higher score from test() is better — confirm.
        if cur_score > best_score:  # save best model
            best_score = cur_score
            torch.save(model.state_dict(), 'best_model.pt')
            print("Best model saved as best_model.pt")
Example #6
0
def main():
    """Train the ``AutoEncoder`` with ``triplet_step`` and report recall.

    Settings are read from ``config.json``. The Dspites dataset is split into
    train / validation / test parts, the model is trained for 10 epochs, the
    weights with the highest validation recall are written to ``save_path``
    (when set), and the model is finally evaluated on the test split.

    Relies on names that are not defined inside this function and are expected
    to exist at module level: ``load_path``, ``save_path``, ``transform1``
    and ``transform2``.
    """
    with open("config.json") as json_file:
        conf = json.load(json_file)
    device = conf['train']['device']

    dataset_path = os.path.join(conf['data']['dataset_path'],
                                conf['data']['dataset_file'])
    dspites_dataset = Dspites(dataset_path)
    train_val = train_val_split(dspites_dataset)
    # Second split of the held-out part: its 'train' portion is used as the
    # test set, its 'val' portion as the validation set.
    val_test = train_val_split(train_val['val'], val_split=0.2)

    data_loader_train = DataLoader(train_val['train'],
                                   batch_size=conf['train']['batch_size'],
                                   shuffle=True,
                                   num_workers=2)
    data_loader_val = DataLoader(val_test['val'],
                                 batch_size=200,
                                 shuffle=False,
                                 num_workers=1)
    data_loader_test = DataLoader(val_test['train'],
                                  batch_size=200,
                                  shuffle=False,
                                  num_workers=1)

    print('metric learning')
    print('train dataset length: ', len(train_val['train']))
    print('val dataset length: ', len(val_test['val']))
    print('test dataset length: ', len(val_test['train']))

    print('latent space size:', conf['model']['latent_size'])
    print('batch size:', conf['train']['batch_size'])
    print('margin:', conf['train']['margin'])

    model = AutoEncoder(in_channels=1,
                        dec_channels=1,
                        latent_size=conf['model']['latent_size'])
    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=conf['train']['lr'])

    if load_path:
        # map_location keeps GPU-saved checkpoints loadable on CPU.
        model.load_state_dict(torch.load(load_path, map_location=device))

    # Highest validation recall seen so far (None until the first epoch ends).
    best_validation_recall = None

    for epoch in range(10):
        # The learning rate is divided by the decay factor at the start of
        # every epoch (including the first), floored at 1e-5.
        for group in optimizer.param_groups:
            group['lr'] = max(0.00001, group['lr'] / conf['train']['lr_decay'])
            print('lr: ', group['lr'])
        loss_list = []

        # Re-enter train mode each epoch because validation below switches
        # the model to eval mode.
        model.train()
        for batch_i, batch in enumerate(data_loader_train):
            # if batch_i == 1000:
            #     break
            batch = batch['image']
            batch = batch.type(torch.FloatTensor)
            batch = batch.to(device)
            loss = triplet_step(model, batch, transform1, transform2)
            loss_list.append(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        model.eval()
        recall, recall10 = recall_validation(model, data_loader_val,
                                             transform1, transform2, device)
        # Recall is a higher-is-better metric, so keep the checkpoint with the
        # MAXIMUM validation recall (tracking the minimum would retain the
        # worst model). Ties are re-saved, favouring the later epoch.
        if best_validation_recall is None or recall >= best_validation_recall:
            best_validation_recall = recall
            if save_path:
                torch.save(model.state_dict(), save_path)
        print('epoch {0}, loss {1:2.4f}'.format(
            epoch,
            sum(loss_list) / len(loss_list)))
        print('recall@3: {0:2.4f}, recall 10%: {1:2.4f}'.format(
            recall, recall10))

    # Reload the best checkpoint only if one could have been written;
    # otherwise evaluate the weights from the last epoch.
    if save_path:
        model.load_state_dict(torch.load(save_path, map_location=device))
    model.eval()
    # Pass ``device`` exactly as the validation call above does.
    recall, recall10 = recall_validation(model, data_loader_test, transform1,
                                         transform2, device)
    print('test recall@3: {0:2.4f}, recall@3 10%: {1:2.4f}'.format(
        recall, recall10))
Example #7
0
class Trainer(object):
    """Adversarially trains an ``AutoEncoder`` (G) against a ``Discriminator`` (D).

    G returns a reconstruction, a latent code and class logits. D is trained
    with BCE-with-logits to separate samples drawn from a scaled normal prior
    from the codes produced by G. G is trained on reconstruction MSE plus a
    weighted adversarial term and a weighted classification cross-entropy.
    Scalars and sample images are logged through ``self.writter``.
    """

    def __init__(self, args):
        """Build networks, data loaders, optimizers, schedulers and the logger.

        Args:
            args: Object exposing ``batch_size``, ``lr_G``, ``lr_D``,
                ``epochs``, ``z_var``, ``sigma``, ``lambda_1``, ``lambda_2``
                and ``log_dir``; it is also passed to the network
                constructors.
        """
        # load network
        self.G = AutoEncoder(args)
        self.D = Discriminator(args)
        self.G.weight_init()
        self.D.weight_init()
        self.G.cuda()
        self.D.cuda()
        self.criterion = nn.MSELoss()

        # load data
        self.train_dataset = CUBDataset(split='train')
        self.valid_dataset = CUBDataset(split='val')
        self.train_loader = DataLoader(dataset=self.train_dataset, batch_size=args.batch_size)
        self.valid_loader = DataLoader(dataset=self.valid_dataset, batch_size=args.batch_size)

        # Optimizers (the discriminator runs at half of the configured lr_D)
        self.G_optim = optim.Adam(self.G.parameters(), lr=args.lr_G)
        self.D_optim = optim.Adam(self.D.parameters(), lr=0.5 * args.lr_D)
        # Both learning rates are halved every 30 scheduler steps.
        self.G_scheduler = StepLR(self.G_optim, step_size=30, gamma=0.5)
        self.D_scheduler = StepLR(self.D_optim, step_size=30, gamma=0.5)

        # Parameters
        self.epochs = args.epochs
        self.batch_size = args.batch_size
        self.z_var = args.z_var
        self.sigma = args.sigma
        self.lambda_1 = args.lambda_1
        self.lambda_2 = args.lambda_2

        # One timestamped log directory per run.
        log_dir = os.path.join(args.log_dir, datetime.now().strftime("%m_%d_%H_%M_%S"))
        self.writter = SummaryWriter(log_dir)

    def train(self):
        """Run ``self.epochs`` training epochs, validating on every even epoch."""
        global_step = 0

        for epoch in range(self.epochs):
            # validate() leaves both networks in eval mode, so train mode has
            # to be restored at the start of every epoch.
            self.G.train()
            self.D.train()
            print("training epoch {}".format(epoch))
            all_num = 0.0
            acc_num = 0.0
            images_index = 0
            for data in tqdm(self.train_loader):
                images = data['image64'].cuda()
                # The input image is also the reconstruction target.
                target_image = images
                target = data['class_id'].cuda()
                # The last batch may be smaller than self.batch_size.
                batch_len = images.size(0)

                recon_x, z_tilde, output = self.G(images)
                z = (self.sigma * torch.randn(z_tilde.size())).cuda()
                log_p_z = log_density_igaussian(z, self.z_var).view(-1, 1)
                ones = torch.ones(batch_len, 1).cuda()
                zeros = torch.zeros(batch_len, 1).cuda()

                # ======== Losses ======== #
                # The discriminator loss sees a detached code, so its backward
                # pass never reaches G and needs no retain_graph.
                D_z = self.D(z)
                D_z_tilde_detached = self.D(z_tilde.detach())
                D_loss = F.binary_cross_entropy_with_logits(D_z + log_p_z, ones) + \
                    F.binary_cross_entropy_with_logits(D_z_tilde_detached + log_p_z, zeros)
                total_D_loss = self.lambda_1 * D_loss

                # The generator loss uses a separate, attached pass through D.
                D_z_tilde = self.D(z_tilde)
                # Normalise by the actual batch length, not the configured one.
                recon_loss = F.mse_loss(recon_x, target_image, reduction='sum').div(batch_len)
                G_loss = F.binary_cross_entropy_with_logits(D_z_tilde + log_p_z, ones)
                class_loss = F.cross_entropy(output, target)
                total_G_loss = recon_loss + self.lambda_1 * G_loss + self.lambda_2 * class_loss

                # ======== Backward / update ======== #
                # Back-propagate the generator loss BEFORE D's weights are
                # updated in place: autograd rejects a backward pass through
                # parameters modified after the forward pass.
                self.G_optim.zero_grad()
                total_G_loss.backward()

                # zero_grad also discards the gradients that the generator
                # backward pass deposited on D's parameters.
                self.D_optim.zero_grad()
                total_D_loss.backward()
                self.D_optim.step()
                self.G_optim.step()

                # ======== Compute Classification Accuracy ======== #
                values, indices = torch.max(output, 1)
                acc_num += torch.sum((indices == target)).cpu().item()
                all_num += len(target)

                # ======== Log by TensorBoardX
                global_step += 1
                if (global_step + 1) % 10 == 0:
                    self.writter.add_scalar('train/recon_loss', recon_loss.cpu().item(), global_step)
                    self.writter.add_scalar('train/G_loss', G_loss.cpu().item(), global_step)
                    self.writter.add_scalar('train/D_loss', D_loss.cpu().item(), global_step)
                    self.writter.add_scalar('train/classify_loss', class_loss.cpu().item(), global_step)
                    self.writter.add_scalar('train/total_G_loss', total_G_loss.cpu().item(), global_step)
                    self.writter.add_scalar('train/acc', acc_num/all_num, global_step)
                    # Log at most 5 image pairs per epoch, chosen at random.
                    if images_index < 5 and torch.rand(1) < 0.5:
                        self.writter.add_image('train_output_{}'.format(images_index), recon_x[0], global_step)
                        self.writter.add_image('train_target_{}'.format(images_index), target_image[0], global_step)
                        images_index += 1

            # Step the schedulers after the epoch's optimizer steps; stepping
            # them first skips the initial learning-rate value on
            # PyTorch >= 1.1.
            self.G_scheduler.step()
            self.D_scheduler.step()

            if epoch % 2 == 0:
                self.validate(global_step)

    def validate(self, global_step):
        """Log accuracy and mean reconstruction error on the validation set.

        Both networks are switched to eval mode and left in it; ``train()``
        restores train mode at the start of its next epoch.

        Args:
            global_step: Step index under which the results are logged.
        """
        self.G.eval()
        self.D.eval()
        acc_num = 0.0
        all_num = 0.0
        recon_loss = 0.0
        images_index = 0
        # No gradients are needed for evaluation.
        with torch.no_grad():
            for data in tqdm(self.valid_loader):
                images = data['image64'].cuda()
                target_image = images
                target = data['class_id'].cuda()
                recon_x, z_tilde, output = self.G(images)
                values, indices = torch.max(output, 1)
                acc_num += torch.sum((indices == target)).cpu().item()
                all_num += len(target)
                recon_loss += F.mse_loss(recon_x, target_image, reduction='sum').cpu().item()
                # Log the first sample of the first 5 batches.
                if images_index < 5:
                    self.writter.add_image('valid_output_{}'.format(images_index), recon_x[0], global_step)
                    self.writter.add_image('valid_target_{}'.format(images_index), target_image[0], global_step)
                    images_index += 1

        # Skip the averages when the loader is empty, avoiding a division by
        # zero.
        if all_num:
            self.writter.add_scalar('valid/acc', acc_num/all_num, global_step)
            self.writter.add_scalar('valid/recon_loss', recon_loss/all_num, global_step)