Example #1
def evaluate_legacy_model(weight_files,
                          num_game,
                          seed,
                          bomb,
                          num_run=1,
                          verbose=True):
    # model_lockers = []
    # greedy_extra = 0
    agents = []
    num_player = len(weight_files)
    assert num_player > 1, "expected one weight file per player (need at least 2)"

    for weight_file in weight_files:
        if verbose:
            print("evaluating: %s\n\tfor %dx%d games" %
                  (weight_file, num_run, num_game))
        if "sad" in weight_file or "aux" in weight_file:
            sad = True
        else:
            sad = False

        device = "cuda:0"

        state_dict = torch.load(weight_file)
        input_dim = state_dict["net.0.weight"].size()[1]
        hid_dim = 512
        output_dim = state_dict["fc_a.weight"].size()[0]

        agent = r2d2.R2D2Agent(False, 3, 0.999, 0.9, device, input_dim,
                               hid_dim, output_dim, 2, 5, False).to(device)
        utils.load_weight(agent.online_net, weight_file, device)
        agents.append(agent)

    scores = []
    perfect = 0
    for i in range(num_run):
        _, _, score, p = evaluate(
            agents,
            num_game,
            num_game * i + seed,
            bomb,
            0,
            sad,
        )
        scores.extend(score)
        perfect += p

    mean = np.mean(scores)
    sem = np.std(scores) / np.sqrt(len(scores))
    perfect_rate = perfect / (num_game * num_run)
    if verbose:
        print("score: %f +/- %f" % (mean, sem), "; perfect: ", perfect_rate)
    return mean, sem, perfect_rate
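A hypothetical invocation of the evaluator above; the checkpoint names are made up, and a CUDA device plus the surrounding r2d2/utils modules are assumed:

mean, sem, perfect_rate = evaluate_legacy_model(
    ["sad_player1.pthw", "sad_player2.pthw"],  # hypothetical weight files
    num_game=1000,
    seed=1,
    bomb=0,
    num_run=1,
)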
Example #2
def load_op_model(method, idx1, idx2):
    """load op models, op models was trained only for 2 player
    """
    root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    # assume model saved in root/models/op
    folder = os.path.join(root, "models", "op", method)
    agents = []
    for idx in [idx1, idx2]:
        if 0 <= idx < 3:
            num_fc = 1
            skip_connect = False
        elif 3 <= idx < 6:
            num_fc = 1
            skip_connect = True
        elif 6 <= idx < 9:
            num_fc = 2
            skip_connect = False
        else:
            num_fc = 2
            skip_connect = True
        weight_file = os.path.join(folder, f"M{idx}.pthw")
        assert os.path.exists(weight_file), f"Cannot find weight at: {weight_file}"

        device = "cuda:0"
        state_dict = torch.load(weight_file)
        input_dim = state_dict["net.0.weight"].size()[1]
        hid_dim = 512
        output_dim = state_dict["fc_a.weight"].size()[0]
        agent = r2d2.R2D2Agent(
            False,
            3,
            0.999,
            0.9,
            device,
            input_dim,
            hid_dim,
            output_dim,
            2,
            5,
            False,
            num_fc_layer=num_fc,
            skip_connect=skip_connect,
        ).to(device)
        utils.load_weight(agent.online_net, weight_file, device)
        agents.append(agent)
    return agents
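A minimal usage sketch for the loader above; the "iql" method name and the indices are hypothetical, and the models/op folder layout plus a CUDA device are assumed:

# Exactly two indices, since OP models are 2-player only.
agents = load_op_model("iql", 0, 7)  # hypothetical method and indices
assert len(agents) == 2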
Example #3
def setup(data_folder):
    np.random.seed(0)
    torch.cuda.manual_seed_all(0)
    torch.manual_seed(0)

    codec = get_encoder()

    dataset = NewsDataset(path=data_folder,
                          ctx_length=128,
                          codec=codec,
                          start_from_zero=True)

    config = GPT2Config()
    model = GPT2LMHeadModel(config)
    if not os.path.exists('gpt2-pytorch_model.bin'):
        print("Downloading GPT-2 checkpoint...")
        # legacy Hugging Face S3 location; it may no longer resolve
        url = 'https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin'
        r = requests.get(url, allow_redirects=True)
        with open('gpt2-pytorch_model.bin', 'wb') as f:
            f.write(r.content)

    model = load_weight(
        model, torch.load('gpt2-pytorch_model.bin', map_location=device))
    model = model.to(device)
    model.eval()
    return codec, model, dataset, config
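One caveat with the download above: r.content buffers the whole checkpoint (several hundred MB) in memory. A streaming variant, using the same requests library, would be:

with requests.get(url, stream=True) as r:
    r.raise_for_status()
    with open('gpt2-pytorch_model.bin', 'wb') as f:
        for chunk in r.iter_content(chunk_size=1 << 20):  # 1 MiB chunks
            f.write(chunk)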
Example #4
def predict(model_type, finetune_dataset, input_path, output_path,
            probability_output, batch_size, gpu=True):
    model = fastsal.fastsal(pretrain_mode=False, model_type=model_type)
    state_dict, opt_state = load_weight('weights/{}_{}.pth'.format(finetune_dataset, model_type), remove_decoder=False)
    model.load_state_dict(state_dict)
    if gpu:
        model.cuda()
    simple_data = img_dataset(input_path, output_path)
    simple_loader = DataLoader(simple_data, batch_size=batch_size, shuffle=False, num_workers=4)
    for x, original_size_list, output_path_list in simple_loader:
        if gpu:
            x = x.float().cuda()
        y = model(x)
        if not probability_output:
            y = nn.Sigmoid()(y)
        if gpu:
            y = y.detach().cpu()
        y = y.numpy()
        for i, prediction in enumerate(y[:, 0, :, :]):
            img_output_path = output_path_list[i]
            original_size = original_size_list[i].numpy()
            print(img_output_path)
            if not probability_output:
                img_data = post_process_png(prediction, original_size)
                cv2.imwrite(img_output_path, img_data)
            else:
                img_data = post_process_probability2(prediction, original_size)
                # splitext is safe for paths whose directories contain dots
                np.save(os.path.splitext(img_output_path)[0], img_data)
Example #5
def main():
    # random seed
    seed = 1234
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # load dataset
    if args.dataset[0] == 'deepfashion':
        ds = pd.read_csv('./Anno/df_info.csv')
        from dataset import DeepFashionDataset as DataManager
    elif args.dataset[0] == 'fld':
        ds = pd.read_csv('./Anno/fld_info.csv')
        from dataset import FLDDataset as DataManager
    else:
        raise ValueError('unknown dataset: %s' % args.dataset[0])

    print('dataset : %s' % (args.dataset[0]))
    if not args.evaluate:
        train_dm = DataManager(ds[ds['evaluation_status'] == 'train'],
                               root=args.root)
        train_dl = DataLoader(train_dm,
                              batch_size=args.batchsize,
                              shuffle=True)

        os.makedirs('models', exist_ok=True)
    test_dm = DataManager(ds[ds['evaluation_status'] == 'test'],
                          root=args.root)
    test_dl = DataLoader(test_dm, batch_size=args.batchsize, shuffle=False)

    val_dm = DataManager(ds[ds['evaluation_status'] == 'val'], root=args.root)
    val_dl = DataLoader(val_dm, batch_size=args.batchsize, shuffle=True)

    # Load model
    print("Load the model...")
    net = torch.nn.DataParallel(StructNet().cuda())
    if args.weight_file is not None:
        weights = torch.load(args.weight_file)
        if args.update_weight:
            weights = utils.load_weight(net, weights)
        net.load_state_dict(weights)

    # evaluate only
    if args.evaluate:
        print("Evaluation only")
        test(net, test_dl, -1)
        return

    # learning parameters
    optimizer = torch.optim.Adam(net.parameters(), lr=args.learning_rate)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 2, 0.1)

    print('Start training')
    for epoch in range(args.epoch):

        train(net, optimizer, train_dl, epoch)
        lr_scheduler.step()
        test(net, val_dl, epoch)
Example #6
def main(model_type, batch_size, dataset_name, dataset_path, size,
         width_bigger, pretrain_path, save_path, probability_output):
    # Datasets for SALICON
    if dataset_name == 'salicon':
        ds_validate = dataset.Salicon(dataset_path,
                                      mode='test',
                                      type=['vgg_img'],
                                      size=(size, ))
    elif dataset_name == 'mit300':
        ds_validate = mit300.dataset(dataset_path,
                                     type=['vgg_img'],
                                     size=(size, ))
        ds_validate.renew_list(width_bigger=width_bigger)
    elif dataset_name == 'mit1003':
        ds_validate = mit1003.dataset(dataset_path,
                                      type=['vgg_img'],
                                      size=(size, ))
        ds_validate.renew_list(width_bigger=width_bigger)
    elif dataset_name == 'dhf1k':
        ds_validate = dhf1k.dataset(dataset_path,
                                    mode='test',
                                    type=['vgg_img'],
                                    size=(size, ))
    else:
        raise ValueError('unknown dataset: %s' % dataset_name)

    file_list = ds_validate.list_names

    if pretrain_path:
        state_dict, opt_state = load_weight(pretrain_path,
                                            remove_decoder=False)
    else:
        print('please specify a trained model.')
        exit()

    model = student_teacher.salgan_teacher_student(
        False, model_type, use_probability_gt=probability_output)
    model.student_net.load_state_dict(state_dict)
    model.cuda()
    dataloader = {
        'val':
        DataLoader(ds_validate,
                   batch_size=batch_size,
                   shuffle=False,
                   num_workers=4)
    }

    print('--------------------------------------------->>>>>>')
    model.eval()
    with t.no_grad():
        train_one(model, dataloader, file_list, 'val', save_path,
                  probability_output)
    print('--------------------------------------------->>>>>>')
Example #7
def load_sad_model(weight_files):
    agents = []
    for weight_file in weight_files:
        # SAD / auxiliary-task checkpoints use a different input encoding
        sad = "sad" in weight_file or "aux" in weight_file

        device = "cuda:0"
        state_dict = torch.load(weight_file)
        input_dim = state_dict["net.0.weight"].size()[1]
        hid_dim = 512
        output_dim = state_dict["fc_a.weight"].size()[0]

        agent = r2d2.R2D2Agent(
            False, 3, 0.999, 0.9, device, input_dim, hid_dim, output_dim, 2, 5, False
        ).to(device)
        utils.load_weight(agent.online_net, weight_file, device)
        agents.append(agent)
    return agents
Example #8
def setup(n_enc_layer=1):
    np.random.seed(0)
    torch.cuda.manual_seed_all(0)
    torch.manual_seed(0)

    codec = get_encoder()
    config = GPT2Config(n_enc_layer=n_enc_layer)
    model = GPT2LMHeadModel(config)
    if not os.path.exists('../gpt2-pytorch_model.bin'):
        print("Downloading GPT-2 checkpoint...")
        # legacy Hugging Face S3 location; it may no longer resolve
        url = 'https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin'
        r = requests.get(url, allow_redirects=True)
        with open('../gpt2-pytorch_model.bin', 'wb') as f:
            f.write(r.content)

    model = load_weight(
        model, torch.load('../gpt2-pytorch_model.bin', map_location=device))
    model = model.to(device)
    return codec, model, config
Example #9
def exportToONNX(weights, model_type, output):

    # Fastsal - Type Coco A
    model = fastsal.fastsal(pretrain_mode=False, model_type=model_type)
    state_dict, opt_state = load_weight(weights, remove_decoder=False)
    model.load_state_dict(state_dict)
    model.eval()
    # dummy input for tracing; 1x3x192x256 RGB is assumed here, matching the
    # input size used in the other FastSal examples
    x = torch.zeros((1, 3, 192, 256))
    torch_out = model(x)
    torch.onnx.export(
        model,                     # model being run
        x,                         # model input (or a tuple for multiple inputs)
        output,                    # where to save the model (file path or file-like object)
        export_params=True,        # store the trained parameter weights inside the model file
        opset_version=11,          # the ONNX opset version to export to
        do_constant_folding=True,  # execute constant folding for optimization
        input_names=["input"],     # the model's input names
        output_names=["output"],   # the model's output names
    )
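To sanity-check the exported file, a minimal sketch that could be appended inside exportToONNX, assuming onnxruntime is installed (the tolerances are arbitrary):

    import numpy as np
    import onnxruntime as ort

    # Run the same dummy input through the exported graph and compare.
    sess = ort.InferenceSession(output)
    (onnx_out,) = sess.run(None, {"input": x.numpy()})
    np.testing.assert_allclose(torch_out.detach().numpy(), onnx_out,
                               rtol=1e-3, atol=1e-5)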
Example #10
def start_eval(model_type,
               batch_size,
               dataset_name,
               dataset_path,
               model_path=None):
    if dataset_name == 'salicon':
        ds_validate = dataset.Salicon(dataset_path,
                                      mode='val',
                                      type=['vgg_img', 'sal_img', 'fixation'],
                                      size=[(192, 256), (480, 640),
                                            (480, 640)])
    elif dataset_name == 'mit1003':
        ds_validate = mit1003.dataset(dataset_path,
                                      type=['vgg_img', 'sal_img', 'fixation'],
                                      size=[(192, 256), None, None])
    else:
        raise ValueError('unknown dataset: %s' % dataset_name)

    if model_path:
        state_dict, opt_state = load_weight(model_path, remove_decoder=False)
    else:
        print('please specify a trained model.')
        exit()

    model = student_teacher.salgan_teacher_student(False, model_type)
    model.student_net.load_state_dict(state_dict)
    model.cuda()

    # model.generator.load_state_dict(state_dict['state_dict'])
    # Dataloaders
    dataloader = {
        'val':
        DataLoader(ds_validate,
                   batch_size=batch_size,
                   shuffle=False,
                   num_workers=4)
    }

    model.eval()
    with t.no_grad():
        train_one(model, dataloader, 'val')
Example #11
def start_train(batch_size, dataset_name, dataset_path, teacher_path, direct,
                model_name):
    if dataset_name == 'salicon':
        dataloader = salicon_data(batch_size, dataset_path)
    elif dataset_name == 'coco':
        dataloader = coco_data(batch_size, dataset_path)
    else:
        raise ValueError('unknown dataset: %s' % dataset_name)

    model = student_teacher.salgan_teacher_student(True, 'C', teacher_path)
    model.cuda()

    lr = 0.01
    lr_decay = 0.1
    optimizer = model.get_optimizer(lr)
    smallest_val = None
    best_epoch = None

    for epoch in range(100):
        model.train()
        loss_train, model = train_one(model, dataloader, optimizer, 'train')
        print('{} loss train {}, lr {}'.format(epoch, loss_train, lr))
        print('--------------------------------------------->>>>>>')
        model.eval()
        loss_val, model = train_one(model, dataloader, optimizer, 'val')
        print('--------------------------------------------->>>>>>')
        print('{} loss val {}'.format(epoch, loss_val))

        smallest_val, best_epoch, model, optimizer = save_weight(
            smallest_val, best_epoch, loss_val, epoch, direct, model_name,
            model, optimizer)
        if epoch in (15, 30, 60):
            path = '{}/{}/{}_{:f}.pth'.format(direct, model_name, best_epoch,
                                              smallest_val)
            state_dict, opt_state = load_weight(path, remove_decoder=False)
            model.student_net.load_state_dict(state_dict)
            # optimizer.load_state_dict(state_dict['optimizer'])
            for param_group in optimizer.param_groups:
                param_group['lr'] *= lr_decay
            lr = lr * lr_decay
Example #12
def run(arg):
    torch.manual_seed(7)
    np.random.seed(7)
    print("lr %f, epoch_num %d, decay_rate %f gamma %f" % (arg.lr, arg.epochs, arg.decay, arg.gamma))

    print("====>Loading data")
    if arg.dataset == 'mnist':
        train_data = get_mnist_dataset("train", arg.batch_size)
        G_output_dim = config.mnist_G_output_dim
        D_input_dim = config.mnist_D_input_dim
    elif arg.dataset == 'cifar':
        train_data = get_cifar_dataset("train", arg.batch_size)
        G_output_dim = config.cifar_G_output_dim
        D_input_dim = config.cifar_D_input_dim
    else:
        raise ValueError('unknown dataset: %s' % arg.dataset)

    print("====>Building model")
    if arg.net == "DCGAN":
        g_net = DC_Generator(config.G_input_dim, config.num_filters, G_output_dim, verbose=False)
        d_net = DC_Discriminator(D_input_dim, config.num_filters[::-1], config.D_output_dim, verbose=False)

        # g_net = Generator()
        # d_net = Discriminator()

        g_net = g_net.to(config.device)
        d_net = d_net.to(config.device)

    g_optimizer = optim.Adam(g_net.parameters(), lr=arg.lr, betas=(0.5, 0.9))
    d_optimizer = optim.Adam(d_net.parameters(), lr=arg.lr, betas=(0.5, 0.9))

    if arg.mul_gpu:
        g_net = nn.DataParallel(g_net)
        d_net = nn.DataParallel(d_net)

    if arg.checkpoint is not None:
        print("load pre train model")
        g_pretrained_dict = torch.load(os.path.join(config.checkpoint_path, arg.dataset + "_" +
                                                      arg.net + "_g_net_" + arg.checkpoint + '.pth'))
        d_pretrained_dict = torch.load(os.path.join(config.checkpoint_path, arg.dataset + "_" +
                                                    arg.net + "_d_net_" + arg.checkpoint + '.pth'))
        g_net = load_weight(g_pretrained_dict, g_net)
        d_net = load_weight(d_pretrained_dict, d_net)

    # logging setup
    logger = get_logger()

    criterion = nn.BCELoss()

    print('Total params: %.2fM' % ((sum(p.numel() for p in g_net.parameters()) +
                                    sum(p.numel() for p in d_net.parameters())) / 1000000.0))
    print("start training: ", datetime.now())
    start_epoch = 0

    # torch.autograd.Variable is deprecated; plain tensors carry autograd state
    fix_noise = torch.randn(config.nrow ** 2, config.G_input_dim).view(
        -1, config.G_input_dim, 1, 1).cuda()
    g_losses = []
    d_losses = []
    for epoch in range(start_epoch, arg.epochs):
        prev_time = datetime.now()

        g_loss, d_loss = train(train_data, g_net, d_net, criterion, g_optimizer, d_optimizer, epoch, logger)
        now_time = datetime.now()
        time_str = count_time(prev_time, now_time)
        print("train: current (%d/%d) batch g_loss is %f d_loss is %f time "
              "is %s" % (epoch, arg.epochs,  g_loss, d_loss, time_str))

        g_losses.append(g_loss)
        d_losses.append(d_loss)

        plot_loss(d_losses, g_losses, epoch, arg.net, arg.dataset)
        plot_sample(g_net, fix_noise, epoch, net_name=arg.net, dataset_name=arg.dataset)

        if epoch % 2 == 0:
            save_checkpoint(arg.dataset, arg.net,  epoch, g_net, d_net)

    save_checkpoint(arg.dataset, arg.net, arg.epochs, g_net, d_net)
Example #13
def evaluate_legacy_model(
    weight_files,
    num_game,
    seed,
    bomb,
    agent_args,
    args,
    num_run=1,
    gen_cross_play=False,
    verbose=True,
):
    agents = []
    num_player = len(weight_files)
    assert num_player > 1, "expected one weight file per player (need at least 2)"

    env_sad = False
    for i, weight_file in enumerate(weight_files):
        if verbose:
            print(
                "evaluating: %s\n\tfor %dx%d games" % (weight_file, num_run, num_game)
            )
        if "sad" in weight_file:
            sad = True
            env_sad = True
        else:
            sad = False

        device = "cuda:0"

        state_dict = torch.load(weight_file)
        input_dim = state_dict["net.0.weight"].size()[1]
        output_dim = state_dict["fc_a.weight"].size()[0]

        if gen_cross_play:
            agent_name = weight_file.split("/")[-1].split(".")[0]

            with open(f"{args.weight_1_dir}/{agent_name}.txt", "r") as f:
                agent_args = {**json.load(f)}
        else:
            learnable_pretrain = True

            if i == 0:
                learnable_agent_name = agent_args["load_learnable_model"]
                if learnable_agent_name != "":
                    agent_args_file = f"{learnable_agent_name[:-4]}txt"
                else:
                    learnable_pretrain = False
            else:
                agent_args_file = f"{weight_file[:-4]}txt"

            if learnable_pretrain:
                with open(agent_args_file, "r") as f:
                    agent_args = {**json.load(f)}

        rnn_type = agent_args["rnn_type"]
        rnn_hid_dim = agent_args["rnn_hid_dim"]
        num_fflayer = agent_args["num_fflayer"]
        num_rnn_layer = agent_args["num_rnn_layer"]

        if rnn_type == "lstm":
            import r2d2_lstm as r2d2
        elif rnn_type == "gru":
            import r2d2_gru as r2d2

        agent = r2d2.R2D2Agent(
            False,
            3,
            0.999,
            0.9,
            device,
            input_dim,
            rnn_hid_dim,
            output_dim,
            num_fflayer,
            num_rnn_layer,
            5,
            False,
            sad=sad,
        ).to(device)

        utils.load_weight(agent.online_net, weight_file, device)
        agents.append(agent)

    scores = []
    perfect = 0
    for i in range(num_run):
        if args.is_rand:
            random.shuffle(agents)

        _, _, score, p = evaluate(
            agents,
            num_game,
            num_game * i + seed,
            bomb,
            0,
            env_sad,
        )
        scores.extend(score)
        perfect += p

    mean = np.mean(scores)
    sem = np.std(scores) / np.sqrt(len(scores))
    perfect_rate = perfect / (num_game * num_run)
    if verbose:
        print("score: %f +/- %f" % (mean, sem), "; perfect: ", perfect_rate)
    return mean, sem, perfect_rate
Example #14
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @File  : convert_weight.py
# @Author: ruixi L
# @Date  : 2019/11/13

import tensorflow as tf
from model import Mymodel
from utils import load_weight

weight_path = r'./model.pb'
inputs = tf.placeholder(tf.float32, shape=[None, 784], name='x')
model = Mymodel(class_num=10)
logits = model.forward(inputs)
name_list = tf.trainable_variables()
with tf.Session() as sess:
    load_op = load_weight(name_list, weight_path)
    for i in range(len(load_op)):
        sess.run(tf.assign(name_list[i], load_op[i]))
    tf.train.Saver().save(sess, './checkpoint/mymodel')
    print('Weight conversion complete')
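The script above uses the TensorFlow 1.x graph API (placeholders, sessions). If run under TensorFlow 2.x, the 1.x behavior would have to be restored first; a compatibility sketch replacing the import line:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()  # restore graph mode so placeholders/Sessions work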
Example #15
        args.train_device,
        learnable_games[0].feature_size(),
        rnn_hid_dim,
        learnable_games[0].num_action(),
        num_fflayer,
        num_rnn_layer,
        args.hand_size,
        False,
        sad=learnable_sad,
    )

    learnable_agent.sync_target_with_online()

    if args.load_learnable_model:
        print("*****loading pretrained model for learnable agent *****")
        utils.load_weight(learnable_agent.online_net,
                          args.load_learnable_model, args.train_device)
        print("*****done*****")
    if args.resume_cont_training:
        print("***** resuming continual training ... ")
        learnable_agent_ckpts = glob.glob(f"{args.save_dir}/*_zero_shot.pthw")
        learnable_agent_ckpts.sort(key=os.path.getmtime)
        print("restoring from ... ", learnable_agent_ckpts[-1])
        utils.load_weight(learnable_agent.online_net,
                          learnable_agent_ckpts[-1], args.train_device)
        epoch_restore = int(learnable_agent_ckpts[-1].split("/")[-1].split(".")
                            [0].split("_")[1][5:])
        print("epoch restore is ... ", epoch_restore)

    learnable_agent = learnable_agent.to(args.train_device)
    print(learnable_agent)
Example #16
        args.multi_step,
        args.gamma,
        args.eta,
        args.train_device,
        games[0].feature_size(),
        args.rnn_hid_dim,
        games[0].num_action(),
        args.num_lstm_layer,
        args.hand_size,
        False,  # uniform priority
    )
    agent.sync_target_with_online()

    if args.load_model:
        print("*****loading pretrained model*****")
        utils.load_weight(agent.online_net, args.load_model, args.train_device)
        print("*****done*****")

    agent = agent.to(args.train_device)
    optim = torch.optim.Adam(agent.online_net.parameters(),
                             lr=args.lr,
                             eps=args.eps)
    print(agent)
    eval_agent = agent.clone(args.train_device, {"vdn": False})

    replay_buffer = rela.RNNPrioritizedReplay(
        args.replay_buffer_size,
        args.seed,
        args.priority_exponent,
        args.priority_weight,
        args.prefetch,
Example #17
import model.fastSal as fastsal
from utils import load_weight
import torch

if __name__ == '__main__':
    coco_c = 'weights/coco_C.pth'  # coco_C
    coco_a = 'weights/coco_A.pth'  # coco_A
    salicon_c = 'weights/salicon_C.pth'  # salicon_C
    salicon_a = 'weights/salicon_A.pth'  # salicon_A

    x = torch.zeros((10, 3, 192, 256))

    model = fastsal.fastsal(pretrain_mode=False, model_type='A')
    state_dict, opt_state = load_weight(coco_a, remove_decoder=False)
    model.load_state_dict(state_dict)
    y = model(x)
    print(y.shape)

    model = fastsal.fastsal(pretrain_mode=False, model_type='A')
    state_dict, opt_state = load_weight(salicon_a, remove_decoder=False)
    model.load_state_dict(state_dict)
    y = model(x)
    print(y.shape)

    model = fastsal.fastsal(pretrain_mode=False, model_type='C')
    state_dict, opt_state = load_weight(coco_c, remove_decoder=False)
    model.load_state_dict(state_dict)
    y = model(x)
    print(y.shape)

    model = fastsal.fastsal(pretrain_mode=False, model_type='C')
    state_dict, opt_state = load_weight(salicon_c, remove_decoder=False)
    model.load_state_dict(state_dict)
    y = model(x)
    print(y.shape)
Example #18
def main():
    # random seed
    seed = 1234
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    from dataset_df2_loader import DeepFashionDataset as DataManager
    with open("./data/train/deepfashion2.json", 'r') as infile:
        ds = json.load(infile)
        ds = ds['annotations'][0:5]

    print("dataset", len(ds), args.batchsize, args.epoch)

    print('dataset : %s' % (args.dataset[0]))
    if not args.evaluate:
        train_dm = DataManager(ds, root=args.root)
        train_dl = DataLoader(train_dm,
                              batch_size=args.batchsize,
                              shuffle=True)

        os.makedirs('models', exist_ok=True)

    with open("./data/validation/deepfashion2_datafile_8.json", 'r') as infile:
        test_data = json.load(infile)

    test_dm = DataManager(
        test_data['annotations'][0:5],
        root="/media/chintu/bharath_ext_hdd/Bharath/Segmentation/Landmark detection/GLE_FLD-master/data/validation/image/",
    )
    test_dl = DataLoader(test_dm, batch_size=args.batchsize, shuffle=False)

    # Load model
    print("Load the model...")
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print("device:", device)

    net = torch.nn.DataParallel(Network(dataset=args.dataset,
                                        flag=args.glem)).to(device)
    if args.weight_file is not None:
        weights = torch.load(args.weight_file)
        if args.update_weight:
            weights = utils.load_weight(net, weights)
        net.load_state_dict(weights)

    # evaluate only
    if args.evaluate:
        print("Evaluation only")
        test(net, test_dl, 0)
        return

    # learning parameters
    optimizer = torch.optim.Adam(net.parameters(), lr=args.learning_rate)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 5, 0.1)

    print('Start training')
    for epoch in range(args.epoch):
        train(net, optimizer, train_dl, epoch)
        lr_scheduler.step()  # step after the optimizer updates (PyTorch >= 1.1 order)
        test(net, test_dl, epoch)
Example #19
def start_train(model_type,
                batch_size,
                dataset_name,
                dataset_path,
                teacher_path,
                direct,
                model_name,
                pretrain_path=None,
                pseudo_path=None):
    if dataset_name == 'salicon':
        dataloader = salicon_data(batch_size, dataset_path)
        use_gt = True
        use_pseudo_gt = False
        use_teacher = True
        use_probability_gt = False
    elif dataset_name == 'coco':
        dataloader = coco_data(batch_size,
                               dataset_path,
                               input_type=['npy_img'],
                               pseudo_path=pseudo_path)
        use_gt = False
        use_pseudo_gt = True
        use_teacher = False
        use_probability_gt = True
    else:
        raise ValueError('unknown dataset: %s' % dataset_name)

    if pretrain_path:
        state_dict, opt_state = load_weight(pretrain_path,
                                            remove_decoder=False)
    else:
        state_dict = None

    model = student_teacher.salgan_teacher_student(
        False,
        model_type,
        teacher_path,
        state_dict,
        use_gt=use_gt,
        use_teacher=use_teacher,
        use_probability_gt=use_probability_gt,
        use_pseudo_gt=use_pseudo_gt)
    model.cuda()

    lr = 0.01
    lr_decay = 0.1
    optimizer = model.get_optimizer(lr)
    smallest_val = None
    best_epoch = None

    for epoch in range(100):
        #with t.no_grad():
        #    metrics = get_saliency_metrics(dataloader['metric'], model, N=100)
        model.train()
        loss_train, model = train_one(model,
                                      dataloader,
                                      optimizer,
                                      'train',
                                      use_gt=use_gt,
                                      use_teacher=use_teacher,
                                      use_pseudo_gt=use_pseudo_gt)
        print('{} loss train {}, lr {}'.format(epoch, loss_train, lr))
        print('--------------------------------------------->>>>>>')
        model.eval()
        loss_val, model = train_one(model,
                                    dataloader,
                                    optimizer,
                                    'val',
                                    use_gt=use_gt,
                                    use_teacher=use_teacher,
                                    use_pseudo_gt=use_pseudo_gt)
        print('--------------------------------------------->>>>>>')
        print('{} loss val {}'.format(epoch, loss_val))

        smallest_val, best_epoch, model, optimizer = save_weight(
            smallest_val, best_epoch, loss_val, epoch, direct, model_name,
            model, optimizer)
        if epoch in (15, 30, 60):
            path = '{}/{}/{}_{:f}.pth'.format(direct, model_name, best_epoch,
                                              smallest_val)
            state_dict, opt_state = load_weight(path, remove_decoder=False)
            model.student_net.load_state_dict(state_dict)
            for param_group in optimizer.param_groups:
                param_group['lr'] *= lr_decay
            lr = lr * lr_decay
Example #20
        s = obs["s"].unsqueeze(0)
        assert s.size(2) == self.in_dim

        x = self.net(s)
        o, (h, c) = self.lstm(x, (h0, c0))
        a = self.fc_a(o).squeeze(0)

        return {
            "a": a,
            "h0": h.transpose(0, 1).contiguous(),
            "c0": c.transpose(0, 1).contiguous(),
        }


## main program ##
parser = argparse.ArgumentParser(description="")
parser.add_argument("--model", type=str, default=None)
args = parser.parse_args()

device = "cuda"
state_dict = torch.load(args.model)
in_dim = state_dict["net.0.weight"].size()[1]
out_dim = state_dict["fc_a.weight"].size()[0]

print("after loading model")
search_model = LSTMNet(device, in_dim, 512, out_dim, 2, 5)
utils.load_weight(search_model, args.model, device)
save_path = args.model.rsplit(".", 1)[0] + ".sparta"
print("saving model to:", save_path)
torch.jit.save(search_model, save_path)
Example #21
def text_generator(state_dict):
    parser = argparse.ArgumentParser()
    parser.add_argument("--text", type=str, required=True)
    parser.add_argument("--quiet", type=bool, default=False)
    parser.add_argument("--nsamples", type=int, default=1)
    parser.add_argument('--unconditional',
                        action='store_true',
                        help='If true, unconditional generation.')
    parser.add_argument("--batch_size", type=int, default=-1)
    parser.add_argument("--length", type=int, default=-1)
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--top_k", type=int, default=40)
    args = parser.parse_args()

    if not args.quiet:
        print(args)

    if args.batch_size == -1:
        args.batch_size = 1
    assert args.nsamples % args.batch_size == 0

    seed = random.randint(0, 2147483647)
    np.random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load Model
    enc = get_encoder()
    config = GPT2Config()
    model = GPT2LMHeadModel(config)
    model = load_weight(model, state_dict)
    model.to(device)
    model.eval()

    if args.length == -1:
        args.length = config.n_ctx // 2
    elif args.length > config.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         config.n_ctx)

    print(args.text)
    context_tokens = enc.encode(args.text)

    generated = 0
    for _ in range(args.nsamples // args.batch_size):
        out = sample_sequence(
            model=model,
            length=args.length,
            context=context_tokens if not args.unconditional else None,
            start_token=enc.encoder['<|endoftext|>']
            if args.unconditional else None,
            batch_size=args.batch_size,
            temperature=args.temperature,
            top_k=args.top_k,
            device=device)
        out = out[:, len(context_tokens):].tolist()
        for i in range(args.batch_size):
            generated += 1
            text = enc.decode(out[i])
            if not args.quiet:
                print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
            print(text)
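A hypothetical entry point for the generator above, assuming the checkpoint file from the earlier examples sits in the working directory:

if __name__ == '__main__':
    # text_generator parses its own CLI flags (--text, --nsamples, ...)
    state_dict = torch.load('gpt2-pytorch_model.bin', map_location='cpu')
    text_generator(state_dict)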