Example #1
0
def main():
    """
    Build a ``WasserstainDCGAN`` model and fit it on the selected dataset.

    The dataset is chosen by ``FLAGS.dataset``. Only the training images are
    used; for 'cifar10' the remaining values returned by the loader are
    discarded.

    :raises ValueError: if ``FLAGS.dataset`` is neither 'mnist' nor 'cifar10'
    """
    if FLAGS.dataset == 'mnist':
        X_train = load_data.load_mnist()
        c_dim = 1
    elif FLAGS.dataset == 'cifar10':
        # The loader returns four values; only the first is needed here.
        X_train, _, _, _ = load_data.load_cifar10(FLAGS.dataset_dir)
        c_dim = 3
    else:
        # Fail with a clear message instead of an UnboundLocalError on c_dim.
        raise ValueError(
            "Unsupported dataset: %r (expected 'mnist' or 'cifar10')" %
            (FLAGS.dataset,))

    model = WasserstainDCGAN(FLAGS, c_dim=c_dim)
    model.construct_network()
    model.fit(X_train)
Example #2
0
def test(net, log=None, batch_size=128, data='cifar100'):
    """
    Test on trained model.

    Puts ``net`` in eval mode, runs it over the loader of the chosen dataset
    with gradients disabled, prints the accuracy and the elapsed wall-clock
    time, and mirrors both to ``log`` when one is given.

    :param net: model to be tested; ``net(inputs)`` must return a pair whose
        first element holds the class scores
    :param log: optional object with a ``write`` method; skipped when None
    :param batch_size: batch size
    :param data: datasets used; one of 'cifar10', 'cifar100', 'svhn',
        'mnist', 'tinyimagenet'
    :return: elapsed inference time in seconds
    :raises ValueError: if ``data`` is not one of the supported names
    """

    net.eval()
    is_train = False

    # data
    if data == 'cifar10':
        test_loader = load_cifar10(is_train, batch_size)
    elif data == 'cifar100':
        test_loader = load_cifar100(is_train, batch_size)
    elif data == 'svhn':
        test_loader = load_svhn(is_train, batch_size)
    elif data == 'mnist':
        test_loader = load_mnist(is_train, batch_size)
    elif data == 'tinyimagenet':
        # NOTE(review): train() unpacks two loaders from
        # load_tiny_imagenet(...); confirm that a single loader is returned
        # when is_train is False.
        test_loader = load_tiny_imagenet(is_train, batch_size)
    else:
        # A bare exit() would end the process silently with status 0.
        raise ValueError('Unsupported dataset: %r' % (data,))

    correct = 0
    total = 0
    inference_start = time.time()
    with torch.no_grad():
        # Unpack the batch directly so the `data` argument is not shadowed.
        for inputs, labels in test_loader:
            outputs, _ = net(inputs.cuda())
            _, predicted = torch.max(F.softmax(outputs, -1), 1)
            total += labels.size(0)
            # .item() keeps the running count a plain Python int.
            correct += (predicted == labels.cuda()).sum().item()
    inference_time = time.time() - inference_start

    # An empty loader would otherwise raise ZeroDivisionError.
    accuracy = 100.0 * correct / total if total else 0.0
    print('Accuracy: %f %%; Inference time: %fs' % (accuracy, inference_time))

    if log is not None:
        # Report the number of images actually evaluated rather than a fixed
        # 10000, which is only right for some of the supported datasets.
        log.write('Accuracy of the network on the %d test images: %f %%\n' %
                  (total, accuracy))
        log.write('Inference time is: %fs\n' % inference_time)

    return inference_time
