def train(config_files, run_number):
    max_accuracy = []
    max_validation_accuracy = []
    for n in range(1, 2):
        X, y = get_data_train(n, config_files, run_number)
        input_size, hidden_size, output_size = X.shape[1], 16, 8
        model = MLP(input_size, hidden_size, output_size)
        model.to(device)
        X, y = X.to(device), y.to(device)
        epochs = 20
        accuracy = []
        test_accuracy = []
        for i in range(epochs):
            output_i, loss = train_optim(model, y, X)
            print("epoch {}".format(i))
            print("accuracy = ", np.sum(output_i == y.cpu().numpy()) / y.size())
            print("loss: {}".format(loss))
            accuracy.append((np.sum(output_i == y.cpu().numpy()) / y.size())[0])
            if not os.path.isdir('checkpoint'):
                os.mkdir('checkpoint')
            torch.save(model.state_dict(), "checkpoint/MLP_model_{}_train.pwf".format(i))
            test_accuracy.append(validate(n, config_files, run_number, model))
            torch.save(model.state_dict(), "checkpoint/MLP_model_{}_validate.pwf".format(i))

        plot_accuracy_n_print(accuracy, max_accuracy, n, run_number, 'train')
        plot_accuracy_n_print(test_accuracy, max_validation_accuracy, n, run_number, 'validate')
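The train_optim helper used above is not included in this snippet. A minimal sketch of one optimization step consistent with how it is called (returning predicted class indices as a NumPy array plus a scalar loss); the loss function and optimizer choice are assumptions:

import torch
import torch.nn as nn

def train_optim(model, y, X, optimizer=None, lr=1e-3):
    # Hypothetical single training step; the real helper is not part of this excerpt.
    # In practice the optimizer would be created once and reused across epochs.
    if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    optimizer.zero_grad()
    logits = model(X)
    loss = criterion(logits, y)
    loss.backward()
    optimizer.step()
    preds = logits.argmax(dim=1).cpu().numpy()  # predicted class indices
    return preds, loss.item()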
Example #2
def main(args):
    # Create model directory
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    # # Build data loader
    # dataset,targets= load_dataset()
    # np.save("__cache_dataset.npy", dataset)
    # np.save("__cache_targets.npy", targets)
    # return

    dataset = np.load("__cache_dataset.npy")
    targets = np.load("__cache_targets.npy")

    # Build the models
    mlp = MLP(args.input_size, args.output_size)

    mlp.load_state_dict(
        torch.load(
            '_backup_model_statedict/mlp_100_4000_PReLU_ae_dd_final.pkl'))

    if torch.cuda.is_available():
        mlp.cuda()

    # Loss and Optimizer
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adagrad(mlp.parameters())

    # Train the Models
    total_loss = []
    print(len(dataset))
    print(len(targets))
    sm = 100  # start saving models after 100 epochs
    for epoch in range(args.num_epochs):
        print("epoch" + str(epoch))
        avg_loss = 0
        for i in range(0, len(dataset), args.batch_size):
            # Forward, Backward and Optimize
            mlp.zero_grad()
            bi, bt = get_input(i, dataset, targets, args.batch_size)
            bi = to_var(bi)
            bt = to_var(bt)
            bo = mlp(bi)
            loss = criterion(bo, bt)
            avg_loss = avg_loss + loss.item()
            loss.backward()
            optimizer.step()
        print("--average loss:")
        print(avg_loss / (len(dataset) / args.batch_size))
        total_loss.append(avg_loss / (len(dataset) / args.batch_size))
        # Save the models
        if epoch == sm:
            model_path = 'mlp_100_4000_PReLU_ae_dd' + str(sm) + '.pkl'
            torch.save(mlp.state_dict(),
                       os.path.join(args.model_path, model_path))
            sm = sm + 50  # save the model every 50 epochs from epoch 100 onwards
    torch.save(total_loss, 'total_loss.dat')
    model_path = 'mlp_100_4000_PReLU_ae_dd_final.pkl'
    torch.save(mlp.state_dict(), os.path.join(args.model_path, model_path))
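The batching helpers get_input and to_var are not shown above. A plausible sketch, assuming the cached arrays are NumPy arrays and the older Variable-based API used in this snippet:

import torch
from torch.autograd import Variable

def get_input(i, dataset, targets, batch_size):
    # Hypothetical helper: slice one mini-batch out of the cached arrays.
    batch_in = torch.from_numpy(dataset[i:i + batch_size]).float()
    batch_target = torch.from_numpy(targets[i:i + batch_size]).float()
    return batch_in, batch_target

def to_var(x):
    # Common helper in older PyTorch code: move to GPU if available and wrap in a Variable.
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x)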
Example #3
def gpu_thread(load, memory_queue, process_queue, common_dict, worker):
    # the only thread with access to the GPU; it performs all the NN computation
    import psutil
    p = psutil.Process()
    p.cpu_affinity([worker])
    import signal
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    try:
        print('process started with pid: {} on core {}'.format(
            os.getpid(), worker),
              flush=True)
        model = MLP(parameters.OBS_SPACE, parameters.ACTION_SPACE)
        model.to(parameters.DEVICE)
        # optimizer = optim.Adam(model.parameters(), lr=5e-5)
        # optimizer = optim.SGD(model.parameters(), lr=3e-2)
        optimizer = optim.RMSprop(model.parameters(), lr=1e-4)
        epochs = 0
        if load:
            checkpoint = torch.load('./model/walker.pt')
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            epochs = checkpoint['epochs']
        observations = torch.Tensor([]).to(parameters.DEVICE)
        rewards = torch.Tensor([]).to(parameters.DEVICE)
        actions = torch.Tensor([]).to(parameters.DEVICE)
        probs = torch.Tensor([]).to(parameters.DEVICE)
        common_dict['epoch'] = epochs
        while True:
            memory_full, observations, rewards, actions, probs = \
                destack_memory(memory_queue, observations, rewards, actions, probs)
            destack_process(model, process_queue, common_dict)
            if len(observations) > parameters.MAXLEN or memory_full:
                epochs += 1
                print('-' * 60 + '\n        epoch ' + str(epochs) + '\n' +
                      '-' * 60)
                run_epoch(epochs, model, optimizer, observations, rewards,
                          actions, probs)
                observations = torch.Tensor([]).to(parameters.DEVICE)
                rewards = torch.Tensor([]).to(parameters.DEVICE)
                actions = torch.Tensor([]).to(parameters.DEVICE)
                probs = torch.Tensor([]).to(parameters.DEVICE)
                torch.save(
                    {
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'epochs': epochs
                    }, './model/walker.pt')
                common_dict['epoch'] = epochs
    except Exception as e:
        print(e)
        print('saving before interruption', flush=True)
        torch.save(
            {
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'epochs': epochs
            }, './model/walker.pt')
def main():
    parser = setup_parser()
    args = parser.parse_args()
    subprocess.run(f"mkdir {args.model}", shell=True)
    torch.manual_seed(42)
    # device="cuda" if args.gpus == -1 else "cpu"

    k_data_loaders = create_k_splitted_data_loaders(args)
    if args.model == "MLP":
        model = MLP(args)
    elif args.model == "CNN":
        model = CNN(args)
    model.apply(reset_weights)

    acc_results, logs = [], []
    for fold, train_loader, test_loader in k_data_loaders:
        print(f"FOLD {fold}\n-----------------------------")
        print("Starting training...")
        model, log = train_loop(train_loader, model, args)
        logs.append(log)

        print("Training process has finished. Saving trained model.")
        torch.save(model.state_dict(), f"./{args.model}/model_fold_{fold}.pth")

        print("Starting testing...")
        correct_rate = test_loop(test_loader, model, fold)
        acc_results.append(correct_rate)

        print("Resetting the model weights...")
        reset_weights(model)

    print(
        f"K-FOLD CROSS VALIDATION RESULTS FOR {args.k_folds} FOLDS\n----------------------"
    )
    print(f"Average: {sum(acc_results) / len(acc_results):.3g}%")
Example #5
def main(dataset, dim, layers, lr, reg, epochs, batchsize):
    n_user = overlap_user(dataset)
    print(n_user)
    logging.info(str(n_user))
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    mf_s, mf_t = load_model(dataset, dim)
    mapping = MLP(dim, layers)
    mf_s = mf_s.to(device)
    mf_t = mf_t.to(device)
    mapping = mapping.to(device)
    opt = torch.optim.Adam(mapping.parameters(), lr=lr, weight_decay=reg)
    mse_loss = nn.MSELoss()

    start = time()
    for epoch in range(epochs):
        loss_sum = 0
        for users in batch_user(n_user, batchsize):
            us = torch.tensor(users).long()
            us = us.to(device)
            u = mf_s.get_embed(us)
            y = mf_t.get_embed(us)
            loss = train(mapping, opt, mse_loss, u, y)
            loss_sum += loss
        print('Epoch %d [%.1f] loss = %f' % (epoch, time()-start, loss_sum))
        logging.info('Epoch %d [%.1f] loss = %f' %
                     (epoch, time()-start, loss_sum))
        start = time()

    mfile = 'pretrain/%s/Mapping.pth.tar' % dataset
    torch.save(mapping.state_dict(), mfile)
    print('save [%.1f]' % (time()-start))
    logging.info('save [%.1f]' % (time()-start))
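The train step called inside the loop above is not shown. A minimal sketch matching its call signature (one gradient step fitting mapping(u) to y and returning a scalar loss):

def train(mapping, opt, mse_loss, u, y):
    # Hypothetical single step: fit the mapping network so that mapping(u) approximates y.
    opt.zero_grad()
    loss = mse_loss(mapping(u), y)
    loss.backward()
    opt.step()
    return loss.item()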
Example #6
def train_net(net,
              device,
              epochs=100,
              batch_size=32,
              lr=0.001,
              val_percent=0.1,
              save_cp=True,
              ):
    ...  # training-loop body omitted in this excerpt
def get_args():
    parser = argparse.ArgumentParser(description='Train the PointRender on images and pre-processed labels',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-e', '--epochs', metavar='E', type=int, default=5,
                        help='Number of epochs', dest='epochs')
    parser.add_argument('-b', '--batch-size', metavar='B', type=int, nargs='?', default=1,
                        help='Batch size', dest='batchsize')
    parser.add_argument('-l', '--learning-rate', metavar='LR', type=float, nargs='?', default=0.1,
                        help='Learning rate', dest='lr')
    parser.add_argument('-f', '--load', dest='load', type=str, default=False,
                        help='Load model from a .pth file')
    parser.add_argument('-v', '--validation', dest='val', type=float, default=10.0,
                        help='Percent of the data that is used as validation (0-100)')

    return parser.parse_args()


if __name__ == '__main__':
    args = get_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    net = MLP(input_voxel = 1, n_classes = 3)

    if args.load:
        net.load_state_dict(
            torch.load(args.load, map_location=device)
        )

    net.to(device=device)

    try:
        train_net(net=net,
                  epochs=args.epochs,
                  batch_size=args.batchsize,
                  lr=args.lr,
                  device=device,
                  val_percent=args.val / 100)
    except KeyboardInterrupt:
        torch.save(net.state_dict(), 'INTERRUPTED.pth')
        logging.info('Saved interrupt')
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)
Example #7
def on_message(client, userdata, msg):
    try:
        if 'loss' in msg.topic:
            logger.info("Loss from trainer received!")
            #logger.info('Topic: ', msg.topic)
            logger.info(f'Topic: {str(msg.topic)}')
            global trainer_losses
            trainer_losses.append(float(msg.payload))

            if len(trainer_losses) == NUM_TRAINERS:
                losses.append(np.average(trainer_losses))
                trainer_losses.clear()
        else:
            logger.info("Model from trainer received!")
            #logger.info('Topic: ', msg.topic)
            #logger.info('Message: ', msg.payload)
            logger.info(f'Topic: {str(msg.topic)}')

            model_str = msg.payload
            buff = io.BytesIO(bytes(model_str))

            # Create a dummy model to read weights
            input_size = 78
            model = MLP(input_size)
            model.load_state_dict(torch.load(buff))

            global trainer_weights
            trainer_weights.append(copy.deepcopy(model.state_dict()))

            # Wait until we get trained weights from all trainers
            if len(trainer_weights) == NUM_TRAINERS:
                update_global_weights_and_send(trainer_weights)
                trainer_weights.clear()

    except Exception:
        logger.info(f"Unexpected error: {str(sys.exc_info())}")
Example #8
def train(args):

    perturb_mock, sgRNA_list_mock = makedata.json_to_perturb_data(path = "/home/member/xywang/WORKSPACE/MaryGUO/one-shot/MOCK_MON_crispr_combine/crispr_analysis")

    total = sc.read_h5ad("/home/member/xywang/WORKSPACE/MaryGUO/one-shot/mock_one_perturbed.h5ad")
    trainset, testset = preprocessing.make_total_data(total,sgRNA_list_mock)

    TrainSet = perturbdataloader(trainset, ways = args.num_ways, support_shots = args.num_shots, query_shots = 15)
    TrainLoader = DataLoader(TrainSet, batch_size=args.batch_size_train, shuffle=False,num_workers=args.num_workers)

    model = MLP(out_features = args.num_ways)

    model.to(device=args.device)
    model.train()
    meta_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Training loop
    with tqdm(TrainLoader, total=args.num_batches) as pbar:
        for batch_idx, (inputs_support, inputs_query, target_support, target_query) in enumerate(pbar):
            model.zero_grad()

            inputs_support = inputs_support.to(device=args.device)
            target_support = target_support.to(device=args.device)

            inputs_query = inputs_query.to(device=args.device)
            target_query = target_query.to(device=args.device)

            outer_loss = torch.tensor(0., device=args.device)
            accuracy = torch.tensor(0., device=args.device)
            for task_idx, (train_input, train_target, test_input,
                           test_target) in enumerate(zip(inputs_support, target_support,inputs_query, target_query)):

                train_logit = model(train_input)
                inner_loss = F.cross_entropy(train_logit, train_target)

                model.zero_grad()
                params = gradient_update_parameters(model,
                                                    inner_loss,
                                                    step_size=args.step_size,
                                                    first_order=args.first_order)

                test_logit = model(test_input, params=params)
                outer_loss += F.cross_entropy(test_logit, test_target)

                with torch.no_grad():
                    accuracy += get_accuracy(test_logit, test_target)

            outer_loss.div_(args.batch_size_train)
            accuracy.div_(args.batch_size_train)

            outer_loss.backward()
            meta_optimizer.step()

            pbar.set_postfix(accuracy='{0:.4f}'.format(accuracy.item()))
            if batch_idx >= args.num_batches or accuracy.item() > 0.95:
                break

    # Save model
    if args.output_folder is not None:
        filename = os.path.join(args.output_folder, 'maml_omniglot_'
                                                    '{0}shot_{1}way.th'.format(args.num_shots, args.num_ways))
        with open(filename, 'wb') as f:
            state_dict = model.state_dict()
            torch.save(state_dict, f)

    # start test
    test_support, test_query, test_target_support, test_target_query \
        = helpfuntions.sample_once(testset,support_shot=args.num_shots, shuffle=False,plus = len(trainset))
    test_query = torch.from_numpy(test_query).to(device=args.device)
    test_target_query = torch.from_numpy(test_target_query).to(device=args.device)

    TrainSet = perturbdataloader_test(test_support, test_target_support)
    TrainLoader = DataLoader(TrainSet, args.batch_size_test)

    meta_optimizer.zero_grad()
    inner_losses = []
    accuracy_test = []

    for epoch in range(args.num_epoch):
        model.to(device=args.device)
        model.train()

        for _, (inputs_support,target_support) in enumerate(TrainLoader):

            inputs_support = inputs_support.to(device=args.device)
            target_support = target_support.to(device=args.device)

            train_logit = model(inputs_support)
            loss = F.cross_entropy(train_logit, target_support)
            inner_losses.append(loss)
            loss.backward()
            meta_optimizer.step()
            meta_optimizer.zero_grad()

            test_logit = model(test_query)
            with torch.no_grad():
                accuracy = get_accuracy(test_logit, test_target_query)
                accuracy_test.append(accuracy)



        if (epoch + 1) % 3 == 0:
            print('Epoch [{}/{}], Loss: {:.4f}, accuracy: {:.4f}'.format(epoch + 1, args.num_epoch, loss, accuracy))
    '''Step 3: Train the Model'''
    print('Training begins: ')

    global_acc = 0
    for epoch in range(num_epoch):
        epoch = epoch+1
        print(f'Epoch {epoch} starts:')
        train_start = time.time()
        train_loss, train_acc = Train(
            train_dataloader,
            model,
            criterion,
            optimizer
        )
        train_end = time.time()

        print(f"Epoch {epoch} completed in: {train_end-train_start}s \t Loss: {train_loss} \t Acc: {train_acc}")

        val_start = time.time()
        val_loss, val_acc = Val(
            val_dataloader,
            model,
            criterion
        )
        val_end = time.time()
        if val_acc > global_acc:
            torch.save(model.state_dict(), f"./{epoch}_model8.pth.tar")
            global_acc = val_acc
        print(f"Validation Loss: {val_loss} \t Validation Acc: {val_acc}")
Example #10
        # evaluate model
        test_loss = eval_model(model,
                               criterion,
                               session_test_loader_list,
                               task_id_dict=task_id_dict,
                               device=device)

        # wrap the test loss in a list (when only a single task is evaluated, a single numeric value is returned)
        if (len(session_test_loader_list) == 1):
            test_loss = [test_loss]

        # store results/stats and base network (to be reused for all successive lamda/gamma variations)
        if (args.save_models):
            base_stats = copy.deepcopy(
                [session, train_loss, val_loss, test_loss,
                 model.state_dict()])
        else:
            base_stats = copy.deepcopy(
                [session, train_loss, val_loss, test_loss])

            # save reference model internally or externally
            if (args.offload_aux_models):
                model.store(model_path_base)
            else:
                model_base_state_dict = copy.deepcopy(model.state_dict())

        ######################
        #     increments     #
        ######################

        for lamb in args.lambda_list:
Example #11
def train_model(config, gpu_id, save_dir, exp_name):
    # Instantiating the model
    model_type = config.get('model_type', 'MLP')
    if model_type == "MLP":
        model = MLP(784, config["hidden_layers"], 10, config["nonlinearity"], config["initialization"], config["dropout"], verbose=True)
    elif model_type == "CNN":
        model = CNN(config["initialization"], config["is_batch_norm"], verbose=True)
    else:
        raise ValueError('config["model_type"] not supported : {}'.format(model_type))

    # Loading the MNIST dataset
    x_train, y_train, x_valid, y_valid, x_test, y_test = utils.load_mnist(config["data_file"], data_format=config["data_format"])

    if config['data_reduction'] != 1.:
        x_train, y_train = utils.reduce_trainset_size(x_train, y_train, config['data_reduction'])

    # If GPU is available, sends model and dataset on the GPU
    if torch.cuda.is_available():
        model.cuda(gpu_id)

        x_train = torch.from_numpy(x_train).cuda(gpu_id)
        y_train = torch.from_numpy(y_train).cuda(gpu_id)

        x_valid = Variable(torch.from_numpy(x_valid), volatile=True).cuda(gpu_id)
        y_valid = Variable(torch.from_numpy(y_valid), volatile=True).cuda(gpu_id)

        x_test = Variable(torch.from_numpy(x_test), volatile=True).cuda(gpu_id)
        y_test = Variable(torch.from_numpy(y_test), volatile=True).cuda(gpu_id)
        print("Running on GPU")
    else:
        x_train = torch.from_numpy(x_train)
        y_train = torch.from_numpy(y_train)

        x_valid = Variable(torch.from_numpy(x_valid))
        y_valid = Variable(torch.from_numpy(y_valid))

        x_test = Variable(torch.from_numpy(x_test))
        y_test = Variable(torch.from_numpy(y_test))
        print("WATCH-OUT : torch.cuda.is_available() returned False. Running on CPU.")

    # Instantiate TensorDataset and DataLoader objects
    train_set = torch.utils.data.TensorDataset(x_train, y_train)
    loader = torch.utils.data.DataLoader(train_set, batch_size=config["mb_size"], shuffle=True)

    # Optimizer and Loss Function
    optimizer = optim.SGD(model.parameters(), lr=config['lr'],
                                              momentum=config['momentum'],
                                              weight_decay=config['L2_hyperparam'] * (config['mb_size'] / x_train.size()[0]))
    loss_fn = nn.NLLLoss()

    # Records the model's performance
    train_tape = [[],[]]
    valid_tape = [[],[]]
    test_tape = [[],[]]
    weights_tape = []

    def evaluate(data, labels):

        model.eval()
        if not isinstance(data, Variable):
            if torch.cuda.is_available():
                data = Variable(data, volatile=True).cuda(gpu_id)
                labels = Variable(labels, volatile=True).cuda(gpu_id)
            else:
                data = Variable(data)
                labels = Variable(labels)

        output = model(data)
        loss = loss_fn(output, labels)
        prediction = torch.max(output.data, 1)[1]
        accuracy = (prediction.eq(labels.data).sum() / labels.size(0)) * 100

        return loss.data[0], accuracy

    if not os.path.exists(os.path.join(save_dir, exp_name)):
        os.makedirs(os.path.join(save_dir, exp_name))

    # Record train accuracy
    train_loss, train_acc = evaluate(x_train, y_train)
    train_tape[0].append(train_loss)
    train_tape[1].append(train_acc)

    # Record valid accuracy
    valid_loss, valid_acc = evaluate(x_valid, y_valid)
    valid_tape[0].append(valid_loss)
    valid_tape[1].append(valid_acc)

    # Record test accuracy
    test_loss, test_acc = evaluate(x_test, y_test)
    test_tape[0].append(test_loss)
    test_tape[1].append(test_acc)

    # Record weights L2 norm
    weights_L2_norm = model.get_weights_L2_norm()
    weights_tape.append(float(weights_L2_norm.data.cpu().numpy()))

    print("BEFORE TRAINING \nLoss : {0:.3f} \nAcc : {1:.3f}".format(valid_loss, valid_acc))

    # TRAINING LOOP
    best_valid_acc = 0
    for epoch in range(1, config["max_epochs"]):
        start = time.time()
        model.train()
        for i,(x_batch, y_batch) in enumerate(loader):

            #pdb.set_trace()

            if torch.cuda.is_available():
                x_batch = Variable(x_batch).cuda(gpu_id)
                y_batch = Variable(y_batch).cuda(gpu_id)
            else:
                x_batch = Variable(x_batch)
                y_batch = Variable(y_batch)

            # Empties the gradients
            optimizer.zero_grad()

            # Feedforward through the model
            output = model(x_batch)

            # Computes the loss
            loss = loss_fn(output, y_batch)

            # Backpropagates to compute the gradients
            loss.backward()

            # Takes one training step
            optimizer.step()

            # Record weights L2 norm
            weights_L2_norm = model.get_weights_L2_norm()
            weights_tape.append(float(weights_L2_norm.data.cpu().numpy()))

        # Record train accuracy
        train_loss, train_acc = evaluate(x_train, y_train)
        train_tape[0].append(train_loss)
        train_tape[1].append(train_acc)

        # Record valid accuracy
        valid_loss, valid_acc = evaluate(x_valid, y_valid)
        valid_tape[0].append(valid_loss)
        valid_tape[1].append(valid_acc)

        # Record test accuracy
        test_loss, test_acc = evaluate(x_test, y_test)
        test_tape[0].append(test_loss)
        test_tape[1].append(test_acc)

        print("Epoch {0} \nLoss : {1:.3f} \nAcc : {2:.3f}".format(epoch, valid_loss, valid_acc))
        print("Time : {0:.2f}".format(time.time() - start))

        # Saves the model
        if valid_acc > best_valid_acc:
            print("NEW BEST MODEL")
            torch.save(model.state_dict(), os.path.join(save_dir, exp_name, "model"))
            best_valid_acc = valid_acc

    # Saves the graphs
    utils.save_results(train_tape, valid_tape, test_tape, weights_tape, save_dir, exp_name, config)
    utils.update_comparative_chart(save_dir, config['show_test'])

    return
Example #12
class Trainer():
    def __init__(self, config_path):
        config = configparser.ConfigParser()
        config.read(config_path)

        self.n_epoch = config.getint("general", "n_epoch")
        self.batch_size = config.getint("general", "batch_size")
        self.train_bert = config.getboolean("general", "train_bert")
        self.lr = config.getfloat("general", "lr")
        self.cut_frac = config.getfloat("general", "cut_frac")
        self.log_dir = Path(config.get("general", "log_dir"))
        if not self.log_dir.exists():
            self.log_dir.mkdir(parents=True)
        self.model_save_freq = config.getint("general", "model_save_freq")

        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # bert_config_path = config.get("bert", "config_path")
        # bert_tokenizer_path = config.get("bert", "tokenizer_path")
        # bert_model_path = config.get("bert", "model_path")

        self.bert_tokenizer = LongformerTokenizer.from_pretrained(
            'allenai/longformer-base-4096')
        # self.bert_tokenizer = BertTokenizer.from_pretrained(bert_tokenizer_path)
        tkzer_save_dir = self.log_dir / "tokenizer"
        if not tkzer_save_dir.exists():
            tkzer_save_dir.mkdir()
        self.bert_tokenizer.save_pretrained(tkzer_save_dir)
        self.bert_model = LongformerModel.from_pretrained(
            'allenai/longformer-base-4096')
        self.bert_config = self.bert_model.config
        # self.bert_config = BertConfig.from_pretrained(bert_config_path)
        # self.bert_model = BertModel.from_pretrained(bert_model_path, config=self.bert_config)
        self.max_seq_length = self.bert_config.max_position_embeddings - 2
        # self.max_seq_length = self.bert_config.max_position_embeddings
        self.bert_model.to(self.device)

        if self.train_bert:
            self.bert_model.train()
        else:
            self.bert_model.eval()

        train_conll_path = config.get("data", "train_path")
        print("train path", train_conll_path)
        assert Path(train_conll_path).exists()
        dev_conll_path = config.get("data", "dev_path")
        print("dev path", dev_conll_path)
        assert Path(dev_conll_path).exists()
        dev1_conll_path = Path(dev_conll_path) / "1"
        print("dev1 path", dev1_conll_path)
        assert dev1_conll_path.exists()
        dev2_conll_path = Path(dev_conll_path) / "2"
        print("dev2 path", dev2_conll_path)
        assert dev2_conll_path.exists()
        self.train_dataset = ConllDataset(train_conll_path)
        # self.dev_dataset = ConllDataset(dev_conll_path)
        self.dev1_dataset = ConllDataset(dev1_conll_path)
        self.dev2_dataset = ConllDataset(dev2_conll_path)
        if self.batch_size == -1:
            self.batch_size = len(self.train_dataset)

        self.scaler = torch.cuda.amp.GradScaler()
        tb_cmt = f"lr_{self.lr}_cut-frac_{self.cut_frac}"
        self.writer = SummaryWriter(log_dir=self.log_dir, comment=tb_cmt)

    def transforms(self, example, label_list):
        feature = convert_single_example(example, label_list,
                                         self.max_seq_length,
                                         self.bert_tokenizer)
        label_ids = feature.label_ids
        label_map = feature.label_map
        gold_labels = [-1] * self.max_seq_length
        # Get "Element" or "Main" token indices
        for i, lid in enumerate(label_ids):
            if lid == label_map['B-Element']:
                gold_labels[i] = 0
            elif lid == label_map['B-Main']:
                gold_labels[i] = 1
            elif lid in (label_map['I-Element'], label_map['I-Main']):
                gold_labels[i] = 2
            elif lid == label_map['X']:
                gold_labels[i] = 3
        # flush data to bert model
        input_ids = torch.tensor(feature.input_ids).unsqueeze(0).to(
            self.device)
        if self.train_bert:
            model_output = self.bert_model(input_ids)
        else:
            with torch.no_grad():
                model_output = self.bert_model(input_ids)

        # lstm (ignore padding parts)
        model_fv = model_output[0]
        input_ids = torch.tensor(feature.input_ids)
        label_ids = torch.tensor(feature.label_ids)
        gold_labels = torch.tensor(gold_labels)
        return model_fv, input_ids, label_ids, gold_labels

    @staticmethod
    def extract_tokens(fv, gold_labels):
        ents, golds = [], []
        ents_mask = [-1] * len(gold_labels)
        ent, gold, ent_id = [], None, 0
        ent_flag = False
        for i, gt in enumerate(gold_labels):
            if gt == 2:  # in case of "I-xxx"
                ent.append(fv[i, :])
                ents_mask[i] = ent_id
                ent_end = i
            elif gt == 3 and ent_flag:  # in case of "X"
                ent.append(fv[i, :])
                ents_mask[i] = ent_id
                ent_end = i
            elif ent:
                ents.append(ent)
                golds.append(gold)
                ent = []
                ent_id += 1
                ent_flag = False
            if gt in (0, 1):  # in case of "B-xxx"
                ent.append(fv[i, :])
                gold = gt
                ents_mask[i] = ent_id
                ent_start = i
                ent_flag = True
        else:
            if ent:
                ents.append(ent)
                golds.append(gold)
        return ents, golds, ents_mask

    def eval(self, dataset):
        tp, fp, tn, fn = 0, 0, 0, 0
        with torch.no_grad():
            for data in tqdm(dataset):
                # flush to Bert
                fname, example = data

                try:
                    fvs, input_ids, label_ids, gold_labels = self.transforms(
                        example, dataset.label_list)
                except RuntimeError:
                    print(f"{fname} cannot put in memory!")
                    continue

                # extract Element/Main tokens
                ents, ent_golds, _ = self.extract_tokens(
                    fvs.squeeze(0), gold_labels)

                for i, ent in enumerate(ents):
                    # convert to torch.tensor
                    inputs = torch.empty(
                        [len(ent),
                         self.bert_config.hidden_size]).to(self.device)
                    for j, token in enumerate(ent):
                        inputs[j, :] = token
                    target = ent_golds[i]
                    inputs = torch.mean(inputs, dim=0, keepdim=True)

                    # classification
                    outputs = self.mlp(inputs)
                    if target == 1:
                        if outputs < 0.5:
                            fn += 1
                        else:
                            tp += 1
                    else:
                        if outputs < 0.5:
                            tn += 1
                        else:
                            fp += 1

        return Score(tp, fp, tn, fn).calc_score()

    def train(self):
        # MLP
        self.mlp = MLP(self.bert_config.hidden_size)
        self.mlp.to(self.device)
        self.mlp.train()
        # learning parameter settings
        params = list(self.mlp.parameters())
        if self.train_bert:
            params += list(self.bert_model.parameters())
        # loss
        self.criterion = BCEWithLogitsLoss()
        # optimizer
        self.optimizer = AdamW(params, lr=self.lr)
        num_train_steps = int(self.n_epoch * len(self.train_dataset) /
                              self.batch_size)
        num_warmup_steps = int(self.cut_frac * num_train_steps)
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer, num_warmup_steps, num_train_steps)

        try:
            best_dev1_f1, best_dev2_f1 = 0, 0
            # best_dev_f1 = 0
            itr = 1
            for epoch in range(1, self.n_epoch + 1):
                print("Epoch : {}".format(epoch))
                print("training...")
                for i in tqdm(
                        range(0, len(self.train_dataset), self.batch_size)):
                    # fvs, ents, batch_samples, inputs, outputs = None, None, None, None, None
                    itr += i
                    # create batch samples
                    if (i + self.batch_size) < len(self.train_dataset):
                        end_i = (i + self.batch_size)
                    else:
                        end_i = len(self.train_dataset)

                    batch_samples, batch_golds = [], []

                    for j in range(i, end_i):
                        # flush to Bert
                        fname, example = self.train_dataset[j]

                        fvs, input_ids, label_ids, gold_labels = self.transforms(
                            example, self.train_dataset.label_list)

                        # extract Element/Main tokens
                        ents, ent_golds, _ = self.extract_tokens(
                            fvs.squeeze(0), gold_labels)
                        for e in ents:
                            ent = torch.empty(
                                [len(e),
                                 self.bert_config.hidden_size]).to(self.device)
                            for k, t in enumerate(e):
                                ent[k, :] = t
                            batch_samples.append(torch.mean(ent, dim=0))
                        batch_golds.extend(ent_golds)

                    # convert to torch.tensor
                    inputs = torch.empty(
                        [len(batch_samples),
                         self.bert_config.hidden_size]).to(self.device)
                    for j, t in enumerate(batch_samples):
                        inputs[j, :] = t
                    targets = torch.tensor(batch_golds,
                                           dtype=torch.float).unsqueeze(1)

                    self.optimizer.zero_grad()
                    with torch.cuda.amp.autocast():
                        outputs = self.mlp(inputs)
                        loss = self.criterion(outputs, targets.to(self.device))
                        # loss = loss / 100
                    self.scaler.scale(loss).backward()
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                    self.scheduler.step()

                    del fvs, ents, batch_samples, inputs, outputs
                    torch.cuda.empty_cache()

                    # write to SummaryWriter
                    self.writer.add_scalar("loss", loss.item(), itr)
                    self.writer.add_scalar(
                        "lr", self.optimizer.param_groups[0]["lr"], itr)

                # write to SummaryWriter
                if self.train_bert:
                    self.bert_model.eval()
                self.mlp.eval()
                # import pdb; pdb.set_trace()

                print("train data evaluation...")
                tr_acc, tr_rec, _, tr_prec, tr_f1 = self.eval(
                    self.train_dataset)
                print(
                    f"acc: {tr_acc}, rec: {tr_rec}, prec: {tr_prec}, f1: {tr_f1}"
                )
                self.writer.add_scalar("train/acc", tr_acc, epoch)
                self.writer.add_scalar("train/rec", tr_rec, epoch)
                self.writer.add_scalar("train/prec", tr_prec, epoch)
                self.writer.add_scalar("train/f1", tr_f1, epoch)
                # print("dev data evaluation...")
                # dev_acc, dev_rec, _, dev_prec, dev_f1 = self.eval(self.dev_dataset)
                # print(f"acc: {dev_acc}, rec: {dev_rec}, prec: {dev_prec}, f1: {dev_f1}")
                # self.writer.add_scalar("dev/acc", dev_acc, epoch)
                # self.writer.add_scalar("dev/rec", dev_rec, epoch)
                # self.writer.add_scalar("dev/prec", dev_prec, epoch)
                # self.writer.add_scalar("dev/f1", dev_f1, epoch)
                # self.writer.flush()
                print("dev1 data evaluation...")
                dev1_acc, dev1_rec, _, dev1_prec, dev1_f1 = self.eval(
                    self.dev1_dataset)
                print(
                    f"acc: {dev1_acc}, rec: {dev1_rec}, prec: {dev1_prec}, f1: {dev1_f1}"
                )
                self.writer.add_scalar("dev1/acc", dev1_acc, epoch)
                self.writer.add_scalar("dev1/rec", dev1_rec, epoch)
                self.writer.add_scalar("dev1/prec", dev1_prec, epoch)
                self.writer.add_scalar("dev1/f1", dev1_f1, epoch)
                self.writer.flush()
                print("dev2 data evaluation...")
                dev2_acc, dev2_rec, _, dev2_prec, dev2_f1 = self.eval(
                    self.dev2_dataset)
                print(
                    f"acc: {dev2_acc}, rec: {dev2_rec}, prec: {dev2_prec}, f1: {dev2_f1}"
                )
                self.writer.add_scalar("dev2/acc", dev2_acc, epoch)
                self.writer.add_scalar("dev2/rec", dev2_rec, epoch)
                self.writer.add_scalar("dev2/prec", dev2_prec, epoch)
                self.writer.add_scalar("dev2/f1", dev2_f1, epoch)
                self.writer.flush()
                if self.train_bert:
                    self.bert_model.train()
                self.mlp.train()

                if epoch % self.model_save_freq == 0:
                    curr_log_dir = self.log_dir / f"epoch_{epoch}"
                    if not curr_log_dir.exists():
                        curr_log_dir.mkdir()
                    if self.train_bert:
                        self.bert_model.save_pretrained(curr_log_dir)
                    torch.save(self.mlp.state_dict(),
                               curr_log_dir / "mlp.model")

                # if best_dev_f1 <= dev_f1:
                #     best_dev_f1 = dev_f1
                #     best_dev_epoch = epoch
                #     if self.train_bert:
                #         best_dev_model = copy.deepcopy(self.bert_model)
                #     best_dev_mlp = copy.deepcopy(self.mlp.state_dict())
                if best_dev1_f1 <= dev1_f1:
                    best_dev1_f1 = dev1_f1
                    best_dev1_epoch = epoch
                    if self.train_bert:
                        best_dev1_model = copy.deepcopy(self.bert_model).cpu()
                    best_dev1_mlp = copy.deepcopy(self.mlp).cpu().state_dict()
                if best_dev2_f1 <= dev2_f1:
                    best_dev2_f1 = dev2_f1
                    best_dev2_epoch = epoch
                    if self.train_bert:
                        best_dev2_model = copy.deepcopy(self.bert_model).cpu()
                    best_dev2_mlp = copy.deepcopy(self.mlp).cpu().state_dict()

        except KeyboardInterrupt:
            # del fvs, ents, batch_samples, inputs, outputs
            # print(f"Best epoch was #{best_dev_epoch}!\nSave params...")
            # save_dev_dir = Path(self.log_dir) / "best"
            # if not save_dev_dir.exists():
            #     save_dev_dir.mkdir()
            # if self.train_bert:
            #     best_dev_model.save_pretrained(save_dev_dir)
            # torch.save(best_dev_mlp, save_dev_dir / "mlp.model")
            # print("Training was successfully finished!")
            print(
                f"Best epoch was dev1: #{best_dev1_epoch}, dev2: #{best_dev2_epoch}!\nSave params..."
            )
            save_dev1_dir = Path(self.log_dir) / "dev1_best"
            if not save_dev1_dir.exists():
                save_dev1_dir.mkdir()
            save_dev2_dir = Path(self.log_dir) / "dev2_best"
            if not save_dev2_dir.exists():
                save_dev2_dir.mkdir()
            if self.train_bert:
                best_dev1_model.save_pretrained(save_dev1_dir)
                best_dev2_model.save_pretrained(save_dev2_dir)
            torch.save(best_dev1_mlp, save_dev1_dir / "mlp.model")
            torch.save(best_dev2_mlp, save_dev2_dir / "mlp.model")
            print("Training was successfully finished!")
            raise KeyboardInterrupt
        else:
            # print(f"Best epoch was #{best_dev_epoch}!\nSave params...")
            # save_dev_dir = Path(self.log_dir) / "best"
            # if not save_dev_dir.exists():
            #     save_dev_dir.mkdir()
            # if self.train_bert:
            #     best_dev_model.save_pretrained(save_dev_dir)
            # torch.save(best_dev_mlp, save_dev_dir / "mlp.model")
            # print("Training was successfully finished!")
            print(
                f"Best epoch was dev1: #{best_dev1_epoch}, dev2: #{best_dev2_epoch}!\nSave params..."
            )
            save_dev1_dir = Path(self.log_dir) / "dev1_best"
            if not save_dev1_dir.exists():
                save_dev1_dir.mkdir()
            save_dev2_dir = Path(self.log_dir) / "dev2_best"
            if not save_dev2_dir.exists():
                save_dev2_dir.mkdir()
            if self.train_bert:
                best_dev1_model.save_pretrained(save_dev1_dir)
                best_dev2_model.save_pretrained(save_dev2_dir)
            torch.save(best_dev1_mlp, save_dev1_dir / "mlp.model")
            torch.save(best_dev2_mlp, save_dev2_dir / "mlp.model")
            print("Training was successfully finished!")
            sys.exit()
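A hedged usage sketch for the Trainer class above; the config file path is an assumption:

if __name__ == "__main__":
    trainer = Trainer("config.ini")  # hypothetical config path
    trainer.train()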
Example #13
                e,
                train_iter,
                valid_iter,
                criterion=criterion)
        train_loss, train_acc, valid_loss, valid_acc = loggings[:4]
        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["valid_loss"].append(valid_loss)
        history["valid_acc"].append(valid_acc)
        if valid_loss <= valid_loss_min:
            print(
                f'\t⚠️ validation loss decreased ({valid_loss_min:.6f} ☛ {valid_loss:.6f})'
            )
            print(f'\t📸 model snapshot saved')
            torch.save(net.state_dict(), f'utils/{args.network}.pth')
            valid_loss_min = valid_loss

    end_time = time.time()
    total_time = end_time - start_time
    index_train_acc, value_train_acc = max(enumerate(history["train_acc"]),
                                           key=operator.itemgetter(1))
    index_val_acc, value_val_acc = max(enumerate(history["valid_acc"]),
                                       key=operator.itemgetter(1))
    print("➲ Best training accuracy was {} at epoch {}".format(
        value_train_acc, index_train_acc + 1))
    print("➲ Best valid accuracy was {} at epoch {}".format(
        value_val_acc, index_val_acc + 1))

    # ---------------------------------------------------
    # plotting statistical results
Example #14
                'batch_repulsive': br,
                'bandwidth_repulsive': bandwidth_repulsive,
                'lambda_repulsive': args.lambda_repulsive
            }
        else:
            kwargs = {}

        data, target = data.cpu(), target.cpu()
        info_batch = optimize(net,
                              optimizer,
                              batch=(data, target),
                              add_repulsive_constraint=args.repulsive is not None,
                              **kwargs)
        step += 1
        for k, v in info_batch.items():
            experiment.log_metric('train_{}'.format(k), v, step=step)

# Save the model
if not (savepath / 'models').exists():
    os.makedirs(savepath / 'models')

model_path = savepath / 'models' / '{}_{}epochs.pt'.format(
    model_name, epoch + 1)
if not model_path.exists():
    torch.save(net.state_dict(), model_path)
else:
    raise ValueError(
        'Error trying to save file at location {}: File already exists'.format(
            model_path))
def train_model(config, gpu_id, save_dir, exp_name):

    # Instantiating the model
    model_type = config.get('model_type', 'MLP')
    if model_type == "MLP":
        model = MLP(config['input_size'],
                    config["hidden_layers"],
                    1,
                    config["nonlinearity"],
                    config["initialization"],
                    config["dropout"],
                    verbose=True)
    elif model_type == "CNN":
        model = CNN(config["initialization"],
                    config["is_batch_norm"],
                    verbose=True)
    else:
        raise ValueError(
            'config["model_type"] not supported : {}'.format(model_type))

    if config['resume']:
        model.load_state_dict(
            torch.load(os.path.join(save_dir, exp_name, "model")))

    # If GPU is available, sends model and dataset on the GPU
    if torch.cuda.is_available():
        model.cuda(gpu_id)
        print("USING GPU-{}".format(gpu_id))

    # Optimizer and Loss Function
    optimizer = optim.RMSprop(model.parameters(), lr=config['lr'])
    loss_fn = nn.CrossEntropyLoss()
    """ Trains an agent with (stochastic) Policy Gradients on Pong. Uses OpenAI Gym. """
    env = gym.make("Pong-v0")
    observation = env.reset()

    prev_x = None  # used in computing the difference frame
    y_list, LL_list, reward_list = [], [], []
    running_reward = None
    reward_sum = 0
    episode_number = 0

    start = time.time()

    # Initializing recorders
    update = 0
    loss_tape = []
    our_score_tape = []
    opponent_score_tape = []
    our_score = 0
    opponent_score = 0

    # TRAINING LOOP
    while update < config['max_updates']:

        if config['render']: env.render()

        # preprocess the observation and set input to network to be difference image
        cur_x = utils.preprocess(observation,
                                 data_format=config['data_format'])
        if prev_x is None:
            x = np.zeros(cur_x.shape)
        else:
            x = cur_x - prev_x
        prev_x = cur_x

        x_torch = Variable(torch.from_numpy(x).float(), requires_grad=False)
        if config['data_format'] == "array":
            x_torch = x_torch.unsqueeze(dim=0).unsqueeze(dim=0)

        if torch.cuda.is_available():
            x_torch = x_torch.cuda(gpu_id)

        # Feedforward through the policy network
        action_prob = model(x_torch)

        # Sample an action from the returned probability
        if np.random.uniform() < action_prob.cpu().data.numpy():
            action = 2  # UP
        else:
            action = 3  # DOWN

        # record the log-likelihoods
        y = 1 if action == 2 else 0  # a "fake label"
        NLL = -y * torch.log(action_prob) - (1 - y) * torch.log(1 -
                                                                action_prob)
        LL_list.append(NLL)
        y_list.append(
            y
        )  # grad that encourages the action that was taken to be taken        TODO: the tensor graph breaks here. Find a way to backpropagate the PG error.

        # step the environment and get new measurements
        observation, reward, done, info = env.step(action)
        reward_sum += reward

        reward_list.append(
            reward
        )  # record reward (has to be done after we call step() to get reward for previous action)

        if done:  # an episode finished (an episode ends when one of the player wins 21 games)
            episode_number += 1

            # Computes loss and reward for each step of the episode
            R = torch.zeros(1, 1)
            loss = 0
            for i in reversed(range(len(reward_list))):
                R = config['gamma'] * R + reward_list[i]
                Return_i = Variable(R)
                if torch.cuda.is_available():
                    Return_i = Return_i.cuda(gpu_id)
                loss = loss + (LL_list[i] *
                               (Return_i)).sum()  # .expand_as(LL_list[i])
            loss = loss / len(reward_list)
            print(loss)

            # Backpropagates to compute the gradients
            loss.backward()

            y_list, LL_list, reward_list = [], [], []  # reset array memory

            # Performs parameter update every config['mb_size'] episodes
            if episode_number % config['mb_size'] == 0:

                # Takes one training step
                optimizer.step()

                # Empties the gradients
                optimizer.zero_grad()

                stop = time.time()
                print("PARAMETER UPDATE ------------ {}".format(stop - start))
                start = time.time()

                utils.save_results(save_dir, exp_name, loss_tape,
                                   our_score_tape, opponent_score_tape, config)

                update += 1
                if update % 10 == 0:
                    torch.save(
                        model.state_dict(),
                        os.path.join(save_dir, exp_name,
                                     "model_" + model.name()))

            # Records the average loss and score of the episode
            loss_tape.append(loss.cpu().data.numpy())

            our_score_tape.append(our_score)
            opponent_score_tape.append(opponent_score)
            our_score = 0
            opponent_score = 0

            # boring book-keeping
            if running_reward is None:
                running_reward = reward_sum
            else:
                running_reward = running_reward * 0.99 + reward_sum * 0.01
            print(
                'resetting env. episode reward total was {0:.2f}. running mean: {1:.2f}'
                .format(reward_sum, running_reward))

            reward_sum = 0
            observation = env.reset()  # reset env
            prev_x = None

        if reward != 0:  # Pong has either +1 or -1 reward exactly when game ends.
            if reward == -1:
                opponent_score += 1
                print('ep {0}: game finished, reward: {1:.2f}'.format(
                    episode_number, reward))
            else:
                our_score += 1
                print(
                    'ep {0}: game finished, reward: {1:.2f} !!!!!!!!!'.format(
                        episode_number, reward))
Example #16
def main():
    # Initialize environment
    env = UnityEnvironment(file_name='../env/Pong/Pong')

    default_brain = env.brain_names[0]
    brain = env.brains[default_brain]

    env_info = env.reset(train_mode=True)[default_brain]

    obs_dim = env_info.vector_observations[0].shape[0]
    act_num = brain.vector_action_space_size[0]
    print('State dimension:', obs_dim)
    print('Action number:', act_num)

    # Set a random seed
    np.random.seed(0)
    torch.manual_seed(0)

    # Create a SummaryWriter object by TensorBoard
    dir_name = 'runs/' + 'dqn/' + 'Pong_dqn' + '_' + time.ctime()
    writer = SummaryWriter(log_dir=dir_name)

    # Main network
    qf = MLP(obs_dim, act_num).to(device)
    # Target network
    qf_target = MLP(obs_dim, act_num).to(device)

    # Initialize target parameters to match main parameters
    qf_target.load_state_dict(qf.state_dict())

    # Create an optimizer
    qf_optimizer = optim.Adam(qf.parameters(), lr=1e-3)

    # Experience buffer
    replay_buffer = ReplayBuffer(obs_dim, 1, args.buffer_size)

    step_count = 0
    sum_returns = 0.
    num_episodes = 0
    recent_returns = deque(maxlen=10)

    start_time = time.time()

    for episode in range(1, args.episode_num + 1):
        total_reward = 0.

        env_info = env.reset(train_mode=True)[default_brain]
        obs = env_info.vector_observations[0]
        done = False

        # Keep interacting until agent reaches a terminal state.
        while not done:
            step_count += 1

            # Collect experience (s, a, r, s') using some policy
            action = select_action(torch.Tensor(obs).to(device), act_num, qf)

            env_info = env.step(int(action))[default_brain]

            next_obs = env_info.vector_observations[0]
            reward = env_info.rewards[0]
            done = env_info.local_done[0]

            # Add experience to replay buffer
            replay_buffer.add(obs, action, reward, next_obs, done)

            # Start training when the number of experience is greater than batch size
            if step_count > args.batch_size:
                batch = replay_buffer.sample(args.batch_size)
                train_model(qf, qf_target, qf_optimizer, batch, step_count)

            total_reward += reward
            obs = next_obs

        recent_returns.append(total_reward)
        sum_returns += total_reward
        num_episodes += 1
        average_return = sum_returns / num_episodes if num_episodes > 0 else 0.0

        # Log experiment result for training episodes
        writer.add_scalar('Train/AverageReturns', average_return, episode)
        writer.add_scalar('Train/EpisodeReturns', sum_returns, episode)

        if episode % 10 == 0:
            print('---------------------------------------')
            print('Episodes:', episode)
            print('Steps:', step_count)
            print('AverageReturn:', round(average_return, 2))
            print('RecentReturn:', np.mean(recent_returns))
            print('Time:', int(time.time() - start_time))
            print('---------------------------------------')

        # Save a training model
        if (np.mean(recent_returns)) >= args.threshold_return:
            print('Recent returns {} exceed threshold return. So end'.format(
                np.mean(recent_returns)))
            if not os.path.exists('./save_model'):
                os.mkdir('./save_model')

            ckpt_path = os.path.join('./save_model/' + 'Pong_dqn' + '_ep_' + str(episode) \
                                                                  + '_rt_' + str(round(average_return, 2)) \
                                                                  + '_t_' + str(int(time.time() - start_time)) + '.pt')
            torch.save(qf.state_dict(), ckpt_path)
            break

    env.close()
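select_action is referenced above but not defined in this snippet. A sketch of an epsilon-greedy policy over the Q-network; the exploration rate is an assumption:

def select_action(obs, act_num, qf, epsilon=0.1):
    # Hypothetical epsilon-greedy action selection for the DQN above.
    if np.random.rand() < epsilon:
        return np.random.randint(act_num)  # explore
    with torch.no_grad():
        q_values = qf(obs)
    return q_values.argmax().item()  # exploit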
Example #17
class REINFORCE:
    def __init__(self, obs_space_size, hidden_sizes, action_space_size,
                 learning_rate, use_cuda, gpu_id):

        self.action_space_size = action_space_size
        self.use_cuda = use_cuda
        self.gpu_id = gpu_id

        # Initializes the policy network and optimizer
        self.policy = MLP(obs_space_size,
                          hidden_sizes,
                          action_space_size,
                          "distribution",
                          "relu",
                          "standard",
                          name="PolicyNetwork",
                          verbose=True)
        self.optimizer = torch.optim.Adam(self.policy.parameters(),
                                          lr=learning_rate)

        # Creates counters
        self.action_count = np.zeros(shape=(self.action_space_size, ))

        self.explore_count = 0
        self.exploit_count = 0

        # If GPU is available, sends model on GPU
        if torch.cuda.is_available() and self.use_cuda:
            self.policy.cuda(gpu_id)
            print("USING GPU-{}".format(gpu_id))

        self.policy.train()

    def select_action(self, observation):

        # Transforms the state into a torch Variable
        x = Variable(torch.Tensor([observation]))

        if torch.cuda.is_available() and self.use_cuda:
            x = x.cuda(self.gpu_id)

        # Forward propagation through policy network
        action_probs = self.policy(x)

        # Samples an action
        action = action_probs.multinomial().data

        # Negative log-likelihood of sampled action
        NLL = -torch.log(action_probs[:, action[0, 0]]).view(1, -1)

        if int(action) == int(torch.max(action_probs, 1)[1].cpu().data):
            self.exploit_count += 1
        else:
            self.explore_count += 1
        self.action_count[int(action)] += 1

        return int(action), NLL

    def compute_gradients(self, reward_list, NLL_list, gamma):

        R = torch.zeros(1, 1)
        loss = 0

        # Iterates through the episode in reverse order to compute return for each step
        for i in reversed(range(len(reward_list))):

            # Discounts reward
            R = gamma * R + reward_list[i]
            Return_i = Variable(R)
            if torch.cuda.is_available() and self.use_cuda:
                Return_i = Return_i.cuda(self.gpu_id)

            # Loss is the NLL at each step weighted by the return for that step
            loss = loss + (NLL_list[i] * Return_i).squeeze()

        # Average to get the total loss
        loss = loss / len(reward_list)
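        # Worked example of the recursion above with gamma = 0.5 and reward_list = [1, 0, 2]
        # (illustration only, not part of the original code):
        #   i = 2: R = 0.5 * 0   + 2 = 2.0
        #   i = 1: R = 0.5 * 2.0 + 0 = 1.0
        #   i = 0: R = 0.5 * 1.0 + 1 = 1.5
        # so the NLL of step 0 is weighted by 1.5, step 1 by 1.0 and step 2 by 2.0,
        # and the final loss is the mean of these weighted terms.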

        # Backpropagation to compute the gradients
        loss.backward()

        return loss.cpu().data.numpy()

    def update_parameters(self):
        # Clips the gradient and apply the update
        torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 40)
        self.optimizer.step()
        self.optimizer.zero_grad()

    def save_policy(self, directory):
        torch.save(self.policy.state_dict(),
                   os.path.join(directory, self.policy.name + "_ckpt.pkl"))

    def load_policy(self, directory):
        self.policy.load_state_dict(
            torch.load(os.path.join(directory, self.policy.name + "_ckpt.pkl")))

    def reset_counters(self):

        self.action_count = np.zeros(shape=(self.action_space_size, ))

        self.explore_count = 0
        self.exploit_count = 0
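# A rough usage sketch (not from the original source) showing how the REINFORCE agent
# above is typically driven. It assumes a Gym-style CartPole environment (old reset/step
# API) and that MLP accepts the hidden_sizes list and constructor arguments shown above.
import gym

def run_reinforce(num_episodes=10, gamma=0.99):
    env = gym.make('CartPole-v0')
    agent = REINFORCE(obs_space_size=4, hidden_sizes=[64], action_space_size=2,
                      learning_rate=1e-3, use_cuda=False, gpu_id=0)
    for _ in range(num_episodes):
        obs, done = env.reset(), False
        rewards, nlls = [], []
        while not done:
            action, nll = agent.select_action(obs)
            obs, reward, done, _ = env.step(action)
            rewards.append(reward)
            nlls.append(nll)
        # Accumulate gradients for the episode, then apply the clipped update
        agent.compute_gradients(rewards, nlls, gamma)
        agent.update_parameters()
        agent.reset_counters()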
Example #18
def train(lr=args.lr,
          n_hidden=args.n_hidden,
          batch_size=args.batch_size,
          dropout=args.dropout,
          valid_freq=3000,
          disp_freq=1000,
          save_freq=100000,
          max_epochs=args.n_epoch,
          patience=15,
          save_name=args.save_name,
          save_dir=args.save_dir,
          device=args.device):
    # Load train and valid dataset
    print('loading train')
    with open(args.train_path, 'rb') as f:
        train_val_y = pickle.load(f)
        train_val_x = pickle.load(f)

    print('loading english test')
    with open(args.en_test_path, 'rb') as f:
        en_test_y = pickle.load(f)
        en_test_x = pickle.load(f)

    print('loading french test')
    with open(args.fr_test_path, 'rb') as f:
        fr_test_y = pickle.load(f)
        fr_test_x = pickle.load(f)

    sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=1125)
    for train_index, test_index in sss.split(train_val_x, train_val_y):
        train_y = train_val_y[train_index]
        train_x = train_val_x[train_index]
        valid_y = train_val_y[test_index]
        valid_x = train_val_x[test_index]

    print('Number of training samples: %d' % train_x.shape[0])
    print('Number of validation samples: %d' % valid_x.shape[0])
    print('Number of English testing samples: %d' % en_test_x.shape[0])
    print('Number of French testing samples: %d' % fr_test_x.shape[0])
    print('-' * 100)

    kf_valid = get_minibatches_idx(len(valid_y), batch_size)
    kf_en_test = get_minibatches_idx(len(en_test_y), batch_size)
    kf_fr_test = get_minibatches_idx(len(fr_test_y), batch_size)

    # Loader parameter: use CUDA pinned memory for faster data loading when on GPU
    pin_memory = (device == 'gpu')
    # Test set

    n_emb = train_x.shape[1]
    n_class = len(set(train_y))
    best_valid_acc = None
    bad_counter = 0

    uidx = 0  # the number of update done
    estop = False  # early stop switch
    net = MLP(n_mlp_layer=args.n_mlp_layers,
              n_hidden=args.n_hidden,
              dropout=args.dropout,
              n_class=n_class,
              n_emb=n_emb,
              device=args.device)

    if args.load_net != '':
        assert os.path.exists(
            args.load_net), 'Path to pretrained net does not exist'
        net.load_state_dict(torch.load(args.load_net))
        print('Loaded existing model stored at: ', args.load_net)

    if args.device == 'gpu':
        net = net.cuda()

    # Begin Training
    net.train()
    print('-' * 100)
    print('Model structure: ')
    print('MLP baseline')
    print(net.main)
    print('-' * 100)
    print('Parameters for tuning: ')
    print(net.state_dict().keys())
    print('-' * 100)

    # Define optimizer
    assert args.optimizer in [
        'SGD', 'Adam', "RMSprop", "LBFGS", "Rprop", "ASGD", "Adadelta",
        "Adagrad", "Adamax"
    ], 'Please choose one of the supported optimizers listed above'
    if args.optimizer == 'SGD':
        optimizer = optim.SGD(lr=lr,
                              params=filter(lambda p: p.requires_grad,
                                            net.parameters()),
                              momentum=0.9)
    else:
        optimizer = getattr(optim, args.optimizer)(params=filter(
            lambda p: p.requires_grad, net.parameters()),
                                                   lr=lr)

    #lambda1 = lambda epoch: epoch // 30
    lambda2 = lambda epoch: 0.98**epoch
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=[lambda2])
    #scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)
    #scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max')
    try:
        for eidx in range(max_epochs):
            scheduler.step()
            # print('Training mode on: ' ,net.training)
            start_time = time.time()
            n_samples = 0
            # Get new shuffled index for the training set
            kf = get_minibatches_idx(len(train_y), batch_size, shuffle=True)

            for _, train_index in kf:
                # Remove gradient from previous batch
                #net.zero_grad()
                optimizer.zero_grad()
                uidx += 1
                y_batch = torch.autograd.Variable(
                    torch.from_numpy(train_y[train_index]).long())
                x_batch = torch.autograd.Variable(
                    torch.from_numpy(train_x[train_index]).float())
                if net.device == 'gpu':
                    y_batch = y_batch.cuda()
                scores = net.forward(x_batch)
                loss = net.loss(scores, y_batch)

                loss.backward()
                optimizer.step()
                n_samples += len(x_batch)
                gradient = 0

                # For logging gradient information
                for name, w in net.named_parameters():
                    if w.grad is not None:
                        w_grad = torch.norm(w.grad.data, 2)**2
                        gradient += w_grad
                gradient = gradient**0.5
                if np.mod(uidx, disp_freq) == 0:
                    print('Epoch ', eidx, 'Update ', uidx, 'Cost ',
                          loss.item(), 'Gradient ', gradient)

                if save_name and np.mod(uidx, save_freq) == 0:
                    print('Saving...')
                    torch.save(
                        net.state_dict(), '%s/%s_epoch%d_update%d.net' %
                        (save_dir, save_name, eidx, uidx))

                if np.mod(uidx, valid_freq) == 0:
                    print("=" * 50)
                    print('Evaluation on validation set: ')
                    kf_valid = get_minibatches_idx(len(valid_y), batch_size)
                    top_1_acc, top_n_acc = eval.net_evaluation(
                        net, kf_valid, valid_x, valid_y)
                    #scheduler.step(top_1_acc)

                    # Save best performance state_dict for testing
                    if best_valid_acc is None:
                        best_valid_acc = top_1_acc
                        best_state_dict = net.state_dict()
                        torch.save(best_state_dict,
                                   '%s/%s_best.net' % (save_dir, save_name))
                    else:
                        if top_1_acc > best_valid_acc:
                            print(
                                'Best validation performance so far, saving model parameters'
                            )
                            print("*" * 50)
                            bad_counter = 0  # reset counter
                            best_valid_acc = top_1_acc
                            best_state_dict = net.state_dict()
                            torch.save(
                                best_state_dict,
                                '%s/%s_best.net' % (save_dir, save_name))
                        else:
                            bad_counter += 1
                            print('Validation accuracy: ', 100 * top_1_acc)
                            print('Getting worse, patience left: ',
                                  patience - bad_counter)
                            print('Best validation accuracy so far: ',
                                  100 * best_valid_acc)
                            # Learning rate annealing
                            lr /= args.lr_anneal
                            print('Learning rate annealed to: ', lr)
                            print('*' * 100)
                            if args.optimizer == 'SGD':
                                optimizer = optim.SGD(
                                    lr=lr,
                                    params=filter(lambda p: p.requires_grad,
                                                  net.parameters()),
                                    momentum=0.9)
                            else:
                                optimizer = getattr(optim, args.optimizer)(
                                    params=filter(lambda p: p.requires_grad,
                                                  net.parameters()),
                                    lr=lr)
                            if bad_counter > patience:
                                print('-' * 100)
                                print('Early Stop!')
                                estop = True
                                break

            epoch_time = time.time() - start_time
            print('Epoch processing time: %.2f s' % epoch_time)
            print('Seen %d samples' % n_samples)
            if estop:
                break
        print('-' * 100)
        print('Training finished')
        best_state_dict = torch.load('%s/%s_best.net' % (save_dir, save_name))
        torch.save(net.state_dict(), '%s/%s_final.net' % (save_dir, save_name))
        net.load_state_dict(best_state_dict)

        print('Evaluation on validation set: ')
        kf_valid = get_minibatches_idx(len(valid_y), batch_size)
        eval.net_evaluation(net, kf_valid, valid_x, valid_y)

        # Evaluate model on the test sets
        print('Evaluation on English test set: ')
        eval.net_evaluation(net, kf_en_test, en_test_x, en_test_y)
        print('Evaluation on French test set: ')
        eval.net_evaluation(net, kf_fr_test, fr_test_x, fr_test_y)
    except KeyboardInterrupt:
        print('-' * 100)
        print("Training interrupted, saving final model...")
        best_state_dict = torch.load('%s/%s_best.net' % (save_dir, save_name))
        torch.save(net.state_dict(), '%s/%s_final.net' % (save_dir, save_name))
        net.load_state_dict(best_state_dict)
        print('Evaluation on validation set: ')
        kf_valid = get_minibatches_idx(len(valid_y), batch_size)
        eval.net_evaluation(net, kf_valid, valid_x, valid_y)

        # Evaluate model on the test sets
        print('Evaluation on English test set: ')
        eval.net_evaluation(net, kf_en_test, en_test_x, en_test_y)
        print('Evaluation on French test set: ')
        eval.net_evaluation(net, kf_fr_test, fr_test_x, fr_test_y)
Example #19
model.reset_parameters()
print("--------------------------")
print("Training...")
for epoch in range(args.epochs):
    loss_tra, train_ep = train(model, args.dev, train_loader, optimizer)
    f1_val = validate(model, args.dev, valid_loader, evaluator)
    train_time += train_ep
    if (epoch + 1) % 5 == 0:
        print(f'Epoch:{epoch+1:02d},'
              f'Train_loss:{loss_tra:.3f}',
              f'Valid_acc:{100*f1_val:.3f}%',
              f'Time_cost:{train_time:.3f}')
    if f1_val > best:
        best = f1_val
        best_epoch = epoch
        torch.save(model.state_dict(), checkpt_file)
        bad_counter = 0
    else:
        bad_counter += 1

    if bad_counter == args.patience:
        break

test_acc = test(model, args.dev, test_loader, evaluator, checkpt_file)
print(f"Train cost: {train_time:.2f}s")
print('Loaded checkpoint from epoch {}'.format(best_epoch))
print(f"Test accuracy: {100*test_acc:.2f}%")

memory_main = 1024 * resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 2**30
memory = memory_main - memory_dataset
print("Memory overhead: {:.2f}GB".format(memory))
class DQN:
    def __init__(self,
                 n_states,
                 n_actions,
                 gamma=0.99,
                 epsilon_start=0.9,
                 epsilon_end=0.05,
                 epsilon_decay=200,
                 memory_capacity=10000,
                 policy_lr=0.01,
                 batch_size=128,
                 device="cpu"):

        self.n_actions = n_actions  # total number of actions
        self.device = device  # device to run on, e.g. cpu or gpu
        self.gamma = gamma  # discount factor for rewards
        # parameters of the epsilon-greedy exploration policy
        self.actions_count = 0  # step counter used for epsilon decay
        self.epsilon = 0
        self.epsilon_start = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay = epsilon_decay
        self.batch_size = batch_size
        self.policy_net = MLP(n_states, n_actions).to(self.device)
        self.target_net = MLP(n_states, n_actions).to(self.device)
        # initialize target_net as an exact copy of policy_net's parameters
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()  # disable BatchNorm and Dropout behaviour for the target net
        # note: parameters() returns only the trainable tensors (requires_grad=True),
        # whereas state_dict() also contains buffers
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=policy_lr)
        self.loss = 0
        self.memory = ReplayBuffer(memory_capacity)
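        # Illustration of the exponential epsilon decay used in choose_action below
        # (with the defaults epsilon_start=0.9, epsilon_end=0.05, epsilon_decay=200):
        #   actions_count = 0    -> epsilon ~= 0.900
        #   actions_count = 200  -> epsilon ~= 0.05 + 0.85 * exp(-1) ~= 0.363
        #   actions_count = 1000 -> epsilon ~= 0.05 + 0.85 * exp(-5) ~= 0.056
        # so exploration decays smoothly towards epsilon_end.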

    def choose_action(self, state, train=True):
        '''Select an action.
        '''
        if train:
            self.epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \
                math.exp(-1. * self.actions_count / self.epsilon_decay)
            self.actions_count += 1
            if random.random() > self.epsilon:
                with torch.no_grad():
                    # Convert the state (originally float64) to a tensor before feeding it to the network.
                    # Note: state = torch.tensor(state).unsqueeze(0) is equivalent to state = torch.tensor([state])
                    state = torch.tensor([state],
                                         device=self.device,
                                         dtype=torch.float32)
                    # e.g. tensor([[-0.0798, -0.0079]], grad_fn=<AddmmBackward>)
                    q_value = self.policy_net(state)
                    # tensor.max(1) returns, for each row, the maximum value and its index,
                    # e.g. torch.return_types.max(values=tensor([10.3587]), indices=tensor([0])),
                    # so tensor.max(1)[1] is the index of the maximum, i.e. the greedy action
                    action = q_value.max(1)[1].item()
            else:
                action = random.randrange(self.n_actions)
            return action
        else:
            with torch.no_grad():  # no gradient tracking
                # Convert the state (originally float64) to a tensor before feeding it to the network.
                # Note: state = torch.tensor(state).unsqueeze(0) is equivalent to state = torch.tensor([state])
                state = torch.tensor(
                    [state], device='cpu', dtype=torch.float32
                )  # e.g. tensor([[-0.0798, -0.0079]])
                q_value = self.target_net(state)
                # tensor.max(1) returns, for each row, the maximum value and its index,
                # so tensor.max(1)[1] is the index of the maximum, i.e. the greedy action
                action = q_value.max(1)[1].item()
            return action

    def update(self):

        if len(self.memory) < self.batch_size:
            return
        # Randomly sample a batch of transitions from the replay buffer
        state_batch, action_batch, reward_batch, next_state_batch, done_batch = self.memory.sample(
            self.batch_size)
        # Convert the batch to tensors, e.g.
        # tensor([[-4.5543e-02, -2.3910e-01,  1.8344e-02,  2.3158e-01], ..., [-1.8615e-02, -2.3921e-01, -1.1791e-02,  2.3400e-01]])
        state_batch = torch.tensor(state_batch,
                                   device=self.device,
                                   dtype=torch.float)
        action_batch = torch.tensor(action_batch,
                                    device=self.device).unsqueeze(
                                        1)  # e.g. tensor([[1],...,[0]])
        reward_batch = torch.tensor(
            reward_batch, device=self.device,
            dtype=torch.float)  # e.g. tensor([1., 1., ..., 1.])
        next_state_batch = torch.tensor(next_state_batch,
                                        device=self.device,
                                        dtype=torch.float)
        done_batch = torch.tensor(np.float32(done_batch),
                                  device=self.device)  # bools converted to floats
        # Compute Q(s_t, a) for the actions that were actually taken.
        # torch.gather: for a = torch.Tensor([[1, 2], [3, 4]]),
        # a.gather(1, torch.LongTensor([[0], [1]])) == torch.Tensor([[1], [4]])
        q_values = self.policy_net(state_batch).gather(
            dim=1, index=action_batch)  # equivalent to calling self.policy_net.forward
        # Compute V(s_{t+1}) for all next states, i.e. the largest target-net Q value
        next_state_values = self.target_net(next_state_batch).max(
            1)[0].detach()  # e.g. tensor([ 0.0060, -0.0171, ...])
        # Compute the TD target expected_q_values;
        # for terminal transitions done_batch is 1, so the target reduces to the reward
        expected_q_values = reward_batch + self.gamma * \
            next_state_values * (1 - done_batch)
        # self.loss = F.smooth_l1_loss(q_values, expected_q_values.unsqueeze(1))  # Huber loss
        self.loss = nn.MSELoss()(q_values,
                                 expected_q_values.unsqueeze(1))  # mean squared error loss
        # Optimize the model
        self.optimizer.zero_grad()  # clear any gradients left over from the previous step
        # backward() uses backpropagation to compute the gradient of the loss
        # with respect to every parameter that requires gradients
        self.loss.backward()
        for param in self.policy_net.parameters():  # clip gradients to prevent explosion
            param.grad.data.clamp_(-1, 1)

        self.optimizer.step()  # apply the parameter update

    def save_model(self, path):
        torch.save(self.target_net.state_dict(), path)

    def load_model(self, path):
        self.target_net.load_state_dict(torch.load(path))
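# A minimal training-loop sketch (not from the original source) for the DQN class above.
# It assumes a Gym-style environment, that ReplayBuffer exposes a
# push(state, action, reward, next_state, done) method, and it adds the periodic
# target-network sync that the class itself does not perform.
def train_dqn(env, agent, num_episodes=200, target_update=10):
    for episode in range(num_episodes):
        state, done = env.reset(), False
        while not done:
            action = agent.choose_action(state)
            next_state, reward, done, _ = env.step(action)
            agent.memory.push(state, action, reward, next_state, done)
            agent.update()
            state = next_state
        # Periodically copy the online network into the target network
        if episode % target_update == 0:
            agent.target_net.load_state_dict(agent.policy_net.state_dict())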
Example #21
        optimizer.zero_grad()
        output = net(data) + args.beta * prior(data).detach()
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        step += 1

        # experiment.log_metric('train_loss', loss.item(), step=step)

# Save the model
if not Path.exists(savepath / 'models'):
    os.makedirs(savepath / 'models')

if not Path.exists(savepath / 'models-prior'):
    os.makedirs(savepath / 'models-prior')

model_path = savepath / 'models' / '{}_{}epochs.pt'.format(
    model_name, epoch + 1)
if not Path.exists(model_path):
    torch.save(net.state_dict(), model_path)

prior_path = savepath / 'models-prior' / '{}-prior_{}epochs.pt'.format(
    model_name, epoch + 1)
if not Path.exists(prior_path):
    torch.save(prior.state_dict(), prior_path)
else:
    raise ValueError(
        'Error trying to save file at location {}: File already exists'.format(
            prior_path))
	
	f02sp = F02SP(FFTSIZE,FS)
	
	f0 = 0.1 * np.arange(400,5000+1) # input, 0.1~800 [Hz]
	sp = f02sp.get_sp(f0)
	
	f0 = torch.from_numpy(f0[:, np.newaxis]).to(dtype).to(device)
	sp = torch.from_numpy(sp).to(dtype).to(device)
	
	# model configure
	model = MLP(in_dim=1, out_dim=FFTSIZE//2+1, numlayer=numlayer, numunit=numunit)
	model = model.to(device)
	criterion = nn.MSELoss()
	optimizer = optim.Adam(model.parameters(), lr=0.001)
	
	# train
	if batchsize is None:
		losslog = batch_train(model, optimizer, criterion, x=f0, y=sp, nepoch=nepoch)
	else:
		losslog = minibatch_train(model, optimizer, criterion, x=f0, y=sp, nepoch=nepoch, batchsize=int(batchsize))
	
	# loss log save
	with open(outlosslog, mode="w") as f:
		f.write(losslog)
	
	# model save
	torch.save(model.state_dict(), outmodel)
	sys.exit(0)

Example #23
def run():
    print(f'Running from {os.getcwd()}')
    train_config, val_config = get_split_configs()
    print(f'Running with\n\ttrain_config: {train_config}\n\tval_config: {val_config}')

    train = AugMNISTDataset(transforms=['color'], config=train_config)
    val = AugMNISTDataset(transforms=['color'], config=val_config)
    train_dataloader = torch.utils.data.DataLoader(train, shuffle=True, batch_size=args.batch_size, num_workers=8)
    val_dataloader = torch.utils.data.DataLoader(val, shuffle=True, batch_size=1000, num_workers=0)

    mlp_width = 512
    if args.use_l0:
        e_model = L0MLP(args.n_hidden, args.input_dim, mlp_width, 1).to(args.device)
        d_model = L0MLP(args.n_hidden, args.input_dim, mlp_width, 1).to(args.device)
    else:
        e_model = MLP(args.n_hidden, args.input_dim, mlp_width, 1).to(args.device)
        d_model = MLP(args.n_hidden, args.input_dim, mlp_width, 1).to(args.device)


    #summary(e_model, (13,))
    #summary(d_model, (13,))

    if args.optimizer == 'sgd':
        e_opt = torch.optim.SGD(e_model.parameters(), momentum=0.9, lr=args.lr)
        d_opt = torch.optim.SGD(d_model.parameters(), momentum=0.9, lr=args.lr)
    elif args.optimizer == 'adam':
        e_opt = torch.optim.Adam(e_model.parameters(), lr=args.lr)
        d_opt = torch.optim.Adam(d_model.parameters(), lr=args.lr)
    step = 0
    task = generate_task()
    decay_epochs = [60,90,120,150]
    e_sched = torch.optim.lr_scheduler.MultiStepLR(e_opt, milestones=decay_epochs, gamma=0.1)
    d_sched = torch.optim.lr_scheduler.MultiStepLR(d_opt, milestones=decay_epochs, gamma=0.1)
    for epoch in range(args.epochs):
        for idx, samples in enumerate(train_dataloader):
            features = get_features(samples).to(args.device)
            entangled_features = get_features(samples, entangle=True).to(args.device)
            labels = get_labels(samples, task).to(args.device)

            if args.use_l0:
                e_out, l0_e  = e_model(entangled_features)
                d_out, l0_d = d_model(features)
            else:
                e_out = e_model(entangled_features)
                d_out = d_model(features)

            e_pred = e_out > 0
            e_acc = (e_pred == labels).float().mean()
            d_pred = d_out > 0
            d_acc = (d_pred == labels).float().mean()

            e_bce = F.binary_cross_entropy_with_logits(e_out, labels)
            e_loss = e_bce
            d_bce = F.binary_cross_entropy_with_logits(d_out, labels)
            d_loss = d_bce

            # L0
            if args.use_l0:
                l0_coef = 1e-1
                d_loss += l0_coef * l0_d / len(samples)
                e_loss += l0_coef * l0_e / len(samples)

            # L1
            if epoch <= args.rampup_begin:
                l1_coef = args.warmup_l1
            else:
                l1_coef = args.warmup_l1 + args.l1 / (args.warmup_l1 + args.l1) * min(args.l1, args.l1 * (float(epoch) - args.rampup_begin) / (args.rampup_end-args.rampup_begin))

            d_loss += l1_coef * l1(d_model)
            e_loss += l1_coef * l1(e_model)

            e_loss.backward()
            e_grad = torch.nn.utils.clip_grad_norm_(e_model.parameters(), 100)
            e_opt.step()
            e_opt.zero_grad()

            d_loss.backward()
            d_grad = torch.nn.utils.clip_grad_norm_(d_model.parameters(), 100)
            d_opt.step()
            d_opt.zero_grad()

            if step % 250 == 0:
                stats = {}
                stats['step'] = step
                stats['train_acc/e'], stats['train_acc/d']  = e_acc, d_acc
                stats['train_loss/e'], stats['train_loss/d']  = e_loss, d_loss
                stats['train_bce/e'], stats['train_bce/d']  = e_bce, d_bce

                if args.warmup_l1 + args.l1 > 0:
                    stats['l1_coef'] = l1_coef

                d_nonzero, d_params = nonzero_params(d_model)
                e_nonzero, e_params = nonzero_params(e_model)
                stats['d_nonzero'], stats['e_nonzero'] = d_nonzero, e_nonzero

                with torch.no_grad():
                    val_samples = next(iter(val_dataloader))
                    val_features = get_features(val_samples).to(args.device)
                    val_entangled_features = get_features(val_samples, entangle=True).to(args.device)
                    val_labels = get_labels(val_samples, task)

                    if args.use_l0:
                        e_out = copy_and_zero(e_model)(val_entangled_features)[0].cpu()
                        d_out = copy_and_zero(d_model)(val_features)[0].cpu()
                    else:
                        e_out = copy_and_zero(e_model)(val_entangled_features).cpu()
                        d_out = copy_and_zero(d_model)(val_features).cpu()

                    stats['val_auc/e'] = metrics.roc_auc_score(val_labels, e_out)
                    stats['val_auc/d'] = metrics.roc_auc_score(val_labels, d_out)
                    stats['lr/e'], stats['lr/d'] = e_sched.get_lr()[0], d_sched.get_lr()[0]

                    e_pred = e_out > 0
                    e_acc = (e_pred == val_labels).float().mean()
                    d_pred = d_out > 0
                    d_acc = (d_pred == val_labels).float().mean()

                    stats['val_acc/e'], stats['val_acc/d'] = e_acc, d_acc

                    # Fetch k wrong predictions
                    #k = 10
                    #e_wrong_mask = [e_pred != val_labels]
                    #d_wrong_mask = [d_pred != val_labels]
                    #wrong_preds_e, ftrs_e = e_out[e_wrong_mask][:k], val_entangled_features[:k]
                    #wrong_preds_d, ftrs_d = d_out[d_wrong_mask][:k], val_features[:k]

                to_save = {
                    'd_model': d_model.state_dict(),
                    'e_model': e_model.state_dict(),
                    'd_opt': d_opt.state_dict(),
                    'e_opt': e_opt.state_dict()
                }
                torch.save(to_save, 'checkpoint.pt')
                if args.log_wandb:
                    wandb.log(stats)
                else:
                    print_stats(stats)

            step += 1
        e_sched.step()
        d_sched.step()
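# run() above relies on an l1(model) helper that is not shown in this snippet; a plausible
# minimal implementation (an assumption, not the original) is the sum of absolute values of
# all trainable parameters:
def l1(model):
    # L1 norm of every trainable parameter tensor, summed into a single scalar tensor
    return sum(p.abs().sum() for p in model.parameters() if p.requires_grad)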
def train(args, logger, model_save_dir, val_dataset, test_dataset,
          train_dataset):
    # set seed
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    pretrain_embed = pickle.load(
        open('../{}/{}'.format(args.embed_dir, args.embed), 'rb'))

    try:
        pretrain_embed = torch.from_numpy(pretrain_embed).float()
    except:
        pretrain_embed = pretrain_embed.float()

    dataLoader = DataLoader(train_dataset,
                            batch_size=args.batch_sz,
                            shuffle=True)
    if args.model == 'MLP':
        model = MLP(args.hidden_dim, pretrain_embed)
    elif args.model == 'MLP3':
        model = MLP3Diff(args.hidden_dim, pretrain_embed)
    elif args.model == 'BiLinear':
        model = BiLinearDiff1(args.hidden_dim, pretrain_embed)
    else:
        model = BiLinearDiffH(args.hidden_dim, pretrain_embed)

    # model = ListMaxTransformer(args.hidden_dim, pretrain_embed)
    if torch.cuda.is_available():
        model.cuda()

    criterion = torch.nn.MSELoss()
    # optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
    # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=args.gamma)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    best_dev_loss = float('+inf')
    best_dev_model = None
    best_dev_test_loss = 0
    counter = 0

    for epoch in range(1, args.n_epoch + 1):
        train_loss = 0
        model.train()
        iteration = 0
        optimizer.zero_grad()

        for batch in dataLoader:
            x = torch.stack(batch['input'])  # 5 x bz
            y = batch['label'].float()  # bz

            if torch.cuda.is_available():
                x = x.cuda()
                y = y.cuda()

            output = model(x)
            loss = criterion(output, y)
            train_loss += loss.item()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 5)
            optimizer.step()

            iteration += 1
            # if iteration % args.iter_print == 0:
            #     logger.info('{}-{}-{}-{}'.format(epoch, iteration, train_loss, train_acc))

        train_loss = train_loss / len(dataLoader)
        dev_loss = val(model, val_dataset)
        test_loss = val(model, test_dataset)

        # scheduler.step()

        if dev_loss < best_dev_loss:
            best_dev_model = model.state_dict().copy()
            best_dev_loss = dev_loss
            best_dev_test_loss = test_loss
            counter = 0
        else:
            counter += 1

        if epoch % 5 == 0:
            logger.info('=================================================')
            logger.info('TRAIN: epoch:{}-loss:{}'.format(epoch, train_loss))
            logger.info('DEV: epoch:{}-loss:{}'.format(epoch, dev_loss))
            logger.info('TEST: epoch:{}-loss:{}'.format(epoch, test_loss))
            logger.info('BEST-DEV-LOSS: {}, BEST-DEV-TEST-LOSS:{}'.format(
                best_dev_loss, best_dev_test_loss))

        if counter > 40:
            break

    logger.info('===================[][][][][]====================')
    logger.info('TRAIN: epoch:{}-loss:{}'.format(epoch, train_loss))
    logger.info('DEV: epoch:{}-loss:{}'.format(epoch, dev_loss))
    logger.info('TEST: epoch:{}-loss:{}'.format(epoch, test_loss))
    logger.info('BEST-DEV-LOSS: {}, BEST-DEV-TEST-LOSS:{}'.format(
        best_dev_loss, best_dev_test_loss))
    torch.save(
        best_dev_model, model_save_dir + '/model-{}-{}-{}-{}.pt'.format(
            best_dev_test_loss, args.lr, args.hidden_dim, args.gamma))

    del dataLoader
    del best_dev_model
    del model
    del train_dataset
    del val_dataset
    del test_dataset
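# train() above calls a val() helper that is not included in this snippet; a plausible
# shape (an assumption, not the original) is an MSE evaluation pass that mirrors the
# dict-batch format and DataLoader usage of the training loop:
def val(model, dataset, batch_size=32):
    model.eval()
    criterion = torch.nn.MSELoss()
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    total = 0.0
    with torch.no_grad():
        for batch in loader:
            x = torch.stack(batch['input'])
            y = batch['label'].float()
            if torch.cuda.is_available():
                x, y = x.cuda(), y.cuda()
            total += criterion(model(x), y).item()
    model.train()
    return total / len(loader)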
class PolicyGradient:
    def __init__(self,
                 state_dim,
                 device='cpu',
                 gamma=0.99,
                 lr=0.01,
                 batch_size=5):
        self.gamma = gamma
        self.policy_net = MLP(state_dim)
        self.optimizer = torch.optim.RMSprop(self.policy_net.parameters(),
                                             lr=lr)
        self.batch_size = batch_size

    def choose_action(self, state):

        state = torch.from_numpy(state).float()
        state = Variable(state)
        probs = self.policy_net(state)
        m = Bernoulli(probs)
        action = m.sample()

        action = action.data.numpy().astype(int)[0]  # convert to a scalar
        return action

    def update(self, reward_pool, state_pool, action_pool):
        # Discount reward
        running_add = 0
        for i in reversed(range(len(reward_pool))):
            if reward_pool[i] == 0:
                running_add = 0
            else:
                running_add = running_add * self.gamma + reward_pool[i]
                reward_pool[i] = running_add
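        # Note: a stored reward of 0 is treated here as an episode boundary (running_add
        # is reset), which matches the common convention of pushing reward 0 at terminal
        # steps when several episodes share one pool.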

        # Normalize reward
        reward_mean = np.mean(reward_pool)
        reward_std = np.std(reward_pool)
        for i in range(len(reward_pool)):
            reward_pool[i] = (reward_pool[i] - reward_mean) / reward_std

        # Gradient descent
        self.optimizer.zero_grad()

        for i in range(len(reward_pool)):
            state = state_pool[i]
            action = Variable(torch.FloatTensor([action_pool[i]]))
            reward = reward_pool[i]

            state = Variable(torch.from_numpy(state).float())
            probs = self.policy_net(state)
            m = Bernoulli(probs)
            loss = -m.log_prob(
                action) * reward  # negative score function x reward
            # print(loss)
            loss.backward()
        self.optimizer.step()

    def save_model(self, path):
        torch.save(self.policy_net.state_dict(), path)

    def load_model(self, path):
        self.policy_net.load_state_dict(torch.load(path))
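# A rough usage sketch (not from the original source) for the PolicyGradient agent above,
# assuming a Gym-style CartPole environment and the Bernoulli (two-action) policy head
# expected by choose_action:
import gym

def run_policy_gradient(num_episodes=100, gamma=0.99):
    env = gym.make('CartPole-v0')
    agent = PolicyGradient(state_dim=4, gamma=gamma)
    state_pool, action_pool, reward_pool = [], [], []
    for episode in range(num_episodes):
        state, done = env.reset(), False
        while not done:
            action = agent.choose_action(state)
            next_state, reward, done, _ = env.step(action)
            state_pool.append(state)
            action_pool.append(float(action))
            # push 0 at terminal steps so update() can detect episode boundaries
            reward_pool.append(0 if done else reward)
            state = next_state
        if (episode + 1) % agent.batch_size == 0:
            agent.update(reward_pool, state_pool, action_pool)
            state_pool, action_pool, reward_pool = [], [], []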
Example #26
class DQN:
    def __init__(self, state_dim, action_dim, cfg):
        """

        :param state_dim: About Task
        :param action_dim: About Task
        :param cfg: Config, About DQN setting
        """
        self.device = cfg.device
        self.action_dim = action_dim
        self.gamma = cfg.gamma
        self.frame_idx = 0  # Decay count for epsilon
        self.epsilon = lambda frame_idx: \
            cfg.epsilon_end + \
            (cfg.epsilon_start - cfg.epsilon_end) * \
            math.exp(-1. * frame_idx / cfg.epsilon_decay)
        self.batch_size = cfg.batch_size
        self.q_value_net = MLP(state_dim,
                               action_dim,
                               hidden_dim=cfg.hidden_dim).to(self.device)
        self.target_net = MLP(state_dim, action_dim,
                              hidden_dim=cfg.hidden_dim).to(self.device)
        self.optimizer = optim.Adam(self.q_value_net.parameters(), lr=cfg.lr)
        self.loss = 0
        self.replay_buffer = ReplayBuffer(cfg.capacity)

    def choose_action(self, state):
        # Select an action using the epsilon-greedy policy
        self.frame_idx += 1
        if random.random() > self.epsilon(self.frame_idx):
            # Will not track the gradient
            with torch.no_grad():
                # Although the original paper's pseudocode writes Q(s, a), the network
                # actually takes s and outputs an |A|-dimensional vector of Q values
                state = torch.tensor([state],
                                     device=self.device,
                                     dtype=torch.float)
                q_value = self.q_value_net(state)

                # output = torch.max(input, dim)
                # dim is the dimension 0/1 of the max function index,
                # 0 is the maximum value of each column,
                # 1 is the maximum value of each row
                # The function will return two tensors,
                # the first tensor is the maximum value of each row;
                # the second tensor is the index of the maximum value of each row.

                # .item(): only one element tensors can be converted to Python scalars
                action = q_value.max(1)[1].item()
        else:
            action = random.randrange(self.action_dim)
        return action

    def update(self):
        if len(self.replay_buffer) < self.batch_size:
            return
        # Randomly sample transitions from the replay buffer
        state_batch, action_batch, reward_batch, next_state_batch, done_batch = \
            self.replay_buffer.sample(self.batch_size)
        state_batch = to_tensor_float(state_batch, device=self.device)
        # tensor([1, 2, 3, 4]).unsqueeze(1)  -> tensor([[1],[2],[3],[4]])
        action_batch = torch.tensor(action_batch,
                                    device=self.device).unsqueeze(1)
        reward_batch = to_tensor_float(reward_batch, device=self.device)
        next_state_batch = to_tensor_float(next_state_batch,
                                           device=self.device)
        done_batch = to_tensor_float(done_batch, device=self.device)

        # Calculate Q(s,a) at time t
        # q_t=Q(s_t,a_t)

        # Use index to index the value of a specific position in a dimension
        # a=torch.Tensor([[1,2],[3,4]]),
        # a.gather(1,torch.LongTensor([[0],[1]]))=torch.Tensor([[1],[4]])

        # index action_batch is obtained from the replay buffer
        q_value = self.q_value_net(state_batch).gather(
            dim=1, index=action_batch)  # shape: [32,1]
        # Calculate Q(s,a) at time t+1
        # q_{t+1}=max_a Q(s_t+1,a)

        # .detach():
        # Return a new Variable, which is separated from the current calculation graph,
        # but still points to the storage location of the original variable.
        # The difference is that requires grad is false.
        # The obtained Variable never needs to calculate its gradient and does not have grad.
        #
        # Even if it re-sets its requirements grad to true later,
        # it will not have a gradient grad
        next_q_value = self.target_net(next_state_batch).max(1)[0].detach()
        # For the termination state, the corresponding expected_q_value is equal to reward
        expected_q_value = reward_batch + self.gamma * next_q_value * (
            1 - done_batch)  # shape: 32
        # loss_fn = torch.nn.MSELoss(reduce=True, size_average=True)
        # reduce = False,return loss in vector form
        # reduce = True, return loss in scalar form
        # size_average = True,return loss.mean()
        # size_average = False,return loss.sum()
        self.loss = nn.MSELoss()(q_value, expected_q_value.unsqueeze(1))
        # Sets the gradients of all optimized :class:`torch.Tensor` s to zero.
        self.optimizer.zero_grad()
        self.loss.backward()
        # Performs a single optimization step (parameter update).
        self.optimizer.step()

    def save(self, path):
        # Returns a dictionary containing a whole state of the module.
        # Both parameters and persistent buffers (e.g. running averages) are included.
        # Keys are corresponding parameter and buffer names.
        torch.save(self.target_net.state_dict(), path + "dqn_checkpoint.pth")

    def load(self, path):
        self.target_net.load_state_dict(torch.load(path +
                                                   "dqn_checkpoint.pth"))
Example #27
class Agent():
    def __init__(self, test=False):
        # device
        if torch.cuda.is_available():
            self.device = torch.device('cuda')
        else :
            self.device = torch.device('cpu')
        
        self.model = MLP(state_dim=4,action_num=2,hidden_dim=256).to(self.device)  
        if test:
            self.load('./pg_best.cpt')        
        # discounted reward
        self.gamma = 0.99 
        # optimizer
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=3e-3)
        # saved rewards and actions
        self.memory = Memory()
        self.tensorboard = TensorboardLogger('./')
    def save(self, save_path):
        print('save model to', save_path)
        torch.save(self.model.state_dict(), save_path)
    def load(self, load_path):
        print('load model from', load_path)
        self.model.load_state_dict(torch.load(load_path))
    def act(self,x,test=False):
        if not test:
            # boring type casting
            x = ((torch.from_numpy(x)).unsqueeze(0)).float().to(self.device)
            # stochastic sample
            action_prob = self.model(x)
            dist = torch.distributions.Categorical(action_prob)
            action = dist.sample()
            # memory log_prob
            self.memory.logprobs.append(dist.log_prob(action))
            return action.item()    
        else :
            self.model.eval()
            x = ((torch.from_numpy(x)).unsqueeze(0)).float().to(self.device)
            with torch.no_grad():
                action_prob = self.model(x)
                # a = np.argmax(action_prob.cpu().numpy())
                dist = torch.distributions.Categorical(action_prob)
                action = dist.sample()
                return action.item()
    def collect_data(self, state, action, reward):
        self.memory.actions.append(action)
        self.memory.rewards.append(torch.tensor(reward))
        self.memory.states.append(state)
    def clear_data(self):
        self.memory.clear_memory()

    def update(self):
        R = 0
        advantage_function = []        
        for t in reversed(range(0, len(self.memory.rewards))):
            R = R * self.gamma + self.memory.rewards[t]
            advantage_function.insert(0, R)

        # turn rewards to pytorch tensor and standardize
        advantage_function = torch.Tensor(advantage_function).to(self.device)
        advantage_function = (advantage_function - advantage_function.mean()) / (advantage_function.std() + np.finfo(np.float32).eps)

        policy_loss = []
        for log_prob, reward in zip(self.memory.logprobs, advantage_function):
            policy_loss.append(-log_prob * reward)
        # Update network weights
        self.optimizer.zero_grad()
        loss = torch.cat(policy_loss).sum()
        loss.backward()
        self.optimizer.step() 
        # boring log
        self.tensorboard.scalar_summary("loss", loss.item())
        self.tensorboard.update()
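# The Agent above relies on a Memory container that is not shown here; a plausible minimal
# version (an assumption, not the original) matching the attributes it uses:
class Memory:
    def __init__(self):
        self.actions, self.states, self.rewards, self.logprobs = [], [], [], []

    def clear_memory(self):
        # Empty the lists in place so references held by the agent stay valid
        del self.actions[:]
        del self.states[:]
        del self.rewards[:]
        del self.logprobs[:]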
Example #28
class Solver(object):
    def __init__(self, config, train_loader, val_loader):
        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda' if self.use_cuda else 'cpu')
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.episodes_per_epoch = config.episodes_per_epoch
        self.N_way_train = config.N_way_train
        self.N_shot_train = config.N_shot_train
        self.N_query_train = config.N_query_train
        self.M_aug_train = config.M_aug_train
        self.N_way_val = config.N_way_val
        self.N_shot_val = config.N_shot_val
        self.N_query_val = config.N_query_val
        self.M_aug_val = config.M_aug_val
        self.matching_fn = config.matching_fn
        self.nz = config.nz

        self.num_epochs = config.num_epochs
        self.resume_iter = config.resume_iter
        self.lr = config.lr
        self.num_steps_decay = config.num_steps_decay
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.weight_decay = config.weight_decay
        self.exp_name = config.name
        os.makedirs(config.ckp_dir, exist_ok=True)
        self.ckp_dir = os.path.join(config.ckp_dir, self.exp_name)
        os.makedirs(self.ckp_dir, exist_ok=True)
        self.log_interval = config.log_interval
        self.ckp_interval = config.ckp_interval

        self.use_wandb = config.use_wandb

        self.build_model()

    def build_model(self):
        self.cnn = Convnet().to(self.device)
        self.g = Hallucinator(self.nz).to(self.device)
        self.mlp = MLP().to(self.device)
        self.optimizer = torch.optim.AdamW(list(self.cnn.parameters()) +
                                           list(self.g.parameters()) +
                                           list(self.mlp.parameters()),
                                           lr=self.lr,
                                           betas=[self.beta1, self.beta2],
                                           weight_decay=self.weight_decay)

        if self.matching_fn == 'parametric':
            self.parametric = nn.Sequential(nn.Linear(800, 400), nn.ReLU(),
                                            nn.Dropout(),
                                            nn.Linear(400, 1)).to(self.device)
            self.optimizer = torch.optim.AdamW(
                list(self.cnn.parameters()) + list(self.g.parameters()) +
                list(self.mlp.parameters()) +
                list(self.parametric.parameters()),
                lr=self.lr,
                betas=[self.beta1, self.beta2],
                weight_decay=self.weight_decay)

        self.scheduler = StepLR(self.optimizer,
                                step_size=self.num_steps_decay,
                                gamma=0.9)

    def save_checkpoint(self, step):
        state = {
            'cnn': self.cnn.state_dict(),
            'g': self.g.state_dict(),
            'mlp': self.mlp.state_dict(),
            'optimizer': self.optimizer.state_dict()
        }

        if self.matching_fn == 'parametric':
            state['parametric'] = self.parametric.state_dict()

        new_checkpoint_path = os.path.join(self.ckp_dir,
                                           '{}-dhm.pth'.format(step + 1))
        torch.save(state, new_checkpoint_path)
        print('model saved to %s' % new_checkpoint_path)

    def load_checkpoint(self, resume_iter):
        print('Loading the trained models from step {}...'.format(resume_iter))
        new_checkpoint_path = os.path.join(self.ckp_dir,
                                           '{}-dhm.pth'.format(resume_iter))
        state = torch.load(new_checkpoint_path)
        self.cnn.load_state_dict(state['cnn'])
        self.g.load_state_dict(state['g'])
        self.mlp.load_state_dict(state['mlp'])
        self.optimizer.load_state_dict(state['optimizer'])
        if self.matching_fn == 'parametric':
            self.parametric.load_state_dict(state['parametric'])
        print('model loaded from %s' % new_checkpoint_path)

    def train(self):
        criterion = nn.CrossEntropyLoss()

        best_mean = 0
        iteration = 0
        self.sample_idx_val = []
        self.noise_val = []
        for i in range(self.episodes_per_epoch):
            self.sample_idx_val.append(
                torch.tensor([
                    torch.randint(self.N_shot_val * i,
                                  self.N_shot_val * (i + 1),
                                  (self.M_aug_val, )).numpy()
                    for i in range(self.N_way_val)
                ]).reshape(-1))
            self.noise_val.append(
                torch.randn((self.N_way_val * self.M_aug_val, self.nz),
                            device=self.device))

        if self.resume_iter:
            print("resuming step %d ..." % self.resume_iter)
            iteration = self.resume_iter
            self.load_checkpoint(self.resume_iter)
            loss, mean, std = self.eval()
            if mean > best_mean:
                best_mean = mean

        episodic_acc = []

        for ep in range(self.num_epochs):
            self.cnn.train()
            self.g.train()
            self.mlp.train()

            for batch_idx, (data, target) in enumerate(self.train_loader):
                data = data.to(self.device)
                self.optimizer.zero_grad()

                support_input = data[:self.N_way_train *
                                     self.N_shot_train, :, :, :]
                query_input = data[self.N_way_train *
                                   self.N_shot_train:, :, :, :]

                label_encoder = {
                    target[i * self.N_shot_train]: i
                    for i in range(self.N_way_train)
                }
                query_label = torch.cuda.LongTensor([
                    label_encoder[class_name]
                    for class_name in target[self.N_way_train *
                                             self.N_shot_train:]
                ])

                support = self.cnn(support_input)
                queries = self.cnn(query_input)

                sample_idx = torch.tensor([
                    torch.randint(self.N_shot_train * i,
                                  self.N_shot_train * (i + 1),
                                  (self.M_aug_train, )).numpy()
                    for i in range(self.N_way_train)
                ]).reshape(-1)

                sample = support[sample_idx]
                noise = torch.randn(
                    (self.N_way_train * self.M_aug_train, self.nz),
                    device=self.device)

                support_g = self.g(sample,
                                   noise).reshape(self.N_way_train,
                                                  self.M_aug_train, -1)
                support = support.reshape(self.N_way_train, self.N_shot_train,
                                          -1)

                support_aug = torch.cat([support, support_g], dim=1)
                support_aug = support_aug.reshape(
                    self.N_way_train * (self.N_shot_train + self.M_aug_train),
                    -1)

                prototypes = self.mlp(support_aug)
                prototypes = prototypes.reshape(
                    self.N_way_train, self.N_shot_train + self.M_aug_train,
                    -1).mean(dim=1)
                queries = self.mlp(queries)

                if self.matching_fn == 'parametric':
                    distances = pairwise_distances(queries, prototypes,
                                                   self.matching_fn,
                                                   self.parametric)

                else:
                    distances = pairwise_distances(queries, prototypes,
                                                   self.matching_fn)

                loss = criterion(-distances, query_label)
                loss.backward()
                self.optimizer.step()

                y_pred = (-distances).softmax(dim=1).max(1, keepdim=True)[1]
                episodic_acc.append(
                    1. * y_pred.eq(query_label.view_as(y_pred)).sum().item() /
                    len(query_label))

                if (iteration + 1) % self.log_interval == 0:
                    episodic_acc = np.array(episodic_acc)
                    mean = episodic_acc.mean()
                    std = episodic_acc.std()

                    print(
                        'Epoch: {:3d} [{:d}/{:d}]\tIteration: {:5d}\tLoss: {:.6f}\tAccuracy: {:.2f} +- {:.2f} %'
                        .format(
                            ep, (batch_idx + 1), len(self.train_loader),
                            iteration + 1, loss.item(), mean * 100,
                            1.96 * std / (self.log_interval)**(1 / 2) * 100))

                    if self.use_wandb:
                        import wandb
                        wandb.log(
                            {
                                "loss":
                                loss.item(),
                                "acc_mean":
                                mean * 100,
                                "acc_ci":
                                1.96 * std /
                                (self.log_interval)**(1 / 2) * 100,
                                'lr':
                                self.optimizer.param_groups[0]['lr']
                            },
                            step=iteration + 1)

                    episodic_acc = []

                if (iteration + 1) % self.ckp_interval == 0:
                    loss, mean, std = self.eval()
                    if mean > best_mean:
                        best_mean = mean
                        self.save_checkpoint(iteration)
                        if self.use_wandb:
                            wandb.run.summary[
                                "best_accuracy"] = best_mean * 100

                    if self.use_wandb:
                        import wandb
                        wandb.log(
                            {
                                "val_loss": loss,
                                "val_acc_mean": mean * 100,
                                "val_acc_ci": 1.96 * std / (600)**(1 / 2) * 100
                            },
                            step=iteration + 1,
                            commit=False)

                iteration += 1

            self.scheduler.step()
        self.save_checkpoint(iteration)

    def eval(self):
        criterion = nn.CrossEntropyLoss()
        self.cnn.eval()
        self.g.eval()
        self.mlp.eval()
        episodic_acc = []
        loss = []

        with torch.no_grad():
            for b_idx, (data, target) in enumerate(self.val_loader):
                data = data.to(self.device)
                support_input = data[:self.N_way_val *
                                     self.N_shot_val, :, :, :]
                query_input = data[self.N_way_val * self.N_shot_val:, :, :, :]

                label_encoder = {
                    target[i * self.N_shot_val]: i
                    for i in range(self.N_way_val)
                }
                query_label = torch.cuda.LongTensor([
                    label_encoder[class_name]
                    for class_name in target[self.N_way_val * self.N_shot_val:]
                ])

                support = self.cnn(support_input)
                queries = self.cnn(query_input)

                sample_idx = self.sample_idx_val[b_idx]
                sample = support[sample_idx]

                noise = self.noise_val[b_idx]

                support_g = self.g(sample,
                                   noise).reshape(self.N_way_val,
                                                  self.M_aug_val, -1)
                support = support.reshape(self.N_way_val, self.N_shot_val, -1)

                support_aug = torch.cat([support, support_g], dim=1)
                support_aug = support_aug.reshape(
                    self.N_way_val * (self.N_shot_val + self.M_aug_val), -1)

                prototypes = self.mlp(support_aug)
                prototypes = prototypes.reshape(
                    self.N_way_val, self.N_shot_val + self.M_aug_val,
                    -1).mean(dim=1)
                queries = self.mlp(queries)

                if self.matching_fn == 'parametric':
                    distances = pairwise_distances(queries, prototypes,
                                                   self.matching_fn,
                                                   self.parametric)
                else:
                    distances = pairwise_distances(queries, prototypes,
                                                   self.matching_fn)

                loss.append(criterion(-distances, query_label).item())
                y_pred = (-distances).softmax(dim=1).max(1, keepdim=True)[1]
                episodic_acc.append(
                    1. * y_pred.eq(query_label.view_as(y_pred)).sum().item() /
                    len(query_label))

        loss = np.array(loss)
        episodic_acc = np.array(episodic_acc)
        loss = loss.mean()
        mean = episodic_acc.mean()
        std = episodic_acc.std()

        print('\nLoss: {:.6f}\tAccuracy: {:.2f} +- {:.2f} %\n'.format(
            loss, mean * 100, 1.96 * std / (600)**(1 / 2) * 100))

        return loss, mean, std
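# Both train() and eval() above call a pairwise_distances helper defined elsewhere; a
# plausible shape (an assumption, not the original) for the non-parametric Euclidean case
# is the squared L2 distance between every query and every prototype:
def pairwise_distances_l2(queries, prototypes):
    # queries: [n_q, d], prototypes: [n_p, d] -> distances: [n_q, n_p]
    return (queries.unsqueeze(1) - prototypes.unsqueeze(0)).pow(2).sum(dim=2)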