def plot_vars(f_step, projection, load_all=False):
    """Render 850 hPa temperature / 500 hPa geopotential-height chart(s).

    Images are named ``/tmp/<projection>_gph_t_850_<run>_<step>.png``. If
    every requested image already exists, nothing is recomputed.

    Args:
        f_step: forecast step to plot when ``load_all`` is False; ignored
            when ``load_all`` is True.
        projection: projection/domain key passed to ``subset_arrays`` and
            used in the output file names.
        load_all: when True, plot the fixed step list (0..78 every step,
            then 81..120 every 3) in parallel with a process pool.

    Returns:
        load_all False: the single file name (``str``) when it was cached,
        otherwise whatever ``single_plot(0, ...)`` returns (presumably the
        same file name -- ``update_figure`` compares it to one).
        load_all True: the list of file names when all were cached,
        otherwise the list produced by ``Pool.map`` over ``single_plot``.
    """
    # The one employed for the figure name when exported
    variable_name = 'gph_t_850'
    run_string, _ = get_run()

    if load_all:
        f_steps = list(range(0, 79)) + list(range(81, 121, 3))
    else:
        f_steps = [f_step]

    filenames = ['/tmp/' + projection + '_' + variable_name +
                 '_%s_%03d.png' % (run_string, step) for step in f_steps]

    if all(os.path.exists(f) for f in filenames):
        # Everything is already on disk. Keep the return type consistent
        # with the plotting path: one name for a single step, a list
        # otherwise.
        return filenames if load_all else filenames[0]

    # otherwise do the plots
    dset = get_dset(vars_3d=['t@850', 'fi@500'], f_times=f_steps).squeeze()
    # squeeze() drops a length-1 step dimension; add it back so single-step
    # and multi-step datasets share the same layout.
    if 'step' not in dset.dims:
        dset = dset.expand_dims('step')
    dset = subset_arrays(dset, projection)
    time = pd.to_datetime(dset.valid_time.values)

    temp_850 = dset['t'] - 273.15  # Kelvin -> Celsius
    z_500 = dset['z']
    gph_500 = mpcalc.geopotential_to_height(z_500)
    gph_500 = xr.DataArray(gph_500.magnitude, coords=z_500.coords,
                           attrs={'standard_name': 'geopotential height',
                                  'units': gph_500.units})

    levels_temp = np.arange(-30., 30., 1.)
    levels_gph = np.arange(4700., 6000., 70.)

    lon, lat = get_coordinates(temp_850)
    lon2d, lat2d = np.meshgrid(lon, lat)

    cmap = get_colormap('temp')

    args = dict(filenames=filenames, projection=projection, levels_temp=levels_temp,
                cmap=cmap, lon2d=lon2d, lat2d=lat2d, lon=lon, lat=lat, temp_850=temp_850.values,
                gph_500=gph_500.values, levels_gph=levels_gph, time=time, run_string=run_string)

    if load_all:
        plot_step = partial(single_plot, **args)
        pool = Pool(cpu_count())
        try:
            results = pool.map(plot_step, range(len(f_steps)))
        finally:
            # Release the worker processes even if a plot fails.
            pool.close()
            pool.join()
    else:
        results = single_plot(0, **args)

    return results
def update_output(n_clicks, chart, projection):
    """Pre-render every forecast step once the trigger has fired more than once.

    Builds the expected ``/tmp/<projection>_<chart>_<run>_<step>.png`` names
    for the full step list and, if any is missing, renders all steps via
    ``plot_vars(..., load_all=True)``.

    Returns:
        Always None; the function is used only for its side effect of
        filling the image cache.
    """
    # The callback may fire before any click with n_clicks=None; in
    # Python 3, `None > 1` raises TypeError, so guard it explicitly. Checking
    # before get_run() also avoids an unnecessary lookup.
    if not n_clicks or n_clicks <= 1:
        return None

    run_string, _ = get_run()
    f_steps = list(range(0, 79)) + list(range(81, 121, 3))
    filenames = ['/tmp/' + projection + '_' + chart + '_%s_%03d.png' % (run_string, f_step)
                 for f_step in f_steps]

    if not all(os.path.exists(f) for f in filenames):
        plot_vars(f_steps, projection, load_all=True)
    return None
