Code Example #1

# Imports required by this snippet. ParallelSampler, test_one, and
# _copy_weights are project-local helpers defined elsewhere in the repo;
# colored is assumed to come from termcolor.
import copy
import datetime

import numpy as np
from termcolor import colored
from tqdm import tqdm

def test(test_data,
         model,
         args,
         num_episodes,
         verbose=True,
         sampled_tasks=None):
    '''
        Evaluate the model on a bag of sampled tasks. Return the mean accuracy
        and its std.
    '''
    # clone the original model
    fast_model = {
        'ebd': copy.deepcopy(model['ebd']),
        'clf': copy.deepcopy(model['clf']),
    }

    if sampled_tasks is None:
        sampled_tasks = ParallelSampler(test_data, args,
                                        num_episodes).get_epoch()

    acc = []

    sampled_tasks = enumerate(sampled_tasks)
    if not args.notqdm:
        sampled_tasks = tqdm(sampled_tasks,
                             total=num_episodes,
                             ncols=80,
                             leave=False,
                             desc=colored('Testing on val', 'yellow'))

    for i, task in sampled_tasks:
        if i == num_episodes and not args.notqdm:
            # when wrapped in tqdm, stop explicitly after num_episodes
            # tasks and close the progress bar
            sampled_tasks.close()
            break
        # reset the fast model to the current initialization before adapting
        _copy_weights(model['ebd'], fast_model['ebd'])
        _copy_weights(model['clf'], fast_model['clf'])
        acc.append(test_one(task, fast_model, args))

    acc = np.array(acc)

    if verbose:
        print("{}, {:s} {:>7.4f}, {:s} {:>7.4f}".format(
            datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S'),
            colored("acc mean", "blue"),
            np.mean(acc),
            colored("std", "blue"),
            np.std(acc),
        ))

    return np.mean(acc), np.std(acc)
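
The helper _copy_weights is called above but never shown. A minimal sketch, assuming it simply overwrites the fast model's parameters with the shared initialization (the name is taken from the snippet; the behavior is inferred, not confirmed):

import torch

def _copy_weights(src_module, dst_module):
    '''
        Hypothetical sketch: copy src_module's parameters into
        dst_module in place, so each episode's adaptation starts
        from the shared initialization.
    '''
    with torch.no_grad():
        for src_p, dst_p in zip(src_module.parameters(),
                                dst_module.parameters()):
            dst_p.copy_(src_p)
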
Code Example #2

# Imports required by this snippet. ParallelSampler and test_one are
# project-local helpers defined elsewhere in the repo; colored is
# assumed to come from termcolor.
import datetime

import numpy as np
from termcolor import colored
from tqdm import tqdm

def test(test_data, model, args, num_episodes, verbose=True, sampled_tasks=None):
    '''
        Evaluate the model on a bag of sampled tasks. Return the mean accuracy
        and its std.
    '''
    # switch to inference mode (affects dropout, batch norm, etc.)
    model['ebd'].eval()
    model['clf'].eval()

    if sampled_tasks is None:
        sampled_tasks = ParallelSampler(test_data, args,
                                        num_episodes).get_epoch()

    acc = []
    if not args.notqdm:
        sampled_tasks = tqdm(sampled_tasks, total=num_episodes, ncols=80,
                             leave=False,
                             desc=colored('Testing on val', 'yellow'))

    for task in sampled_tasks:
        acc.append(test_one(task, model, args))

    acc = np.array(acc)

    if verbose:
        print("{}, {:s} {:>7.4f}, {:s} {:>7.4f}".format(
                datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S'),
                colored("acc mean", "blue"),
                np.mean(acc),
                colored("std", "blue"),
                np.std(acc),
                ), flush=True)

    return np.mean(acc), np.std(acc)
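
A sketch of how this test function might be invoked. The args fields below are assumptions covering only the attributes the snippet itself reads; ParallelSampler almost certainly requires more, and test_data and model are assumed to come from the repo's data-loading and model-building code:

from types import SimpleNamespace

# hypothetical call site: evaluate over 100 sampled episodes
args = SimpleNamespace(notqdm=False)
mean_acc, std_acc = test(test_data, model, args, num_episodes=100)
print('mean={:.4f} std={:.4f}'.format(mean_acc, std_acc))
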
Code Example #3

# Imports required by this snippet. grad_param, ParallelSampler,
# train_one, _copy_weights, and _meta_update are project-local helpers,
# and test is the function from Code Example #1; colored is assumed to
# come from termcolor.
import copy
import datetime
import os
import time

import numpy as np
import torch
from termcolor import colored
from tqdm import tqdm

def train(train_data, val_data, model, args):
    '''
        Train the model with MAML-style meta-updates.
        Use val_data to do early stopping.
    '''
    # creating a tmp directory to save the models
    out_dir = os.path.abspath(
        os.path.join(os.path.curdir, "tmp-runs", str(int(time.time() * 1e7))))
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    best_acc = 0
    sub_cycle = 0
    best_path = None

    opt = torch.optim.Adam(grad_param(model, ['ebd', 'clf']), lr=args.lr)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,
                                                           'max',
                                                           patience=5,
                                                           factor=0.1,
                                                           verbose=True)

    # clone the original model
    fast_model = {
        'ebd': copy.deepcopy(model['ebd']),
        'clf': copy.deepcopy(model['clf']),
    }

    print("{}, Start training".format(
        datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S')))

    train_gen = ParallelSampler(train_data, args,
                                args.train_episodes * args.maml_batchsize)
    val_gen = ParallelSampler(val_data, args, args.val_episodes)
    for ep in range(args.train_epochs):
        sampled_tasks = train_gen.get_epoch()

        meta_grad_dict = {'clf': [], 'ebd': []}

        train_episodes = range(args.train_episodes)
        if not args.notqdm:
            train_episodes = tqdm(train_episodes,
                                  ncols=80,
                                  leave=False,
                                  desc=colored('Training on train', 'yellow'))

        for _ in train_episodes:
            # update the initialization based on a batch of tasks
            total_grad = {'ebd': [], 'clf': []}

            for _ in range(args.maml_batchsize):
                task = next(sampled_tasks)

                # clone the current initialization
                _copy_weights(model['ebd'], fast_model['ebd'])
                _copy_weights(model['clf'], fast_model['clf'])

                # get the meta gradient
                train_one(task, fast_model, args, total_grad)

            ebd_grad, clf_grad = _meta_update(model, total_grad, opt, task,
                                              args.maml_batchsize,
                                              args.clip_grad)
            meta_grad_dict['ebd'].append(ebd_grad)
            meta_grad_dict['clf'].append(clf_grad)

        # evaluate training accuracy
        if ep % 10 == 0:
            acc, std = test(train_data, model, args,
                            args.train_episodes * args.maml_batchsize, False,
                            train_gen.get_epoch())
            print("{}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f} ".format(
                datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S'),
                "ep",
                ep,
                colored("train", "red"),
                colored("acc:", "blue"),
                acc,
                std,
            ),
                  flush=True)

        # evaluate validation accuracy
        cur_acc, cur_std = test(val_data, model, args, args.val_episodes,
                                False, val_gen.get_epoch())
        # step the LR scheduler on val accuracy; without this call the
        # ReduceLROnPlateau created above never takes effect
        scheduler.step(cur_acc)
        print(("{}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f} "
               "{:s} {:s}{:>7.4f}, {:s}{:>7.4f}").format(
                   datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S'),
                   "ep", ep, colored("val  ", "cyan"), colored("acc:", "blue"),
                   cur_acc, cur_std, colored("train stats", "cyan"),
                   colored("ebd_grad:", "blue"),
                   np.mean(np.array(meta_grad_dict['ebd'])),
                   colored("clf_grad:", "blue"),
                   np.mean(np.array(meta_grad_dict['clf']))),
              flush=True)

        # Update the current best model if val acc is better
        if cur_acc > best_acc:
            best_acc = cur_acc
            best_path = os.path.join(out_dir, str(ep))

            # save current model
            print("{}, Save cur best model to {}".format(
                datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S'),
                best_path))

            torch.save(model['ebd'].state_dict(), best_path + '.ebd')
            torch.save(model['clf'].state_dict(), best_path + '.clf')

            sub_cycle = 0
        else:
            sub_cycle += 1

        # Break if the val acc hasn't improved in the past patience epochs
        if sub_cycle == args.patience:
            break

    print("{}, End of training. Restore the best weights".format(
        datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S')))

    # restore the best saved model
    model['ebd'].load_state_dict(torch.load(best_path + '.ebd'))
    model['clf'].load_state_dict(torch.load(best_path + '.clf'))

    if args.save:
        # save the restored best model to a permanent location
        out_dir = os.path.abspath(
            os.path.join(os.path.curdir, "saved-runs",
                         str(int(time.time() * 1e7))))
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

        best_path = os.path.join(out_dir, 'best')

        print("{}, Save best model to {}".format(
            datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S'),
            best_path),
              flush=True)

        torch.save(model['ebd'].state_dict(), best_path + '.ebd')
        torch.save(model['clf'].state_dict(), best_path + '.clf')

        with open(best_path + '_args.txt', 'w') as f:
            for attr, value in sorted(args.__dict__.items()):
                f.write("{}={}\n".format(attr, value))

    return
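
_meta_update is not shown above. A minimal sketch of a first-order outer step, assuming total_grad[key] holds one per-parameter gradient list per inner task; the real helper also receives the last task and may differ, so treat this as an illustration rather than the repo's implementation:

import torch

def _meta_update_sketch(model, total_grad, opt, batch_size, clip_grad):
    '''
        Hypothetical: average the per-task gradients collected by
        train_one, clip them, and apply them with the optimizer.
        Returns a mean gradient norm per component, mirroring the
        (ebd_grad, clf_grad) values logged above.
    '''
    opt.zero_grad()
    norms = {}
    for key in ('ebd', 'clf'):
        params = [p for p in model[key].parameters() if p.requires_grad]
        # transpose the per-task gradient lists into per-parameter tuples
        for p, task_grads in zip(params, zip(*total_grad[key])):
            p.grad = torch.stack(task_grads).sum(dim=0) / batch_size
        torch.nn.utils.clip_grad_norm_(params, clip_grad)
        grad_norms = [p.grad.norm().item() for p in params]
        norms[key] = sum(grad_norms) / len(grad_norms)
    opt.step()
    return norms['ebd'], norms['clf']
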
Code Example #4

# Imports required by this snippet. grad_param, ParallelSampler, and
# train_one are project-local helpers, and test is the function from
# Code Example #2; colored is assumed to come from termcolor.
import datetime
import os
import time

import numpy as np
import torch
from termcolor import colored
from tqdm import tqdm

def train(train_data, val_data, model, args):
    '''
        Train the model
        Use val_data to do early stopping

        Args:
            model (dict): {'ebd': embedding, 'clf': classifier}
    '''
    # creating a tmp directory to save the models
    out_dir = os.path.abspath(os.path.join(
                                  os.path.curdir,
                                  "tmp-runs",
                                  str(int(time.time() * 1e7))))
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    # Write results
    # write_acc_tr = 'acc_base.csv'
    # init_csv(write_acc_tr)
    # write_acc_val = 'val_acc_base.csv'
    # init_csv(write_acc_val)

    best_acc = 0
    sub_cycle = 0
    best_path = None

    # grad_param collects the learnable parameters of the embedding
    # ('ebd') and the classifier ('clf')
    params_to_opt = grad_param(model, ['ebd', 'clf'])
    opt = torch.optim.Adam(params_to_opt, lr=args.lr)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            opt, 'max', patience=args.patience//2, factor=0.1, verbose=True)

    print("{}, Start training".format(
        datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S')), flush=True)

    train_gen = ParallelSampler(train_data, args, args.train_episodes)
    train_gen_val = ParallelSampler(train_data, args, args.val_episodes)
    val_gen = ParallelSampler(val_data, args, args.val_episodes)

    for ep in range(args.train_epochs):
        sampled_tasks = train_gen.get_epoch()

        grad = {'clf': [], 'ebd': []}

        if not args.notqdm:
            sampled_tasks = tqdm(sampled_tasks, total=train_gen.num_episodes,
                    ncols=80, leave=False, desc=colored('Training on train',
                        'yellow'))

        for task in sampled_tasks:
            if task is None:
                break
            train_one(task, model, opt, args, grad)

        if ep % 10 == 0:
            acc, std = test(train_data, model, args, args.val_episodes, False,
                            train_gen_val.get_epoch())
            print("{}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f} ".format(
                datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S'),
                "ep", ep,
                colored("train", "red"),
                colored("acc:", "blue"), acc, std,
                ), flush=True)

            # write_csv(write_acc_tr, acc, std, ep)

        # Evaluate validation accuracy
        cur_acc, cur_std = test(val_data, model, args, args.val_episodes, False,
                                val_gen.get_epoch())
        # step the LR scheduler on val accuracy; without this call the
        # ReduceLROnPlateau created above never takes effect
        scheduler.step(cur_acc)
        print(("{}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f}, "
               "{:s} {:s}{:>7.4f}, {:s}{:>7.4f}").format(
               datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S'),
               "ep", ep,
               colored("val  ", "cyan"),
               colored("acc:", "blue"), cur_acc, cur_std,
               colored("train stats", "cyan"),
               colored("ebd_grad:", "blue"), np.mean(np.array(grad['ebd'])),
               colored("clf_grad:", "blue"), np.mean(np.array(grad['clf'])),
               ), flush=True)

        # if ep % 10 == 0: write_csv(write_acc_val, cur_acc, cur_std, ep)

        # Update the current best model if val acc is better
        if cur_acc > best_acc:
            best_acc = cur_acc
            best_path = os.path.join(out_dir, str(ep))

            # save current model
            print("{}, Save cur best model to {}".format(
                datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S'),
                best_path))

            torch.save(model['ebd'].state_dict(), best_path + '.ebd')
            torch.save(model['clf'].state_dict(), best_path + '.clf')

            sub_cycle = 0
        else:
            sub_cycle += 1

        # Break if the val acc hasn't improved in the past patience epochs
        if sub_cycle == args.patience:
            break

    print("{}, End of training. Restore the best weights".format(
            datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S')),
            flush=True)

    # restore the best saved model
    model['ebd'].load_state_dict(torch.load(best_path + '.ebd'))
    model['clf'].load_state_dict(torch.load(best_path + '.clf'))

    if args.save:
        # save the restored best model to a permanent location
        out_dir = os.path.abspath(os.path.join(
                                      os.path.curdir,
                                      "saved-runs",
                                      str(int(time.time() * 1e7))))
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

        best_path = os.path.join(out_dir, 'best')

        print("{}, Save best model to {}".format(
            datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S'),
            best_path), flush=True)

        torch.save(model['ebd'].state_dict(), best_path + '.ebd')
        torch.save(model['clf'].state_dict(), best_path + '.clf')

        with open(best_path + '_args.txt', 'w') as f:
            for attr, value in sorted(args.__dict__.items()):
                f.write("{}={}\n".format(attr, value))

    return
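
A sketch of how this training entry point might be wired up. The flag set below covers only the attributes Code Example #4 reads; the real repository defines many more (Code Example #3 additionally uses maml_batchsize and clip_grad), and train_data, val_data, and model are assumed to come from its data-loading and model-building code:

import argparse

# hypothetical driver with the minimal flag set used above
parser = argparse.ArgumentParser()
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--patience', type=int, default=20)
parser.add_argument('--train_epochs', type=int, default=1000)
parser.add_argument('--train_episodes', type=int, default=100)
parser.add_argument('--val_episodes', type=int, default=100)
parser.add_argument('--notqdm', action='store_true')
parser.add_argument('--save', action='store_true')
args = parser.parse_args()

train(train_data, val_data, model, args)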