Code example #1
    def __init__(self, db):
        super(NetworkFactory, self).__init__()

        # module_file = "models.{}".format(system_configs.snapshot_name)
        module_file = "utils.ExtremeNet"
        print("module_file: {}".format(module_file))
        nnet_module = importlib.import_module(module_file)

        self.model = DummyModule(nnet_module.model(db))
        self.loss = nnet_module.loss
        self.network = Network(self.model, self.loss)
        self.network = DataParallel(self.network,
                                    chunk_sizes=system_configs.chunk_sizes)

        total_params = 0
        for params in self.model.parameters():
            num_params = 1
            for x in params.size():
                num_params *= x
            total_params += num_params
        print("total parameters: {}".format(total_params))

        if system_configs.opt_algo == "adam":
            self.optimizer = torch.optim.Adam(
                filter(lambda p: p.requires_grad, self.model.parameters()))
        elif system_configs.opt_algo == "sgd":
            self.optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                                    self.model.parameters()),
                                             lr=system_configs.learning_rate,
                                             momentum=0.9,
                                             weight_decay=0.0001)
        else:
            raise ValueError("unknown optimizer")
Code example #2
    def set_device(self, gpus, chunk_sizes, device):
        if len(gpus) > 1:
            self.model = DataParallel(self.model,
                                      device_ids=gpus,
                                      chunk_sizes=chunk_sizes).to(device)
        else:
            self.model = self.model.to(device)

        for state in self.optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device=device, non_blocking=True)
Code example #3
File: gran_runner.py Project: texbomb/GRAN
    def train(self):
        ### create data loader
        train_dataset = eval(self.dataset_conf.loader_name)(self.config,
                                                            self.graphs_train,
                                                            tag='train')
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=self.train_conf.batch_size,
            shuffle=self.train_conf.shuffle,
            num_workers=self.train_conf.num_workers,
            collate_fn=train_dataset.collate_fn,
            drop_last=False)

        # create models
        model = eval(self.model_conf.name)(self.config)

        if self.use_gpu:
            model = DataParallel(model, device_ids=self.gpus).to(self.device)

        # create optimizer
        params = filter(lambda p: p.requires_grad, model.parameters())
        if self.train_conf.optimizer == 'SGD':
            optimizer = optim.SGD(params,
                                  lr=self.train_conf.lr,
                                  momentum=self.train_conf.momentum,
                                  weight_decay=self.train_conf.wd)
        elif self.train_conf.optimizer == 'Adam':
            optimizer = optim.Adam(params,
                                   lr=self.train_conf.lr,
                                   weight_decay=self.train_conf.wd)
        else:
            raise ValueError("Non-supported optimizer!")

        early_stop = EarlyStopper([0.0], win_size=100, is_decrease=False)
        lr_scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=self.train_conf.lr_decay_epoch,
            gamma=self.train_conf.lr_decay)

        # reset gradient
        optimizer.zero_grad()

        # resume training
        resume_epoch = 0
        if self.train_conf.is_resume:
            model_file = os.path.join(self.train_conf.resume_dir,
                                      self.train_conf.resume_model)
            load_model(model.module if self.use_gpu else model,
                       model_file,
                       self.device,
                       optimizer=optimizer,
                       scheduler=lr_scheduler)
            resume_epoch = self.train_conf.resume_epoch

        # Training Loop
        iter_count = 0
        results = defaultdict(list)
        for epoch in range(resume_epoch, self.train_conf.max_epoch):
            model.train()
            lr_scheduler.step()
            train_iterator = iter(train_loader)

            for inner_iter in range(len(train_loader) // self.num_gpus):
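                # one DataLoader batch is drawn per GPU in each outer iteration,
                # hence the len(train_loader) // num_gpus trip count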
                optimizer.zero_grad()

                batch_data = []
                if self.use_gpu:
                    for _ in self.gpus:
                        data = next(train_iterator)
                        batch_data.append(data)
                        iter_count += 1

                avg_train_loss = .0
                for ff in range(self.dataset_conf.num_fwd_pass):
                    batch_fwd = []

                    if self.use_gpu:
                        for dd, gpu_id in enumerate(self.gpus):
                            data = {}
                            data['adj'] = batch_data[dd][ff]['adj'].pin_memory(
                            ).to(gpu_id, non_blocking=True)
                            data['edges'] = batch_data[dd][ff][
                                'edges'].pin_memory().to(gpu_id,
                                                         non_blocking=True)
                            data['node_idx_gnn'] = batch_data[dd][ff][
                                'node_idx_gnn'].pin_memory().to(
                                    gpu_id, non_blocking=True)
                            data['node_idx_feat'] = batch_data[dd][ff][
                                'node_idx_feat'].pin_memory().to(
                                    gpu_id, non_blocking=True)
                            data['label'] = batch_data[dd][ff][
                                'label'].pin_memory().to(gpu_id,
                                                         non_blocking=True)
                            data['att_idx'] = batch_data[dd][ff][
                                'att_idx'].pin_memory().to(gpu_id,
                                                           non_blocking=True)
                            data['subgraph_idx'] = batch_data[dd][ff][
                                'subgraph_idx'].pin_memory().to(
                                    gpu_id, non_blocking=True)
                            batch_fwd.append((data, ))

                    if batch_fwd:
                        train_loss = model(*batch_fwd).mean()
                        avg_train_loss += train_loss

                        # assign gradient
                        train_loss.backward()

                # clip_grad_norm_(model.parameters(), 5.0e-0)
                optimizer.step()
                avg_train_loss /= float(self.dataset_conf.num_fwd_pass)

                # reduce
                train_loss = float(avg_train_loss.data.cpu().numpy())

                self.writer.add_scalar('train_loss', train_loss, iter_count)
                results['train_loss'] += [train_loss]
                results['train_step'] += [iter_count]

                if iter_count % self.train_conf.display_iter == 0 or iter_count == 1:
                    logger.info(
                        "NLL Loss @ epoch {:04d} iteration {:08d} = {}".format(
                            epoch + 1, iter_count, train_loss))

            # snapshot model
            if (epoch + 1) % self.train_conf.snapshot_epoch == 0:
                logger.info("Saving Snapshot @ epoch {:04d}".format(epoch + 1))
                snapshot(model.module if self.use_gpu else model,
                         optimizer,
                         self.config,
                         epoch + 1,
                         scheduler=lr_scheduler)

        pickle.dump(
            results,
            open(os.path.join(self.config.save_dir, 'train_stats.p'), 'wb'))
        self.writer.close()

        return 1
Code example #4
class NetworkFactory(object):
    def __init__(self, db):
        super(NetworkFactory, self).__init__()

        # module_file = "models.{}".format(system_configs.snapshot_name)
        module_file = "utils.ExtremeNet"
        print("module_file: {}".format(module_file))
        nnet_module = importlib.import_module(module_file)

        self.model = DummyModule(nnet_module.model(db))
        self.loss = nnet_module.loss
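        # Network presumably bundles the model and its loss so the loss is computed
        # inside DataParallel on each GPU; chunk_sizes looks like a project-specific
        # DataParallel extension for splitting batches unevenly across GPUs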
        self.network = Network(self.model, self.loss)
        self.network = DataParallel(self.network,
                                    chunk_sizes=system_configs.chunk_sizes)

        total_params = 0
        for params in self.model.parameters():
            num_params = 1
            for x in params.size():
                num_params *= x
            total_params += num_params
        print("total parameters: {}".format(total_params))

        if system_configs.opt_algo == "adam":
            self.optimizer = torch.optim.Adam(
                filter(lambda p: p.requires_grad, self.model.parameters()))
        elif system_configs.opt_algo == "sgd":
            self.optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                                    self.model.parameters()),
                                             lr=system_configs.learning_rate,
                                             momentum=0.9,
                                             weight_decay=0.0001)
        else:
            raise ValueError("unknown optimizer")

    def cuda(self):
        self.model.cuda()

    def train_mode(self):
        self.network.train()

    def eval_mode(self):
        self.network.eval()

    def train(self, xs, ys, **kwargs):
        xs = [x.cuda(non_blocking=True) for x in xs]
        ys = [y.cuda(non_blocking=True) for y in ys]

        self.optimizer.zero_grad()
        loss = self.network(xs, ys)
        loss = loss.mean()
        loss.backward()
        self.optimizer.step()
        return loss

    def validate(self, xs, ys, **kwargs):
        with torch.no_grad():
            xs = [x.cuda(non_blocking=True) for x in xs]
            ys = [y.cuda(non_blocking=True) for y in ys]

            loss = self.network(xs, ys)
            loss = loss.mean()
            return loss

    def test(self, xs, **kwargs):
        with torch.no_grad():
            xs = [x.cuda(non_blocking=True) for x in xs]
            return self.model(*xs, **kwargs)

    def set_lr(self, lr):
        print("setting learning rate to: {}".format(lr))
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr

    def load_pretrained_params(self, pretrained_model):
        print("loading from {}".format(pretrained_model))
        with open(pretrained_model, "rb") as f:
            params = torch.load(f)
            self.model.load_state_dict(params, strict=False)

    def load_params(self, iteration):
        cache_file = system_configs.snapshot_file.format(iteration)
        print("loading model from {}".format(cache_file))
        with open(cache_file, "rb") as f:
            params = torch.load(f)
            self.model.load_state_dict(params)

    def save_params(self, iteration):
        cache_file = system_configs.snapshot_file.format(iteration)
        print("saving model to {}".format(cache_file))
        with open(cache_file, "wb") as f:
            params = self.model.state_dict()
            torch.save(params, f)
Code example #5
    def train(self):
        torch.autograd.set_detect_anomaly(True)

        ### create data loader
        train_dataset = eval(self.dataset_conf.loader_name)(self.config,
                                                            self.graphs_train,
                                                            tag='train')
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=self.train_conf.batch_size,
            shuffle=self.train_conf.shuffle,  # true for grid
            num_workers=self.train_conf.num_workers,
            collate_fn=train_dataset.collate_fn,
            drop_last=False)

        # create models
        model = eval(self.model_conf.name)(self.config)
        criterion = nn.BCEWithLogitsLoss()
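        # this variant trains a binary graph classifier: logits are scored against
        # 'complete_graph_label' with BCEWithLogitsLoss rather than the generative NLL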

        if self.use_gpu:
            model = DataParallel(model, device_ids=self.gpus).to(self.device)
            criterion = criterion.cuda()
        model.train()

        # create optimizer
        params = filter(lambda p: p.requires_grad, model.parameters())
        if self.train_conf.optimizer == 'SGD':
            optimizer = optim.SGD(params,
                                  lr=self.train_conf.lr,
                                  momentum=self.train_conf.momentum,
                                  weight_decay=self.train_conf.wd)
        elif self.train_conf.optimizer == 'Adam':
            optimizer = optim.Adam(params,
                                   lr=self.train_conf.lr,
                                   weight_decay=self.train_conf.wd)
        else:
            raise ValueError("Non-supported optimizer!")

        # TODO: not used?
        early_stop = EarlyStopper([0.0], win_size=100, is_decrease=False)
        lr_scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=self.train_conf.lr_decay_epoch,
            gamma=self.train_conf.lr_decay)

        # reset gradient
        optimizer.zero_grad()

        best_acc = 0.
        # resume training
        # TODO: record resume_epoch to the saved file
        resume_epoch = 0
        if self.train_conf.is_resume:
            model_file = os.path.join(self.train_conf.resume_dir,
                                      self.train_conf.resume_model)
            load_model(model.module if self.use_gpu else model,
                       model_file,
                       self.device,
                       optimizer=optimizer,
                       scheduler=lr_scheduler)
            resume_epoch = self.train_conf.resume_epoch

        # Training Loop
        iter_count = 0
        results = defaultdict(list)
        for epoch in range(resume_epoch, self.train_conf.max_epoch):
            model.train()
            train_iterator = iter(train_loader)

            avg_acc_whole_epoch = 0.
            cnt = 0.

            for inner_iter in range(len(train_loader) // self.num_gpus):
                optimizer.zero_grad()

                batch_data = []
                if self.use_gpu:
                    for _ in self.gpus:
                        data = next(train_iterator)
                        batch_data.append(data)
                        iter_count += 1

                avg_train_loss = .0
                avg_acc = 0.
                for ff in range(self.dataset_conf.num_fwd_pass):
                    batch_fwd = []

                    if self.use_gpu:
                        for dd, gpu_id in enumerate(self.gpus):
                            data = {}
                            data['adj'] = batch_data[dd][ff]['adj'].pin_memory(
                            ).to(gpu_id, non_blocking=True)
                            data['edges'] = batch_data[dd][ff][
                                'edges'].pin_memory().to(gpu_id,
                                                         non_blocking=True)
                            # data['node_idx_gnn'] = batch_data[dd][ff]['node_idx_gnn'].pin_memory().to(gpu_id, non_blocking=True)
                            # data['node_idx_feat'] = batch_data[dd][ff]['node_idx_feat'].pin_memory().to(gpu_id, non_blocking=True)
                            # data['label'] = batch_data[dd][ff]['label'].pin_memory().to(gpu_id, non_blocking=True)
                            # data['att_idx'] = batch_data[dd][ff]['att_idx'].pin_memory().to(gpu_id, non_blocking=True)
                            data['subgraph_idx'] = batch_data[dd][ff][
                                'subgraph_idx'].pin_memory().to(
                                    gpu_id, non_blocking=True)
                            data['complete_graph_label'] = batch_data[dd][ff][
                                'complete_graph_label'].pin_memory().to(
                                    gpu_id, non_blocking=True)
                            batch_fwd.append((data, ))

                    pred = model(*batch_fwd)
                    label = data['complete_graph_label'][:, None]
                    train_loss = criterion(pred, label).mean()
                    train_loss.backward()

                    pred = (torch.sigmoid(pred) > 0.5).type_as(label)
                    avg_acc += (pred.eq(label)).float().mean().item()

                    avg_train_loss += train_loss.item()

                    # assign gradient

                # clip_grad_norm_(model.parameters(), 5.0e-0)
                optimizer.step()
                lr_scheduler.step()
                avg_train_loss /= self.dataset_conf.num_fwd_pass  # num_fwd_pass always 1
                avg_acc /= self.dataset_conf.num_fwd_pass

                avg_acc_whole_epoch += avg_acc
                cnt += len(data['complete_graph_label'])

                # reduce
                self.writer.add_scalar('train_loss', avg_train_loss,
                                       iter_count)
                self.writer.add_scalar('train_acc', avg_acc, iter_count)
                results['train_loss'] += [avg_train_loss]
                results['train_acc'] += [avg_acc]
                results['train_step'] += [iter_count]

                # if iter_count % self.train_conf.display_iter == 0 or iter_count == 1:
                #   logger.info("NLL Loss @ epoch {:04d} iteration {:08d} = {}\tAcc = {}".format(epoch + 1, iter_count, train_loss, avg_acc))

            avg_acc_whole_epoch /= cnt
            is_new_best = avg_acc_whole_epoch > best_acc
            if is_new_best:
                logger.info('!!! New best')
                best_acc = avg_acc_whole_epoch
            logger.info("Avg acc = {} @ epoch {:04d}".format(
                avg_acc_whole_epoch, epoch + 1))

            # snapshot model
            if (epoch + 1) % self.train_conf.snapshot_epoch == 0 or is_new_best:
                logger.info("Saving Snapshot @ epoch {:04d}".format(epoch + 1))
                snapshot(model.module if self.use_gpu else model,
                         optimizer,
                         self.config,
                         epoch + 1,
                         scheduler=lr_scheduler)

        pickle.dump(
            results,
            open(os.path.join(self.config.save_dir, 'train_stats.p'), 'wb'))
        self.writer.close()

        return 1
Code example #6
File: gran_runner.py Project: pclucas14/GRAN
    def train(self):
        ### create data loader
        train_dataset = eval(self.dataset_conf.loader_name)(self.config,
                                                            self.graphs_train,
                                                            tag='train')
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=self.train_conf.batch_size,
            shuffle=self.train_conf.shuffle,
            num_workers=self.train_conf.num_workers,
            collate_fn=train_dataset.collate_fn,
            drop_last=False)

        # create models
        model = eval(self.model_conf.name)(self.config)
        print('number of parameters : {}'.format(
            sum([np.prod(x.shape) for x in model.parameters()])))

        if self.use_gpu:
            model = DataParallel(model, device_ids=self.gpus).to(self.device)

        # create optimizer
        params = filter(lambda p: p.requires_grad, model.parameters())
        if self.train_conf.optimizer == 'SGD':
            optimizer = optim.SGD(params,
                                  lr=self.train_conf.lr,
                                  momentum=self.train_conf.momentum,
                                  weight_decay=self.train_conf.wd)
        elif self.train_conf.optimizer == 'Adam':
            optimizer = optim.Adam(params,
                                   lr=self.train_conf.lr,
                                   weight_decay=self.train_conf.wd)
        else:
            raise ValueError("Non-supported optimizer!")

        early_stop = EarlyStopper([0.0], win_size=100, is_decrease=False)

        from copy import deepcopy
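        # note: the scheduler below wraps a deep copy of the optimizer, so stepping
        # it would not change the learning rate actually used for updates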
        lr_scheduler = optim.lr_scheduler.MultiStepLR(
            deepcopy(optimizer),
            milestones=self.train_conf.lr_decay_epoch,
            gamma=self.train_conf.lr_decay)

        # reset gradient
        optimizer.zero_grad()

        # resume training
        resume_epoch = 0
        if self.train_conf.is_resume:
            model_file = os.path.join(self.train_conf.resume_dir,
                                      self.train_conf.resume_model)
            load_model(model.module if self.use_gpu else model,
                       model_file,
                       self.device,
                       optimizer=optimizer,
                       scheduler=lr_scheduler)
            resume_epoch = self.train_conf.resume_epoch

        # Training Loop
        iter_count = 0
        results = defaultdict(list)
        for epoch in range(resume_epoch, self.train_conf.max_epoch):
            has_sampled = False
            model.train()
            # lr_scheduler.step()
            train_iterator = iter(train_loader)

            for inner_iter in range(len(train_loader) // self.num_gpus):
                optimizer.zero_grad()

                batch_data = []
                if self.use_gpu:
                    for _ in self.gpus:
                        data = next(train_iterator)
                        batch_data.append(data)
                        iter_count += 1

                avg_train_loss = .0
                for ff in range(self.dataset_conf.num_fwd_pass):
                    batch_fwd = []

                    if self.use_gpu:
                        for dd, gpu_id in enumerate(self.gpus):
                            data = {}
                            data['adj'] = batch_data[dd][ff]['adj'].pin_memory(
                            ).to(gpu_id, non_blocking=True)
                            data['edges'] = batch_data[dd][ff][
                                'edges'].pin_memory().to(gpu_id,
                                                         non_blocking=True)
                            data['node_idx_gnn'] = batch_data[dd][ff][
                                'node_idx_gnn'].pin_memory().to(
                                    gpu_id, non_blocking=True)
                            data['node_idx_feat'] = batch_data[dd][ff][
                                'node_idx_feat'].pin_memory().to(
                                    gpu_id, non_blocking=True)
                            data['label'] = batch_data[dd][ff][
                                'label'].pin_memory().to(gpu_id,
                                                         non_blocking=True)
                            data['att_idx'] = batch_data[dd][ff][
                                'att_idx'].pin_memory().to(gpu_id,
                                                           non_blocking=True)
                            data['subgraph_idx'] = batch_data[dd][ff][
                                'subgraph_idx'].pin_memory().to(
                                    gpu_id, non_blocking=True)
                            batch_fwd.append((data, ))

                    if batch_fwd:
                        train_loss = model(*batch_fwd).mean()
                        avg_train_loss += train_loss

                        # assign gradient
                        train_loss.backward()

                # clip_grad_norm_(model.parameters(), 5.0e-0)
                optimizer.step()
                avg_train_loss /= float(self.dataset_conf.num_fwd_pass)

                # reduce
                train_loss = float(avg_train_loss.data.cpu().numpy())

                self.writer.add_scalar('train_loss', train_loss, iter_count)
                results['train_loss'] += [train_loss]
                results['train_step'] += [iter_count]

                if iter_count % self.train_conf.display_iter == 0 or iter_count == 1:
                    logger.info(
                        "NLL Loss @ epoch {:04d} iteration {:08d} = {}".format(
                            epoch + 1, iter_count, train_loss))

            # snapshot model
            if (epoch + 1) % self.train_conf.snapshot_epoch == 0:
                logger.info("Saving Snapshot @ epoch {:04d}".format(epoch + 1))
                snapshot(model.module if self.use_gpu else model,
                         optimizer,
                         self.config,
                         epoch + 1,
                         scheduler=lr_scheduler)

            if (epoch + 1) % 20 == 0 and not has_sampled:
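                # every 20 epochs, sample 10 graphs and plot their largest
                # connected components with a spring layout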
                has_sampled = True
                print('saving graphs')
                model.eval()
                graphs_gen = [
                    get_graph(aa.cpu().data.numpy())
                    for aa in model.module._sampling(10)
                ]
                model.train()

                vis_graphs = []
                for gg in graphs_gen:
                    CGs = [gg.subgraph(c) for c in nx.connected_components(gg)]
                    CGs = sorted(CGs,
                                 key=lambda x: x.number_of_nodes(),
                                 reverse=True)
                    vis_graphs += [CGs[0]]

                total = len(vis_graphs)  #min(3, len(vis_graphs))
                draw_graph_list(vis_graphs[:total],
                                2,
                                int(total // 2),
                                fname='sample/gran_%d.png' % epoch,
                                layout='spring')

        pickle.dump(
            results,
            open(os.path.join(self.config.save_dir, 'train_stats.p'), 'wb'))
        self.writer.close()

        return 1
Code example #7
class Trainer(object):
    def __init__(self, model, optimizer, lr_scheduler, cfg):
        self.model = model
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.cfg = cfg
        self.set_device(cfg.gpus, cfg.chunk_sizes, cfg.device)
        self.metrics = ['loss', 'class_loss', 'score_loss', 'bbox_loss']

    def run_epoch(self, phase, epoch, data_loader):
        start_time = time.time()

        if phase == 'train':
            self.model.train()
        else:
            self.model.eval()
            torch.cuda.empty_cache()

        metric_loggers = {m: MetricLogger() for m in self.metrics}
        data_timer, net_timer = MetricLogger(), MetricLogger()
        num_iters = len(data_loader) if self.cfg.num_iters < 0 else self.cfg.num_iters
        end = time.time()

        for iter_id, batch in enumerate(data_loader):
            if iter_id >= num_iters:
                break

            for k in batch:
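                # move each batch entry to the target device, skipping image_meta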
                if 'image_meta' not in k:
                    batch[k] = batch[k].to(device=self.cfg.device,
                                           non_blocking=True)
            data_timer.update(time.time() - end)
            end = time.time()

            loss, loss_stats = self.model(batch)
            loss = loss.mean()

            if phase == 'train':
                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(
                    filter(lambda p: p.requires_grad, self.model.parameters()),
                    self.cfg.grad_norm)
                self.optimizer.step()

            msg = 'epoch {0:<3s} {1:<5s} [{2}/{3}] '.format(
                str(epoch) + ':', phase, iter_id, num_iters)
            for m in metric_loggers:
                value = loss_stats[m].mean().item()
                metric_loggers[m].update(value, batch['image'].shape[0])
                msg += '| {} {:.3f} '.format(m, value)

            net_timer.update(time.time() - end)
            end = time.time()

            msg += '| data {:.1f}ms | net {:.1f}ms'.format(
                1000. * data_timer.val, 1000. * net_timer.val)
            if iter_id % self.cfg.print_interval == 0:
                print(msg)

            del loss, loss_stats

        if phase == 'train':
            self.lr_scheduler.step()

        stats = {k: v.avg for k, v in metric_loggers.items()}
        stats.update({'epoch_time': (time.time() - start_time) / 60.})

        return stats

    def train_epoch(self, epoch, data_loader):
        return self.run_epoch('train', epoch, data_loader)

    @torch.no_grad()
    def val_epoch(self, epoch, data_loader):
        return self.run_epoch('val', epoch, data_loader)

    def set_device(self, gpus, chunk_sizes, device):
        if len(gpus) > 1:
            self.model = DataParallel(self.model,
                                      device_ids=gpus,
                                      chunk_sizes=chunk_sizes).to(device)
        else:
            self.model = self.model.to(device)

        for state in self.optimizer.state.values():
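            # move any existing optimizer state tensors (e.g. momentum buffers)
            # onto the same device as the model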
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device=device, non_blocking=True)
Code example #8
    def train(self):
        ### create data loader
        train_dataset = eval(self.dataset_conf.loader_name)(self.config,
                                                            self.graphs_train,
                                                            tag='train')
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=self.train_conf.batch_size,
            shuffle=self.train_conf.shuffle,
            num_workers=self.train_conf.num_workers,
            collate_fn=train_dataset.collate_fn,
            drop_last=False)

        # create models
        # model = eval(self.model_conf.name)(self.config)
        from model.transformer import make_model
        model = make_model(max_node=self.config.model.max_num_nodes,
                           d_out=20,
                           N=7,
                           d_model=64,
                           d_ff=64,
                           dropout=0.4)  # d_out, N, d_model, d_ff, h
        # d_out=20, N=15, d_model=16, d_ff=16, dropout=0.2) # d_out, N, d_model, d_ff, h
        # d_out=20, N=3, d_model=64, d_ff=64, dropout=0.1) # d_out, N, d_model, d_ff, h

        if self.use_gpu:
            model = DataParallel(model, device_ids=self.gpus).to(self.device)

        # create optimizer
        params = filter(lambda p: p.requires_grad, model.parameters())
        if self.train_conf.optimizer == 'SGD':
            optimizer = optim.SGD(params,
                                  lr=self.train_conf.lr,
                                  momentum=self.train_conf.momentum,
                                  weight_decay=self.train_conf.wd)
        elif self.train_conf.optimizer == 'Adam':
            optimizer = optim.Adam(params,
                                   lr=self.train_conf.lr,
                                   weight_decay=self.train_conf.wd)
        else:
            raise ValueError("Non-supported optimizer!")

        early_stop = EarlyStopper([0.0], win_size=100, is_decrease=False)
        lr_scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=self.train_conf.lr_decay_epoch,
            gamma=self.train_conf.lr_decay)

        # reset gradient
        optimizer.zero_grad()

        # resume training
        resume_epoch = 0
        if self.train_conf.is_resume:
            model_file = os.path.join(self.train_conf.resume_dir,
                                      self.train_conf.resume_model)
            load_model(model.module if self.use_gpu else model,
                       model_file,
                       self.device,
                       optimizer=optimizer,
                       scheduler=lr_scheduler)
            resume_epoch = self.train_conf.resume_epoch

        # Training Loop
        iter_count = 0
        results = defaultdict(list)
        for epoch in range(resume_epoch, self.train_conf.max_epoch):
            model.train()
            lr_scheduler.step()
            train_iterator = iter(train_loader)

            for inner_iter in range(len(train_loader) // self.num_gpus):
                optimizer.zero_grad()

                batch_data = []
                if self.use_gpu:
                    for _ in self.gpus:
                        data = next(train_iterator)
                        batch_data += [data]

                avg_train_loss = .0
                for ff in range(self.dataset_conf.num_fwd_pass):
                    batch_fwd = []

                    if self.use_gpu:
                        for dd, gpu_id in enumerate(self.gpus):
                            data = batch_data[dd]

                            adj, lens = data['adj'], data['lens']

                            # this is only for grid
                            # adj = adj[:, :, :100, :100]
                            # lens = [min(99, x) for x in lens]

                            adj = adj.to('cuda:%d' % gpu_id)

                            # build masks
                            node_feat, attn_mask, lens = preprocess(adj, lens)
                            batch_fwd.append(
                                (node_feat, attn_mask.clone(), lens))

                    if batch_fwd:
                        node_feat, attn_mask, lens = batch_fwd[0]
                        log_theta, log_alpha = model(*batch_fwd)

                        train_loss = model.module.mix_bern_loss(
                            log_theta, log_alpha, adj, lens)

                        avg_train_loss += train_loss

                        # assign gradient
                        train_loss.backward()

                # clip_grad_norm_(model.parameters(), 5.0e-0)
                optimizer.step()
                avg_train_loss /= float(self.dataset_conf.num_fwd_pass)

                # reduce
                train_loss = float(avg_train_loss.data.cpu().numpy())

                self.writer.add_scalar('train_loss', train_loss, iter_count)
                results['train_loss'] += [train_loss]
                results['train_step'] += [iter_count]

                if iter_count % self.train_conf.display_iter == 0 or iter_count == 1:
                    logger.info(
                        "NLL Loss @ epoch {:04d} iteration {:08d} = {}".format(
                            epoch + 1, iter_count, train_loss))

                if epoch % 50 == 0 and inner_iter == 0:
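                    # every 50 epochs: sample graphs and plot their largest connected
                    # components next to one training graph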
                    model.eval()
                    print('saving graphs')
                    graphs_gen = [get_graph(adj[0].cpu().data.numpy())] + [
                        get_graph(aa.cpu().data.numpy())
                        for aa in model.module.sample(
                            19, max_node=self.config.model.max_num_nodes)
                    ]
                    model.train()

                    vis_graphs = []
                    for gg in graphs_gen:
                        CGs = [
                            gg.subgraph(c) for c in nx.connected_components(gg)
                        ]
                        CGs = sorted(CGs,
                                     key=lambda x: x.number_of_nodes(),
                                     reverse=True)
                        try:
                            vis_graphs += [CGs[0]]
                        except:
                            pass

                    try:
                        total = len(vis_graphs)  #min(3, len(vis_graphs))
                        draw_graph_list(vis_graphs[:total],
                                        4,
                                        int(total // 4),
                                        fname='sample/trans_sl:%d_%d.png' %
                                        (int(model.module.self_loop), epoch),
                                        layout='spring')
                    except:
                        print('sample saving failed')

            # snapshot model
            if (epoch + 1) % self.train_conf.snapshot_epoch == 0:
                logger.info("Saving Snapshot @ epoch {:04d}".format(epoch + 1))
                snapshot(model.module if self.use_gpu else model,
                         optimizer,
                         self.config,
                         epoch + 1,
                         scheduler=lr_scheduler)

        pickle.dump(
            results,
            open(os.path.join(self.config.save_dir, 'train_stats.p'), 'wb'))
        self.writer.close()

        return 1