def update_figure(chart, f_step, projection):
    """Return the base64-encoded PNG for ``chart`` at forecast step ``f_step``.

    The image is looked up in /tmp first. When it is missing it is rendered
    by ``plot_vars`` and the name it reports is checked against the expected
    one before encoding.
    """
    run_string, _ = get_run()
    expected = '/tmp/' + projection + '_' + chart + '_%s_%03d.png' % (run_string, f_step)

    if not os.path.exists(expected):
        produced = plot_vars(f_step, projection, load_all=False)
        assert produced == expected, "Mismatching filename strings! From plot_vars:%s , defined:%s" % (produced, expected)
        return b64_image(produced)

    return b64_image(expected)
import logging

from torchvision import datasets

import configs.pwb as params
import model.g_and_t_model as model
import trainer.g_and_t_trainer as trainer
import utils
from experiment.experiment import experiment
from utils.g_and_t_utils import *

p = params.PwbParameters()
# Parse the command line once and reuse the namespace, instead of
# re-parsing sys.argv for every value.
known_args = p.parse_known_args()[0]
total_seeds = len(known_args.seed)
rank = known_args.rank
# Separate dict copies: get_run receives its own dict, so nothing it does to
# its argument can leak into the "all_args" snapshot stored below (the
# original got this isolation by parsing into distinct namespaces).
all_args = dict(vars(known_args))
print("All hyperparameters = ", all_args)

flags = utils.get_run(dict(vars(known_args)), rank)

utils.set_seed(flags["seed"])

my_experiment = experiment(flags["name"],
                           flags,
                           flags['output_dir'],
                           commit_changes=False,
                           rank=int(rank / total_seeds),
                           seed=total_seeds)

my_experiment.results["all_args"] = all_args

logger = logging.getLogger('experiment')

