Example #1
# NOTE: assumed imports for this snippet; init_dataset, init_model, train and
# test are helpers from the surrounding repo (see Example #5 for init_dataset).
import argparse
import os

import torch


def main():
    '''
    Parse command-line options, initialize the data loaders, model and
    optimizer, then train and evaluate.
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument('--exp', type=str, default='default')
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--iterations', type=int, default=10000)
    parser.add_argument('--dataset', type=str, default='omniglot')
    parser.add_argument('--num_cls', type=int, default=5)
    parser.add_argument('--num_samples', type=int, default=1)
    parser.add_argument('--num_eval_samples', type=int, default=1)
    parser.add_argument('--lr', type=float, default=0.0001)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--cuda', action='store_true')
    parser.add_argument('--task_shuffling', type=str, default='intertask')
    parser.add_argument('--dynamic_k', action='store_true')
    options = parser.parse_args()

    if not os.path.exists(options.exp):
        os.makedirs(options.exp)

    if torch.cuda.is_available() and not options.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    tr_dataloader, val_dataloader, trainval_dataloader, test_dataloader = init_dataset(
        options)
    model = init_model(options)
    optim = torch.optim.Adam(params=model.parameters(), lr=options.lr)
    res = train(opt=options,
                tr_dataloader=tr_dataloader,
                val_dataloader=val_dataloader,
                model=model,
                optim=optim)
    best_state, best_acc, train_loss, train_acc, val_loss, val_acc = res
    print('Testing with last model...')
    test(opt=options, test_dataloader=test_dataloader, model=model)

    model.load_state_dict(best_state)
    print('Testing with best model...')
    test(opt=options, test_dataloader=test_dataloader, model=model)
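The snippet defines main() but no entry point; a minimal usage sketch, assuming the file is saved as train.py (hypothetical name):

if __name__ == '__main__':
    main()

# e.g. from the shell (flags as defined by the parser above):
#   python train.py --dataset omniglot --num_cls 5 --num_samples 1 --cuda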
Example #2
# NOTE: assumed imports for this snippet; utils, load_mnist_dataset and
# build_theano_fn come from the surrounding repo (lyr is lasagne.layers).
import time

import numpy as np
import theano
import lasagne.layers as lyr
from fuel.schemes import ShuffledScheme

import utils


def main():
    args = utils.get_args()

    # Settings
    BATCH_SIZE = 128
    NB_EPOCHS = args.epochs  # default 25
    LRN_RATE = 0.0001 / 28

    if args.verbose:
        PRINT_DELAY = 1
    else:
        PRINT_DELAY = 500

    # if running on server (MILA), copy dataset locally
    dataset_path = utils.init_dataset(args, 'mnist')

    # load dataset
    data = load_mnist_dataset(dataset_path, args)
    train_feats, train_targets, valid_feats, valid_targets, test_feats, test_targets = data

    # Get theano functions
    predict, preds_grad, network = build_theano_fn()

    print 'Starting training...'

    hist_errors = []
    running_err_avg = 0
    running_last_err_avg = 0

    # Create mask that will be used to create sequences
    mask = np.tril(np.ones((784, 784), dtype=theano.config.floatX), 0)
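    # A sketch of the mask semantics (comment not in the original code):
    #   np.tril(np.ones((4, 4)), 0)[1] -> array([ 1.,  1.,  0.,  0.])
    # so train_x * mask[t] keeps pixels 0..t and zeroes out the rest.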

    for i in xrange(NB_EPOCHS):

        print 'Epoch #%s of %s' % ((i + 1), NB_EPOCHS)

        epoch_err = []
        num_batch = 0
        t_epoch = time.time()

        # iterate over minibatches for training
        schemes_train = ShuffledScheme(examples=train_feats.shape[0], batch_size=1)

        # treat each single example as its own episode
        for batch_idx in schemes_train.get_request_iterator():

            batch_delta_params = []
            batch_error = []
            batch_grads = []

            num_batch += 1
            t_batch = time.time()

            train_x = train_feats[batch_idx]
            true_y = train_targets[batch_idx]

            nb_seq = 0

            for t in xrange(784):

                # apply mask at fixed interval, making sequences of pixels appear
                if (t + 1) % 28 == 0:

                    nb_seq += 1

                    seq_x = train_x * mask[t]
                    pred = predict(seq_x)[0, 0]
                    grad = preds_grad(seq_x)

                    # if nb_seq == 1:
                    #     sum_grad = np.copy(grad)
                    # else:
                    #     sum_grad += np.copy(grad)

                    if t < 783:
                        seq_x_prime = train_x * mask[t + 28]
                        pred_y_prime = predict(seq_x_prime)[0, 0]
                        TD_error = (pred_y_prime - pred)
                        error = (true_y - pred)
                    else:
                        TD_error = (true_y - pred)
                        error = (true_y - pred)

                    # per-parameter TD update (assumes preds_grad returns one
                    # gradient array per network parameter)
                    param_values = lyr.get_all_param_values(network)
                    param_values = [
                        p + LRN_RATE * TD_error * g
                        for p, g in zip(param_values, grad)
                    ]
                    lyr.set_all_param_values(network, param_values)
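                    # In TD(0) terms (a sketch, my reading of the code): with
                    # V(s) = predict(seq_x), the update is
                    #   w <- w + lr * (target - V(s)) * dV(s)/dw,
                    # where the target is V(s') for intermediate sequences and
                    # the true label for the last one.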

                    batch_error.append(error)

            # track error statistics for this example
            last_error = np.abs(error)[0]
            sqrd_error = np.linalg.norm(batch_error, 2)  # L2 norm of per-sequence errors
            epoch_err.append(sqrd_error)
            running_err_avg = 0.05 * sqrd_error + 0.95 * running_err_avg
            running_last_err_avg = 0.05 * last_error + 0.95 * running_last_err_avg

            if num_batch % PRINT_DELAY == 0:
                print '- batch %s, err %s (avg %s), last %s (avg %s), in %s sec' % (
                    num_batch, np.round(sqrd_error, 4),
                    np.round(running_err_avg, 4), np.round(last_error, 4),
                    np.round(running_last_err_avg, 4),
                    np.round(time.time() - t_batch, 2))

        print '- Epoch train (mean err %s) in %s sec' % (
            np.round(np.mean(epoch_err), 4), round(time.time() - t_epoch))

        # hist_errors.append(epoch_err)
        utils.dump_objects_output(args, epoch_err, 'epoch_%s_error_sqrd_norm.pkl' % (i + 1))
Example #3
# NOTE: assumed imports for this snippet; utils, dcgan and reconstruct_img
# are helpers from the surrounding repo.
import os
import time
import cPickle as pkl

import numpy as np
import theano

import utils
import dcgan


def main():
    args = utils.get_args()

    NB_GEN = args.gen  # default 5
    if args.reload is not None:
        RELOAD_SRC = args.reload[0]
        RELOAD_ID = args.reload[1]

    # if running on server (MILA), copy dataset locally
    dataset_path = utils.init_dataset(args, 'mscoco_inpainting/preprocessed')
    valid_path = os.path.join(dataset_path, 'val2014')

    if args.captions:
        t = time.time()
        embedding_model = utils.init_google_word2vec_model(args)
        print 'Embedding model was loaded in %s secs' % np.round(
            time.time() - t, 0)

    # build network and get theano functions for training
    theano_fn = dcgan.gen_theano_fn(args)
    train_discr, train_gen, predict, reconstr_fn, reconstr_noise_shrd, model = theano_fn

    # get different file names for the split data set
    valid_files = utils.get_preprocessed_files(valid_path)
    valid_full_files, valid_cter_files, valid_capt_files = valid_files

    NB_VALID_FILES = len(valid_full_files)

    corruption_mask = utils.get_corruption_mask()

    if args.reload is not None:

        # Reload previously saved model
        discriminator, generator = model
        file_discr = 'discriminator_epoch_%s.pkl' % RELOAD_ID
        file_gen = 'generator_epoch_%s.pkl' % RELOAD_ID
        t_load = time.time()
        loaded_discr = utils.reload_model(args, discriminator, file_discr,
                                          RELOAD_SRC)
        loaded_gen = utils.reload_model(args, generator, file_gen, RELOAD_SRC)

        if loaded_discr and loaded_gen:

            if args.verbose:
                print 'models loaded in %s sec' % (round(
                    time.time() - t_load, 0))

            # choose a random validation file
            file_id = np.random.choice(NB_VALID_FILES)

            # load file (start the timer before the read so the verbose
            # message reports the actual load time)
            t_load = time.time()
            with open(valid_full_files[file_id], 'r') as f:
                valid_full = np.load(f).astype(theano.config.floatX)

            if args.captions:
                # load file with the captions
                with open(valid_capt_files[file_id], 'rb') as f:
                    valid_capt = pkl.load(f)

            if args.verbose:
                print 'file %s loaded in %s sec' % (
                    valid_full_files[file_id], round(time.time() - t_load, 0))

            # pick a given number of images from that file
            batch_valid = np.random.choice(len(valid_full),
                                           NB_GEN,
                                           replace=False)

            # reconstruct images from their corrupted versions
            img_uncorrupt = valid_full[batch_valid]

            if args.captions:
                captions = utils.captions_to_embedded_matrix(
                    embedding_model, batch_valid, valid_capt)
                img_reconstr = reconstruct_img(args, img_uncorrupt,
                                               corruption_mask, reconstr_fn,
                                               reconstr_noise_shrd, captions)
            else:
                img_reconstr = reconstruct_img(args, img_uncorrupt,
                                               corruption_mask, reconstr_fn,
                                               reconstr_noise_shrd)

            # save images
            for i, images_reconstr in enumerate(img_reconstr):
                utils.save_pics_gan(
                    args,
                    images_reconstr,
                    'pred_rload_%s_%s_caption_%s_copy_%s' %
                    (RELOAD_SRC, RELOAD_ID, args.captions, i + 1),
                    show=False,
                    save=True,
                    tanh=False)
            utils.save_pics_gan(args,
                                img_uncorrupt,
                                'true_rload_%s_%s_caption_%s' %
                                (RELOAD_SRC, RELOAD_ID, args.captions),
                                show=False,
                                save=True,
                                tanh=False)

            if args.captions:
                save_code = 'rload_%s_%s' % (RELOAD_SRC, RELOAD_ID)
                utils.save_captions(args, save_code, valid_capt, batch_valid)

            if args.mila:
                utils.move_results_from_local()
Example #4
# NOTE: assumed imports for this snippet; utils and gen_theano_fn come from
# the surrounding repo (gen_theano_fn presumably from its dcgan module).
import os
import time
import cPickle as pkl

import numpy as np
import theano
from fuel.schemes import ShuffledScheme

import utils
from dcgan import gen_theano_fn


def main():
    args = utils.get_args()

    # Settings for training
    BATCH_SIZE = 128
    NB_EPOCHS = args.epochs  # default 25
    NB_GEN = args.gen  # default 5
    TRAIN_STEPS_DISCR = 15
    TRAIN_STEPS_GEN = 10
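    # Training schedule sketch (my reading of the loop below): the
    # discriminator trains on every minibatch, and after every
    # TRAIN_STEPS_DISCR such minibatches the generator trains for
    # TRAIN_STEPS_GEN consecutive steps.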
    if args.reload is not None:
        RELOAD_SRC = args.reload[0]
        RELOAD_ID = args.reload[1]

    if args.verbose:
        BATCH_PRINT_DELAY = 1
    else:
        BATCH_PRINT_DELAY = 100

    # if running on server (MILA), copy dataset locally
    dataset_path = utils.init_dataset(args, 'mscoco_inpainting/preprocessed')
    train_path = os.path.join(dataset_path, 'train2014')
    valid_path = os.path.join(dataset_path, 'val2014')

    if args.captions:
        t = time.time()
        embedding_model = utils.init_google_word2vec_model(args)
        print 'Embedding model was loaded in %s secs' % np.round(
            time.time() - t, 0)

    # build network and get theano functions for training
    theano_fn = gen_theano_fn(args)
    train_discr, train_gen, predict, reconstr_fn, reconstr_noise_shrd, model = theano_fn

    # get different file names for the split data set
    train_files = utils.get_preprocessed_files(train_path)
    train_full_files, train_cter_files, train_capt_files = train_files

    valid_files = utils.get_preprocessed_files(valid_path)
    valid_full_files, valid_cter_files, valid_capt_files = valid_files

    NB_TRAIN_FILES = len(train_full_files)
    NB_VALID_FILES = len(valid_full_files)

    print 'Starting training...'

    train_loss = []

    if args.reload is not None:
        discriminator, generator = model
        file_discr = 'discriminator_epoch_%s.pkl' % RELOAD_ID
        file_gen = 'generator_epoch_%s.pkl' % RELOAD_ID
        loaded_discr = utils.reload_model(args, discriminator, file_discr,
                                          RELOAD_SRC)
        loaded_gen = utils.reload_model(args, generator, file_gen, RELOAD_SRC)

    for i in xrange(NB_EPOCHS):

        print 'Epoch #%s of %s' % ((i + 1), NB_EPOCHS)

        epoch_acc = 0
        epoch_loss = 0
        num_batch = 0
        t_epoch = time.time()
        d_batch_loss = 0
        g_batch_loss = 0
        steps_loss_g = []  # generator loss at every step
        steps_loss_d = []  # discriminator loss at every step
        d_train_step = 0

        # iterate over the split dataset files in random order
        for file_id in np.random.choice(NB_TRAIN_FILES,
                                        NB_TRAIN_FILES,
                                        replace=False):

            t_load = time.time()

            # load file with full image
            with open(train_full_files[file_id], 'r') as f:
                train_full = np.load(f).astype(theano.config.floatX)

            if args.captions:
                # load file with the captions
                with open(train_capt_files[file_id], 'rb') as f:
                    train_capt = pkl.load(f)

            if args.verbose:
                print 'file %s loaded in %s sec' % (
                    train_full_files[file_id], round(time.time() - t_load, 0))

            # iterate over minibatches for training
            schemes_train = ShuffledScheme(examples=len(train_full),
                                           batch_size=BATCH_SIZE)

            for batch_idx in schemes_train.get_request_iterator():

                d_train_step += 1

                t_batch = time.time()
                # generate batch of uniform samples
                rdm_d = np.random.uniform(-1., 1., size=(len(batch_idx), 100))
                rdm_d = rdm_d.astype(theano.config.floatX)

                # train with a minibatch on discriminator
                if args.captions:
                    # generate embeddings for the batch
                    d_capts_batch = utils.captions_to_embedded_matrix(
                        embedding_model, batch_idx, train_capt)

                    d_batch_loss = train_discr(train_full[batch_idx], rdm_d,
                                               d_capts_batch)

                else:
                    d_batch_loss = train_discr(train_full[batch_idx], rdm_d)

                steps_loss_d.append(d_batch_loss)
                steps_loss_g.append(g_batch_loss)

                if num_batch % BATCH_PRINT_DELAY == 0:
                    print '- train discr batch %s, loss %s in %s sec' % (
                        num_batch, np.round(d_batch_loss, 4),
                        np.round(time.time() - t_batch, 2))

                # check if it is time to train the generator
                if d_train_step >= TRAIN_STEPS_DISCR:

                    # reset discriminator step counter
                    d_train_step = 0

                    # train the generator for a fixed number of steps
                    for g_step in xrange(TRAIN_STEPS_GEN):

                        # generate batch of uniform samples
                        rdm_g = np.random.uniform(-1.,
                                                  1.,
                                                  size=(BATCH_SIZE, 100))
                        rdm_g = rdm_g.astype(theano.config.floatX)

                        # train with a minibatch on generator
                        if args.captions:
                            # sample a random set of captions from current training file
                            g_batch_idx = np.random.choice(len(train_full),
                                                           BATCH_SIZE,
                                                           replace=False)
                            g_capts_batch = utils.captions_to_embedded_matrix(
                                embedding_model, g_batch_idx, train_capt)

                            g_batch_loss = train_gen(rdm_g, g_capts_batch)
                        else:
                            g_batch_loss = train_gen(rdm_g)

                        steps_loss_d.append(d_batch_loss)
                        steps_loss_g.append(g_batch_loss)

                        if num_batch % BATCH_PRINT_DELAY == 0:
                            print '- train gen step %s, loss %s' % (
                                g_step + 1, np.round(g_batch_loss, 4))

                epoch_loss += d_batch_loss + g_batch_loss
                num_batch += 1

        train_loss.append(np.round(epoch_loss, 4))

        if args.save > 0 and i % args.save == 0:
            discriminator, generator = model
            utils.save_model(args, discriminator,
                             'discriminator_epoch_%s.pkl' % i)
            utils.save_model(args, generator, 'generator_epoch_%s.pkl' % i)

        print '- Epoch train (loss %s) in %s sec' % (
            train_loss[i], round(time.time() - t_epoch))

        # save losses at each step
        utils.dump_objects_output(args, (steps_loss_d, steps_loss_g),
                                  'steps_loss_epoch_%s.pkl' % i)

    print 'Training completed.'

    # Generate images from pure noise, with random captions from the
    # validation set when applicable
    if NB_GEN > 0:

        if args.reload is not None:
            assert loaded_gen and loaded_discr, 'An error occurred during loading, cannot generate.'
            save_code = 'rload_%s_%s' % (RELOAD_SRC, RELOAD_ID)
        else:
            save_code = 'no_reload'

        rdm_noise = np.random.uniform(-1., 1., size=(NB_GEN, 100))
        rdm_noise = rdm_noise.astype(theano.config.floatX)
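        # 100 is the latent dimension the generator is assumed to expect
        # (DCGAN-style models typically draw z ~ U(-1, 1))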

        # choose a random validation file
        file_id = np.random.choice(NB_VALID_FILES)

        # load file
        with open(valid_full_files[file_id], 'r') as f:
            valid_full = np.load(f).astype(theano.config.floatX)

        # pick a given number of images from that file (needed below whether
        # or not captions are used)
        batch_valid = np.random.choice(len(valid_full), NB_GEN, replace=False)

        if args.captions:

            # load file with the captions
            with open(valid_capt_files[file_id], 'rb') as f:
                valid_capt = pkl.load(f)

            captions = utils.captions_to_embedded_matrix(
                embedding_model, batch_valid, valid_capt)
            # captions = np.empty((NB_GEN, 300), dtype=theano.config.floatX)  # used for debugging

            # make predictions
            imgs_noise, probs_noise = predict(rdm_noise, captions)
        else:
            # make predictions
            imgs_noise, probs_noise = predict(rdm_noise)

        if args.verbose:
            print probs_noise

        # save images
        true_imgs = valid_full[batch_valid]

        utils.save_pics_gan(args,
                            imgs_noise,
                            'noise_caption_%s_%s' % (args.captions, save_code),
                            show=False,
                            save=True,
                            tanh=False)
        utils.save_pics_gan(args,
                            true_imgs,
                            'true_caption_%s_%s' % (args.captions, save_code),
                            show=False,
                            save=True,
                            tanh=False)

        if args.captions:
            utils.save_captions(args, save_code, valid_capt, batch_valid)

    if args.mila:
        utils.move_results_from_local()
Example #5
# NOTE: assumed imports for this snippet (argparse and torch were not shown).
import argparse

import torch

from snail import SnailFewShot
from utils import init_dataset
from train import batch_for_few_shot

parser = argparse.ArgumentParser()
parser.add_argument('--exp', type=str, default='default')
parser.add_argument('--iterations', type=int, default=10000)
parser.add_argument('--dataset', type=str, default='mini_imagenet')
parser.add_argument('--num_cls', type=int, default=5)
parser.add_argument('--num_samples', type=int, default=1)
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--cuda', action='store_true')
options = parser.parse_args()

model = SnailFewShot(options.num_cls, options.num_samples, options.dataset)
weights = torch.load('mini_imagenet_5way_1shot/best_model.pth')
model.load_state_dict(weights)
if options.cuda:
    model = model.cuda()
model.eval()
_, val_dataloader, _, _ = init_dataset(options)
val_iter = iter(val_dataloader)
for batch in val_iter:
    x, y = batch
    x, y, last_targets = batch_for_few_shot(options, x, y)
    model_output = model(x, y)
    last_output = model_output[:, -1, :]  # prediction at the final (query) step

    print(last_output)
    print(last_targets)
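    # A sketch (not in the original script): batch accuracy over the query
    # step could be computed as, e.g.:
    #   preds = last_output.argmax(dim=1)
    #   acc = preds.eq(last_targets).float().mean().item()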
    import pdb
    pdb.set_trace()
Example #6
import os
import time
import cPickle as pkl

import numpy as np
from fuel.schemes import ShuffledScheme

import utils

if __name__ == '__main__':

    args = utils.get_args()

    # if running on server (MILA), copy dataset locally
    dataset_path = utils.init_dataset(args, 'mscoco_inpainting')
    train_path = os.path.join(dataset_path, 'preprocessed/train2014')
    valid_path = os.path.join(dataset_path, 'preprocessed/val2014')

    NB_TRAIN = 82782
    NB_VALID = 40504
    MAX_IMGS = 15000

    train_count = 0
    valid_count = 0

    # process for training dataset
    schemes_train = ShuffledScheme(examples=NB_TRAIN, batch_size=MAX_IMGS)
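    # The scheme shuffles all NB_TRAIN indices once and yields them in
    # disjoint chunks of MAX_IMGS, so each output file below gets a random
    # subset of the training images.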

    for i, batch_idx in enumerate(schemes_train.get_request_iterator()):

        print 'train file %s in progress...' % i