Example #1
    def __init__(self, config):

        if 'bAbI' in config.dataset_dir:
            self.train_data = bAbIDataset(config.dataset_dir, config.task)
            self.train_loader = DataLoader(self.train_data,
                                           batch_size=config.batch_size,
                                           num_workers=1,
                                           shuffle=True)

            self.test_data = bAbIDataset(config.dataset_dir,
                                         config.task,
                                         train=False)
            self.test_loader = DataLoader(self.test_data,
                                          batch_size=config.batch_size,
                                          num_workers=1,
                                          shuffle=False)
        elif 'CBTest' in config.dataset_dir:
            self.train_data = CBTestDataset(config.dataset_dir,
                                            config.word_type,
                                            perc_dict=config.perc_dict)
            print("Training set size: ", self.train_data.__len__())
            self.train_loader = DataLoader(self.train_data,
                                           batch_size=config.batch_size,
                                           num_workers=1,
                                           shuffle=True)

            self.test_data = copy.deepcopy(self.train_data)
            self.test_data.set_train_test(train=False)
            print("Testing set size: ", self.test_data.__len__())
            self.test_loader = DataLoader(self.test_data,
                                          batch_size=config.batch_size,
                                          num_workers=1,
                                          shuffle=False)

        settings = {
            "use_cuda": config.cuda,
            "num_vocab": self.train_data.num_vocab,
            "embedding_dim": 20,
            "sentence_size": self.train_data.sentence_size,
            "max_hops": config.max_hops
        }

        print("Longest sentence length", self.train_data.sentence_size)
        print("Longest story length", self.train_data.max_story_size)
        print("Average story length", self.train_data.mean_story_size)
        print("Number of vocab", self.train_data.num_vocab)

        self.mem_n2n = MemN2N(settings)
        self.ce_fn = nn.CrossEntropyLoss(reduction='sum')
        self.opt = torch.optim.SGD(self.mem_n2n.parameters(), lr=config.lr)
        print(self.mem_n2n)

        if config.cuda:
            self.ce_fn = self.ce_fn.cuda()
            self.mem_n2n = self.mem_n2n.cuda()

        self.start_epoch = 0
        self.config = config
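
A minimal sketch of a config object that would exercise this constructor's bAbI branch; the enclosing class is not shown in the excerpt, so the `Trainer` name and all field values below are hypothetical:

from types import SimpleNamespace

config = SimpleNamespace(dataset_dir='data/bAbI', task=1, batch_size=32,
                         cuda=False, max_hops=3, lr=0.01)
trainer = Trainer(config)  # hypothetical name for the class defining this __init__
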
Example #2
    def __init__(self, config):
        self.train_data = bAbIDataset(config.dataset_dir, config.task)
        self.train_loader = DataLoader(self.train_data,
                                       batch_size=config.batch_size,
                                       num_workers=1,
                                       shuffle=True)

        self.test_data = bAbIDataset(config.dataset_dir,
                                     config.task,
                                     train=False)
        self.test_loader = DataLoader(self.test_data,
                                      batch_size=config.batch_size,
                                      num_workers=1,
                                      shuffle=False)

        settings = {
            "use_cuda": config.cuda,
            "num_vocab": self.train_data.num_vocab,
            "embedding_dim": 20,
            "sentence_size": self.train_data.sentence_size,
            "max_hops": config.max_hops
        }

        print("Longest sentence length", self.train_data.sentence_size)
        print("Longest story length", self.train_data.max_story_size)
        print("Average story length", self.train_data.mean_story_size)
        print("Number of vocab", self.train_data.num_vocab)

        self.mem_n2n = MemN2N(settings)
        self.ce_fn = nn.CrossEntropyLoss(reduction='sum')
        self.opt = torch.optim.SGD(self.mem_n2n.parameters(),
                                   lr=config.lr,
                                   weight_decay=1e-5)
        print(self.mem_n2n)

        if config.cuda:
            self.ce_fn = self.ce_fn.cuda()
            self.mem_n2n = self.mem_n2n.cuda()

        self.start_epoch = 0
        self.config = config
Example #3
        noise = noise.to(device)
        p.grad.data.add_(noise)


def decay_learning_rate(opt, epoch, lr, decay_interval, decay_ratio):
    decay_count = max(0, epoch // decay_interval)
    lr = lr * (decay_ratio**decay_count)
    for param_group in opt.param_groups:
        param_group["lr"] = lr
    return lr
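
# Illustrative usage, with hypothetical values: given a base lr of 0.01,
# decay_interval=25 and decay_ratio=0.5, decay_learning_rate keeps
# lr = 0.01 for epochs 0-24, then 0.005 for 25-49, 0.0025 for 50-74, and so on:
#
#     current_lr = decay_learning_rate(opt, epoch, lr=0.01,
#                                      decay_interval=25, decay_ratio=0.5)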


device = torch.device(
    "cuda:0" if args.use_cuda and torch.cuda.is_available() else "cpu")

train_data = bAbIDataset(args.dataset_dir, args.task)
train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True)
test_data = bAbIDataset(args.dataset_dir, args.task, train=False)
test_loader = DataLoader(test_data, batch_size=args.batch_size, shuffle=False)
print("Task: {} Train datset size: {}, Test dataset size: {}".format(
    args.task, len(train_data), len(test_data)))

settings = {
    "device": device,
    "num_vocab": train_data.num_vocab,
    "embedding_dim": args.embedding_dim,
    "sentence_size": train_data.sentence_size,
    "max_hops": args.max_hops
}
print("Longest sentence length", train_data.sentence_size)
print("Longest story length", train_data.max_story_size)
Example #4
def main():
    global epoch
    # Get arguments, set up, prepare data, and print some info
    args = parse()

    log_path = os.path.join("logs", args.name)
    if not os.path.exists(log_path):
        os.makedirs(log_path)
    writer = SummaryWriter(log_path)

    if args.task == 'babi':
        train_dataset = bAbIDataset(args.dataset_path, args.babi_task)
        val_dataset = bAbIDataset(args.dataset_path,
                                  args.babi_task,
                                  train=False)
    else:
        raise NotImplementedError

    # Setting up the Model
    if args.model == 'lstm':
        model = LSTM(40,
                     train_dataset.num_vocab,
                     100,
                     args.device,
                     sentence_size=max(train_dataset.sentence_size,
                                       train_dataset.query_size))
        print("Using LSTM")
    else:
        # model = REN(args.num_blocks, train_dataset.num_vocab, 100, args.device, train_dataset.sentence_size,
        #             train_dataset.query_size).to(args.device)
        model = RecurrentEntityNetwork(train_dataset.num_vocab,
                                       device=args.device,
                                       sequence_length=max(
                                           train_dataset.sentence_size,
                                           train_dataset.query_size))
        print("Using EntNet")
    if args.multi:  # TODO: What's this?
        model = torch.nn.DataParallel(model, device_ids=args.gpu_range)

    if args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)
    elif args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    weight_decay=args.weight_decay)
    else:
        Exception("Invalid optimizer")
    if args.cyc_lr:
        cycle_momentum = (args.optimizer == 'sgd')
        lr_scheduler = torch.optim.lr_scheduler.CyclicLR(
            optimizer,
            5e-5,
            args.lr,
            cycle_momentum=cycle_momentum,
            step_size_up=args.cyc_step_size_up)
    else:
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                       step_size=25,
                                                       gamma=0.5)
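
    # Note: CyclicLR is designed to be stepped once per batch, presumably inside
    # the training routine that receives `scheduler=lr_scheduler` below; the
    # StepLR fallback is instead stepped once per epoch after saver.tick().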

    # Before we are getting started, let's get ready to give some feedback
    print("Dataset size: ", len(train_dataset))
    print("Sentence size:", train_dataset.sentence_size)
    print("Vocab set", [
        str(i) + ': ' + str(train_dataset.vocab[i])
        for i in range(len(train_dataset.vocab))
    ])

    # Prepare Visdom
    Visdom.start()
    lr_plt = Visdom.Plot2D("Curent learning rate",
                           store_interval=1,
                           xlabel="Epochs",
                           ylabel="Learning Rate")
    # TODO: Check legend
    train_loss = Visdom.Plot2D("Loss on Train Data",
                               store_interval=1,
                               xlabel="iteration",
                               ylabel="loss",
                               legend=['one', 'two', 'three'])
    train_accuracy = Visdom.Plot2D("Accuracy on Train Data",
                                   store_interval=1,
                                   xlabel="iteration",
                                   ylabel="accuracy")
    validation_loss = Visdom.Plot2D("Loss on Validation Set",
                                    store_interval=1,
                                    xlabel="epoch",
                                    ylabel="loss")
    validation_accuracy = Visdom.Plot2D("Accuracy on Validation Set",
                                        store_interval=1,
                                        xlabel="epoch",
                                        ylabel="accuracy")
    babi_text_plt = Visdom.Text("Network Output")
    train_plots = {'loss': train_loss, 'accuracy': train_accuracy}
    val_plots = {'text': babi_text_plt}

    epoch = 0

    # Register Variables and plots to save
    saver = Saver(os.path.join(args.output_path, args.name),
                  short_interval=args.save_interval)
    saver.register('train_loss', StateSaver(train_loss))
    saver.register('train_accuracy', StateSaver(train_accuracy))
    saver.register('validation_loss', StateSaver(validation_loss))
    saver.register('validation_accuracy', StateSaver(validation_accuracy))
    saver.register('lr_plot', StateSaver(lr_plt))
    saver.register("model", StateSaver(model))
    saver.register("optimizer", StateSaver(optimizer))
    saver.register("epoch", GlobalVarSaver('epoch'))
    # saver.register("train_dataset", StateSaver(train_dataset))
    # saver.register("val_dataset", StateSaver(val_dataset))

    eval_on_start = False
    print("Given model argument to load from: ", args.load_model)
    # TODO: Load learning rate scheduler
    if args.load_model:
        if not saver.load(args.load_model):
            #  model.reset_parameters()
            print('Not loading, something went wrong', args.load_model)
        else:
            eval_on_start = False

    start_epoch = epoch
    end_epoch = start_epoch + args.epochs
    model.to(args.device)

    # TODO: Use saver only on full epochs or use it on certain iteration
    """ TRAIN START """
    # Eval on Start
    if eval_on_start:
        val_result = val_dataset.eval(args, model, plots=val_plots)
        validation_loss.add_point(0, val_result['loss'])
        validation_accuracy.add_point(0, val_result['accuracy'])
        saver.write(epoch)
    for epoch in range(start_epoch, end_epoch):
        train_result = train_dataset.test(args,
                                          model,
                                          optimizer,
                                          epoch=epoch,
                                          plots=train_plots,
                                          scheduler=lr_scheduler)
        val_result = val_dataset.eval(args,
                                      model,
                                      epoch=epoch + 1,
                                      plots=val_plots)
        validation_loss.add_point(epoch, val_result['loss'])
        validation_accuracy.add_point(epoch, val_result['accuracy'])

        current_lr = None
        for param_group in optimizer.param_groups:
            current_lr = param_group['lr']
            break
        lr_plt.add_point(epoch, current_lr if current_lr else 0)

        saver.tick(epoch + 1)
        if not args.cyc_lr:
            lr_scheduler.step()

        # TODO: Add writer
        # Log
        if epoch % args.save_interval == 0 or epoch == end_epoch - 1:
            for param_group in optimizer.param_groups:
                log_lr = param_group['lr']
                break

            log = 'Epoch: [{epoch}]\t Train Loss {tl} Acc {ta}\t Val Loss {vl} Acc {va} lr {lr}'.format(
                epoch=epoch,
                tl=round(train_result['loss'], 3),
                ta=round(train_result['accuracy'], 3),
                vl=round(val_result['loss'], 3),
                va=round(val_result['accuracy'], 3),
                lr=log_lr)
            print(log)
Example #5
File: main.py Project: zhshLee/ggnn
# Number of epochs for training
n_epochs = 10

# GGNN hidden state size
state_dim = 4

# Number of propagation steps
n_steps = 5

# Annotation dimension. For the bAbI tasks we have a one-hot encoding per node.
annotation_dim = 1

# One fold of our preprocessed dataset.
dataset_path = 'babi_data/processed_1/train/%d_graphs.txt' % task_id

train_dataset = bAbIDataset(dataset_path, question_id=0, is_train=True)
train_data_loader = bAbIDataLoader(train_dataset,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   num_workers=2)

test_dataset = bAbIDataset(dataset_path, question_id=0, is_train=False)
test_data_loader = bAbIDataLoader(test_dataset,
                                  batch_size=batch_size,
                                  shuffle=False,
                                  num_workers=2)

n_edge_types = train_dataset.n_edge_types
n_nodes = train_dataset.n_node

# The dataset has the form: [(adjacency matrix, annotation, target), ...]
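#
# Illustrative sketch: iterating the loaders above therefore yields batches of
# that form; the exact tensor shapes depend on the preprocessing, so the
# unpacking below is only an assumption.
#
#     for adj_matrix, annotation, target in train_data_loader:
#         ...  # feed adj_matrix and annotation to the GGNN, compare with target

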
def main(args):
    train_dataset = bAbIDataset(args.datadir, args.task)
    val_dataset = bAbIDataset(args.datadir, args.task, train=False)
    print("Dataset size: ", len(train_dataset))
    #print("Vocab size: ", train_dataset.num_vocab)
    print("Sentence size:", train_dataset.sentence_size)
    print("Vocab set: ", train_dataset.vocab)
    print("Story shape:", train_dataset[0][0].shape)

    train_loader = data_utils.DataLoader(train_dataset,
                                         batch_size=args.batchsize,
                                         num_workers=args.njobs,
                                         shuffle=True,
                                         pin_memory=True,
                                         timeout=300,
                                         drop_last=True)
    val_loader = data_utils.DataLoader(val_dataset,
                                       batch_size=args.batchsize,
                                       num_workers=args.njobs,
                                       shuffle=False,
                                       pin_memory=True,
                                       timeout=300,
                                       drop_last=True)

    model = REN(20, train_dataset.num_vocab, 100, args.device,
                train_dataset.sentence_size,
                train_dataset.query_size).to(args.device)
    model.init_keys()
    log_path = os.path.join("logs", args.exp_name)
    if not os.path.exists(log_path):
        os.makedirs(log_path)
    writer = SummaryWriter(log_path)

    if args.multi:
        model = torch.nn.DataParallel(model, device_ids=args.gpu_range)

    loss = torch.nn.CrossEntropyLoss().to(args.device)

    if args.cyc_lr:
        lr = cyclic_lr(0, 10, 2e-4, 1e-2)
    else:
        lr = args.lr
    optimizer = optim.Adam(model.parameters(),
                           lr=lr,
                           weight_decay=args.weight_decay)

    if not args.cyc_lr:
        scheduler = StepLR(optimizer, step_size=25, gamma=0.5)

    start_epoch, end_epoch = 0, args.epochs
    if args.load_model is not None and args.load_model != '':
        pt_model = torch.load(args.load_model)
        try:
            model.load_state_dict(pt_model['state_dict'])
        except Exception:
            model = torch.nn.DataParallel(model, device_ids=[args.gpuid])
            model.load_state_dict(pt_model['state_dict'])
        optimizer.load_state_dict(pt_model['optimizer'])
        start_epoch = pt_model['epochs']
        end_epoch = start_epoch + args.epochs

    for epoch in range(start_epoch, end_epoch):
        train_result = train(model, loss, optimizer, train_loader, args)
        val_result = eval(model, loss, val_loader, args)
        if args.cyc_lr:
            lr = cyclic_lr(epoch * (len(train_dataset) // args.batchsize), 10,
                           2e-4, 1e-2)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr  # cyclic_learning_rate(epoch * (len(train_dataset) // args.batchsize))
        elif epoch < 200:
            scheduler.step()

        for key in train_result.keys():
            writer.add_scalar('{}_{}'.format('train', key), train_result[key],
                              epoch)
        for key in val_result.keys():
            writer.add_scalar('{}_{}'.format('val', key), val_result[key],
                              epoch)
        for param_group in optimizer.param_groups:
            writer.add_scalar('lr', param_group['lr'], epoch)
            break

        parameters = [
            p.grad.data for p in model.parameters() if p.grad is not None
        ]
        writer.add_scalar('gradient_norm', weight_norm(parameters), epoch)

        writer.add_scalar('output/R',
                          weight_norm(model.output.R.weight.grad.data), epoch)
        writer.add_scalar('output/H',
                          weight_norm(model.output.H.weight.grad.data), epoch)
        writer.add_scalar('story_enc/mask',
                          weight_norm(model.story_enc.mask.grad), epoch)
        writer.add_scalar('query_enc/mask',
                          weight_norm(model.query_enc.mask.grad), epoch)
        writer.add_scalar('prelu', weight_norm(model.prelu.weight.grad.data),
                          epoch)
        writer.add_scalar('embed',
                          weight_norm(model.embedlayer.weight.grad.data),
                          epoch)
        writer.add_scalar('cell/U', weight_norm(model.cell.U.weight.grad.data),
                          epoch)
        writer.add_scalar('cell/V', weight_norm(model.cell.V.weight.grad.data),
                          epoch)
        writer.add_scalar('cell/W', weight_norm(model.cell.W.weight.grad.data),
                          epoch)
        writer.add_scalar('cell/bias', weight_norm(model.cell.bias.grad),
                          epoch)

        # writer.add_scalar('param_output/R', weight_norm(model.output.R.weight.data), epoch)
        # writer.add_scalar('param_output/H', weight_norm(model.output.H.weight.data), epoch)
        # writer.add_scalar('param_story_enc/mask', weight_norm(model.story_enc.mask), epoch)
        # writer.add_scalar('param_query_enc/mask', weight_norm(model.query_enc.mask), epoch)
        # writer.add_scalar('param_prelu', weight_norm(model.prelu.weight.data), epoch)
        # writer.add_scalar('param_embed', weight_norm(model.embedlayer.weight.data), epoch)
        # writer.add_scalar('param_cell/U', weight_norm(model.cell.U.weight.data), epoch)
        # writer.add_scalar('param_cell/V', weight_norm(model.cell.V.weight.data), epoch)
        # writer.add_scalar('param_cell/W', weight_norm(model.cell.W.weight.data), epoch)
        # writer.add_scalar('param_cell/bias', weight_norm(model.cell.bias), epoch)

        if epoch % args.save_interval == 0 or epoch == end_epoch - 1:
            for param_group in optimizer.param_groups:
                log_lr = param_group['lr']
                break
            logline = ('Epoch: [{0}]\t Train Loss {1:.4f} Acc {2:.3f}\t '
                       'Val Loss {3:.4f} Acc {4:.3f} lr {5:.4f}').format(
                epoch, train_result['loss'], train_result['accuracy'],
                val_result['loss'], val_result['accuracy'], log_lr)
            print(logline)

            torch.save(
                {
                    'state_dict': model.state_dict(),
                    'epochs': epoch + 1,
                    'args': args,
                    'train_scores': train_result,
                    'val_scores': val_result,
                    'optimizer': optimizer.state_dict()
                },
                os.path.join(args.output_path,
                             "%s_%d.pth" % (args.exp_name, epoch)))

    return None