logger.info("Selected hyperparameters %s", str(flags))
Beispiel #5
0
def main():
    """Omniglot continual-learning evaluation with a per-schedule LR sweep.

    For every class count in the colon-separated ``args['schedule']``:

    1. LR search (5 repetitions): sample that many of the first 650 classes,
       reset the model's adaptation variables, train one pass over the
       class-sorted stream at each candidate LR, and keep the LR with the
       highest value returned by ``eval_iterator`` on the training split.
    2. Pick the most frequently selected LR (ties resolve to the smallest).
    3. Run ``args['runs']`` evaluations with that LR on fresh class samples,
       recording train and test scores.

    Results are written to the experiment JSON after every repetition/run.
    """
    from collections import Counter

    p = class_parser_eval.Parser()
    rank = p.parse_known_args()[0].rank
    all_args = vars(p.parse_known_args()[0])
    print("All args = ", all_args)

    args = utils.get_run(vars(p.parse_known_args()[0]), rank)

    utils.set_seed(args['seed'])

    my_experiment = experiment(args['name'], args, "../results/", commit_changes=False, rank=0, seed=1)

    data_train = df.DatasetFactory.get_dataset("omniglot", train=True, background=False, path=args['path'])
    data_test = df.DatasetFactory.get_dataset("omniglot", train=False, background=False, path=args['path'])
    final_results_train = []
    final_results_test = []
    lr_sweep_results = []

    gpu_to_use = rank % args["gpus"]
    if torch.cuda.is_available():
        device = torch.device('cuda:' + str(gpu_to_use))
        logger.info("Using gpu : %s", 'cuda:' + str(gpu_to_use))
    else:
        device = torch.device('cpu')

    config = mf.ModelFactory.get_model("na", args['dataset'], output_dimension=1000)

    maml = load_model(args, config)
    maml = maml.to(device)

    args['schedule'] = [int(x) for x in args['schedule'].split(":")]
    no_of_classes_schedule = args['schedule']
    print(args["schedule"])
    for total_classes in no_of_classes_schedule:
        lr_sweep_range = [0.03, 0.01, 0.003, 0.001, 0.0003, 0.0001, 0.00003, 0.00001]
        lr_all = []
        for lr_search_runs in range(0, 5):

            classes_to_keep = np.random.choice(list(range(650)), total_classes, replace=False)

            dataset = utils.remove_classes_omni(data_train, classes_to_keep)

            iterator_sorted = torch.utils.data.DataLoader(
                utils.iterator_sorter_omni(dataset, False, classes=no_of_classes_schedule),
                batch_size=1,
                shuffle=args['iid'], num_workers=2)

            dataset = utils.remove_classes_omni(data_train, classes_to_keep)
            iterator_train = torch.utils.data.DataLoader(dataset, batch_size=32,
                                                         shuffle=False, num_workers=1)

            max_acc = -1000
            for lr in lr_sweep_range:

                maml.reset_vars()

                opt = torch.optim.Adam(maml.get_adaptation_parameters(), lr=lr)

                train_iterator(iterator_sorted, device, maml, opt)

                correct = eval_iterator(iterator_train, device, maml)
                if correct > max_acc:
                    max_acc = correct
                    max_lr = lr

            lr_all.append(max_lr)
            results_mem_size = (max_acc, max_lr)
            lr_sweep_results.append((total_classes, results_mem_size))

            my_experiment.results["LR Search Results"] = lr_sweep_results
            my_experiment.store_json()
            logger.debug("LR RESULTS = %s", str(lr_sweep_results))

        # Most frequently selected LR; ties resolve to the smallest value,
        # as scipy.stats.mode does. `stats.mode(lr_all)[0][0]` breaks on
        # SciPy >= 1.11, where mode() returns scalars by default.
        lr_counts = Counter(lr_all)
        top_count = max(lr_counts.values())
        best_lr = float(min(v for v, c in lr_counts.items() if c == top_count))

        logger.info("BEST LR = %s", str(best_lr))

        for current_run in range(0, args['runs']):

            classes_to_keep = np.random.choice(list(range(650)), total_classes, replace=False)

            dataset = utils.remove_classes_omni(data_train, classes_to_keep)

            iterator_sorted = torch.utils.data.DataLoader(
                utils.iterator_sorter_omni(dataset, False, classes=no_of_classes_schedule),
                batch_size=1,
                shuffle=args['iid'], num_workers=2)

            dataset = utils.remove_classes_omni(data_test, classes_to_keep)
            iterator_test = torch.utils.data.DataLoader(dataset, batch_size=32,
                                                        shuffle=False, num_workers=1)

            dataset = utils.remove_classes_omni(data_train, classes_to_keep)
            iterator_train = torch.utils.data.DataLoader(dataset, batch_size=32,
                                                         shuffle=False, num_workers=1)

            lr = best_lr

            maml.reset_vars()

            opt = torch.optim.Adam(maml.get_adaptation_parameters(), lr=lr)

            train_iterator(iterator_sorted, device, maml, opt)

            logger.info("Result after one epoch for LR = %f", lr)

            correct = eval_iterator(iterator_train, device, maml)

            correct_test = eval_iterator(iterator_test, device, maml)

            results_mem_size = (correct, best_lr, "train")
            logger.info("Final Max Result train = %s", str(correct))
            final_results_train.append((total_classes, results_mem_size))

            results_mem_size = (correct_test, best_lr, "test")
            logger.info("Final Max Result test= %s", str(correct_test))
            final_results_test.append((total_classes, results_mem_size))

            my_experiment.results["Final Results"] = final_results_train
            my_experiment.results["Final Results Test"] = final_results_test
            my_experiment.store_json()
            logger.debug("FINAL RESULTS = %s", str(final_results_train))
            logger.debug("FINAL RESULTS = %s", str(final_results_test))