Example #3
0
def train(net,
          lr,
          log=None,
          optimizer_option='SGD',
          data='cifar100',
          epochs=350,
          batch_size=128,
          is_train=True,
          net_st=None,
          beta=0.0,
          lrd=10):
    """
    Train the model.

    Optimises ``net`` with cross-entropy loss on the GPU. When ``net_st`` is
    given, an MSE term between the second outputs of ``net`` and ``net_st``
    is added: weight ``beta`` for the last entry and ``beta / 50`` for every
    other entry. The learning rate is adjusted through ``adjust_lr`` at
    epochs 150 and 250, and validation runs after every 10th epoch.

    :param net: model to be trained; ``net(inputs)`` must return a pair
        ``(class scores, sequence of intermediate outputs)``
    :param lr: learning rate
    :param log: optional object with a ``write`` method; progress lines are
        mirrored to it when it is not None
    :param optimizer_option: optimizer type
    :param data: datasets used to train; one of 'cifar10', 'cifar100',
        'svhn', 'mnist', 'tinyimagenet'
    :param epochs: number of training epochs
    :param batch_size: batch size
    :param is_train: whether it is a training process
    :param net_st: uncompressed model
    :param beta: transfer parameter
    :param lrd: forwarded to ``adjust_lr`` (presumably the decay divisor;
        confirm against ``adjust_lr``)
    :return: for 'tinyimagenet', a copy of the weights with the highest
        validation accuracy seen so far (the live state dict if no
        validation improved on 0); for the other datasets, the live state
        dict of ``net``, i.e. its final weights
    :raises ValueError: if ``data`` is not one of the supported names
    """
    # Local import: only needed to snapshot the best weights.
    import copy

    net.train()
    if net_st is not None:
        net_st.eval()

    if data == 'cifar10':
        trainloader = load_cifar10(is_train, batch_size)
        valloader = load_cifar10(False, batch_size)
    elif data == 'cifar100':
        trainloader = load_cifar100(is_train, batch_size)
        valloader = load_cifar100(False, batch_size)
    elif data == 'svhn':
        trainloader = load_svhn(is_train, batch_size)
        valloader = load_svhn(False, batch_size)
    elif data == 'mnist':
        trainloader = load_mnist(is_train, batch_size)
        # Without a validation loader the periodic validation step would
        # fail with a NameError.
        valloader = load_mnist(False, batch_size)
    elif data == 'tinyimagenet':
        trainloader, valloader = load_tiny_imagenet(is_train, batch_size)
    else:
        # A bare exit() would end the process silently with status 0.
        raise ValueError('Unsupported dataset: %r' % (data,))

    # Decide this once, so nothing inside the loops can clobber the check.
    track_best = data == 'tinyimagenet'

    criterion = nn.CrossEntropyLoss()
    criterion_mse = nn.MSELoss()
    optimizer = get_optimizer(net, lr, optimizer_option)

    start_time = time.time()
    # Start from start_time so the first progress line reports a real
    # interval rather than the number of seconds since the Unix epoch.
    last_time = start_time

    best_acc = 0
    # state_dict() returns references to the live tensors, so this tracks
    # the current weights until it is replaced by a deep-copied snapshot.
    best_param = net.state_dict()

    iteration = 0
    for epoch in range(epochs):
        banner = "****************** EPOCH = %d ******************" % epoch
        print(banner)
        if log is not None:
            log.write(banner + "\n")

        total = 0
        correct = 0

        # change learning rate
        if epoch == 150 or epoch == 250:
            lr = adjust_lr(lr, lrd=lrd, log=log)
            optimizer = get_optimizer(net, lr, optimizer_option)

        for inputs, labels in trainloader:
            iteration += 1

            # forward
            inputs, labels = inputs.cuda(), labels.cuda()
            outputs, outputs_conv = net(inputs)
            loss = criterion(outputs, labels)
            if net_st is not None:
                # The reference model only provides targets, so no autograd
                # graph is needed for its forward pass.
                with torch.no_grad():
                    _, outputs_st_conv = net_st(inputs)
                last_idx = len(outputs_st_conv) - 1
                for idx, target in enumerate(outputs_st_conv):
                    # The final entry gets the full weight, the rest beta/50.
                    weight = beta if idx == last_idx else beta / 50
                    loss = loss + weight * criterion_mse(
                        outputs_conv[idx], target.detach())

            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            _, predicted = torch.max(F.softmax(outputs, -1), 1)
            total += labels.size(0)
            # .item() keeps the counter a plain int and off the autograd graph.
            correct += (predicted == labels).sum().item()

            if iteration % 100 == 99:
                now_time = time.time()
                message = 'accuracy: %f %%; loss: %f; time: %ds' % (
                    100.0 * correct / total, loss.item(),
                    now_time - last_time)
                print(message)
                if log is not None:
                    log.write(message + '\n')

                total = 0
                correct = 0
                last_time = now_time

        # validation
        if epoch % 10 == 9:
            net.eval()
            val_acc = validation(net, valloader, log)
            net.train()
            if not track_best:
                best_param = net.state_dict()
            elif val_acc > best_acc:
                best_acc = val_acc
                # Deep copy: otherwise later optimizer steps would overwrite
                # the stored "best" weights in place.
                best_param = copy.deepcopy(net.state_dict())

    elapsed = time.time() - start_time
    print('Finished Training. It took %ds in total' % elapsed)
    if log is not None:
        log.write('Finished Training. It took %ds in total\n' % elapsed)
    return best_param
