Example #1
def build_dgi_dataset(args):
    # Expected keys for each playthrough:
    # game, step, action, graph_local, graph_seen, graph_full
    playthroughs = (json.loads(line.rstrip(",\n")) for line in open(args.input)
                    if len(line.strip()) > 1)

    graph_dataset = GraphDataset()
    dataset = []
    for example in playthroughs:
        # For each data point we want the following 3 keys:
        # game, step, graph
        dataset.append({
            "game": example["game"],
            "step": example["step"],
            "graph": graph_dataset.compress(
                example["graph_{}".format(args.graph_type)]),
        })

    if args.output is None:
        args.output = os.path.splitext(args.input)[0] + ".dgi.{}.json".format(
            args.graph_type)

    data = {
        "graph_index": graph_dataset.dumps(),
        "examples": dataset,
    }
    with open(args.output, "w") as f:
        json.dump(data, f)

    if args.verbose:
        print("This dataset has {:,} datapoints.".format(len(dataset)))
Example #2
def compress_command_generation_dataset(args):
    # Expected keys for each playthrough:
    # game, step, observation, previous_action, target_commands, previous_graph_seen, graph_seen
    playthroughs = (json.loads(line.rstrip(",\n")) for line in open(args.input) if len(line.strip()) > 1)

    graph_dataset = GraphDataset()
    dataset = []
    for example in tqdm(playthroughs):
        previous_graph_seen = graph_dataset.compress(example["previous_graph_seen"])
        target_commands = example["target_commands"]

        # For each data point we want the following 6 keys:
        # game, step, observation, previous_action, target_commands, previous_graph_seen
        dataset.append({
            "game": example["game"],
            "step": example["step"],
            "observation": example["observation"],
            "previous_action": example["previous_action"],
            "previous_graph_seen": previous_graph_seen,
            "target_commands": example["target_commands"],
        })

    if args.output is None:
        args.output = os.path.splitext(args.input)[0] + ".cmd_gen.json"

    data = {
        "graph_index": graph_dataset.dumps(),
        "examples": dataset,
    }
    with open(args.output, "w") as f:
        json.dump(data, f)

    if args.verbose:
        print("This dataset has {:,} datapoints.".format(len(dataset)))
    def load_dataset_for_dgi(self, split):
        file_path = pjoin(self.data_path,
                          self.FILENAMES_MAP[self.graph_type][split])
        with open(file_path) as f:
            data = json.load(f)

        graph_dataset = GraphDataset.loads(data["graph_index"])
        self.dataset[split]["graph_dataset"] = graph_dataset

        desc = "Loading {}".format(os.path.basename(file_path))
        for example in tqdm(data["examples"], desc=desc):
            graph = example["graph"]
            self.dataset[split]["graph"].append(graph)
Example #4
    def load_dataset_for_ap(self, split):
        file_path = pjoin(self.data_path, self.FILENAMES_MAP[self.graph_type][split])
        with open(file_path) as f:
            data = json.load(f)

        graph_dataset = GraphDataset.loads(data["graph_index"])
        self.dataset[split]["graph_dataset"] = graph_dataset

        desc = "Loading {}".format(os.path.basename(file_path))
        for example in tqdm(data["examples"], desc=desc):
            target_action = example["target_action"]
            curr_graph = example["current_graph"]
            prev_graph = example["previous_graph"]
            candidates = example["action_choices"]

            self.dataset[split]["current_graph"].append(curr_graph)
            self.dataset[split]["previous_graph"].append(prev_graph)
            self.dataset[split]["target_action"].append(target_action)
            self.dataset[split]["action_choices"].append(candidates)
    def load_dataset_for_cmd_gen(self, split):
        file_path = pjoin(self.data_path, self.FILENAMES_MAP[split])
        desc = "Loading {}".format(os.path.basename(file_path))
        print(desc)
        with open(file_path) as f:
            data = json.load(f)

        graph_dataset = GraphDataset.loads(data["graph_index"])
        self.dataset[split]["graph_dataset"] = graph_dataset

        for example in tqdm(data["examples"], desc=desc):
            observation = "{feedback} <sep> {action}".format(
                feedback=example["observation"],
                action=example["previous_action"])
            # Need to sort target commands to enable the seq2seq model to learn the ordering.
            target_commands = " <sep> ".join(
                sort_target_commands(example["target_commands"]))

            self.dataset[split]["observation_strings"].append(observation)
            self.dataset[split]["previous_triplets"].append(
                example["previous_graph_seen"])
            self.dataset[split]["target_commands"].append(target_commands)
Example #6
def train(reload_dataset=False, pretrain_model_path=None, optim_fu='adam'):
    writer = SummaryWriter()

    vis = visdom.Visdom(env="Graph_Attention_compression")
    viz = Visdom_line(vis=vis, win="Graph_Attention")

    # Configuration
    DATA_DIR = './data/train_pairs'
    DICT_PATH = './checkpoint/dict_20000.pkl'
    EMBEDDING_PATH_RANDOM = './model/save_embedding_97and3.ckpt'
    SAVE_EMBEDDING = False
    RELOAD_DATASET = reload_dataset

    SAVE_DATASET_OBJ = './data/dataset.pkl'
    SAVE_MODEL_PATH = './checkpoint/Graph_Attn/'

    PRINT_STEP = 10
    SAVE_STEP = 1
    GPU_NUM = 0

    torch.manual_seed(2)
    torch.cuda.set_device(GPU_NUM)

    config = GraphAttenConfig()

    model = LSTMGraphAttn(config)
    model.cuda()

    if not os.path.exists(SAVE_MODEL_PATH):
        os.makedirs(SAVE_MODEL_PATH)

    # Load the embeddings
    embed = get_word_embed().cuda()
    embed_flag = get_flag_embed().cuda()
    vocab = get_vocab()

    criterion = nn.CrossEntropyLoss(ignore_index=2)
    optimizer = optim.Adam(model.parameters(), lr=0.0001)

    trainset = GraphDataset(vocab=vocab)
    trainloader = DataLoader(dataset=trainset,
                             batch_size=config.batch_size,
                             collate_fn=graph_fn,
                             pin_memory=True,
                             shuffle=True)

    global_step = 0
    for epoch in range(config.epoch):
        epoch_loss = 0
        for index, (src, trg, neighbor,
                    labels) in enumerate(tqdm(trainloader)):
            src = embed(src.cuda())
            trg = embed(trg.cuda())
            neighbor = embed(neighbor.cuda())

            flag4encoder = torch.zeros(src.shape[0], src.shape[1], 3).cuda()
            src = torch.cat([src, flag4encoder], dim=2)

            flag4decoder = torch.zeros([labels.shape[0], 1]).long()
            flag4decoder = torch.cat([flag4decoder, labels[:, :-1]],
                                     dim=1).cuda()
            flag4decoder = embed_flag(flag4decoder)

            flag4neighbor = torch.zeros(neighbor.shape[0], neighbor.shape[1],
                                        neighbor.shape[2], 3).cuda()
            neighbor = torch.cat([neighbor, flag4neighbor], dim=-1)

            trg = torch.cat([trg, flag4decoder], dim=2)
            labels = labels.cuda()

            out = model(src, trg, neighbor)
            out = out.view(-1, 2)
            labels = labels.view(-1)
            loss = criterion(out, labels)
            epoch_loss += loss.item()
            if index % PRINT_STEP == 0:
                print(loss.item())
            optimizer.zero_grad()
            loss.backward()

            optimizer.step()

            writer.add_scalar('loss', loss.item(), global_step)
            global_step += 1

        model.save(SAVE_MODEL_PATH + 'model-' + str(epoch) + '.ckpt')
        writer.add_scalar('epoch_loss', epoch_loss, epoch)
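The flag4decoder tensor above is the label sequence shifted right by one position with a zero start flag prepended, which is what gives the decoder teacher forcing. A toy illustration of that shift (the tensor values and shapes here are made up for demonstration only):

# Toy illustration of the flag4decoder construction: prepend a zero "start"
# flag and drop the last label, so position t sees the label from t-1.
import torch

labels = torch.tensor([[1, 0, 1, 1],
                       [0, 1, 0, 0]])               # (batch, seq_len)
start = torch.zeros([labels.shape[0], 1]).long()    # start-of-sequence flag
shifted = torch.cat([start, labels[:, :-1]], dim=1)
print(shifted)
# tensor([[0, 1, 0, 1],
#         [0, 0, 1, 0]])
# embed_flag(shifted) would then be concatenated onto the target embeddings.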
Example #7
def main():
    # random.seed(0)
    # torch.manual_seed(0)
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False
    # np.random.seed(0)
    # torch.cuda.manual_seed(0)

    # set all hyperparameters
    network_name = 'WRN_40_2'
    num_epochs = 35
    batch_size = 1
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    n_retrain_epochs = 40 
    trials = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.85, 0.9]
    lr = 3e-4
    opt = "Adam"
    use_temp = False
    use_steps = False

    # set paths
    checkpointPath = './GNN_model/CIFAR10_checkpoints/CP__num_e_{}__retrain_e_{}__lr_{}__opt_{}__useTemp_{}__useSteps_{}__epoch_{}.pt'.format(num_epochs, n_retrain_epochs, lr, opt, use_temp, use_steps, '{}')    
    continue_train = False
    checkpointLoadPath = './GNN_model/CIFAR10_checkpoints/CP__num_e_{}__retrain_e_{}__lr_{}__opt_{}__useTemp_{}__useSteps_{}__epoch_{}.pt'.format(num_epochs, n_retrain_epochs, lr, opt, use_temp, use_steps, '20')

    # get GNN path
    info = networks_data.get(network_name)
    trained_model_path = info.get('trained_GNN_path').replace('.pt', '___num_e_{}__retrain_e_{}__lr_{}__opt_{}__useTemp_{}__useSteps_{}.pt'.format(num_epochs, n_retrain_epochs, lr, opt, use_temp, use_steps))

    # declare GNN model
    model = GNNPrunningNet(in_channels=6, out_channels=128).to(device)
    if opt == "Adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    else:
        # lr = 0.1
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, nesterov=True, weight_decay=5e-4)
        scheduler = MultiStepLR(optimizer, milestones=[int(elem*num_epochs) for elem in [0.3, 0.6, 0.8]], gamma=0.2)
    crit = GNN_prune_loss

    # declare TensorBoard writer
    summary_path = '{}-num_e_{}__retrain_e_{}__lr_{}__opt_{}__useTemp_{}__useSteps_{}/training'.format(network_name, num_epochs, n_retrain_epochs, lr, opt, use_temp, use_steps)
    writer = SummaryWriter(summary_path)


    root            = info.get('root')
    net_graph_path  = info.get('graph_path')
    sd_path         = info.get('sd_path')
    net             = info.get('network')
    orig_net_loss   = info.get('orig_net_loss') 

    isWRN = (network_name == "WRN_40_2")
    train_dataset   = GraphDataset(root, network_name, isWRN, net_graph_path)
    train_loader    = DataLoader(train_dataset, batch_size=batch_size)

    orig_net = net().to(device)
    orig_net.load_state_dict(torch.load(sd_path, map_location=device))

    model.train()

    dataset_name = info.get('dataset_name')
    network_train_data = datasets_train.get(dataset_name)

    print("Start training")

    if continue_train == True:
        cp = torch.load(checkpointLoadPath, map_location=device)
        trained_epochs = cp['epoch'] + 1
        sd = cp['model_state_dict']
        model.load_state_dict(sd)
        op_sd = cp['optimizer_state_dict']
        optimizer.load_state_dict(op_sd)
    else:
        trained_epochs = 0

    loss_all = 0.0
    data_all = 0.0
    sparse_all = 0.0
    if use_temp == True:
        T = 1.0
        if trained_epochs > 0:
            T = np.power(2, np.floor(trained_epochs / int(num_epochs/3)))

    for epoch in range(trained_epochs, num_epochs):
        
        for data in train_loader:
            data = data.to(device)
            optimizer.zero_grad()
            output = model(data)

            if use_temp == True:
                # Use temperature
                nom = torch.pow((torch.exp(torch.tensor(T, device=device))), output)
                dom = torch.pow((torch.exp(torch.tensor(T, device=device))), output) + torch.pow((torch.exp(torch.tensor(T, device=device))), (1-output))
                output = nom/dom
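                # Note: exp(T)**x == exp(T*x), so this is equivalent to
                # torch.sigmoid(T * (2*output - 1)); a larger T pushes the
                # scores toward hard 0/1 decisions.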
                # continue as usual

            sparse_term, data_term, data_grad = crit(output, orig_net, orig_net_loss, network_name, network_train_data, device, gamma1=10, gamma2=0.1)

            if use_steps == True:
                if epoch % 3 == 0: # do 2 steps in data direction then 1 in sparsity
                    sparse_term.backward()
                else:
                    output.backward(data_grad)
            else:            
                sparse_term.backward(retain_graph=True)
                output.backward(data_grad)

            data_all += data.num_graphs * data_term.item()
            sparse_all += data.num_graphs * sparse_term.item()
            loss_all = data_all + sparse_all  # keep the total consistent with the two running terms
            optimizer.step()
            
        print("epoch {}. total loss is: {}".format(epoch+1, (data_term.item() + sparse_term.item()) / len(train_dataset)))
        
        if opt != "Adam":
            scheduler.step()

        if use_temp == True:
            # increase temperature 3 times
            if (epoch+1) % int(num_epochs/3) == 0:
                T *= 2

        if epoch % 10 == 9:
            writer.add_scalars('Learning curve', {
            'loss data term': data_all/10,
            'loss sparsity term': sparse_all/10,
            'training loss': loss_all/10
            }, epoch+1)            

            # save checkpoint
            if opt == "Adam":
                torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss_all,
                }, checkpointPath.format(epoch+1))
            else:
                torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss_all,
                'scheduler_state_dict': scheduler.state_dict(),
                }, checkpointPath.format(epoch+1))

            loss_all = 0.0
            data_all = 0.0
            sparse_all = 0.0
            

    torch.save(model.state_dict(), trained_model_path)            

    print("Start evaluating")

    model.load_state_dict(torch.load(trained_model_path, map_location=device))

    model.eval()

    network_val_data = datasets_test.get(dataset_name)
    val_data_loader = torch.utils.data.DataLoader(network_val_data, batch_size=1024, shuffle=False, num_workers=8) 

    for trial, p_factor in enumerate(trials):
        with torch.no_grad():
            for data in train_loader:
                data = data.to(device)

                pred = model(data)

                prunedNet = getPrunedNet(pred, orig_net, network_name, prune_factor=p_factor).to(device)

        # Train the pruned network
        prunedNet.train()

        data_train_loader = torch.utils.data.DataLoader(network_train_data, batch_size=256, shuffle=False, num_workers=8) 
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(prunedNet.parameters(), lr=0.1, momentum=0.9, nesterov=True, weight_decay=5e-4)
        scheduler = MultiStepLR(optimizer, milestones=[int(elem*n_retrain_epochs) for elem in [0.3, 0.6, 0.8]], gamma=0.2)

        for epoch in range(n_retrain_epochs):
            for i, (images, labels) in enumerate(data_train_loader):
                images, labels = images.to(device), labels.to(device)
                optimizer.zero_grad()
                output = prunedNet(images)
                loss = criterion(output, labels)

                if i % 30 == 0:
                    print('Train - Epoch %d, Batch: %d, Loss: %f' % (epoch+1, i, loss.detach().cpu().item()))

                loss.backward()
                optimizer.step()

            scheduler.step()

        # Evaluate the pruned net
        with torch.no_grad():

            total_correct = 0
            cuda_time = 0.0            
            cpu_time = 0.0

            for i, (images, labels) in enumerate(val_data_loader):
                images, labels = images.to(device), labels.to(device)

                with torch.autograd.profiler.profile(use_cuda=True) as prof:
                    output = prunedNet(images)
                cuda_time += sum([item.cuda_time for item in prof.function_events])
                cpu_time += sum([item.cpu_time for item in prof.function_events])

                pred = output.detach().max(1)[1]
                total_correct += pred.eq(labels.view_as(pred)).sum()

            p_acc = float(total_correct) / len(network_val_data)
            p_num_params = gnp(prunedNet)
            p_cuda_time = cuda_time / len(network_val_data)
            p_cpu_time = cpu_time / len(network_val_data)

            print("The pruned network for prune factor {} accuracy is: {}".format(p_factor, p_acc))
            print("The pruned network number of parameters is: {}".format(p_num_params))
            print("The pruned network cuda time is: {}".format(p_cuda_time))
            print("The pruned network cpu time is: {}".format(p_cpu_time))

        # Evaluate the original net
        with torch.no_grad():

            total_correct = 0
            cuda_time = 0.0            
            cpu_time = 0.0
            
            for i, (images, labels) in enumerate(val_data_loader):
                images, labels = images.to(device), labels.to(device)

                with torch.autograd.profiler.profile(use_cuda=True) as prof:
                    output = orig_net(images)
                cuda_time += sum([item.cuda_time for item in prof.function_events])
                cpu_time += sum([item.cpu_time for item in prof.function_events])

                pred = output.detach().max(1)[1]
                total_correct += pred.eq(labels.view_as(pred)).sum()

            o_acc = float(total_correct) / len(network_val_data)
            o_num_params = gnp(orig_net)
            o_cuda_time = cuda_time / len(network_val_data)
            o_cpu_time = cpu_time / len(network_val_data)

            print("The original network accuracy is: {}".format(o_acc))
            print("The original network number of parameters is: {}".format(o_num_params))
            print("The original network cuda time is: {}".format(o_cuda_time))
            print("The original network cpu time is: {}".format(o_cpu_time))

        writer.add_scalars('Network accuracy', {
            'original': o_acc,
            'pruned': p_acc
            }, 100*p_factor)
        writer.add_scalars('Network number of parameters', {
            'original': o_num_params,
            'pruned': p_num_params
            }, 100*p_factor)
        writer.add_scalars('Network GPU time', {
            'original': o_cuda_time,
            'pruned': p_cuda_time
            }, 100*p_factor)
        writer.add_scalars('Network CPU time', {
            'original': o_cpu_time,
            'pruned': p_cpu_time
            }, 100*p_factor)

    writer.close()
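gnp() is used above to report parameter counts but is not defined in this snippet; a minimal stand-in, assuming it simply counts trainable parameters:

# Hypothetical stand-in for gnp(): count the trainable parameters of a network.
def count_parameters(net):
    return sum(p.numel() for p in net.parameters() if p.requires_grad)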