Beispiel #6
0
def main():
    """Omniglot continual-learning evaluation with a per-schedule LR sweep.

    For every class count in the colon-separated ``args['schedule']``:

    1. LR search (5 repetitions): sample that many of the first 650 classes,
       load a fresh model per candidate LR, train one pass over the
       class-sorted stream, and keep the LR with the best accuracy on the
       ``train=not args['test']`` split.
    2. Pick the most frequently selected LR (ties resolve to the smallest).
    3. Run ``args['runs']`` evaluations with that LR on fresh class samples.

    "Temp Results" (LR search) and "Final Results" are written to the
    experiment JSON after every repetition/run.
    """
    from collections import Counter

    p = class_parser_eval.Parser()
    rank = p.parse_known_args()[0].rank
    all_args = vars(p.parse_known_args()[0])
    print("All args = ", all_args)

    args = utils.get_run(vars(p.parse_known_args()[0]), rank)

    utils.set_seed(args['seed'])

    my_experiment = experiment(args['name'],
                               args,
                               "../results/",
                               commit_changes=False,
                               rank=0,
                               seed=1)

    # Device and model config depend on neither the class count nor the LR,
    # and both are needed by the final runs as well. Compute them once here
    # instead of inside the LR-search loop.
    gpu_to_use = rank % args["gpus"]
    if torch.cuda.is_available():
        device = torch.device('cuda:' + str(gpu_to_use))
        logger.info("Using gpu : %s", 'cuda:' + str(gpu_to_use))
    else:
        device = torch.device('cpu')

    config = mf.ModelFactory.get_model("na",
                                       args['dataset'],
                                       output_dimension=1000)

    final_results_all = []
    temp_result = []
    args['schedule'] = [int(x) for x in args['schedule'].split(":")]
    class_schedule = args['schedule']
    print(args["schedule"])
    for tot_class in class_schedule:
        print("Classes current step = ", tot_class)
        lr_list = [0.03, 0.01, 0.003, 0.001, 0.0003, 0.0001, 0.00003, 0.00001]
        lr_all = []
        for lr_search in range(0, 5):

            keep = np.random.choice(list(range(650)), tot_class, replace=False)

            dataset = utils.remove_classes_omni(
                df.DatasetFactory.get_dataset("omniglot",
                                              train=True,
                                              background=False,
                                              path=args['path']), keep)
            iterator_sorted = torch.utils.data.DataLoader(
                utils.iterator_sorter_omni(dataset,
                                           False,
                                           classes=class_schedule),
                batch_size=1,
                shuffle=args['iid'],
                num_workers=2)
            dataset = utils.remove_classes_omni(
                df.DatasetFactory.get_dataset("omniglot",
                                              train=not args['test'],
                                              background=False,
                                              path=args['path']), keep)
            iterator = torch.utils.data.DataLoader(dataset,
                                                   batch_size=32,
                                                   shuffle=False,
                                                   num_workers=1)

            max_acc = -1000
            for lr in lr_list:

                print(lr)
                maml = load_model(args, config)
                maml = maml.to(device)

                opt = torch.optim.Adam(maml.get_adaptation_parameters(), lr=lr)

                # One pass over the class-sorted stream.
                for img, y in iterator_sorted:
                    img = img.to(device)
                    y = y.to(device)

                    pred = maml(img)
                    opt.zero_grad()
                    loss = F.cross_entropy(pred, y)
                    loss.backward()
                    opt.step()

                logger.info("Result after one epoch for LR = %f", lr)
                correct = 0
                for img, target in iterator:
                    img = img.to(device)
                    target = target.to(device)
                    logits_q = maml(img)

                    pred_q = logits_q.argmax(dim=1)

                    correct += torch.eq(pred_q, target).sum().item() / len(img)

                accuracy = correct / len(iterator)
                logger.info(str(accuracy))
                if accuracy > max_acc:
                    max_acc = accuracy
                    max_lr = lr

            lr_all.append(max_lr)
            logger.info("Final Max Result = %s", str(max_acc))
            results_mem_size = (max_acc, max_lr)
            temp_result.append((tot_class, results_mem_size))
            print("A=  ", results_mem_size)
            logger.info("Temp Results = %s", str(results_mem_size))

            my_experiment.results["Temp Results"] = temp_result
            my_experiment.store_json()
            print("LR RESULTS = ", temp_result)

        # Most frequently selected LR; ties resolve to the smallest value,
        # as scipy.stats.mode does. `stats.mode(lr_all)[0][0]` breaks on
        # SciPy >= 1.11, where mode() returns scalars by default.
        lr_counts = Counter(lr_all)
        top_count = max(lr_counts.values())
        best_lr = float(min(v for v, c in lr_counts.items() if c == top_count))

        logger.info("BEST LR = %s", str(best_lr))

        for run_idx in range(0, args['runs']):

            keep = np.random.choice(list(range(650)), tot_class, replace=False)

            # Load from args['path'], the same location used by the LR search.
            dataset = utils.remove_classes_omni(
                df.DatasetFactory.get_dataset("omniglot",
                                              train=True,
                                              background=False,
                                              path=args['path']), keep)
            iterator_sorted = torch.utils.data.DataLoader(
                utils.iterator_sorter_omni(dataset,
                                           False,
                                           classes=class_schedule),
                batch_size=1,
                shuffle=args['iid'],
                num_workers=2)
            dataset = utils.remove_classes_omni(
                df.DatasetFactory.get_dataset("omniglot",
                                              train=not args['test'],
                                              background=False,
                                              path=args['path']), keep)
            iterator = torch.utils.data.DataLoader(dataset,
                                                   batch_size=32,
                                                   shuffle=False,
                                                   num_workers=1)

            lr = best_lr

            maml = load_model(args, config)
            maml = maml.to(device)

            opt = torch.optim.Adam(maml.get_adaptation_parameters(), lr=lr)

            # One pass over the class-sorted stream.
            for img, y in iterator_sorted:
                img = img.to(device)
                y = y.to(device)

                pred = maml(img)
                opt.zero_grad()
                loss = F.cross_entropy(pred, y)
                loss.backward()
                opt.step()

            logger.info("Result after one epoch for LR = %f", lr)
            correct = 0
            for img, target in iterator:
                img = img.to(device)
                target = target.to(device)
                logits_q = maml(img,
                                vars=None,
                                bn_training=False,
                                feature=False)

                pred_q = logits_q.argmax(dim=1)

                correct += torch.eq(pred_q, target).sum().item() / len(img)

            accuracy = correct / len(iterator)
            logger.info(str(accuracy))

            results_mem_size = (accuracy, lr)
            logger.info("Final Max Result = %s", str(accuracy))
            final_results_all.append((tot_class, results_mem_size))
            print("A=  ", results_mem_size)
            logger.info("Final results = %s", str(results_mem_size))

            my_experiment.results["Final Results"] = final_results_all
            my_experiment.store_json()
            print("FINAL RESULTS = ", final_results_all)