Example #4
0
# Remaining blocks are appended one at a time, in order: pooling, a
# 1024 -> 512 linear layer with batch norm, ReLU and dropout, a 512 -> 10
# linear layer, and finally the cross-entropy loss block.
net.add_block(layers.MaxPool2d(stride=2))
net.add_block(layers.Linear(n_in=8 * 8 * 16, n_out=512))
net.add_block(layers.BatchNorm1d(n_in=512))
net.add_block(activations.ReLU(inplace=True))
net.add_block(layers.Dropout(keep_prob=0.5))
net.add_block(layers.Linear(n_in=512, n_out=10))
net.add_block(losses.CrossEntropyLoss())

# Optimizer: momentum 0.9 with the Nesterov flag set; the learning rate and
# its decay come from variables defined earlier in the script.
optimizer = optimizers.Momentum(
    learning_rate=learning_rate,
    momentum=0.9,
    Nesterov=True,
    lr_decay=lr_decay,
)

# Normalisation statistics, rescaled from the 0-255 range by dividing by 255.
mean = [125.3 / 255, 123.0 / 255, 113.9 / 255]
std = [63.0 / 255, 62.1 / 255, 66.7 / 255]

# Training data gets pad/crop/flip augmentation before normalisation;
# validation and test data are only converted and normalised.
train_transform = Transforms(
    [ToTensor(), Pad(4), RandomCrop(32), RandomHorizontalFlip(),
     Normalize(mean, std)])
val_test_transform = Transforms([
    ToTensor(),
    Normalize(mean, std),
])

# Load the three CIFAR-10 splits and hand them to the network.
train, val, test = load_data.load_cifar10(
    train_transform,
    val_test_transform,
)
net.load_data(train, val, test)

# Run training.
net.train(optimizer, num_epoches, batch_size, eval_freq)
Example #5
0
# Hyper-parameters handed to the HFVAE and Decoder constructors below.
z_dim, k_size = 12, 3
n_group_per_scale = [1] * 4
n_ch = 64  # 128
gamma = 1e-2

# Build the model on a symbolic input; HFVAE returns the full model along
# with its encoder and three further values.
x = layers.Input(shape=input_shape)
(hfvae, encoder, latent_variables, latent_stats,
 delta_stats) = HFVAE(x, z_dim, n_group_per_scale, n_ch, k_size)
decoder = Decoder(z_dim, n_group_per_scale, n_ch)

hfvae.summary()

# Compiled with an empty loss list and 'mse' as the only metric.
hfvae.compile(
    optimizer=Adam(learning_rate=1e-5),
    loss=[],
    metrics=['mse'],
)

# Load data
x_train, x_test = load_data.load_cifar10()

# Run configuration.
weights_filename = 'hfvae_64.h5'
TRAIN = True
# NOTE(review): epochs is 0, so fit() is asked for zero epochs; confirm
# that this is intended.
epochs, batch_size = 0, 100

if TRAIN:
    # Start from the saved weights, then fit with the inputs as the targets.
    hfvae.load_weights('saved_weights/' + weights_filename)
    hfvae.fit(
        x_train,
        x_train,
        batch_size=batch_size,
        epochs=epochs,
        verbose=1,
    )