Example #1
    def __init__(self, config, model_config_branch, model_name='contentgen'):

        logger.debug("{} initialized".format(__name__))

        self.model_name = model_name
        self.config = config
        self.use_gpu = config.use_gpu
        self.gpus = config.gpus
        self.device = config.device
        self.seed = config.seed
        self.random_gen = random.Random(self.seed)

        self.model_dir = os.path.join(self.config.models_dir, model_name)

        #config
        self.model_config_path = os.path.join(self.model_dir, 'config.yaml')
        assert os.path.exists(
            self.model_config_path), "Invalid config file: {}".format(
                self.model_config_path)
        self.model_config = read_config(self.model_config_path)

        self.batch_size = 1
        self.temperature = model_config_branch.temperature
        self.max_gen_len = model_config_branch.max_gen_len
        self.save_sample = model_config_branch.save_sample

        self.file_exts = self.model_config.dataset.file_ext
        self.seq_len = self.model_config.dataset.seq_len
        self.end_token = self.model_config.dataset.end_token
        self.all_letters = list(string.printable) + [self.end_token]
        self.n_letters = len(self.all_letters) + 1  #EOS MARKER

        #snapshot
        self.model_snapshot = os.path.join(self.model_dir,
                                           model_config_branch.model_snapshot)
        assert os.path.exists(
            self.model_snapshot), "Invalid snapshot: {}".format(
                self.model_snapshot)

        #architecture
        self.model_file = model_config_branch.model_file
        self.model_arch = os.path.join(self.model_dir, self.model_file)
        assert os.path.exists(self.model_arch), "Invalid arch: {}".format(
            self.model_arch)

        #initialize module and model
        model_object = self.model_config.model.model_name
        spec = importlib.util.spec_from_file_location(model_object,
                                                      self.model_arch)

        model_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(model_module)

        init_method = getattr(model_module, model_object)
        self.model_func = init_method(self.model_config, self.n_letters - 1,
                                      self.seq_len)

        #load checkpoints
        load_model(self.model_func, self.model_snapshot, self.device)
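
The constructor above loads the model class from the architecture file at runtime. A minimal, self-contained sketch of that importlib pattern, where the class name 'CharRNN' and the path 'models/char_rnn.py' are hypothetical placeholders:

import importlib.util


def load_class_from_file(class_name, file_path):
    """Load a class by name from an arbitrary Python source file."""
    spec = importlib.util.spec_from_file_location(class_name, file_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)      # execute the file as a module
    return getattr(module, class_name)   # fetch the class object by name


# hypothetical usage mirroring the constructor above
# ModelClass = load_class_from_file('CharRNN', 'models/char_rnn.py')
# model = ModelClass(model_config, n_letters - 1, seq_len)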
Example #2
  def __init__(self, config, model_config_branch, model_name='gen'):

    logger.debug("{} initialized".format(__name__))

    self.model_name = model_name
    self.config = config
    self.use_gpu = config.use_gpu
    self.gpus = config.gpus
    self.device = config.device
    self.seed = config.seed
    self.random_gen = random.Random(self.seed)

    self.model_dir = os.path.join(self.config.models_dir, model_name)


    #config
    self.model_config_path = os.path.join(self.model_dir, 'config.yaml')
    assert os.path.exists(self.model_config_path), "Invalid config file: {}".format(self.model_config_path)    
    self.model_config = read_config(self.model_config_path)
    
    self.samples = config.samples
    self.batch_size = model_config_branch.batch_size
    #self.num_node_limit = model_config_branch.num_node_limit
    self.draw_settings = model_config_branch.draw_settings

    #snapshot
    self.model_snapshot = os.path.join(self.model_dir, model_config_branch.model_snapshot)
    assert os.path.exists(self.model_snapshot), "Invalid snapshot: {}".format(self.model_snapshot)
    
    #architecture
    self.model_file = model_config_branch.model_file
    self.model_arch = os.path.join(self.model_dir, self.model_file)
    assert os.path.exists(self.model_arch), "Invalid arch: {}".format(self.model_arch)

    #initialize module and model
    model_object = self.model_config.model.model_name
    spec = importlib.util.spec_from_file_location(
     model_object, self.model_arch
     )

    model_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(model_module)

    init_method = getattr(model_module, model_object)
    self.model_func = init_method(self.model_config)
    
    #load checkpoints
    load_model(self.model_func, self.model_snapshot, self.device)

    m_pkl = os.path.join(self.model_dir, 'metrics_gaussian_tv_0000100.pkl')
    with open(m_pkl, "rb") as f:
      self.metrics = pickle.load(f)
Example #3
    def train(self):
        train_data_dict, val_data_dict, test_data_dict = load_data_dicts(
            10000, 50000, 10000)

        train_data = train_data_dict['X'].astype(np.float32)
        val_data = val_data_dict['X'].astype(np.float32)
        test_data = test_data_dict['X'].astype(np.float32)
        train_label = np.argmax(train_data_dict['T'], axis=1)
        val_label = np.argmax(val_data_dict['T'], axis=1)
        test_label = np.argmax(test_data_dict['T'], axis=1)

        # create models
        model = eval(self.model_conf.name)(self.config)

        # create optimizer
        params = model.hyper_param
        if self.train_conf.meta_optimizer == 'SGD':
            optimizer = optim.SGD(params,
                                  lr=self.train_conf.meta_lr,
                                  momentum=self.train_conf.meta_momentum,
                                  weight_decay=0.0)
        elif self.train_conf.meta_optimizer == 'Adam':
            optimizer = optim.Adam(params,
                                   lr=self.train_conf.meta_lr,
                                   weight_decay=0.0)
        else:
            raise ValueError("Non-supported meta optimizer!")

        # reset gradient
        optimizer.zero_grad()

        # resume training
        if self.train_conf.is_resume:
            load_model(model,
                       self.train_conf.resume_model,
                       optimizer=optimizer)

        if self.use_gpu:
            model = nn.DataParallel(model, device_ids=self.gpus).cuda()

        # Training Loop
        best_meta_train_loss = 10.0
        results = defaultdict(list)
        for ii in range(self.train_conf.max_meta_iter):
            optimizer.zero_grad()
            train_loss, meta_train_loss, grad_hyper = model(
                train_data, train_label, val_data, val_label, ii)

            # decay meta learning rate
            if ii == 50:
                for pg in optimizer.param_groups:
                    pg['lr'] *= 0.1

            # meta optimization step
            for p, g in zip(params, grad_hyper):
                p.grad = g

            optimizer.step()
            print('Meta validation loss @step {} = {}'.format(
                ii + 1, meta_train_loss))
            if meta_train_loss < best_meta_train_loss:
                best_meta_train_loss = meta_train_loss

            results['meta_train_loss'] += [meta_train_loss]
            results['last_train_loss'] = train_loss

        with open(os.path.join(self.config.save_dir, 'train_stats.p'), 'wb') as f:
            pickle.dump(results, f)
        return best_meta_train_loss
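
The meta step above bypasses backward() and writes externally computed hypergradients straight into p.grad before calling optimizer.step(). A toy sketch of that mechanism with a single parameter and the hand-computed gradient of x**2 (all values illustrative):

import torch
from torch import optim

param = torch.nn.Parameter(torch.tensor([3.0]))
optimizer = optim.SGD([param], lr=0.1)

for _ in range(10):
    optimizer.zero_grad()
    grad = 2.0 * param.detach()  # gradient of param**2, computed outside autograd
    param.grad = grad            # assign directly instead of calling backward()
    optimizer.step()

print(param.item())  # shrinks toward 0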
Example #4
    def eval(self):
        ## This should generate a sample for every single type

        logger.debug('starting eval')

        self.config.save_dir = os.path.join(self.test_conf.test_exp_dir,
                                            self.test_conf.test_model_dir)

        model_file = os.path.join(self.config.save_dir,
                                  self.test_conf.test_model_name)
        ss_dist_file = os.path.join(self.config.save_dir, 'ss_dist.p')
        batch_size = self.test_conf.test_batch_size
        max_gen_length = self.test_conf.max_gen_length
        num_gen_samples = self.test_conf.num_gen_samples
        #start_string = self.test_conf.start_string
        temperature = self.test_conf.temperature

        #Calc the start string dist for each kind of file
        if not os.path.exists(ss_dist_file):
            ss_dist = self.start_string_dist()
            with open(ss_dist_file, "wb") as f:
                pickle.dump(ss_dist, f)
        else:
            with open(ss_dist_file, "rb") as f:
                ss_dist = pickle.load(f)

        assert batch_size == 1, "batch size needs to be one"

        logger.debug("Loading Model: {}".format(model_file))
        logger.debug("Batch Size: {} max_length: {}".format(
            batch_size, max_gen_length))

        model = eval(self.model_conf.model_name)(self.config, self.n_letters)

        model = load_model(model, model_file, self.device)['model']

        model = model.to(self.device)
        model.eval()

        text_generated = {
            'dir': {i: [] for i in range(self.max_depth + 1)},
            'file': {i: [] for i in range(1, self.max_depth + 1)},
        }

        with torch.no_grad():
            for node_type in text_generated.keys():
                for depth in text_generated[node_type]:
                    n_samples = 0
                    while n_samples < num_gen_samples:
                        #sample start letter from distribution
                        start_letter_dist = ss_dist[node_type][depth]
                        #m = torch.distributions.Categorical(start_letter_dist)
                        #start_letter_idx = m.sample()
                        sample = np.random.multinomial(1,
                                                       start_letter_dist,
                                                       size=1)
                        start_letter_idx = np.argmax(sample)

                        input_eval = torch.LongTensor(
                            [start_letter_idx]).view(1, -1)
                        nt_eval = torch.LongTensor(
                            [self.node_types.index(node_type)])
                        depth_eval = torch.LongTensor([depth])

                        #Move to gpu
                        input_eval = input_eval.pin_memory().to(
                            0, non_blocking=True)
                        nt_eval = nt_eval.pin_memory().to(0, non_blocking=True)
                        depth_eval = depth_eval.pin_memory().to(
                            0, non_blocking=True)
                        hidden = model.initHidden().pin_memory().to(
                            0, non_blocking=True)

                        text_gen = []
                        for _ in range(max_gen_length):
                            pred, hidden = model(nt_eval, depth_eval,
                                                 input_eval, hidden)

                            pred = pred[0].squeeze()
                            pred = pred / temperature
                            m = torch.distributions.Categorical(logits=pred)
                            pred_id = m.sample()

                            #print(i, pred_id)
                            next_char = self.all_letters[pred_id.item()]

                            if next_char == self.end_token: break

                            text_gen.append(next_char)
                            input_eval = pred_id.view(-1, 1)

                        full_name = self.all_letters[
                            start_letter_idx] + ''.join(text_gen)
                        if full_name not in text_generated[node_type][depth]:
                            text_generated[node_type][depth].append(full_name)
                            n_samples += 1

        save_name = os.path.join(
            self.config.save_dir,
            'name_gen_sample_{}.txt'.format(str(int(time.time()))))
        with open(save_name, 'w') as f:
            for node_type in text_generated.keys():
                f.write("-" * 100)
                f.write("\nNode Type: {} \n".format(node_type))
                for depth in text_generated[node_type]:
                    samples = text_generated[node_type][depth]
                    out = ', '.join(samples)
                    f.write("{:3} | {}\n".format(depth, out))

        logger.info("Saved sample @ {}".format(save_name))
Example #5
    def train(self):
        # create data loader
        train_dataset = eval(self.dataset_conf.loader_name)(self.config,
                                                            split='train')
        dev_dataset = eval(self.dataset_conf.loader_name)(self.config,
                                                          split='dev')
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=self.train_conf.batch_size,
            shuffle=self.train_conf.shuffle,
            num_workers=self.train_conf.num_workers,
            collate_fn=train_dataset.collate_fn,
            drop_last=False)
        dev_loader = torch.utils.data.DataLoader(
            dev_dataset,
            batch_size=self.train_conf.batch_size,
            shuffle=False,
            num_workers=self.train_conf.num_workers,
            collate_fn=dev_dataset.collate_fn,
            drop_last=False)

        # create models
        model = eval(self.model_conf.name)(self.config)

        if self.use_gpu:
            model = nn.DataParallel(model, device_ids=self.gpus).cuda()

        # create optimizer
        params = filter(lambda p: p.requires_grad, model.parameters())
        if self.train_conf.optimizer == 'SGD':
            optimizer = optim.SGD(params,
                                  lr=self.train_conf.lr,
                                  momentum=self.train_conf.momentum,
                                  weight_decay=self.train_conf.wd)
        elif self.train_conf.optimizer == 'Adam':
            optimizer = optim.Adam(params,
                                   lr=self.train_conf.lr,
                                   weight_decay=self.train_conf.wd)
        else:
            raise ValueError("Non-supported optimizer!")

        early_stop = EarlyStopper([0.0], win_size=10, is_decrease=False)

        lr_scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=self.train_conf.lr_decay_steps,
            gamma=self.train_conf.lr_decay)

        # reset gradient
        optimizer.zero_grad()

        # resume training
        if self.train_conf.is_resume:
            load_model(model,
                       self.train_conf.resume_model,
                       optimizer=optimizer)

        # Training Loop
        iter_count = 0
        best_val_loss = np.inf
        results = defaultdict(list)
        for epoch in range(self.train_conf.max_epoch):
            # validation
            if (epoch + 1) % self.train_conf.valid_epoch == 0 or epoch == 0:
                model.eval()
                val_loss = []

                for data in tqdm(dev_loader):
                    if self.use_gpu:
                        data['node_feat'], data['node_mask'], data[
                            'label'] = data_to_gpu(data['node_feat'],
                                                   data['node_mask'],
                                                   data['label'])

                        if self.model_conf.name == 'LanczosNet':
                            data['L'], data['D'], data['V'] = data_to_gpu(
                                data['L'], data['D'], data['V'])
                        elif self.model_conf.name == 'GraphSAGE':
                            data['nn_idx'], data[
                                'nonempty_mask'] = data_to_gpu(
                                    data['nn_idx'], data['nonempty_mask'])
                        elif self.model_conf.name == 'GPNN':
                            data['L'], data['L_cluster'], data[
                                'L_cut'] = data_to_gpu(data['L'],
                                                       data['L_cluster'],
                                                       data['L_cut'])
                        else:
                            data['L'] = data_to_gpu(data['L'])[0]

                    with torch.no_grad():
                        if self.model_conf.name == 'AdaLanczosNet':
                            pred, _ = model(data['node_feat'],
                                            data['L'],
                                            label=data['label'],
                                            mask=data['node_mask'])
                        elif self.model_conf.name == 'LanczosNet':
                            pred, _ = model(data['node_feat'],
                                            data['L'],
                                            data['D'],
                                            data['V'],
                                            label=data['label'],
                                            mask=data['node_mask'])
                        elif self.model_conf.name == 'GraphSAGE':
                            pred, _ = model(data['node_feat'],
                                            data['nn_idx'],
                                            data['nonempty_mask'],
                                            label=data['label'],
                                            mask=data['node_mask'])
                        elif self.model_conf.name == 'GPNN':
                            pred, _ = model(data['node_feat'],
                                            data['L'],
                                            data['L_cluster'],
                                            data['L_cut'],
                                            label=data['label'],
                                            mask=data['node_mask'])
                        else:
                            pred, _ = model(data['node_feat'],
                                            data['L'],
                                            label=data['label'],
                                            mask=data['node_mask'])

                    curr_loss = (pred - data['label']
                                 ).abs().cpu().numpy() * self.const_factor
                    val_loss += [curr_loss]

                val_loss = float(np.mean(np.concatenate(val_loss)))
                logger.info("Avg. Validation MAE = {}".format(val_loss))
                self.writer.add_scalar('val_loss', val_loss, iter_count)
                results['val_loss'] += [val_loss]

                # save best model
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    snapshot(model.module if self.use_gpu else model,
                             optimizer,
                             self.config,
                             epoch + 1,
                             tag='best')

                logger.info(
                    "Current Best Validation MAE = {}".format(best_val_loss))

                # check early stop
                if early_stop.tick([val_loss]):
                    snapshot(model.module if self.use_gpu else model,
                             optimizer,
                             self.config,
                             epoch + 1,
                             tag='last')
                    self.writer.close()
                    break

            # training
            model.train()
            lr_scheduler.step()
            for data in train_loader:
                optimizer.zero_grad()

                if self.use_gpu:
                    data['node_feat'], data['node_mask'], data[
                        'label'] = data_to_gpu(data['node_feat'],
                                               data['node_mask'],
                                               data['label'])

                    if self.model_conf.name == 'LanczosNet':
                        data['L'], data['D'], data['V'] = data_to_gpu(
                            data['L'], data['D'], data['V'])
                    elif self.model_conf.name == 'GraphSAGE':
                        data['nn_idx'], data['nonempty_mask'] = data_to_gpu(
                            data['nn_idx'], data['nonempty_mask'])
                    elif self.model_conf.name == 'GPNN':
                        data['L'], data['L_cluster'], data[
                            'L_cut'] = data_to_gpu(data['L'],
                                                   data['L_cluster'],
                                                   data['L_cut'])
                    else:
                        data['L'] = data_to_gpu(data['L'])[0]

                if self.model_conf.name == 'AdaLanczosNet':
                    _, train_loss = model(data['node_feat'],
                                          data['L'],
                                          label=data['label'],
                                          mask=data['node_mask'])
                elif self.model_conf.name == 'LanczosNet':
                    _, train_loss = model(data['node_feat'],
                                          data['L'],
                                          data['D'],
                                          data['V'],
                                          label=data['label'],
                                          mask=data['node_mask'])
                elif self.model_conf.name == 'GraphSAGE':
                    _, train_loss = model(data['node_feat'],
                                          data['nn_idx'],
                                          data['nonempty_mask'],
                                          label=data['label'],
                                          mask=data['node_mask'])
                elif self.model_conf.name == 'GPNN':
                    _, train_loss = model(data['node_feat'],
                                          data['L'],
                                          data['L_cluster'],
                                          data['L_cut'],
                                          label=data['label'],
                                          mask=data['node_mask'])
                else:
                    _, train_loss = model(data['node_feat'],
                                          data['L'],
                                          label=data['label'],
                                          mask=data['node_mask'])

                # assign gradient
                train_loss.backward()
                optimizer.step()
                train_loss = float(train_loss.data.cpu().numpy())
                self.writer.add_scalar('train_loss', train_loss, iter_count)
                results['train_loss'] += [train_loss]
                results['train_step'] += [iter_count]

                # display loss
                if (iter_count + 1) % self.train_conf.display_iter == 0:
                    logger.info(
                        "Loss @ epoch {:04d} iteration {:08d} = {}".format(
                            epoch + 1, iter_count + 1, train_loss))

                iter_count += 1

            # snapshot model
            if (epoch + 1) % self.train_conf.snapshot_epoch == 0:
                logger.info("Saving Snapshot @ epoch {:04d}".format(epoch + 1))
                snapshot(model.module if self.use_gpu else model, optimizer,
                         self.config, epoch + 1)

        results['best_val_loss'] += [best_val_loss]
        with open(os.path.join(self.config.save_dir, 'train_stats.p'), 'wb') as f:
            pickle.dump(results, f)
        self.writer.close()
        logger.info("Best Validation MAE = {}".format(best_val_loss))

        return best_val_loss
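
A compact illustration of the MultiStepLR schedule used above: the learning rate is multiplied by gamma each time the step counter passes a milestone. Milestones and rates below are illustrative only:

import torch
from torch import optim

param = torch.nn.Parameter(torch.zeros(1))
optimizer = optim.SGD([param], lr=0.1)
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3, 6], gamma=0.1)

for epoch in range(8):
    optimizer.step()   # a real training epoch would run here
    scheduler.step()   # lr: 0.1 -> 0.01 after the 3rd step -> 0.001 after the 6th
    print(epoch, optimizer.param_groups[0]['lr'])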
Example #6
  def eval(self):
    logger.debug('starting eval')
    
    self.config.save_dir = self.test_conf.test_model_dir

    model_file = os.path.join(self.config.save_dir, self.test_conf.test_model_name)
    batch_size = self.test_conf.test_batch_size
    num_gen = self.test_conf.num_test_gen
    start_string = self.test_conf.start_string
    temperature = self.test_conf.temperature

    assert batch_size == 1, "batch size needs to be one"

    logger.debug("Loading Model: {}".format(model_file))
    logger.debug("Batch Size: {} Num of characters to generate: {} Start string: {}".format(
      batch_size, num_gen, start_string
    ))

    model = eval(self.model_conf.model_name)(
      self.config,
      self.n_letters - 1,
      self.seq_len
      )

    model = load_model(
        model,
        model_file,
        self.device
        )['model']
      
    model = model.to(self.device)
    model.eval()
    
    text_generated = []
    input_eval = torch.LongTensor([self.all_letters.index(s) for s in start_string]).view(1, -1)
    ext = torch.LongTensor([self.file_ext.index(self.test_ext)])

    ext_tensor = ext.pin_memory().to(0, non_blocking=True)
    input_eval = input_eval.pin_memory().to(0, non_blocking=True)
    hidden = model.initHidden().pin_memory().to(0, non_blocking=True)
    
    with torch.no_grad():
      for i in range(num_gen):
        #print(i, input_eval.shape)
        pred, hidden = model(ext_tensor, input_eval, hidden)
        pred = pred[0].squeeze()
        pred = pred / temperature

        m = torch.distributions.Categorical(logits=pred)
        pred_id = m.sample()
        if i == 0 and len(start_string) > 1:
          pred_id = pred_id[-1]

        #print(i, pred_id)
        next_char = self.all_letters[pred_id.item()]
        text_generated.append(next_char)

        input_eval = pred_id.view(-1, 1)

    full_code = start_string + ''.join(text_generated)

    save_name = os.path.join(self.config.save_dir, 'sample_{}.{}'.format(self.test_nb, self.test_ext))

    with open(save_name, 'wb') as f:
      f.write(full_code.encode('utf-8'))
Example #7
    def train(self):
        ### create data loader
        train_dataset = eval(self.dataset_conf.loader_name)(self.config,
                                                            self.graphs_train,
                                                            tag='train')
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=self.train_conf.batch_size,
            shuffle=self.train_conf.shuffle,
            num_workers=self.train_conf.num_workers,
            collate_fn=train_dataset.collate_fn,
            drop_last=False)

        # create models
        model = eval(self.model_conf.name)(self.config)

        if self.use_gpu:
            model = DataParallel(model, device_ids=self.gpus).to(self.device)

        # create optimizer
        params = filter(lambda p: p.requires_grad, model.parameters())
        if self.train_conf.optimizer == 'SGD':
            optimizer = optim.SGD(params,
                                  lr=self.train_conf.lr,
                                  momentum=self.train_conf.momentum,
                                  weight_decay=self.train_conf.wd)
        elif self.train_conf.optimizer == 'Adam':
            optimizer = optim.Adam(params,
                                   lr=self.train_conf.lr,
                                   weight_decay=self.train_conf.wd)
        else:
            raise ValueError("Non-supported optimizer!")

        early_stop = EarlyStopper([0.0], win_size=100, is_decrease=False)
        lr_scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=self.train_conf.lr_decay_epoch,
            gamma=self.train_conf.lr_decay)

        # reset gradient
        optimizer.zero_grad()

        # resume training
        resume_epoch = 0
        if self.train_conf.is_resume:
            model_file = os.path.join(self.train_conf.resume_dir,
                                      self.train_conf.resume_model)
            load_model(model.module if self.use_gpu else model,
                       model_file,
                       self.device,
                       optimizer=optimizer,
                       scheduler=lr_scheduler)
            resume_epoch = self.train_conf.resume_epoch

        # Training Loop
        iter_count = 0
        results = defaultdict(list)
        for epoch in range(resume_epoch, self.train_conf.max_epoch):
            model.train()
            lr_scheduler.step()
            train_iterator = train_loader.__iter__()

            for inner_iter in range(len(train_loader) // self.num_gpus):
                optimizer.zero_grad()

                batch_data = []
                if self.use_gpu:
                    for _ in self.gpus:
                        data = next(train_iterator)
                        batch_data.append(data)
                        iter_count += 1

                avg_train_loss = .0
                for ff in range(self.dataset_conf.num_fwd_pass):
                    batch_fwd = []

                    if self.use_gpu:
                        for dd, gpu_id in enumerate(self.gpus):
                            data = {}
                            data['adj'] = batch_data[dd][ff]['adj'].pin_memory(
                            ).to(gpu_id, non_blocking=True)
                            data['edges'] = batch_data[dd][ff][
                                'edges'].pin_memory().to(gpu_id,
                                                         non_blocking=True)
                            data['node_idx_gnn'] = batch_data[dd][ff][
                                'node_idx_gnn'].pin_memory().to(
                                    gpu_id, non_blocking=True)
                            data['node_idx_feat'] = batch_data[dd][ff][
                                'node_idx_feat'].pin_memory().to(
                                    gpu_id, non_blocking=True)
                            data['label'] = batch_data[dd][ff][
                                'label'].pin_memory().to(gpu_id,
                                                         non_blocking=True)
                            data['att_idx'] = batch_data[dd][ff][
                                'att_idx'].pin_memory().to(gpu_id,
                                                           non_blocking=True)
                            data['subgraph_idx'] = batch_data[dd][ff][
                                'subgraph_idx'].pin_memory().to(
                                    gpu_id, non_blocking=True)
                            batch_fwd.append((data, ))

                    if batch_fwd:
                        train_loss = model(*batch_fwd).mean()
                        avg_train_loss += train_loss

                        # assign gradient
                        train_loss.backward()

                # clip_grad_norm_(model.parameters(), 5.0e-0)
                optimizer.step()
                avg_train_loss /= float(self.dataset_conf.num_fwd_pass)

                # reduce
                train_loss = float(avg_train_loss.data.cpu().numpy())

                self.writer.add_scalar('train_loss', train_loss, iter_count)
                results['train_loss'] += [train_loss]
                results['train_step'] += [iter_count]

                if iter_count % self.train_conf.display_iter == 0 or iter_count == 1:
                    logger.info(
                        "NLL Loss @ epoch {:04d} iteration {:08d} = {}".format(
                            epoch + 1, iter_count, train_loss))

            # snapshot model
            if (epoch + 1) % self.train_conf.snapshot_epoch == 0:
                logger.info("Saving Snapshot @ epoch {:04d}".format(epoch + 1))
                snapshot(model.module if self.use_gpu else model,
                         optimizer,
                         self.config,
                         epoch + 1,
                         scheduler=lr_scheduler)

        with open(os.path.join(self.config.save_dir, 'train_stats.p'), 'wb') as f:
            pickle.dump(results, f)
        self.writer.close()

        return 1
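
The batch distribution above relies on pinned host memory so the copy to each GPU can be asynchronous. A minimal sketch of that idiom; the copy only actually overlaps with queued GPU work when a CUDA device is present:

import torch

x = torch.randn(4, 1024)
if torch.cuda.is_available():
    # pin_memory() places the tensor in page-locked host memory, which lets the
    # non_blocking copy overlap with kernels already queued on the device
    x = x.pin_memory().to('cuda:0', non_blocking=True)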
Example #8
    def train(self):
        # create data loader
        train_dataset = BinaryMNIST(self.dataset_conf.path,
                                    num_imgs=self.dataset_conf.num_imgs,
                                    train=True,
                                    transform=transforms.ToTensor(),
                                    download=True)
        val_dataset = BinaryMNIST(self.dataset_conf.path,
                                  num_imgs=self.dataset_conf.num_imgs,
                                  train=False,
                                  transform=transforms.ToTensor(),
                                  download=True)

        train_loader = DataLoader(train_dataset,
                                  batch_size=self.train_conf.batch_size,
                                  shuffle=self.train_conf.shuffle,
                                  num_workers=self.train_conf.num_workers,
                                  drop_last=False)
        val_loader = DataLoader(val_dataset,
                                batch_size=self.train_conf.batch_size,
                                shuffle=False,
                                num_workers=self.train_conf.num_workers,
                                drop_last=False)

        # create models
        model = eval(self.model_conf.name)(self.config)

        # create optimizer
        params = model.parameters()
        if self.train_conf.optimizer == 'SGD':
            optimizer = optim.SGD(params,
                                  lr=self.train_conf.lr,
                                  momentum=self.train_conf.momentum,
                                  weight_decay=self.train_conf.wd)
        elif self.train_conf.optimizer == 'Adam':
            optimizer = optim.Adam(params,
                                   lr=self.train_conf.lr,
                                   weight_decay=self.train_conf.wd)
        else:
            raise ValueError("Non-supported optimizer!")

        lr_scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=self.train_conf.lr_decay_steps,
            gamma=self.train_conf.lr_decay)

        # reset gradient
        optimizer.zero_grad()

        # resume training
        if self.train_conf.is_resume:
            load_model(model,
                       self.train_conf.resume_model,
                       optimizer=optimizer)

        if self.use_gpu:
            model = nn.DataParallel(model, device_ids=self.gpus).cuda()

        # Training Loop
        iter_count = 0
        results = defaultdict(list)
        for epoch in range(self.train_conf.max_epoch):
            # validation
            if (epoch + 1) % self.train_conf.valid_epoch == 0 or epoch == 0:
                model.eval()
                val_loss = []
                val_counter = 0
                for imgs, labels in val_loader:
                    if self.use_gpu:
                        imgs, labels = imgs.cuda(), labels.cuda()

                    imgs, labels = imgs.float(), labels.float()
                    imgs_corrupt = self.rand_corrupt(
                        imgs, corrupt_level=self.dataset_conf.corrupt_level)
                    curr_loss, imgs_memory, _, _ = model(imgs_corrupt)
                    img_recover = imgs_memory[-self.model_conf.input_dim:]
                    img_recover_show = img_recover.clone().detach()
                    img_recover_show.requires_grad = False
                    img_recover_show[img_recover_show >= 0.5] = 1.0
                    img_recover_show[img_recover_show < 0.5] = 0.0
                    val_loss += [float(curr_loss.data.cpu().numpy())]
                    val_counter += 1

                val_loss = float(np.mean(val_loss))
                logger.info("Avg. Validation Loss = {}".format(
                    np.log10(val_loss)))
                results['val_loss'] += [val_loss]
                model.train()

            # training
            lr_scheduler.step()
            for imgs, labels in train_loader:
                if self.use_gpu:
                    imgs, labels = imgs.cuda(), labels.cuda()

                imgs, labels = imgs.float(), labels.float()
                optimizer.zero_grad()
                train_loss, imgs_memory, diff_norm, grad = model(imgs)

                for pp, ww in zip(model.parameters(), grad):
                    pp.grad = ww

                optimizer.step()
                train_loss = float(train_loss.data.cpu().numpy())
                results['train_loss'] += [train_loss]
                results['train_step'] += [iter_count]

                # display loss
                if iter_count % self.train_conf.display_iter == 0:
                    logger.info(
                        "Loss (log10) @ epoch {:04d} iteration {:08d} = {}".format(
                            epoch + 1, iter_count + 1, np.log10(train_loss)))

                    tmp_key = 'diff_norm_{}'.format(iter_count + 1)
                    results[tmp_key] = diff_norm

                iter_count += 1

            # snapshot model
            if (epoch + 1) % self.train_conf.snapshot_epoch == 0:
                logger.info("Saving Snapshot @ epoch {:04d}".format(epoch + 1))
                snapshot(model.module if self.use_gpu else model, optimizer,
                         self.config, epoch + 1)

        with open(os.path.join(self.config.save_dir, 'train_stats.p'), 'wb') as f:
            pickle.dump(results, f)
Example #9
import torch

from utils.data_reader import get_train_dev_test_data
from utils.train_helper import load_model, eval_model, get_data_loader

train_data, dev_data, test_data = get_train_dev_test_data()
model = load_model("model/checkpoints/DeepCoNN_20200601215955.pt")
model.to(model.config.device)
loss = torch.nn.MSELoss()
data_iter = get_data_loader(test_data, model.config)
print(eval_model(model, data_iter, loss))
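
A generic sketch of what an MSE evaluation helper along the lines of eval_model might look like; this is only an assumption for illustration, since the real helper lives in utils.train_helper and its signature and batch format may differ:

import torch


def eval_mse(model, data_iter, loss_fn, device):
    # assumes each batch is an (inputs, targets) pair; the real loader may differ
    model.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        for inputs, targets in data_iter:
            inputs, targets = inputs.to(device), targets.to(device)
            preds = model(inputs)
            total += loss_fn(preds, targets).item() * targets.size(0)
            count += targets.size(0)
    return total / count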
Example #10
    def test(self):
        self.config.save_dir_train = self.test_conf.test_model_dir

        ### test dataset
        test_dataset = eval(self.dataset_conf.loader_name)(self.config,
                                                           self.graphs_test,
                                                           tag='test')
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=self.test_conf.batch_size,
            shuffle=False,
            num_workers=self.train_conf.num_workers,
            collate_fn=test_dataset.collate_fn,
            drop_last=False)

        ### load model
        args = self.config.model
        n_labels = self.dataset_conf.max_m + self.dataset_conf.max_n
        G = define_G(args.nz, args.ngf, args.netG, args.final_activation,
                     args.norm_G)
        model_file_G = os.path.join(self.config.save_dir_train,
                                    self.test_conf.test_model_name)

        load_model(G, model_file_G, self.device)
        if self.use_gpu:
            G = G.cuda()  #nn.DataParallel(G).to(self.device)
        G.train()

        if not hasattr(self.config.test,
                       'hard_multi') or not self.config.test.hard_multi:
            hard_thre_list = [None]
        else:
            hard_thre_list = np.arange(0.5, 1, 0.1)

        for test_hard_idx, hard_thre in enumerate(hard_thre_list):
            logger.info('Test pass {}. Hard threshold {}'.format(
                test_hard_idx, hard_thre))
            ### Generate Graphs
            A_pred = []
            gen_run_time = []

            for batch_data in test_loader:
                # asserted in arg helper
                ff = 0

                with torch.no_grad():
                    data = {}
                    data['adj'] = batch_data[ff]['adj'].pin_memory().to(
                        self.config.device, non_blocking=True)
                    data['m'] = batch_data[ff]['m'].to(self.config.device,
                                                       non_blocking=True)
                    data['n'] = batch_data[ff]['n'].to(self.config.device,
                                                       non_blocking=True)

                    batch_size = data['adj'].size(0)

                    i_onehot = torch.zeros(
                        (batch_size, self.dataset_conf.max_m),
                        requires_grad=True).pin_memory().to(self.config.device,
                                                            non_blocking=True)
                    i_onehot.scatter_(1, data['m'][:, None] - 1, 1)
                    j_onehot = torch.zeros(
                        (batch_size, self.dataset_conf.max_n),
                        requires_grad=True).pin_memory().to(self.config.device,
                                                            non_blocking=True)
                    j_onehot.scatter_(1, data['n'][:, None] - 1, 1)
                    y_onehot = torch.cat((i_onehot, j_onehot), dim=1)

                    if args.nz > n_labels:
                        noise = torch.randn(
                            (batch_size, args.nz - n_labels, 1, 1),
                            requires_grad=True).to(self.config.device,
                                                   non_blocking=True)
                        z_input = torch.cat(
                            (y_onehot.view(batch_size, n_labels, 1, 1), noise),
                            dim=1)
                    else:
                        z_input = y_onehot.view(batch_size, n_labels, 1, 1)

                    start_time = time.time()
                    output = G(z_input).squeeze(1)  # (B, 1, n, n)
                    if self.model_conf.final_activation == 'tanh':
                        output = (output + 1) / 2
                    if self.model_conf.is_sym:
                        output = torch.tril(output, diagonal=-1)
                        output = output + output.transpose(1, 2)
                    gen_run_time += [time.time() - start_time]

                    if hard_thre is not None:
                        A_pred += [(output[batch_idx, ...] >
                                    hard_thre).long().cpu().numpy()
                                   for batch_idx in range(batch_size)]
                    else:
                        A_pred += [
                            torch.bernoulli(output[batch_idx,
                                                   ...]).long().cpu().numpy()
                            for batch_idx in range(batch_size)
                        ]

            logger.info('Average test time per mini-batch = {}'.format(
                np.mean(gen_run_time)))

            graphs_gen = [get_graph(aa) for aa in A_pred]

            ### Visualize Generated Graphs
            if self.is_vis:
                num_col = self.vis_num_row
                num_row = self.num_vis // num_col
                test_epoch = self.test_conf.test_model_name
                test_epoch = test_epoch[test_epoch.rfind('_') +
                                        1:test_epoch.find('.pth')]
                if hard_thre is not None:
                    save_name = os.path.join(
                        self.config.save_dir_train,
                        '{}_gen_graphs_epoch_{}_hard_{}.png'.format(
                            self.config.test.test_model_name[:-4], test_epoch,
                            int(round(hard_thre * 10))))
                    save_name2 = os.path.join(
                        self.config.save_dir,
                        '{}_gen_graphs_epoch_{}_hard_{}.png'.format(
                            self.config.test.test_model_name[:-4], test_epoch,
                            int(round(hard_thre * 10))))
                else:
                    save_name = os.path.join(
                        self.config.save_dir_train,
                        '{}_gen_graphs_epoch_{}.png'.format(
                            self.config.test.test_model_name[:-4], test_epoch))
                    save_name2 = os.path.join(
                        self.config.save_dir,
                        '{}_gen_graphs_epoch_{}.png'.format(
                            self.config.test.test_model_name[:-4], test_epoch))

                # remove isolated nodes for better visualization
                graphs_pred_vis = [
                    copy.deepcopy(gg) for gg in graphs_gen[:self.num_vis]
                ]

                if self.better_vis:
                    # actually not necessary with the following largest connected component selection
                    for gg in graphs_pred_vis:
                        gg.remove_nodes_from(list(nx.isolates(gg)))

                # display the largest connected component for better visualization
                vis_graphs = []
                for gg in graphs_pred_vis:
                    if self.better_vis:
                        CGs = [
                            gg.subgraph(c) for c in nx.connected_components(gg)
                        ]
                        CGs = sorted(CGs,
                                     key=lambda x: x.number_of_nodes(),
                                     reverse=True)
                        vis_graphs += [CGs[0]]
                    else:
                        vis_graphs += [gg]
                print('number of nodes after better vis',
                      [tmp_g.number_of_nodes() for tmp_g in vis_graphs])

                if self.is_single_plot:
                    # draw_graph_list(vis_graphs, num_row, num_col, fname=save_name, layout='spring')
                    draw_graph_list(vis_graphs,
                                    num_row,
                                    num_col,
                                    fname=save_name2,
                                    layout='spring')
                else:
                    # draw_graph_list_separate(vis_graphs, fname=save_name[:-4], is_single=True, layout='spring')
                    draw_graph_list_separate(vis_graphs,
                                             fname=save_name2[:-4],
                                             is_single=True,
                                             layout='spring')

                if test_hard_idx == 0:
                    save_name = os.path.join(self.config.save_dir_train,
                                             'train_graphs.png')

                    if self.is_single_plot:
                        draw_graph_list(self.graphs_train[:self.num_vis],
                                        num_row,
                                        num_col,
                                        fname=save_name,
                                        layout='spring')
                    else:
                        draw_graph_list_separate(
                            self.graphs_train[:self.num_vis],
                            fname=save_name[:-4],
                            is_single=True,
                            layout='spring')

            ### Evaluation
            if self.config.dataset.name in ['lobster']:
                acc = eval_acc_lobster_graph(graphs_gen)
                logger.info(
                    'Validity accuracy of generated graphs = {}'.format(acc))

            num_nodes_gen = [len(aa) for aa in graphs_gen]

            # Compared with Validation Set
            num_nodes_dev = [len(gg.nodes)
                             for gg in self.graphs_dev]  # shape B X 1
            mmd_degree_dev, mmd_clustering_dev, mmd_4orbits_dev, mmd_spectral_dev = evaluate(
                self.graphs_dev, graphs_gen, degree_only=False)
            mmd_num_nodes_dev = compute_mmd([np.bincount(num_nodes_dev)],
                                            [np.bincount(num_nodes_gen)],
                                            kernel=gaussian_emd)

            # Compared with Test Set
            num_nodes_test = [len(gg.nodes)
                              for gg in self.graphs_test]  # shape B X 1
            mmd_degree_test, mmd_clustering_test, mmd_4orbits_test, mmd_spectral_test = evaluate(
                self.graphs_test, graphs_gen, degree_only=False)
            mmd_num_nodes_test = compute_mmd([np.bincount(num_nodes_test)],
                                             [np.bincount(num_nodes_gen)],
                                             kernel=gaussian_emd)

            logger.info(
                "Validation MMD scores of #nodes/degree/clustering/4orbits/spectral are = {:.4E}/{:.4E}/{:.4E}/{:.4E}/{:.4E}"
                .format(Decimal(mmd_num_nodes_dev), Decimal(mmd_degree_dev),
                        Decimal(mmd_clustering_dev), Decimal(mmd_4orbits_dev),
                        Decimal(mmd_spectral_dev)))
            logger.info(
                "Test MMD scores of #nodes/degree/clustering/4orbits/spectral are = {:.4E}/{:.4E}/{:.4E}/{:.4E}/{:.4E}"
                .format(Decimal(mmd_num_nodes_test), Decimal(mmd_degree_test),
                        Decimal(mmd_clustering_test),
                        Decimal(mmd_4orbits_test), Decimal(mmd_spectral_test)))
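
The test loop above turns the generator's edge probabilities into binary adjacency matrices either by hard thresholding or by Bernoulli sampling, and symmetrizes them in the is_sym branch. A toy sketch of both; the probs tensor is a stand-in for G's output:

import torch

probs = torch.rand(2, 6, 6)                  # fake edge probabilities in [0, 1]

adj_hard = (probs > 0.5).long()              # deterministic hard threshold
adj_sampled = torch.bernoulli(probs).long()  # stochastic Bernoulli sampling

# keep the strict lower triangle and mirror it, as in the is_sym branch
adj_sym = torch.tril(adj_sampled.float(), diagonal=-1)
adj_sym = adj_sym + adj_sym.transpose(1, 2)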
Example #11
    def train(self):
        ### create data loader
        train_dataset = eval(self.dataset_conf.loader_name)(self.config,
                                                            self.graphs_train,
                                                            tag='train')
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=self.train_conf.batch_size,
            shuffle=self.train_conf.shuffle,  # true for grid
            num_workers=self.train_conf.num_workers,
            collate_fn=train_dataset.collate_fn,
            drop_last=False)

        # create models
        # model = eval(self.model_conf.name)(self.config)
        args = self.config.model
        n_labels = self.dataset_conf.max_m + self.dataset_conf.max_n
        G = define_G(args.nz, args.ngf, args.netG, args.final_activation,
                     args.norm_G)
        D = define_D(args.ndf, args.netD, norm=args.norm_D)

        ### define losses
        criterionGAN = GANLoss(args.gan_mode)
        rote_loss = nn.L1Loss(reduction='none')
        if args.sparsity > 0.:
            sparse_loss = nn.L1Loss()

        if self.use_gpu:
            # G = DataParallel(G).to(self.device)
            # D = DataParallel(D).to(self.device)
            G = G.cuda()
            D = D.cuda()
            criterionGAN = criterionGAN.to(self.device)
            rote_loss = rote_loss.cuda()
            if args.sparsity > 0.:
                sparse_loss = sparse_loss.cuda()

        G.train()
        D.train()

        # create optimizer
        G_params = filter(lambda p: p.requires_grad, G.parameters())
        D_params = filter(lambda p: p.requires_grad, D.parameters())
        optimizer_G = optim.Adam(G_params,
                                 lr=self.train_conf.lr,
                                 betas=(self.train_conf.beta1, 0.999))
        optimizer_D = optim.Adam(D_params,
                                 lr=self.train_conf.lr,
                                 betas=(self.train_conf.beta1, 0.999))
        fake_pool = ImagePool(args.pool_size)

        # resume training
        # TODO: record resume_epoch to the saved file
        resume_epoch = 0
        if self.train_conf.is_resume:
            model_file_G = os.path.join(self.train_conf.resume_dir,
                                        'G_' + self.train_conf.resume_model)
            model_file_D = os.path.join(self.train_conf.resume_dir,
                                        'D_' + self.train_conf.resume_model)
            load_model(G, model_file_G, self.device, optimizer=optimizer_G)
            load_model(D, model_file_D, self.device, optimizer=optimizer_D)
            resume_epoch = int(
                osp.splitext(self.train_conf.resume_model)[0].split('_')[-1])
            #original: self.train_conf.resume_epoch

        # Training Loop
        iter_count = 0  # iter idx thoughout the whole training
        results = defaultdict(list)
        for epoch in range(resume_epoch, self.train_conf.max_epoch):
            train_iterator = train_loader.__iter__()

            for batch_data in train_iterator:
                set_requires_grad(D, False)
                # set_requires_grad(G, True)
                optimizer_G.zero_grad()

                iter_count += 1
                # assert in arg helper
                ff = 0
                data = {}
                data['adj'] = batch_data[ff]['adj'].pin_memory().to(
                    self.config.device, non_blocking=True)
                data['m'] = batch_data[ff]['m'].to(self.config.device,
                                                   non_blocking=True)
                data['n'] = batch_data[ff]['n'].to(self.config.device,
                                                   non_blocking=True)

                batch_size = data['adj'].size(0)

                i_onehot = torch.zeros(
                    (batch_size, self.dataset_conf.max_m),
                    requires_grad=True).pin_memory().to(self.config.device,
                                                        non_blocking=True)
                i_onehot.scatter_(1, data['m'][:, None] - 1, 1)
                j_onehot = torch.zeros(
                    (batch_size, self.dataset_conf.max_n),
                    requires_grad=True).pin_memory().to(self.config.device,
                                                        non_blocking=True)
                j_onehot.scatter_(1, data['n'][:, None] - 1, 1)
                y_onehot = torch.cat((i_onehot, j_onehot), dim=1)

                if args.nz > n_labels:
                    noise = torch.randn(
                        (batch_size, args.nz - n_labels, 1, 1),
                        requires_grad=True).to(self.config.device,
                                               non_blocking=True)
                    z_input = torch.cat(
                        (y_onehot.view(batch_size, n_labels, 1, 1), noise),
                        dim=1)
                else:
                    z_input = y_onehot.view(batch_size, n_labels, 1, 1)

                output = G(z_input)  # (B, 1, n, n)
                if self.model_conf.is_sym:
                    output = torch.tril(output, diagonal=-1)
                    output = output + output.transpose(2, 3)

                loss_G = 0.
                if args.sparsity > 0:
                    loss_G_sparse = sparse_loss(
                        output,
                        torch.tensor(0.).expand_as(output).cuda())
                    loss_G += args.sparsity * loss_G_sparse
                if args.lambda_rote > 0:
                    if args.final_activation == 'tanh':
                        tmp_obj = (data['adj'] - 0.5) * 2
                    else:
                        tmp_obj = data['adj']
                    loss_G_rote = rote_loss(output, tmp_obj)
                    rote_mask = (loss_G_rote > 0.2).type_as(loss_G_rote)
                    loss_G_rote = (loss_G_rote * rote_mask).mean()
                    loss_G += args.lambda_rote * loss_G_rote

                # backward G

                loss_G_GAN = criterionGAN(D(output), True)
                loss_G += loss_G_GAN
                loss_G.backward()
                optimizer_G.step()

                # backward D
                set_requires_grad(D, True)
                # set_requires_grad(G, False)
                optimizer_D.zero_grad()
                real = data['adj']

                if args.final_activation == 'sigmoid':
                    ones_soft = torch.rand_like(real) * 0.1 + 0.9
                    zeros_soft = torch.rand_like(real) * 0.1
                elif args.final_activation == 'tanh':
                    ones_soft = torch.rand_like(real) * 0.2 + 0.8
                    zeros_soft = -(torch.rand_like(real) * 0.2 + 0.8)
                ones_mask = (real == 1.)
                zeros_mask = (real == 0.)
                real[ones_mask] = ones_soft[ones_mask]
                real[zeros_mask] = zeros_soft[zeros_mask]
                if self.model_conf.is_sym:
                    real = torch.tril(real, diagonal=-1)
                    real = real + real.transpose(2, 3)
                pred_real = D(real)
                loss_D_real = criterionGAN(pred_real, True)
                # Fake
                if args.pool_size:
                    queried_fake = fake_pool.query(output.detach())
                else:
                    queried_fake = output.detach()
                pred_fake = D(queried_fake)
                loss_D_fake = criterionGAN(pred_fake, False)
                # Combined loss and calculate gradients
                loss_D = (loss_D_real + loss_D_fake) * 0.5
                loss_D.backward()
                optimizer_D.step()

                # reduce
                self.writer.add_scalar('train_loss_G', loss_G.item(),
                                       iter_count)
                self.writer.add_scalar('train_loss_D', loss_D.item(),
                                       iter_count)
                results['train_loss_G'] += [loss_G.item()]
                results['train_loss_D'] += [loss_D.item()]
                results['train_step'] += [iter_count]

                if iter_count % self.train_conf.display_iter == 0 or iter_count == 1:
                    logger.info(
                        "@ epoch {:04d} iter {:08d} loss_G: {:.5f}, loss_G_GAN: {:.5f}, loss_D: {:.5f}, loss_D_real: {:.5f}, loss_D_fake: {:.5f}"
                        .format(epoch + 1, iter_count, loss_G.item(),
                                loss_G_GAN.item(), loss_D.item(),
                                loss_D_real.item(), loss_D_fake.item()))
                    if args.lambda_rote > 0:
                        logger.info(
                            "@ epoch {:04d} iter {:08d} loss_rote: {:.5f}".
                            format(epoch + 1, iter_count, loss_G_rote.item()))

            # snapshot model
            if (epoch + 1) % self.train_conf.snapshot_epoch == 0:
                logger.info("Saving Snapshot @ epoch {:04d}".format(epoch + 1))
                snapshot(G,
                         optimizer_G,
                         self.config,
                         epoch + 1,
                         fname_prefix='G_')
                snapshot(D,
                         optimizer_D,
                         self.config,
                         epoch + 1,
                         fname_prefix='D_')

        pickle.dump(
            results,
            open(os.path.join(self.config.save_dir, 'train_stats.p'), 'wb'))
        self.writer.close()
        return 1
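Before scoring real samples, the discriminator update above replaces the hard 0/1 adjacency entries with noisy soft targets. A minimal standalone sketch of that label-smoothing step (the function name and toy input are illustrative, not part of the example):

import torch

def smooth_adjacency(real: torch.Tensor, final_activation: str = 'sigmoid') -> torch.Tensor:
    """Replace hard 0/1 adjacency entries with soft targets, as in the discriminator update above."""
    real = real.clone()  # the example modifies data['adj'] in place; a copy avoids that here
    if final_activation == 'sigmoid':
        ones_soft = torch.rand_like(real) * 0.1 + 0.9    # real targets in [0.9, 1.0)
        zeros_soft = torch.rand_like(real) * 0.1         # fake targets in [0.0, 0.1)
    elif final_activation == 'tanh':
        ones_soft = torch.rand_like(real) * 0.2 + 0.8    # real targets in [0.8, 1.0)
        zeros_soft = -(torch.rand_like(real) * 0.2 + 0.8)
    else:
        raise ValueError('unsupported activation: {}'.format(final_activation))
    ones_mask = real == 1.
    zeros_mask = real == 0.
    real[ones_mask] = ones_soft[ones_mask]
    real[zeros_mask] = zeros_soft[zeros_mask]
    return real

adj = (torch.rand(2, 1, 5, 5) > 0.5).float()
print(smooth_adjacency(adj)[0, 0])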
Example #12
0
    def train(self):
        ### create data loader
        train_dataset = eval(self.dataset_conf.loader_name)(self.config,
                                                            self.graphs_train,
                                                            tag='train')
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=self.train_conf.batch_size,
            shuffle=self.train_conf.shuffle,
            num_workers=self.train_conf.num_workers,
            collate_fn=train_dataset.collate_fn,
            drop_last=False)

        # create models
        model = eval(self.model_conf.name)(self.config)
        print('number of parameters : {}'.format(
            sum([np.prod(x.shape) for x in model.parameters()])))

        if self.use_gpu:
            model = DataParallel(model, device_ids=self.gpus).to(self.device)

        # create optimizer
        params = filter(lambda p: p.requires_grad, model.parameters())
        if self.train_conf.optimizer == 'SGD':
            optimizer = optim.SGD(params,
                                  lr=self.train_conf.lr,
                                  momentum=self.train_conf.momentum,
                                  weight_decay=self.train_conf.wd)
        elif self.train_conf.optimizer == 'Adam':
            optimizer = optim.Adam(params,
                                   lr=self.train_conf.lr,
                                   weight_decay=self.train_conf.wd)
        else:
            raise ValueError("Non-supported optimizer!")

        early_stop = EarlyStopper([0.0], win_size=100, is_decrease=False)

        lr_scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=self.train_conf.lr_decay_epoch,
            gamma=self.train_conf.lr_decay)

        # reset gradient
        optimizer.zero_grad()

        # resume training
        resume_epoch = 0
        if self.train_conf.is_resume:
            model_file = os.path.join(self.train_conf.resume_dir,
                                      self.train_conf.resume_model)
            load_model(model.module if self.use_gpu else model,
                       model_file,
                       self.device,
                       optimizer=optimizer,
                       scheduler=lr_scheduler)
            resume_epoch = self.train_conf.resume_epoch

        # Training Loop
        iter_count = 0
        results = defaultdict(list)
        for epoch in range(resume_epoch, self.train_conf.max_epoch):
            has_sampled = False
            model.train()
            # lr_scheduler.step()
            train_iterator = train_loader.__iter__()

            for inner_iter in range(len(train_loader) // self.num_gpus):
                optimizer.zero_grad()

                batch_data = []
                if self.use_gpu:
                    for _ in self.gpus:
                        data = next(train_iterator)
                        batch_data.append(data)
                        iter_count += 1

                avg_train_loss = .0
                for ff in range(self.dataset_conf.num_fwd_pass):
                    batch_fwd = []

                    if self.use_gpu:
                        for dd, gpu_id in enumerate(self.gpus):
                            data = {}
                            data['adj'] = batch_data[dd][ff]['adj'].pin_memory(
                            ).to(gpu_id, non_blocking=True)
                            data['edges'] = batch_data[dd][ff][
                                'edges'].pin_memory().to(gpu_id,
                                                         non_blocking=True)
                            data['node_idx_gnn'] = batch_data[dd][ff][
                                'node_idx_gnn'].pin_memory().to(
                                    gpu_id, non_blocking=True)
                            data['node_idx_feat'] = batch_data[dd][ff][
                                'node_idx_feat'].pin_memory().to(
                                    gpu_id, non_blocking=True)
                            data['label'] = batch_data[dd][ff][
                                'label'].pin_memory().to(gpu_id,
                                                         non_blocking=True)
                            data['att_idx'] = batch_data[dd][ff][
                                'att_idx'].pin_memory().to(gpu_id,
                                                           non_blocking=True)
                            data['subgraph_idx'] = batch_data[dd][ff][
                                'subgraph_idx'].pin_memory().to(
                                    gpu_id, non_blocking=True)
                            batch_fwd.append((data, ))

                    if batch_fwd:
                        train_loss = model(*batch_fwd).mean()
                        avg_train_loss += train_loss

                        # assign gradient
                        train_loss.backward()

                # clip_grad_norm_(model.parameters(), 5.0e-0)
                optimizer.step()
                avg_train_loss /= float(self.dataset_conf.num_fwd_pass)

                # reduce
                train_loss = float(avg_train_loss.data.cpu().numpy())

                self.writer.add_scalar('train_loss', train_loss, iter_count)
                results['train_loss'] += [train_loss]
                results['train_step'] += [iter_count]

                if iter_count % self.train_conf.display_iter == 0 or iter_count == 1:
                    logger.info(
                        "NLL Loss @ epoch {:04d} iteration {:08d} = {}".format(
                            epoch + 1, iter_count, train_loss))

            # snapshot model
            if (epoch + 1) % self.train_conf.snapshot_epoch == 0:
                logger.info("Saving Snapshot @ epoch {:04d}".format(epoch + 1))
                snapshot(model.module if self.use_gpu else model,
                         optimizer,
                         self.config,
                         epoch + 1,
                         scheduler=lr_scheduler)

            if (epoch + 1) % 20 == 0 and not has_sampled:
                has_sampled = True
                print('saving graphs')
                model.eval()
                graphs_gen = [
                    get_graph(aa.cpu().data.numpy())
                    for aa in model.module._sampling(10)
                ]
                model.train()

                vis_graphs = []
                for gg in graphs_gen:
                    CGs = [gg.subgraph(c) for c in nx.connected_components(gg)]
                    CGs = sorted(CGs,
                                 key=lambda x: x.number_of_nodes(),
                                 reverse=True)
                    vis_graphs += [CGs[0]]

                total = len(vis_graphs)  #min(3, len(vis_graphs))
                draw_graph_list(vis_graphs[:total],
                                2,
                                int(total // 2),
                                fname='sample/gran_%d.png' % epoch,
                                layout='spring')

        pickle.dump(
            results,
            open(os.path.join(self.config.save_dir, 'train_stats.p'), 'wb'))
        self.writer.close()

        return 1
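The per-GPU loop above moves every tensor in the batch through pinned host memory with a non-blocking copy. A minimal sketch of that transfer pattern, assuming a CUDA device is available (the helper name and toy batch are illustrative):

import torch

def to_device_async(batch: dict, device: torch.device) -> dict:
    """Pin host memory and copy each tensor asynchronously, mirroring the per-GPU loop above."""
    out = {}
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            out[key] = value.pin_memory().to(device, non_blocking=True)
        else:
            out[key] = value
    return out

if torch.cuda.is_available():
    batch = {'adj': torch.rand(4, 10, 10), 'label': torch.randint(0, 2, (4,)).float()}
    batch = to_device_async(batch, torch.device('cuda:0'))
    print({k: v.device for k, v in batch.items()})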
Example #13
0
    def test(self):
        self.config.save_dir = self.test_conf.test_model_dir

        ### Compute Erdos-Renyi baseline
        if self.config.test.is_test_ER:
            p_ER = sum([
                aa.number_of_edges() for aa in self.graphs_train
            ]) / sum([aa.number_of_nodes()**2 for aa in self.graphs_train])
            graphs_gen = [
                nx.fast_gnp_random_graph(self.max_num_nodes, p_ER, seed=ii)
                for ii in range(self.num_test_gen)
            ]
        else:
            ### load model
            model = eval(self.model_conf.name)(self.config)
            model_file = os.path.join(self.config.save_dir,
                                      self.test_conf.test_model_name)
            load_model(model, model_file, self.device)

            # create graph classifier
            graph_classifier = GraphSAGE(3, 2, 3, 32, 'add')
            # graph_classifier = DiffPool(3, 2, max_num_nodes=630)
            # graph_classifier = DGCNN(3, 2, 'PROTEINS_full')
            graph_classifier.load_state_dict(
                torch.load('output/MODEL_PROTEINS.pkl'))

            if self.use_gpu:
                model = nn.DataParallel(model,
                                        device_ids=self.gpus).to(self.device)
                graph_classifier = graph_classifier.to(self.device)

            model.eval()
            graph_classifier.eval()

            ### Generate Graphs
            A_pred = []
            num_nodes_pred = []
            num_test_batch = 5000

            gen_run_time = []
            graph_acc_count = 0
            # for ii in tqdm(range(1)):
            #     with torch.no_grad():
            #         start_time = time.time()
            #         input_dict = {}
            #         input_dict['is_sampling'] = True
            #         input_dict['batch_size'] = self.test_conf.batch_size
            #         input_dict['num_nodes_pmf'] = self.num_nodes_pmf_train
            #         A_tmp, label_tmp = model(input_dict)
            #         gen_run_time += [time.time() - start_time]
            #         A_pred += [aa.data.cpu().numpy() for aa in A_tmp]
            #         num_nodes_pred += [aa.shape[0] for aa in A_tmp]

            from classifier.losses import MulticlassClassificationLoss
            classifier_loss = MulticlassClassificationLoss()

            ps = []

            acc_count_by_label = {0: 0, 1: 0}

            graph_acc_count = 0
            for ii in tqdm(range(2 * num_test_batch)):
                with torch.no_grad():
                    label = ii % 2
                    graph_label = torch.tensor([label]).to('cuda').long()
                    start_time = time.time()
                    input_dict = {}
                    input_dict['is_sampling'] = True
                    input_dict['batch_size'] = self.test_conf.batch_size
                    input_dict['num_nodes_pmf'] = self.num_nodes_pmf_by_group[
                        graph_label.item()]
                    input_dict['graph_label'] = graph_label

                    A_tmp, label_tmp = model(input_dict)
                    gen_run_time += [time.time() - start_time]
                    A_tmp = A_tmp[0]
                    label_tmp = label_tmp[0]

                    label_tmp = label_tmp.long()

                    lower_part = torch.tril(A_tmp, diagonal=-1)

                    x = torch.zeros((A_tmp.shape[0], 3)).to(self.device)
                    x[list(range(A_tmp.shape[0])), label_tmp] = 1

                    edge_mask = (lower_part != 0).to(self.device)
                    edges = edge_mask.nonzero().transpose(0, 1).to(self.device)
                    edges_other_way = edges[[1, 0]]
                    edges = torch.cat([edges, edges_other_way],
                                      dim=-1).to(self.device)

                    batch = torch.zeros(A_tmp.shape[0]).long().to(self.device)

                    data = Bunch(x=x,
                                 edge_index=edges,
                                 batch=batch,
                                 y=graph_label,
                                 edge_weight=None)

                    n_nodes = batch.shape[0]
                    n_edges = edges.shape[1]

                    output = graph_classifier(data)

                    if not isinstance(output, tuple):
                        output = (output, )

                    graph_classification_loss, graph_classification_acc = classifier_loss(
                        data.y, *output)
                    graph_acc_count += graph_classification_acc / 100

                    acc_count_by_label[label] += graph_classification_acc / 100

                    print(graph_classification_acc, graph_label)

                    if ii % 100 == 99:
                        n_graphs_each = (ii + 1) / 2
                        print("\033[92m" +
                              "Class 0: %.3f ----  Class 1: %.3f" %
                              (acc_count_by_label[0] / n_graphs_each,
                               acc_count_by_label[1] / n_graphs_each) +
                              "\033[0m")

            logger.info('Average test time per mini-batch = {}'.format(
                np.mean(gen_run_time)))
            for label in [0, 1]:
                graph_acc_count = acc_count_by_label[label]
                logger.info('Class %s: ' % (label) +
                            'Conditional graph generation accuracy = {}'.
                            format(graph_acc_count / num_test_batch))

            graphs_gen = [get_graph(aa) for aa in A_pred]

        ### Visualize Generated Graphs
        if self.is_vis:
            num_col = self.vis_num_row
            num_row = int(np.ceil(self.num_vis / num_col))
            test_epoch = self.test_conf.test_model_name
            test_epoch = test_epoch[test_epoch.rfind('_') +
                                    1:test_epoch.find('.pth')]
            save_name = os.path.join(
                self.config.save_dir,
                '{}_gen_graphs_epoch_{}_block_{}_stride_{}.png'.format(
                    self.config.test.test_model_name[:-4], test_epoch,
                    self.block_size, self.stride))

            # remove isolated nodes for better visualization
            graphs_pred_vis = [
                copy.deepcopy(gg) for gg in graphs_gen[:self.num_vis]
            ]

            if self.better_vis:
                for gg in graphs_pred_vis:
                    gg.remove_nodes_from(list(nx.isolates(gg)))

            # display the largest connected component for better visualization
            vis_graphs = []
            for gg in graphs_pred_vis:
                CGs = [gg.subgraph(c) for c in nx.connected_components(gg)]
                CGs = sorted(CGs,
                             key=lambda x: x.number_of_nodes(),
                             reverse=True)
                vis_graphs += [CGs[0]]

            if self.is_single_plot:
                draw_graph_list(vis_graphs,
                                num_row,
                                num_col,
                                fname=save_name,
                                layout='spring')
            else:
                draw_graph_list_separate(vis_graphs,
                                         fname=save_name[:-4],
                                         is_single=True,
                                         layout='spring')

            save_name = os.path.join(self.config.save_dir, 'train_graphs.png')

            if self.is_single_plot:
                draw_graph_list(self.graphs_train[:self.num_vis],
                                num_row,
                                num_col,
                                fname=save_name,
                                layout='spring')
            else:
                draw_graph_list_separate(self.graphs_train[:self.num_vis],
                                         fname=save_name[:-4],
                                         is_single=True,
                                         layout='spring')

        ### Evaluation
        if self.config.dataset.name in ['lobster']:
            acc = eval_acc_lobster_graph(graphs_gen)
            logger.info(
                'Validity accuracy of generated graphs = {}'.format(acc))

        num_nodes_gen = [len(aa) for aa in graphs_gen]

        # Compared with Validation Set
        num_nodes_dev = [len(gg.nodes)
                         for gg in self.graphs_dev]  # shape B X 1
        mmd_degree_dev, mmd_clustering_dev, mmd_4orbits_dev, mmd_spectral_dev = evaluate(
            self.graphs_dev, graphs_gen, degree_only=False)
        mmd_num_nodes_dev = compute_mmd([np.bincount(num_nodes_dev)],
                                        [np.bincount(num_nodes_gen)],
                                        kernel=gaussian_emd)

        # Compared with Test Set
        num_nodes_test = [len(gg.nodes)
                          for gg in self.graphs_test]  # shape B X 1
        mmd_degree_test, mmd_clustering_test, mmd_4orbits_test, mmd_spectral_test = evaluate(
            self.graphs_test, graphs_gen, degree_only=False)
        mmd_num_nodes_test = compute_mmd([np.bincount(num_nodes_test)],
                                         [np.bincount(num_nodes_gen)],
                                         kernel=gaussian_emd)

        logger.info(
            "Validation MMD scores of #nodes/degree/clustering/4orbits/spectral are = {}/{}/{}/{}/{}"
            .format(mmd_num_nodes_dev, mmd_degree_dev, mmd_clustering_dev,
                    mmd_4orbits_dev, mmd_spectral_dev))
        logger.info(
            "Test MMD scores of #nodes/degree/clustering/4orbits/spectral are = {}/{}/{}/{}/{}"
            .format(mmd_num_nodes_test, mmd_degree_test, mmd_clustering_test,
                    mmd_4orbits_test, mmd_spectral_test))

        if self.config.dataset.name in ['lobster']:
            return mmd_degree_dev, mmd_clustering_dev, mmd_4orbits_dev, mmd_spectral_dev, mmd_degree_test, mmd_clustering_test, mmd_4orbits_test, mmd_spectral_test, acc
        else:
            return mmd_degree_dev, mmd_clustering_dev, mmd_4orbits_dev, mmd_spectral_dev, mmd_degree_test, mmd_clustering_test, mmd_4orbits_test, mmd_spectral_test
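To feed a sampled adjacency matrix to the graph classifier, the test loop above keeps the strictly lower triangle, extracts its nonzero entries, and duplicates each edge in both directions. A minimal sketch of that conversion (the classifier and the Bunch container are omitted):

import torch

def adjacency_to_edge_index(A: torch.Tensor) -> torch.Tensor:
    """Convert a dense adjacency matrix into a 2 x E undirected edge index, as in the loop above."""
    lower = torch.tril(A, diagonal=-1)                  # keep each undirected edge once
    edges = (lower != 0).nonzero().t()                  # 2 x E/2 index pairs
    return torch.cat([edges, edges[[1, 0]]], dim=-1)    # append the reversed direction

A = torch.tensor([[0., 1., 0.],
                  [1., 0., 1.],
                  [0., 1., 0.]])
print(adjacency_to_edge_index(A))  # tensor([[1, 2, 0, 1], [0, 1, 1, 2]])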
Example #14
0
  def test(self):
    self.config.save_dir_train = self.test_conf.test_model_dir

    if not self.config.test.is_test_ER:
      ### load model
      model = eval(self.model_conf.name)(self.config)
      model_file = os.path.join(self.config.save_dir_train, self.test_conf.test_model_name)
      load_model(model, model_file, self.device)
      if self.use_gpu:
        model = nn.DataParallel(model, device_ids=self.gpus).to(self.device)
      model.eval()

      if hasattr(self.config, 'complete_graph_model'):
        complete_graph_model = eval(self.config.complete_graph_model.name)(self.config.complete_graph_model)
        complete_graph_model_file = os.path.join(self.config.complete_graph_model.test_model_dir,
                                                 self.config.complete_graph_model.test_model_name)
        load_model(complete_graph_model, complete_graph_model_file, self.device)
        if self.use_gpu:
          complete_graph_model = nn.DataParallel(complete_graph_model, device_ids=self.gpus).to(self.device)
        complete_graph_model.eval()

    if self.config.test.is_test_ER or not hasattr(self.config.test, 'hard_multi') or not self.config.test.hard_multi:
      hard_thre_list = [None]
    else:
      hard_thre_list = np.arange(0.5, 1, 0.1)

    for test_hard_idx, hard_thre in enumerate(hard_thre_list):
      if self.config.test.is_test_ER:
        ### Compute Erdos-Renyi baseline
        p_ER = sum([aa.number_of_edges() for aa in self.graphs_train]) / sum([aa.number_of_nodes() ** 2 for aa in self.graphs_train])
        graphs_gen = [nx.fast_gnp_random_graph(self.max_num_nodes, p_ER, seed=ii) for ii in range(self.num_test_gen)]
      else:
        logger.info('Test pass {}. Hard threshold {}'.format(test_hard_idx, hard_thre))
        ### Generate Graphs
        A_pred = []
        num_nodes_pred = []
        num_test_batch = int(np.ceil(self.num_test_gen / self.test_conf.batch_size))

        gen_run_time = []
        for ii in tqdm(range(num_test_batch)):
          with torch.no_grad():
            start_time = time.time()
            input_dict = {}
            input_dict['is_sampling'] = True
            input_dict['batch_size'] = self.test_conf.batch_size
            input_dict['num_nodes_pmf'] = self.num_nodes_pmf_train
            input_dict['hard_thre'] = hard_thre
            A_tmp = model(input_dict)

            if hasattr(self.config, 'complete_graph_model'):
              final_A_list = []
              for batch_idx in range(len(A_tmp)):
                new_pmf = torch.zeros(len(self.num_nodes_pmf_train))
                max_prob = 0.
                max_prob_num_nodes = None
                for num_nodes, prob in enumerate(self.num_nodes_pmf_train):
                  if prob == 0.:
                    continue
                  tmp_data = {}
                  A_tmp_tmp = A_tmp[batch_idx][:num_nodes, :num_nodes]
                  tmp_data['adj'] = F.pad(
                    A_tmp_tmp,
                    (0, self.config.complete_graph_model.model.max_num_nodes-num_nodes, 0, 0),
                    'constant', value=.0)[None, None, ...]

                  adj = torch.tril(A_tmp_tmp, diagonal=-1)
                  adj = adj + adj.transpose(0, 1)
                  edges = adj.to_sparse().coalesce().indices()
                  tmp_data['edges'] = edges.t()
                  tmp_data['subgraph_idx'] = torch.zeros(num_nodes).long().to(self.device, non_blocking=True)

                  tmp_logit = complete_graph_model(tmp_data)
                  new_pmf[num_nodes] = torch.sigmoid(tmp_logit).item()

                  if new_pmf[num_nodes] > max_prob:
                    max_prob = new_pmf[num_nodes]
                    max_prob_num_nodes = num_nodes

                  if new_pmf[num_nodes] <= 0.9:
                    new_pmf[num_nodes] = 0.

                if (new_pmf == 0.).all():
                  logger.info('(new_pmf == 0.).all(), use {} nodes with max prob {}'.format(max_prob_num_nodes, max_prob))
                  final_num_nodes = max_prob_num_nodes
                else:
                  final_num_nodes = torch.multinomial(new_pmf, 1).item()
                final_A_list.append(
                  A_tmp_tmp[:final_num_nodes, :final_num_nodes]
                )
              A_tmp = final_A_list
            gen_run_time += [time.time() - start_time]
            A_pred += [aa.cpu().numpy() for aa in A_tmp]
            num_nodes_pred += [aa.shape[0] for aa in A_tmp]
        print('num_nodes_pred', num_nodes_pred)
        logger.info('Average test time per mini-batch = {}'.format(
          np.mean(gen_run_time)))

        graphs_gen = [get_graph(aa) for aa in A_pred]

      ### Visualize Generated Graphs
      if self.is_vis:
        num_col = self.vis_num_row
        num_row = self.num_vis // num_col
        test_epoch = self.test_conf.test_model_name
        test_epoch = test_epoch[test_epoch.rfind('_') + 1:test_epoch.find('.pth')]
        if hard_thre is not None:
          save_name = os.path.join(self.config.save_dir_train, '{}_gen_graphs_epoch_{}_hard_{}.png'.format(
            self.config.test.test_model_name[:-4], test_epoch,
            int(round(hard_thre*10))))
          save_name2 = os.path.join(self.config.save_dir,
                                   '{}_gen_graphs_epoch_{}_hard_{}.png'.format(
                                     self.config.test.test_model_name[:-4], test_epoch,
                                     int(round(hard_thre * 10))))
        else:
          save_name = os.path.join(self.config.save_dir_train,
                                   '{}_gen_graphs_epoch_{}.png'.format(
                                     self.config.test.test_model_name[:-4], test_epoch))
          save_name2 = os.path.join(self.config.save_dir,
                                    '{}_gen_graphs_epoch_{}.png'.format(
                                      self.config.test.test_model_name[:-4], test_epoch))

        # remove isolated nodes for better visualization
        graphs_pred_vis = [copy.deepcopy(gg) for gg in graphs_gen[:self.num_vis]]

        if self.better_vis:
          # actually not necessary with the following largest connected component selection
          for gg in graphs_pred_vis:
            gg.remove_nodes_from(list(nx.isolates(gg)))

        # display the largest connected component for better visualization
        vis_graphs = []
        for gg in graphs_pred_vis:
          if self.better_vis:
            CGs = [gg.subgraph(c) for c in nx.connected_components(gg)]
            CGs = sorted(CGs, key=lambda x: x.number_of_nodes(), reverse=True)
            vis_graphs += [CGs[0]]
          else:
            vis_graphs += [gg]
        print('number of nodes after better vis', [tmp_g.number_of_nodes() for tmp_g in vis_graphs])

        if self.is_single_plot:
          # draw_graph_list(vis_graphs, num_row, num_col, fname=save_name, layout='spring')
          draw_graph_list(vis_graphs, num_row, num_col, fname=save_name2, layout='spring')
        else:
          # draw_graph_list_separate(vis_graphs, fname=save_name[:-4], is_single=True, layout='spring')
          draw_graph_list_separate(vis_graphs, fname=save_name2[:-4], is_single=True, layout='spring')

        if test_hard_idx == 0:
          save_name = os.path.join(self.config.save_dir_train, 'train_graphs.png')

          if self.is_single_plot:
            draw_graph_list(
              self.graphs_train[:self.num_vis],
              num_row,
              num_col,
              fname=save_name,
              layout='spring')
          else:
            draw_graph_list_separate(
              self.graphs_train[:self.num_vis],
              fname=save_name[:-4],
              is_single=True,
              layout='spring')

      ### Evaluation
      if self.config.dataset.name in ['lobster']:
        acc = eval_acc_lobster_graph(graphs_gen)
        logger.info('Validity accuracy of generated graphs = {}'.format(acc))

      num_nodes_gen = [len(aa) for aa in graphs_gen]

      # Compared with Validation Set
      num_nodes_dev = [len(gg.nodes) for gg in self.graphs_dev]  # shape B X 1
      mmd_degree_dev, mmd_clustering_dev, mmd_4orbits_dev, mmd_spectral_dev = evaluate(self.graphs_dev, graphs_gen,
                                                                                       degree_only=False)
      mmd_num_nodes_dev = compute_mmd([np.bincount(num_nodes_dev)], [np.bincount(num_nodes_gen)], kernel=gaussian_emd)

      # Compared with Test Set
      num_nodes_test = [len(gg.nodes) for gg in self.graphs_test]  # shape B X 1
      mmd_degree_test, mmd_clustering_test, mmd_4orbits_test, mmd_spectral_test = evaluate(self.graphs_test, graphs_gen,
                                                                                           degree_only=False)
      mmd_num_nodes_test = compute_mmd([np.bincount(num_nodes_test)], [np.bincount(num_nodes_gen)], kernel=gaussian_emd)

      logger.info(
        "Validation MMD scores of #nodes/degree/clustering/4orbits/spectral are = {:.4E}/{:.4E}/{:.4E}/{:.4E}/{:.4E}".format(Decimal(mmd_num_nodes_dev),
                                                                                                   Decimal(mmd_degree_dev),
                                                                                                   Decimal(mmd_clustering_dev),
                                                                                                   Decimal(mmd_4orbits_dev),
                                                                                                   Decimal(mmd_spectral_dev)))
      logger.info(
        "Test MMD scores of #nodes/degree/clustering/4orbits/spectral are = {:.4E}/{:.4E}/{:.4E}/{:.4E}/{:.4E}".format(Decimal(mmd_num_nodes_test),
                                                                                                   Decimal(mmd_degree_test),
                                                                                                   Decimal(mmd_clustering_test),
                                                                                                   Decimal(mmd_4orbits_test),
                                                                                                   Decimal(mmd_spectral_test)))
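Several of these test methods fall back to an Erdős-Rényi baseline: an edge probability is estimated from the training graphs and G(n, p) samples are drawn with networkx. A minimal sketch of that baseline (the toy training graphs are illustrative):

import networkx as nx

def er_baseline(graphs_train, max_num_nodes, num_samples):
    """Estimate p from the training graphs and draw G(n, p) samples, as in the examples above."""
    p_ER = sum(g.number_of_edges() for g in graphs_train) / \
           sum(g.number_of_nodes() ** 2 for g in graphs_train)
    return [nx.fast_gnp_random_graph(max_num_nodes, p_ER, seed=ii)
            for ii in range(num_samples)]

train_graphs = [nx.path_graph(6), nx.cycle_graph(8)]
samples = er_baseline(train_graphs, max_num_nodes=10, num_samples=3)
print([g.number_of_edges() for g in samples])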
Example #15
0
    def train(self):
        torch.autograd.set_detect_anomaly(True)

        ### create data loader
        train_dataset = eval(self.dataset_conf.loader_name)(self.config,
                                                            self.graphs_train,
                                                            tag='train')
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=self.train_conf.batch_size,
            shuffle=self.train_conf.shuffle,  # true for grid
            num_workers=self.train_conf.num_workers,
            collate_fn=train_dataset.collate_fn,
            drop_last=False)

        # create models
        model = eval(self.model_conf.name)(self.config)
        criterion = nn.BCEWithLogitsLoss()

        if self.use_gpu:
            model = DataParallel(model, device_ids=self.gpus).to(self.device)
            criterion = criterion.cuda()
        model.train()

        # create optimizer
        params = filter(lambda p: p.requires_grad, model.parameters())
        if self.train_conf.optimizer == 'SGD':
            optimizer = optim.SGD(params,
                                  lr=self.train_conf.lr,
                                  momentum=self.train_conf.momentum,
                                  weight_decay=self.train_conf.wd)
        elif self.train_conf.optimizer == 'Adam':
            optimizer = optim.Adam(params,
                                   lr=self.train_conf.lr,
                                   weight_decay=self.train_conf.wd)
        else:
            raise ValueError("Non-supported optimizer!")

        # TODO: not used?
        early_stop = EarlyStopper([0.0], win_size=100, is_decrease=False)
        lr_scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=self.train_conf.lr_decay_epoch,
            gamma=self.train_conf.lr_decay)

        # reset gradient
        optimizer.zero_grad()

        best_acc = 0.
        # resume training
        # TODO: record resume_epoch to the saved file
        resume_epoch = 0
        if self.train_conf.is_resume:
            model_file = os.path.join(self.train_conf.resume_dir,
                                      self.train_conf.resume_model)
            load_model(model.module if self.use_gpu else model,
                       model_file,
                       self.device,
                       optimizer=optimizer,
                       scheduler=lr_scheduler)
            resume_epoch = self.train_conf.resume_epoch

        # Training Loop
        iter_count = 0
        results = defaultdict(list)
        for epoch in range(resume_epoch, self.train_conf.max_epoch):
            model.train()
            train_iterator = train_loader.__iter__()

            avg_acc_whole_epoch = 0.
            cnt = 0.

            for inner_iter in range(len(train_loader) // self.num_gpus):
                optimizer.zero_grad()

                batch_data = []
                if self.use_gpu:
                    for _ in self.gpus:
                        data = next(train_iterator)
                        batch_data.append(data)
                        iter_count += 1

                avg_train_loss = .0
                avg_acc = 0.
                for ff in range(self.dataset_conf.num_fwd_pass):
                    batch_fwd = []

                    if self.use_gpu:
                        for dd, gpu_id in enumerate(self.gpus):
                            data = {}
                            data['adj'] = batch_data[dd][ff]['adj'].pin_memory(
                            ).to(gpu_id, non_blocking=True)
                            data['edges'] = batch_data[dd][ff][
                                'edges'].pin_memory().to(gpu_id,
                                                         non_blocking=True)
                            # data['node_idx_gnn'] = batch_data[dd][ff]['node_idx_gnn'].pin_memory().to(gpu_id, non_blocking=True)
                            # data['node_idx_feat'] = batch_data[dd][ff]['node_idx_feat'].pin_memory().to(gpu_id, non_blocking=True)
                            # data['label'] = batch_data[dd][ff]['label'].pin_memory().to(gpu_id, non_blocking=True)
                            # data['att_idx'] = batch_data[dd][ff]['att_idx'].pin_memory().to(gpu_id, non_blocking=True)
                            data['subgraph_idx'] = batch_data[dd][ff][
                                'subgraph_idx'].pin_memory().to(
                                    gpu_id, non_blocking=True)
                            data['complete_graph_label'] = batch_data[dd][ff][
                                'complete_graph_label'].pin_memory().to(
                                    gpu_id, non_blocking=True)
                            batch_fwd.append((data, ))

                    pred = model(*batch_fwd)
                    label = data['complete_graph_label'][:, None]
                    train_loss = criterion(pred, label).mean()
                    train_loss.backward()

                    pred = (torch.sigmoid(pred) > 0.5).type_as(label)
                    avg_acc += (pred.eq(label)).float().mean().item()

                    avg_train_loss += train_loss.item()

                    # assign gradient

                # clip_grad_norm_(model.parameters(), 5.0e-0)
                optimizer.step()
                lr_scheduler.step()
                avg_train_loss /= self.dataset_conf.num_fwd_pass  # num_fwd_pass always 1
                avg_acc /= self.dataset_conf.num_fwd_pass

                avg_acc_whole_epoch += avg_acc
                cnt += 1.  # avg_acc is already a per-batch mean, so count batches

                # reduce
                self.writer.add_scalar('train_loss', avg_train_loss,
                                       iter_count)
                self.writer.add_scalar('train_acc', avg_acc, iter_count)
                results['train_loss'] += [avg_train_loss]
                results['train_acc'] += [avg_acc]
                results['train_step'] += [iter_count]

                # if iter_count % self.train_conf.display_iter == 0 or iter_count == 1:
                #   logger.info("NLL Loss @ epoch {:04d} iteration {:08d} = {}\tAcc = {}".format(epoch + 1, iter_count, train_loss, avg_acc))

            avg_acc_whole_epoch /= cnt
            is_new_best = avg_acc_whole_epoch > best_acc
            if is_new_best:
                logger.info('!!! New best')
                best_acc = avg_acc_whole_epoch
            logger.info("Avg acc = {} @ epoch {:04d}".format(
                avg_acc_whole_epoch, epoch + 1))

            # snapshot model
            if (epoch +
                    1) % self.train_conf.snapshot_epoch == 0 or is_new_best:
                logger.info("Saving Snapshot @ epoch {:04d}".format(epoch + 1))
                snapshot(model.module if self.use_gpu else model,
                         optimizer,
                         self.config,
                         epoch + 1,
                         scheduler=lr_scheduler)

        pickle.dump(
            results,
            open(os.path.join(self.config.save_dir, 'train_stats.p'), 'wb'))
        self.writer.close()

        return 1
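The classifier above measures accuracy by thresholding sigmoid outputs at 0.5 and comparing against the 0/1 labels. A minimal sketch of that computation (the tensor values are illustrative):

import torch

def binary_accuracy(logits: torch.Tensor, labels: torch.Tensor) -> float:
    """Threshold sigmoid outputs at 0.5 and compare with binary labels, as in the loop above."""
    preds = (torch.sigmoid(logits) > 0.5).type_as(labels)
    return preds.eq(labels).float().mean().item()

logits = torch.tensor([[2.3], [-1.7], [0.4]])
labels = torch.tensor([[1.], [0.], [0.]])
print(binary_accuracy(logits, labels))  # 2 of 3 correct -> ~0.667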
Example #16
0
    def test(self):
        self.config.save_dir = self.test_conf.test_model_dir

        ### Compute Erdos-Renyi baseline
        # if self.config.test.is_test_ER:
        p_ER = sum([aa.number_of_edges()
                    for aa in self.graphs_train]) / sum([aa.number_of_nodes()**2 for aa in self.graphs_train])
        # graphs_baseline = [nx.fast_gnp_random_graph(self.max_num_nodes, p_ER, seed=ii) for ii in range(self.num_test_gen)]
        graphs_gen = [nx.fast_gnp_random_graph(self.max_num_nodes, p_ER, seed=ii) for ii in range(self.num_test_gen)]
        temp = []
        for G in graphs_gen:
            G.remove_nodes_from(list(nx.isolates(G)))
            if G is not None:
                #  take the largest connected component
                CGs = [G.subgraph(c) for c in nx.connected_components(G)]
                CGs = sorted(CGs, key=lambda x: x.number_of_nodes(), reverse=True)
                temp.append(CGs[0])
        # graphs_gen = temp
        graphs_baseline = temp

        # else:
        ### load model
        model = eval(self.model_conf.name)(self.config)
        model_file = os.path.join(self.config.save_dir, self.test_conf.test_model_name)
        load_model(model, model_file, self.device)

        if self.use_gpu:
            model = nn.DataParallel(model, device_ids=self.gpus).to(self.device)

        model.eval()

        ### Generate Graphs
        A_pred = []
        node_label_pred = []
        num_nodes_pred = []
        num_test_batch = int(np.ceil(self.num_test_gen / self.test_conf.batch_size))

        gen_run_time = []
        for ii in tqdm(range(num_test_batch)):
            with torch.no_grad():
                start_time = time.time()
                input_dict = {}
                input_dict['is_sampling'] = True
                input_dict['batch_size'] = self.test_conf.batch_size
                input_dict['num_nodes_pmf'] = self.num_nodes_pmf_train
                A_tmp, node_label_tmp = model(input_dict)
                gen_run_time += [time.time() - start_time]
                A_pred += [aa.data.cpu().numpy() for aa in A_tmp]
                node_label_pred += [ll.data.cpu().numpy() for ll in node_label_tmp]
                num_nodes_pred += [aa.shape[0] for aa in A_tmp]
        # print(len(A_pred), type(A_pred[0]))

        logger.info('Average test time per mini-batch = {}'.format(np.mean(gen_run_time)))

        # print(A_pred[0].shape,
        #       get_graph(A_pred[0]).number_of_nodes(),
        #       get_graph_with_labels(A_pred[0], node_label_pred[0]).number_of_nodes())
        # print(A_pred[0])
        # return
        # graphs_gen = [get_graph(aa) for aa in A_pred]
        graphs_gen = [get_graph_with_labels(aa, ll) for aa, ll in zip(A_pred, node_label_pred)]
        valid_pctg, bipartite_pctg = calculate_validity(graphs_gen)  # for adding bipartite graph attribute

        # return

        ### Visualize Generated Graphs
        if self.is_vis:
            num_col = self.vis_num_row
            num_row = int(np.ceil(self.num_vis / num_col))
            test_epoch = self.test_conf.test_model_name
            test_epoch = test_epoch[test_epoch.rfind('_') + 1:test_epoch.find('.pth')]
            save_name = os.path.join(self.config.save_dir, '{}_gen_graphs_epoch_{}_block_{}_stride_{}.png'.format(
                self.config.test.test_model_name[:-4], test_epoch, self.block_size, self.stride))

            # remove isolated nodes for better visualization
            # graphs_pred_vis = [copy.deepcopy(gg) for gg in graphs_gen[:self.num_vis]]
            graphs_pred_vis = [copy.deepcopy(gg) for gg in graphs_gen if not gg.graph['bipartite']]
            logger.info('Number of not bipartite graphs: {} / {}'.format(len(graphs_pred_vis), len(graphs_gen)))
            # if self.better_vis:
            #     for gg in graphs_pred_vis:
            #         gg.remove_nodes_from(list(nx.isolates(gg)))

            # # display the largest connected component for better visualization
            # vis_graphs = []
            # for gg in graphs_pred_vis:
            #     CGs = [gg.subgraph(c) for c in nx.connected_components(gg)] # nx.subgraph makes a graph frozen!
            #     CGs = sorted(CGs, key=lambda x: x.number_of_nodes(), reverse=True)
            #     vis_graphs += [CGs[0]]
            vis_graphs = graphs_pred_vis

            if self.is_single_plot:
                draw_graph_list(vis_graphs, num_row, num_col, fname=save_name, layout='spring')
            else:  #XD: using this for now
                draw_graph_list_separate(vis_graphs, fname=save_name[:-4], is_single=True, layout='spring')

            save_name = os.path.join(self.config.save_dir, 'train_graphs.png')
            if self.is_single_plot:
                draw_graph_list(self.graphs_train[:self.num_vis], num_row, num_col, fname=save_name, layout='spring')
                print('training single plot saved at:', save_name)
            else:  #XD: using this for now
                graph_list_train = [get_graph_from_nx(G) for G in self.graphs_train[:self.num_vis]]
                draw_graph_list_separate(graph_list_train, fname=save_name[:-4], is_single=True, layout='spring')
                print('training plots saved individually at:', save_name[:-4])
        return

        ### Evaluation
        if self.config.dataset.name in ['lobster']:
            acc = eval_acc_lobster_graph(graphs_gen)
            logger.info('Validity accuracy of generated graphs = {}'.format(acc))
        '''=====XD====='''
        ## graphs_gen = [generate_random_baseline_single(len(aa)) for aa in graphs_gen]  # use this line for random baseline MMD scores. Remember to comment it later!
        # draw_hists(self.graphs_test, graphs_baseline, graphs_gen)
        valid_pctg, bipartite_pctg = calculate_validity(graphs_gen)
        # logger.info('Generated {} graphs, valid percentage = {:.2f}, bipartite percentage = {:.2f}'.format(
        #     len(graphs_gen), valid_pctg, bipartite_pctg))
        # # return
        '''=====XD====='''

        num_nodes_gen = [len(aa) for aa in graphs_gen]

        # # Compared with Validation Set
        # num_nodes_dev = [len(gg.nodes) for gg in self.graphs_dev]  # shape B X 1
        # mmd_degree_dev, mmd_clustering_dev, mmd_4orbits_dev, mmd_spectral_dev, mmd_mean_degree_dev, mmd_max_degree_dev, mmd_mean_centrality_dev, mmd_assortativity_dev, mmd_mean_degree_connectivity_dev = evaluate(self.graphs_dev, graphs_gen, degree_only=False)
        # mmd_num_nodes_dev = compute_mmd([np.bincount(num_nodes_dev)], [np.bincount(num_nodes_gen)], kernel=gaussian_emd)
        # logger.info("Validation MMD scores of #nodes/degree/clustering/4orbits/spectral/... are = {:.5f}/{:.5f}/{:.5f}/{:.5f}/{:.5f}/{:.5f}/{:.5f}/{:.5f}/{:.5f}/{:.5f}".format(mmd_num_nodes_dev, mmd_degree_dev, mmd_clustering_dev, mmd_4orbits_dev, mmd_spectral_dev, mmd_mean_degree_dev, mmd_max_degree_dev, mmd_mean_centrality_dev, mmd_assortativity_dev, mmd_mean_degree_connectivity_dev))

        # Compared with Test Set
        num_nodes_test = [len(gg.nodes) for gg in self.graphs_test]  # shape B X 1
        mmd_degree_test, mmd_clustering_test, mmd_4orbits_test, mmd_spectral_test, mmd_mean_degree_test, mmd_max_degree_test, mmd_mean_centrality_test, mmd_assortativity_test, mmd_mean_degree_connectivity_test = evaluate(
            self.graphs_test, graphs_gen, degree_only=False)
        mmd_num_nodes_test = compute_mmd(
            [np.bincount(num_nodes_test)], [np.bincount(num_nodes_gen)], kernel=gaussian_emd)

        logger.info(
            "Test MMD scores of #nodes/degree/clustering/4orbits/spectral/... are = {:.5f}/{:.5f}/{:.5f}/{:.5f}/{:.5f}/{:.5f}/{:.5f}/{:.5f}/{:.5f}/{:.5f}".
            format(mmd_num_nodes_test, mmd_degree_test, mmd_clustering_test, mmd_4orbits_test, mmd_spectral_test,
                   mmd_mean_degree_test, mmd_max_degree_test, mmd_mean_centrality_test, mmd_assortativity_test,
                   mmd_mean_degree_connectivity_test))
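Both for the ER baseline and before visualization, these examples keep only the largest connected component of each graph. A minimal sketch of that step with networkx (the toy graph is illustrative):

import networkx as nx

def largest_connected_component(g: nx.Graph) -> nx.Graph:
    """Return the subgraph induced by the largest connected component, as done before plotting above."""
    components = [g.subgraph(c) for c in nx.connected_components(g)]
    return max(components, key=lambda sub: sub.number_of_nodes())

g = nx.Graph([(0, 1), (1, 2), (3, 4)])
g.add_node(5)  # isolated node
print(sorted(largest_connected_component(g).nodes()))  # [0, 1, 2]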
Example #17
0
    def test(self):
        with torch.no_grad():
            ### create data loader
            test_dataset = eval(self.dataset_conf.loader_name)(
                self.config, self.graphs_test, tag='test')
            test_loader = torch.utils.data.DataLoader(
                test_dataset,
                batch_size=self.train_conf.batch_size,
                shuffle=False,  # true for grid
                num_workers=self.train_conf.num_workers,
                collate_fn=test_dataset.collate_fn,
                drop_last=False)

            self.config.save_dir_train = self.test_conf.test_model_dir

            ### load model
            model = eval(self.model_conf.name)(self.config)
            criterion = nn.BCEWithLogitsLoss()
            model_file = os.path.join(self.config.save_dir_train,
                                      self.test_conf.test_model_name)
            load_model(model, model_file, self.device)

            if self.use_gpu:
                model = nn.DataParallel(model,
                                        device_ids=self.gpus).to(self.device)
                criterion = criterion.cuda()

            model.eval()

            test_iterator = test_loader.__iter__()

            iter_count = 0

            total_count = 0.
            total_avg_test_loss = 0.
            total_avg_test_acc = 0.

            for inner_iter in range(len(test_loader) // self.num_gpus):
                batch_data = []
                if self.use_gpu:
                    for _ in self.gpus:
                        data = next(test_iterator)
                        batch_data.append(data)
                        iter_count += 1

                for ff in range(self.dataset_conf.num_fwd_pass):
                    batch_fwd = []

                    if self.use_gpu:
                        for dd, gpu_id in enumerate(self.gpus):
                            data = {}
                            data['adj'] = batch_data[dd][ff]['adj'].pin_memory(
                            ).to(gpu_id, non_blocking=True)
                            data['edges'] = batch_data[dd][ff][
                                'edges'].pin_memory().to(gpu_id,
                                                         non_blocking=True)
                            # data['node_idx_gnn'] = batch_data[dd][ff]['node_idx_gnn'].pin_memory().to(gpu_id, non_blocking=True)
                            # data['node_idx_feat'] = batch_data[dd][ff]['node_idx_feat'].pin_memory().to(gpu_id, non_blocking=True)
                            # data['label'] = batch_data[dd][ff]['label'].pin_memory().to(gpu_id, non_blocking=True)
                            # data['att_idx'] = batch_data[dd][ff]['att_idx'].pin_memory().to(gpu_id, non_blocking=True)
                            data['subgraph_idx'] = batch_data[dd][ff][
                                'subgraph_idx'].pin_memory().to(
                                    gpu_id, non_blocking=True)
                            data['complete_graph_label'] = batch_data[dd][ff][
                                'complete_graph_label'].pin_memory().to(
                                    gpu_id, non_blocking=True)
                            batch_fwd.append(data)

                    pred = model(*batch_fwd)
                    label = data['complete_graph_label'][:, None]

                    # compute the loss on raw logits before thresholding
                    test_loss = criterion(pred, label).mean().item()
                    pred = (torch.sigmoid(pred) > 0.5).type_as(label)
                    test_acc = (pred.eq(label)).float().mean().item()

                    total_count += 1.  # test_loss/test_acc are per-batch means, so count batches
                    total_avg_test_loss += test_loss
                    total_avg_test_acc += test_acc

                # reduce
                self.writer.add_scalar('test_loss', test_loss, iter_count)
                self.writer.add_scalar('test_acc', test_acc, iter_count)

                if iter_count % self.train_conf.display_iter == 0 or iter_count == 1:
                    logger.info(
                        "Test NLL Loss @ iteration {:08d} = {}\tAcc = {}".
                        format(iter_count, test_loss, test_acc))

            logger.info("Test final avg NLL Loss = {} Acc = {}".format(
                total_avg_test_loss / total_count,
                total_avg_test_acc / total_count))
            self.writer.close()
Example #18
0
    def train(self):
        # create data loader
        train_dataset = Citation(self.dataset_conf.path,
                                 feat_dim_pca=self.model_conf.feat_dim,
                                 dataset_name=self.dataset_conf.name,
                                 split='train',
                                 train_ratio=self.dataset_conf.train_ratio,
                                 use_rand_split=self.dataset_conf.rand_split,
                                 seed=self.config.seed)
        val_dataset = Citation(self.dataset_conf.path,
                               feat_dim_pca=self.model_conf.feat_dim,
                               dataset_name=self.dataset_conf.name,
                               split='val',
                               train_ratio=self.dataset_conf.train_ratio,
                               use_rand_split=self.dataset_conf.rand_split,
                               seed=self.config.seed)
        train_loader = DataLoader(train_dataset,
                                  batch_size=self.train_conf.batch_size,
                                  shuffle=self.train_conf.shuffle,
                                  num_workers=self.train_conf.num_workers,
                                  drop_last=False)
        val_loader = DataLoader(val_dataset,
                                batch_size=self.train_conf.batch_size,
                                shuffle=False,
                                num_workers=self.train_conf.num_workers,
                                drop_last=False)

        # create models
        model = eval(self.model_conf.name)(self.config)

        # create optimizer
        params = model.parameters()
        if self.train_conf.optimizer == 'SGD':
            optimizer = optim.SGD(params,
                                  lr=self.train_conf.lr,
                                  momentum=self.train_conf.momentum,
                                  weight_decay=self.train_conf.wd)
        elif self.train_conf.optimizer == 'Adam':
            optimizer = optim.Adam(params,
                                   lr=self.train_conf.lr,
                                   weight_decay=self.train_conf.wd)
        else:
            raise ValueError("Non-supported optimizer!")

        lr_scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=self.train_conf.lr_decay_steps,
            gamma=self.train_conf.lr_decay)

        # reset gradient
        optimizer.zero_grad()

        # resume training
        if self.train_conf.is_resume:
            load_model(model,
                       self.train_conf.resume_model,
                       optimizer=optimizer)

        if self.use_gpu:
            model = nn.DataParallel(model, device_ids=self.gpus).cuda()

        # Training Loop
        iter_count = 0
        best_val_acc = .0
        results = defaultdict(list)
        for epoch in range(self.train_conf.max_epoch):
            # validation
            if (epoch + 1) % self.train_conf.valid_epoch == 0 or epoch == 0:
                model.eval()
                val_loss = []
                total, correct = .0, .0
                for node_feat, node_label, edge, mask in val_loader:
                    if self.use_gpu:
                        node_feat, node_label, edge, mask = node_feat.cuda(
                        ), node_label.cuda(), edge.cuda(), mask.cuda()

                    node_feat, node_label, edge, mask = node_feat.float(
                    ), node_label.long(), edge.long(), mask.byte()

                    node_logit, node_label, _, curr_loss, _ = model(
                        edge, node_feat, target=node_label, mask=mask)
                    val_loss += [float(curr_loss.data.cpu().numpy())]
                    _, predicted = torch.max(node_logit.data, 1)
                    total += node_label.size(0)
                    correct += predicted.eq(
                        node_label.data).cpu().numpy().sum()

                val_loss = float(np.mean(val_loss))
                val_acc = 100.0 * correct / total

                # save best model
                if val_acc > best_val_acc:
                    best_val_acc = val_acc
                    snapshot(model,
                             optimizer,
                             self.config,
                             epoch + 1,
                             tag='best')

                logger.info("Avg. Validation Loss = {}".format(val_loss))
                logger.info("Validation Accuracy = {}".format(val_acc))
                logger.info("Current Best Validation Accuracy = {}".format(
                    best_val_acc))
                results['val_loss'] += [val_loss]
                results['val_acc'] += [val_acc]
                model.train()

            # training
            lr_scheduler.step()
            for node_feat, node_label, edge, mask in train_loader:
                if self.use_gpu:
                    node_feat, node_label, edge, mask = node_feat.cuda(
                    ), node_label.cuda(), edge.cuda(), mask.cuda()

                node_feat, node_label, edge, mask = node_feat.float(
                ), node_label.long(), edge.long(), mask.byte()
                # optimizer.zero_grad()

                node_logit, _, diff_norm, train_loss, grad_w = model(
                    edge, node_feat, target=node_label, mask=mask)

                # assign gradient
                for pp, ww in zip(model.parameters(), grad_w):
                    pp.grad = ww

                optimizer.step()
                train_loss = float(train_loss.data.cpu().numpy())
                results['train_loss'] += [train_loss]
                results['train_step'] += [iter_count]

                # display loss
                if (iter_count + 1) % self.train_conf.display_iter == 0:
                    logger.info(
                        "Loss @ epoch {:04d} iteration {:08d} = {}".format(
                            epoch + 1, iter_count + 1, train_loss))
                    tmp_key = 'diff_norm_{}'.format(iter_count + 1)
                    results[tmp_key] = diff_norm.data.cpu().numpy().tolist()

                iter_count += 1

            # snapshot model
            if (epoch + 1) % self.train_conf.snapshot_epoch == 0:
                logger.info("Saving Snapshot @ epoch {:04d}".format(epoch + 1))
                snapshot(model.module if self.use_gpu else model, optimizer,
                         self.config, epoch + 1)

        results['best_val_acc'] += [best_val_acc]
        pickle.dump(
            results,
            open(os.path.join(self.config.save_dir, 'train_stats.p'), 'wb'))

        return best_val_acc
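Instead of calling backward(), the training loop above assigns gradients returned by the model's forward pass directly to the parameters before stepping the optimizer. A minimal sketch of that pattern with made-up gradients standing in for the model's output:

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Stand-in for the gradients that the example's model returns from its forward pass.
grad_w = [torch.ones_like(p) for p in model.parameters()]

optimizer.zero_grad()
for pp, ww in zip(model.parameters(), grad_w):
    pp.grad = ww      # assign the externally computed gradient
optimizer.step()      # the SGD update consumes the assigned gradients
print(model.weight)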
Example #19
0
import torch
from torch.utils.data import DataLoader

from utils.data_reader import get_train_dev_test_data, get_review_dict
from utils.data_set import NarreDataset
from utils.log_hepler import logger
from utils.train_helper import load_model, eval_model

train_data, dev_data, test_data = get_train_dev_test_data()
model = load_model("model/checkpoints/NarreModel_20200606153827.pt")
model.config.device = "cuda:1"
model.to(model.config.device)
loss = torch.nn.MSELoss()

review_by_user, review_by_item = get_review_dict("test")
dataset = NarreDataset(test_data, review_by_user, review_by_item, model.config)
data_iter = DataLoader(dataset,
                       batch_size=model.config.batch_size,
                       shuffle=True)

logger.info(f"Loss on test dataset: {eval_model(model, data_iter, loss)}")
Example #20
0
    def test(self):
        self.config.save_dir = self.test_conf.test_model_dir

        ### Compute Erdos-Renyi baseline
        if self.config.test.is_test_ER:
            p_ER = sum([
                aa.number_of_edges() for aa in self.graphs_train
            ]) / sum([aa.number_of_nodes()**2 for aa in self.graphs_train])
            graphs_gen = [
                nx.fast_gnp_random_graph(self.max_num_nodes, p_ER, seed=ii)
                for ii in range(self.num_test_gen)
            ]
        else:
            ### load model
            model = eval(self.model_conf.name)(self.config)
            model_file = os.path.join(self.config.save_dir,
                                      self.test_conf.test_model_name)
            load_model(model, model_file, self.device)

            if self.use_gpu:
                model = nn.DataParallel(model,
                                        device_ids=self.gpus).to(self.device)

            model.eval()

            ### Generate Graphs
            A_pred = []
            num_nodes_pred = []
            alpha_list = []
            num_test_batch = int(
                np.ceil(self.num_test_gen / self.test_conf.batch_size))

            gen_run_time = []
            for ii in tqdm(range(num_test_batch)):
                with torch.no_grad():
                    start_time = time.time()
                    input_dict = {}
                    input_dict['is_sampling'] = True
                    input_dict['batch_size'] = self.test_conf.batch_size
                    input_dict['num_nodes_pmf'] = self.num_nodes_pmf_train
                    A_tmp, alpha_temp = model(input_dict)
                    gen_run_time += [time.time() - start_time]
                    A_pred += [aa.data.cpu().numpy() for aa in A_tmp]
                    num_nodes_pred += [aa.shape[0] for aa in A_tmp]
                    alpha_list += [aa.data.cpu().numpy() for aa in alpha_temp]

            logger.info('Average test time per mini-batch = {}'.format(
                np.mean(gen_run_time)))

            graphs_gen = [
                get_graph(aa, alpha_list[i]) for i, aa in enumerate(A_pred)
            ]

        ### Visualize Generated Graphs
        if self.is_vis:
            num_col = self.vis_num_row
            num_row = int(np.ceil(self.num_vis / num_col))
            test_epoch = self.test_conf.test_model_name
            test_epoch = test_epoch[test_epoch.rfind('_') +
                                    1:test_epoch.find('.pth')]
            save_name = os.path.join(
                self.config.save_dir,
                '{}_gen_graphs_epoch_{}_block_{}_stride_{}.png'.format(
                    self.config.test.test_model_name[:-4], test_epoch,
                    self.block_size, self.stride))

            # copy graphs for visualization (isolated nodes are removed below when better_vis is set)
            graphs_pred_vis = [
                copy.deepcopy(gg) for gg in graphs_gen[:self.num_vis]
            ]

            # save each generated graph to its own pickle file
            for i, gg in enumerate(graphs_pred_vis):
                name = os.path.join(
                    self.config.save_dir,
                    '{}_gen_graphs_epoch_{}_{}.pickle'.format(
                        self.config.test.test_model_name[:-4], test_epoch, i))
                with open(name, 'wb') as handle:
                    pickle.dump(gg, handle)

            if self.better_vis:
                for gg in graphs_pred_vis:
                    gg.remove_nodes_from(list(nx.isolates(gg)))

            # display the largest connected component for better visualization
            vis_graphs = []
            for gg in graphs_pred_vis:
                CGs = [gg.subgraph(c) for c in nx.connected_components(gg)]
                CGs = sorted(CGs,
                             key=lambda x: x.number_of_nodes(),
                             reverse=True)
                vis_graphs += [CGs[0]]

            if self.is_single_plot:
                draw_graph_list(vis_graphs,
                                num_row,
                                num_col,
                                fname=save_name,
                                layout='spring')
            else:
                draw_graph_list_separate(vis_graphs,
                                         fname=save_name[:-4],
                                         is_single=True,
                                         layout='spring')

            save_name = os.path.join(self.config.save_dir, 'train_graphs.png')

            if self.is_single_plot:
                draw_graph_list(self.graphs_train[:self.num_vis],
                                num_row,
                                num_col,
                                fname=save_name,
                                layout='spring')
            else:
                draw_graph_list_separate(self.graphs_train[:self.num_vis],
                                         fname=save_name[:-4],
                                         is_single=True,
                                         layout='spring')

        ### Evaluation
        if self.config.dataset.name in ['lobster']:
            acc = eval_acc_lobster_graph(graphs_gen)
            logger.info(
                'Validity accuracy of generated graphs = {}'.format(acc))

        num_nodes_gen = [len(aa) for aa in graphs_gen]

        # Compared with Validation Set
        num_nodes_dev = [len(gg.nodes)
                         for gg in self.graphs_dev]  # shape B X 1
        mmd_degree_dev, mmd_clustering_dev, mmd_4orbits_dev, mmd_spectral_dev = evaluate(
            self.graphs_dev, graphs_gen, degree_only=False)
        mmd_num_nodes_dev = compute_mmd([np.bincount(num_nodes_dev)],
                                        [np.bincount(num_nodes_gen)],
                                        kernel=gaussian_emd)

        # Compared with Test Set
        num_nodes_test = [len(gg.nodes)
                          for gg in self.graphs_test]  # shape B X 1
        mmd_degree_test, mmd_clustering_test, mmd_4orbits_test, mmd_spectral_test = evaluate(
            self.graphs_test, graphs_gen, degree_only=False)
        mmd_num_nodes_test = compute_mmd([np.bincount(num_nodes_test)],
                                         [np.bincount(num_nodes_gen)],
                                         kernel=gaussian_emd)

        logger.info(
            "Validation MMD scores of #nodes/degree/clustering/4orbits/spectral are = {}/{}/{}/{}/{}"
            .format(mmd_num_nodes_dev, mmd_degree_dev, mmd_clustering_dev,
                    mmd_4orbits_dev, mmd_spectral_dev))
        logger.info(
            "Test MMD scores of #nodes/degree/clustering/4orbits/spectral are = {}/{}/{}/{}/{}"
            .format(mmd_num_nodes_test, mmd_degree_test, mmd_clustering_test,
                    mmd_4orbits_test, mmd_spectral_test))

        if self.config.dataset.name in ['lobster']:
            return mmd_degree_dev, mmd_clustering_dev, mmd_4orbits_dev, mmd_spectral_dev, mmd_degree_test, mmd_clustering_test, mmd_4orbits_test, mmd_spectral_test, acc
        else:
            return mmd_degree_dev, mmd_clustering_dev, mmd_4orbits_dev, mmd_spectral_dev, mmd_degree_test, mmd_clustering_test, mmd_4orbits_test, mmd_spectral_test
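The Erdős–Rényi baseline at the top of test() estimates a single edge probability from the training graphs, p_ER = (total edges) / (total nodes squared), and samples G(n, p) graphs with it. A standalone sketch of that estimate, assuming graphs_train is a list of networkx graphs:

import networkx as nx

def er_baseline(graphs_train, num_nodes, num_samples, seed=0):
    # edge-density estimate matching the expression used in test() above
    p_er = sum(g.number_of_edges() for g in graphs_train) \
        / sum(g.number_of_nodes() ** 2 for g in graphs_train)
    return [nx.fast_gnp_random_graph(num_nodes, p_er, seed=seed + i)
            for i in range(num_samples)]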
Example #21
    def test(self):
        test_dataset = eval(self.dataset_conf.loader_name)(self.config,
                                                           split='test')
        # create data loader
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=self.test_conf.batch_size,
            shuffle=False,
            num_workers=self.test_conf.num_workers,
            collate_fn=test_dataset.collate_fn,
            drop_last=False)

        # create models
        model = eval(self.model_conf.name)(self.config)
        load_model(model, self.test_conf.test_model)

        if self.use_gpu:
            model = nn.DataParallel(model, device_ids=self.gpus).cuda()

        model.eval()
        test_loss = []
        for data in tqdm(test_loader):
            if self.use_gpu:
                data['node_feat'], data['node_mask'], data[
                    'label'] = data_to_gpu(data['node_feat'],
                                           data['node_mask'], data['label'])

                if self.model_conf.name == 'LanczosNet':
                    data['D'], data['V'] = data_to_gpu(data['D'], data['V'])
                elif self.model_conf.name == 'GraphSAGE':
                    data['nn_idx'], data['nonempty_mask'] = data_to_gpu(
                        data['nn_idx'], data['nonempty_mask'])
                elif self.model_conf.name == 'GPNN':
                    data['L'], data['L_cluster'], data['L_cut'] = data_to_gpu(
                        data['L'], data['L_cluster'], data['L_cut'])
                else:
                    data['L'] = data_to_gpu(data['L'])[0]

            with torch.no_grad():
                if self.model_conf.name == 'AdaLanczosNet':
                    pred, _ = model(data['node_feat'],
                                    data['L'],
                                    label=data['label'],
                                    mask=data['node_mask'])
                elif self.model_conf.name == 'LanczosNet':
                    pred, _ = model(data['node_feat'],
                                    data['L'],
                                    data['D'],
                                    data['V'],
                                    label=data['label'],
                                    mask=data['node_mask'])
                elif self.model_conf.name == 'GraphSAGE':
                    pred, _ = model(data['node_feat'],
                                    data['nn_idx'],
                                    data['nonempty_mask'],
                                    label=data['label'],
                                    mask=data['node_mask'])
                elif self.model_conf.name == 'GPNN':
                    pred, _ = model(data['node_feat'],
                                    data['L'],
                                    data['L_cluster'],
                                    data['L_cut'],
                                    label=data['label'],
                                    mask=data['node_mask'])
                else:
                    pred, _ = model(data['node_feat'],
                                    data['L'],
                                    label=data['label'],
                                    mask=data['node_mask'])

                curr_loss = (pred - data['label']
                             ).abs().cpu().numpy() * self.const_factor
                test_loss += [curr_loss]

        test_loss = float(np.mean(np.concatenate(test_loss)))
        logger.info("Test MAE = {}".format(test_loss))

        return test_loss
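data_to_gpu above is a project helper that moves a variable number of tensors onto the GPU and returns them as a tuple (hence the [0] indexing when a single tensor is passed). A minimal sketch of such a helper (an assumption, not the project's actual implementation):

def data_to_gpu(*tensors):
    # move every tensor onto the default CUDA device and return them as a tuple,
    # so callers can unpack several at once or index [0] for a single tensor
    return tuple(t.cuda() for t in tensors)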
Example #22
  def train(self):
    logger.debug('starting training')

    train_dataset = eval(self.dataset_conf.loader_name)(self.config, self.graphs_train, tag='train')
    
    #Get start and end tokens

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=self.batch_size,
        shuffle=self.train_conf.shuffle,
        num_workers=self.train_conf.num_workers,
        drop_last=False
        )

    model = eval(self.model_conf.model_name)(
      self.config,
      train_dataset.n_letters,
      train_dataset.seq_len
      )

    #move to gpu and parallelize
    if self.use_gpu:
      model = data_parallel.DataParallel(model, device_ids=self.gpus).to(self.device)

    model_params = filter(lambda p: p.requires_grad, model.parameters())

    #Setup optimizer
    if self.train_conf.optimizer == 'SGD':
      optimizer = optim.SGD(
          model_params,
          lr=self.train_conf.lr,
          momentum=self.train_conf.momentum,
          weight_decay=self.train_conf.wd
          )      
    elif self.train_conf.optimizer == 'Adam':
      optimizer = optim.Adam(
        model_params, 
        lr=self.train_conf.lr, 
        weight_decay=self.train_conf.wd
        )
    else:
      raise ValueError("Non-supported optimizer!")

    # lr_scheduler = optim.lr_scheduler.MultiStepLR(
    #     optimizer,
    #     milestones=self.train_conf.lr_decay_epoch,
    #     gamma=self.train_conf.lr_decay)

    # reset gradient
    # for i, p in enumerate(model.parameters()):
    #     logger.info("{}: {}".format(i, p))
    # print("-"*80)
    #criterion = nn.NLLLoss()
    criterion = nn.CrossEntropyLoss()

    # resume training
    resume_epoch = 0
    if self.train_conf.is_resume:
      resume_epoch = self.train_conf.resume_epoch
      model_file = os.path.join(self.train_conf.resume_dir,
                                self.train_conf.resume_model)
      obj = load_model(
          model.module if self.use_gpu else model,
          model_file,
          self.device,
          optimizer=optimizer
          )
      
      if self.use_gpu:
        model.module = obj['model']
      else:
        model = obj['model']

      optimizer = obj['optimizer']
      scheduler = obj['scheduler']

     
    results = defaultdict(list)
    
    for epoch in range(resume_epoch, self.train_conf.max_epoch):
      model.train()

      train_iterator = train_loader.__iter__()
      if epoch == 0:
        iter_length = len(train_iterator)
        logger.debug("Length of train loader: {}".format(iter_length))  

      avg_train_loss = .0   
      iter_count = 0  
      for _, (inp, target, ext) in enumerate(train_iterator):
        
        model.module.zero_grad()
        optimizer.zero_grad()

        iter_count += 1
        loss = .0
         
        # this loop assumes GPU execution: tensors are pinned and pushed to device 0
        input_tensor = inp.pin_memory().to(0, non_blocking=True)
        target_tensor = target.pin_memory().to(0, non_blocking=True)
        ext_tensor = ext.pin_memory().to(0, non_blocking=True)
        hidden = torch.cat(
            [model.module.initHidden().pin_memory().to(0, non_blocking=True)
             for _ in range(input_tensor.size(0))],
            dim=1)

        output, hidden = model(ext_tensor, input_tensor, hidden)

        for batch in range(output.size(0)):
          l = criterion(output[batch], target_tensor[batch])
          loss += l
        avg_train_loss += float(loss.item()) / output.size(0)

        loss.backward()
        optimizer.step()
        #lr_scheduler.step()

        if iter_count % self.train_conf.display_iter == 0 and iter_count > 1:
          avg_train_loss /= self.train_conf.display_iter
          results['train_loss'] += [avg_train_loss]
          results['train_step'] += [iter_count]

          logger.info("Loss @ epoch {:04d} iteration {:08d} = {}".format(epoch + 1, iter_count, avg_train_loss))
          # reset the running average for the next display window
          avg_train_loss = .0

          
      #if iter_count % self.train_conf.display_code_iter == 0 and iter_count > 0:
      # log one randomly chosen sample from the last batch of the epoch
      choice = random.choice(range(output.size(0)))
      file_type = self.file_ext[ext_tensor[choice].squeeze().detach().item()]
      target_char = self.tochar(target_tensor[choice])
      predict_char = self.tochar(torch.argmax(output[choice], dim=1))
      logger.info("Epoch {} Iter {} | Sample Start ----------------------".format(epoch, iter_count))
      logger.info("File Type: {}".format(file_type))
      logger.info("Predict: {}".format(''.join(predict_char)))
      logger.info("Target : {}".format(''.join(target_char)))
      logger.info("--------------------------------------------------------")
      #logger.info("output: {}".format(output[0]))

      # snapshot model
      if epoch % self.train_conf.snapshot_epoch == 0:
        logger.info("Saving Snapshot @ epoch {:04d}".format(epoch + 1))
        snapshot(model.module, optimizer, self.config, epoch + 1)

    pickle.dump(results, open(os.path.join(self.config.save_dir, 'train_stats.p'), 'wb'))
    
    return 1
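The sample logging in train() above relies on a self.tochar helper to turn index tensors back into characters. Assuming a vocabulary built as list(string.printable) + [end_token], as elsewhere in this repo, a plausible sketch of such a helper is (hypothetical, not the actual method):

import string

END_TOKEN = '<EOS>'                                  # assumed end-of-sequence marker
ALL_LETTERS = list(string.printable) + [END_TOKEN]

def tochar(index_tensor):
    # map a 1-D tensor of vocabulary indices back to characters,
    # skipping anything outside the vocabulary (e.g. padding ids)
    return [ALL_LETTERS[i] for i in index_tensor.tolist()
            if 0 <= i < len(ALL_LETTERS)]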
Example #23
    def train(self):
        ### create data loader
        train_dataset = eval(self.dataset_conf.loader_name)(self.config,
                                                            self.graphs_train,
                                                            tag='train')
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=self.train_conf.batch_size,
            shuffle=self.train_conf.shuffle,
            num_workers=self.train_conf.num_workers,
            collate_fn=train_dataset.collate_fn,
            drop_last=False)

        # create models
        # model = eval(self.model_conf.name)(self.config)
        from model.transformer import make_model
        model = make_model(max_node=self.config.model.max_num_nodes,
                           d_out=20,
                           N=7,
                           d_model=64,
                           d_ff=64,
                           dropout=0.4)  # d_out, N, d_model, d_ff, h
        # d_out=20, N=15, d_model=16, d_ff=16, dropout=0.2) # d_out, N, d_model, d_ff, h
        # d_out=20, N=3, d_model=64, d_ff=64, dropout=0.1) # d_out, N, d_model, d_ff, h

        if self.use_gpu:
            model = DataParallel(model, device_ids=self.gpus).to(self.device)

        # create optimizer
        params = filter(lambda p: p.requires_grad, model.parameters())
        if self.train_conf.optimizer == 'SGD':
            optimizer = optim.SGD(params,
                                  lr=self.train_conf.lr,
                                  momentum=self.train_conf.momentum,
                                  weight_decay=self.train_conf.wd)
        elif self.train_conf.optimizer == 'Adam':
            optimizer = optim.Adam(params,
                                   lr=self.train_conf.lr,
                                   weight_decay=self.train_conf.wd)
        else:
            raise ValueError("Non-supported optimizer!")

        early_stop = EarlyStopper([0.0], win_size=100, is_decrease=False)
        lr_scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=self.train_conf.lr_decay_epoch,
            gamma=self.train_conf.lr_decay)

        # reset gradient
        optimizer.zero_grad()

        # resume training
        resume_epoch = 0
        if self.train_conf.is_resume:
            model_file = os.path.join(self.train_conf.resume_dir,
                                      self.train_conf.resume_model)
            load_model(model.module if self.use_gpu else model,
                       model_file,
                       self.device,
                       optimizer=optimizer,
                       scheduler=lr_scheduler)
            resume_epoch = self.train_conf.resume_epoch

        # Training Loop
        iter_count = 0
        results = defaultdict(list)
        for epoch in range(resume_epoch, self.train_conf.max_epoch):
            model.train()
            lr_scheduler.step()
            train_iterator = train_loader.__iter__()

            for inner_iter in range(len(train_loader) // self.num_gpus):
                optimizer.zero_grad()

                batch_data = []
                if self.use_gpu:
                    for _ in self.gpus:
                        data = next(train_iterator)
                        batch_data += [data]

                avg_train_loss = .0
                for ff in range(self.dataset_conf.num_fwd_pass):
                    batch_fwd = []

                    if self.use_gpu:
                        for dd, gpu_id in enumerate(self.gpus):
                            data = batch_data[dd]

                            adj, lens = data['adj'], data['lens']

                            # this is only for grid
                            # adj = adj[:, :, :100, :100]
                            # lens = [min(99, x) for x in lens]

                            adj = adj.to('cuda:%d' % gpu_id)

                            # build masks
                            node_feat, attn_mask, lens = preprocess(adj, lens)
                            batch_fwd.append(
                                (node_feat, attn_mask.clone(), lens))

                    if batch_fwd:
                        node_feat, attn_mask, lens = batch_fwd[0]
                        log_theta, log_alpha = model(*batch_fwd)

                        train_loss = model.module.mix_bern_loss(
                            log_theta, log_alpha, adj, lens)

                        avg_train_loss += train_loss

                        # assign gradient
                        train_loss.backward()

                # clip_grad_norm_(model.parameters(), 5.0e-0)
                optimizer.step()
                avg_train_loss /= float(self.dataset_conf.num_fwd_pass)

                # reduce
                train_loss = float(avg_train_loss.data.cpu().numpy())

                self.writer.add_scalar('train_loss', train_loss, iter_count)
                results['train_loss'] += [train_loss]
                results['train_step'] += [iter_count]

                if iter_count % self.train_conf.display_iter == 0 or iter_count == 1:
                    logger.info(
                        "NLL Loss @ epoch {:04d} iteration {:08d} = {}".format(
                            epoch + 1, iter_count, train_loss))

                if epoch % 50 == 0 and inner_iter == 0:
                    model.eval()
                    print('saving graphs')
                    graphs_gen = [get_graph(adj[0].cpu().data.numpy())] + [
                        get_graph(aa.cpu().data.numpy())
                        for aa in model.module.sample(
                            19, max_node=self.config.model.max_num_nodes)
                    ]
                    model.train()

                    vis_graphs = []
                    for gg in graphs_gen:
                        CGs = [
                            gg.subgraph(c) for c in nx.connected_components(gg)
                        ]
                        CGs = sorted(CGs,
                                     key=lambda x: x.number_of_nodes(),
                                     reverse=True)
                        # keep the largest connected component, if the graph has one
                        if CGs:
                            vis_graphs += [CGs[0]]

                    try:
                        total = len(vis_graphs)  #min(3, len(vis_graphs))
                        draw_graph_list(vis_graphs[:total],
                                        4,
                                        int(total // 4),
                                        fname='sample/trans_sl:%d_%d.png' %
                                        (int(model.module.self_loop), epoch),
                                        layout='spring')
                    except Exception:
                        print('sample saving failed')

            # snapshot model
            if (epoch + 1) % self.train_conf.snapshot_epoch == 0:
                logger.info("Saving Snapshot @ epoch {:04d}".format(epoch + 1))
                snapshot(model.module if self.use_gpu else model,
                         optimizer,
                         self.config,
                         epoch + 1,
                         scheduler=lr_scheduler)

        pickle.dump(
            results,
            open(os.path.join(self.config.save_dir, 'train_stats.p'), 'wb'))
        self.writer.close()

        return 1
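mix_bern_loss above is the model's own mixture-of-Bernoulli negative log-likelihood over adjacency entries. A simplified, self-contained sketch of that kind of loss (the tensor shapes and the graph-level pooling of the mixture logits are assumptions, not the project's exact formulation):

import torch
import torch.nn.functional as F

def mix_bern_nll_sketch(log_theta, log_alpha, adj):
    # log_theta: [E, K] per-edge Bernoulli logits for K mixture components
    # log_alpha: [E, K] per-edge mixture logits
    # adj:       [E]    binary (float) ground-truth edge indicators
    target = adj.unsqueeze(-1).expand_as(log_theta)
    # per-edge, per-component Bernoulli log-likelihood
    ll = -F.binary_cross_entropy_with_logits(log_theta, target, reduction='none')
    # pool the mixture logits to graph level (assumption), normalize over components
    log_mix = torch.log_softmax(log_alpha.mean(dim=0), dim=0)  # [K]
    # marginalize over components with log-sum-exp and negate for the NLL
    return -torch.logsumexp(ll.sum(dim=0) + log_mix, dim=0)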