Beispiel #7
0
def main():
    """Pre-train a ``learner.Learner`` with plain supervised cross-entropy.

    Loads the full training set of ``args['dataset']``, trains with Adam for
    ``args['epoch']`` epochs, logs the mean per-batch accuracy of each epoch,
    and saves the whole model to ``<experiment path>model.net`` after every
    epoch.
    """
    p = params.Parser()
    # Parse once and reuse the namespace.
    known_args = p.parse_known_args()[0]
    rank = known_args.rank
    all_args = vars(known_args)
    print("All args = ", all_args)

    args = utils.get_run(vars(known_args), rank)

    utils.set_seed(args['seed'])

    my_experiment = experiment(args['name'],
                               args,
                               "../results/",
                               commit_changes=False,
                               rank=0,
                               seed=1)

    gpu_to_use = rank % args["gpus"]
    if torch.cuda.is_available():
        device = torch.device('cuda:' + str(gpu_to_use))
        logger.info("Using gpu : %s", 'cuda:' + str(gpu_to_use))
    else:
        device = torch.device('cpu')

    dataset = df.DatasetFactory.get_dataset(args['dataset'],
                                            background=True,
                                            train=True,
                                            path=args["path"],
                                            all=True)

    iterator = torch.utils.data.DataLoader(dataset,
                                           batch_size=256,
                                           shuffle=True,
                                           num_workers=0)

    logger.info(str(args))

    config = mf.ModelFactory.get_model("na", args["dataset"])

    maml = learner.Learner(config).to(device)

    for k, v in maml.named_parameters():
        print(k, v.requires_grad)

    opt = torch.optim.Adam(maml.parameters(), lr=args["lr"])

    for e in range(args["epoch"]):
        correct = 0
        for img, y in tqdm(iterator):
            img = img.to(device)
            y = y.to(device)
            pred = maml(img)

            opt.zero_grad()
            loss = F.cross_entropy(pred, y.long())
            loss.backward()
            opt.step()
            # Kept as a tensor inside the loop to avoid a device sync per batch.
            correct += (pred.argmax(1) == y).sum().float() / len(y)
        # float() turns the 0-d (possibly CUDA) tensor into a plain number so
        # the log shows "0.93" rather than "tensor(0.93, device='cuda:0')".
        logger.info("Accuracy at epoch %d = %s", e,
                    str(float(correct) / len(iterator)))
        torch.save(maml, my_experiment.path + "model.net")
def plot_var(f_step, projection):
    """Plot 850 hPa temperature and 500 hPa geopotential height for one step.

    Intended to be called only when the image does not exist yet. Reads the
    data, draws temperature (filled) and geopotential height (white
    contours with H/L markers) on the requested projection, and saves the
    figure as ``/tmp/<projection>_gph_t_850_<run>_<step>.png``.

    The data preparation lives here rather than in utils.py because it can
    change from case to case.

    Returns:
        The path of the saved PNG.
    """
    # The one employed for the figure name when exported
    variable_name = 'gph_t_850'
    # Build the name of the output image
    run_string, _ = get_run()
    filename = '/tmp/' + projection + '_' + \
        variable_name + '_%s_%03d.png' % (run_string, f_step)

    dset = get_dset(vars_3d=['t@850', 'fi@500'], f_times=f_step).squeeze()
    dset = subset_arrays(dset, projection)
    time = pd.to_datetime(dset.valid_time.values)

    temp_850 = dset['t'] - 273.15  # Kelvin -> Celsius
    z_500 = dset['z']
    gph_500 = mpcalc.geopotential_to_height(z_500)
    gph_500 = xr.DataArray(gph_500.magnitude, coords=z_500.coords,
                           attrs={'standard_name': 'geopotential height',
                                  'units': gph_500.units})

    levels_temp = np.arange(-30., 30., 1.)
    levels_gph = np.arange(4700., 6000., 70.)

    cmap = get_colormap('temp')

    fig = plt.figure(figsize=(figsize_x, figsize_y))
    # NOTE(review): kept from the original; the axes created here are
    # superseded by the cartopy axes below, but removing the call could
    # change the figure layout depending on get_projection_cartopy.
    plt.gca()

    lon, lat = get_coordinates(temp_850)
    lon2d, lat2d = np.meshgrid(lon, lat)

    ax = get_projection_cartopy(plt, projection, compute_projection=True)

    if projection == 'euratl':
        norm = BoundaryNorm(levels_temp, ncolors=cmap.N)
        cs = ax.pcolormesh(lon2d, lat2d, temp_850, cmap=cmap, norm=norm)
    else:
        cs = ax.contourf(lon2d, lat2d, temp_850, extend='both',
                         cmap=cmap, levels=levels_temp)

    c = ax.contour(lon2d, lat2d, gph_500, levels=levels_gph,
                   colors='white', linewidths=1.)

    ax.clabel(c, c.levels, inline=True, fmt='%4.0f', fontsize=6)

    plot_maxmin_points(ax, lon, lat, gph_500,
                       'max', 80, symbol='H', color='royalblue', random=True)
    plot_maxmin_points(ax, lon, lat, gph_500,
                       'min', 80, symbol='L', color='coral', random=True)

    annotation_forecast(ax, time)
    annotation(
        ax, 'Geopotential height @500hPa [m] and temperature @850hPa [C]', loc='lower left', fontsize=6)
    annotation_run(ax, time)

    plt.colorbar(cs, orientation='horizontal',
                 label='Temperature', pad=0.03, fraction=0.04)

    plt.savefig(filename, **options_savefig)
    # Close (not merely clear) the figure: plt.clf() leaves it registered
    # with pyplot, so every call would leak one figure.
    plt.close(fig)

    return filename
Beispiel #9
0
    print("Reading benchmark dataset.")
    (index_sets,
     index_keys) = read_sets_from_file(args.index_set_file,
                                       sample_ratio=args.index_sample_ratio,
                                       skip=1)
    (query_sets,
     query_keys) = read_sets_from_file(args.index_set_file,
                                       sample_ratio=args.query_sample_ratio,
                                       skip=1)

    # Initialize output SQLite database.
    init_results_db(args.output)

    # Run ground truth.
    params = {"benchmark": benchmark_settings}
    if get_run("ground_truth", k, None, params, args.output) is None:
        print("Running Ground Truth.")
        ground_truth_results, ground_truth_times = search_jaccard_topk(
            (index_sets, index_keys), (query_sets, query_keys), k)
        save_results("ground_truth", k, None, params, ground_truth_results,
                     ground_truth_times, args.output)

    # Run HNSW
    for M in Ms:
        for efC in efCs:
            index_params = {
                'M': M,
                'indexThreadQty': num_threads,
                'efConstruction': efC,
                'post': 0,
            }
def main():
    """Supervised pre-training with a per-class train/validation split.

    For each of the first 664 class labels, the first 15 samples (in dataset
    order) form the training subset and the remaining samples the
    validation subset. Validation samples are drawn from a second copy of
    the training data built without the ``augment`` option.

    Training uses SGD with momentum and a MultiStepLR schedule. Per-epoch
    accuracy/loss are logged to the logger and TensorBoard. The best
    validation model is saved as ``model_best.net``; ``results.json`` and
    ``last_model.net`` are written to the experiment directory at the end.
    """
    p = params.Parser()
    _args = p.parse_args()
    rank = _args.rank
    print("All args = ", _args)

    args = utils.get_run(vars(_args), rank)

    utils.set_seed(args["seed"])

    if args["log_root"]:
        log_root = osp.join("./results", args["log_root"]) + "/"
    else:
        log_root = osp.join("./results/")

    my_experiment = experiment(
        args["name"],
        args,
        log_root,
        commit_changes=False,
        rank=0,
        seed=args["seed"],
    )
    writer = SummaryWriter(my_experiment.path + "tensorboard")

    gpu_to_use = rank % args["gpus"]
    if torch.cuda.is_available():
        device = torch.device("cuda:" + str(gpu_to_use))
        logger.info("Using gpu : %s", "cuda:" + str(gpu_to_use))
    else:
        device = torch.device("cpu")

    print("Train dataset")
    dataset = df.DatasetFactory.get_dataset(
        args["dataset"],
        background=True,
        train=True,
        path=args["path"],
        all=True,
        resize=args["resize"],
        augment=args["augment"],
        prefetch_gpu=args["prefetch_gpu"],
    )
    print("Val dataset")
    val_dataset = df.DatasetFactory.get_dataset(
        args["dataset"],
        background=True,
        train=True,
        path=args["path"],
        all=True,
        resize=args["resize"],
        prefetch_gpu=args["prefetch_gpu"],
    )

    # Split every class: first 15 samples -> train, the rest -> validation.
    train_labels = np.arange(664)
    class_labels = np.array(np.asarray(torch.as_tensor(dataset.targets, device="cpu")))
    labels_mapping = {
        tl: (class_labels == tl).astype(int).nonzero()[0] for tl in train_labels
    }
    train_indices = [tl[:15] for tl in labels_mapping.values()]
    val_indices = [tl[15:] for tl in labels_mapping.values()]
    train_indices = [i for sublist in train_indices for i in sublist]
    val_indices = [i for sublist in val_indices for i in sublist]

    trainset = torch.utils.data.Subset(dataset, train_indices)

    print("Total samples:", len(class_labels))
    print("Train samples:", len(train_indices))
    print("Val samples:", len(val_indices))

    valset = torch.utils.data.Subset(val_dataset, val_indices)

    train_iterator = torch.utils.data.DataLoader(
        trainset,
        batch_size=64,
        shuffle=True,
        num_workers=0,
        drop_last=True,
    )
    val_iterator = torch.utils.data.DataLoader(
        valset,
        batch_size=256,
        shuffle=True,
        num_workers=0,
        drop_last=False,
    )

    logger.info("Args:")
    logger.info(str(vars(_args)))
    logger.info(str(args))

    config = mf.ModelFactory.get_model("na", args["dataset"], resize=args["resize"])

    maml = learner.Learner(config).to(device)

    for k, v in maml.named_parameters():
        print(k, v.requires_grad)

    opt = torch.optim.SGD(
        maml.parameters(),
        lr=args["lr"],
        momentum=0.9,
        weight_decay=5e-4,
    )
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        opt,
        milestones=_args.schedule,
        gamma=0.1,
    )

    best_val_acc = 0

    histories = {
        "train": {"acc": [], "loss": [], "step": []},
        "val": {"acc": [], "loss": [], "step": []},
    }

    for e in range(args["epoch"]):
        correct = 0
        total_loss = 0.0
        maml.train()
        for img, y in tqdm(train_iterator):
            img = img.to(device)
            y = y.to(device)
            pred = maml(img)

            opt.zero_grad()
            loss = F.cross_entropy(pred, y.long())
            loss.backward()
            opt.step()
            correct += (pred.argmax(1) == y).float().mean()
            # detach(): summing the raw loss would chain every batch's
            # autograd graph onto total_loss and keep it alive all epoch.
            total_loss += loss.detach()
        correct = float(correct)
        total_loss = float(total_loss)
        scheduler.step()

        val_correct = 0
        val_total_loss = 0.0
        maml.eval()
        with torch.no_grad():
            for img, y in tqdm(val_iterator):
                img = img.to(device)
                y = y.to(device)
                pred = maml(img)

                loss = F.cross_entropy(pred, y.long())
                val_correct += (pred.argmax(1) == y).sum().float()
                # Weight by batch size so the epoch mean is per-sample even
                # when the final batch is smaller.
                val_total_loss += loss * y.size(0)
        val_correct = float(val_correct)
        val_total_loss = float(val_total_loss)
        val_acc = val_correct / len(val_indices)
        val_loss = val_total_loss / len(val_indices)

        # Train metrics are means of per-batch means.
        train_correct = correct / len(train_iterator)
        train_loss = total_loss / len(train_iterator)

        logger.info("Accuracy at epoch %d = %s", e, str(train_correct))
        logger.info("Loss at epoch %d = %s", e, str(train_loss))
        logger.info("Val Accuracy at epoch %d = %s", e, str(val_acc))
        logger.info("Val Loss at epoch %d = %s", e, str(val_loss))

        histories["train"]["acc"].append(train_correct)
        histories["train"]["loss"].append(train_loss)
        histories["val"]["acc"].append(val_acc)
        histories["val"]["loss"].append(val_loss)
        histories["train"]["step"].append(e + 1)
        histories["val"]["step"].append(e + 1)

        writer.add_scalar("/train/accuracy", train_correct, e + 1)
        writer.add_scalar("/train/loss", train_loss, e + 1)
        writer.add_scalar("/val/accuracy", val_acc, e + 1)
        # Validation loss goes to its own tag; it used to overwrite the
        # "/train/loss" series.
        writer.add_scalar("/val/loss", val_loss, e + 1)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            logger.info(f"\nNew best validation accuracy: {str(best_val_acc)}\n")
            torch.save(maml, my_experiment.path + "model_best.net")

    with open(my_experiment.path + "results.json", "w") as f:
        json.dump(histories, f)
    torch.save(maml, my_experiment.path + "last_model.net")
    # Flush pending TensorBoard events to disk.
    writer.close()