Exemplo n.º 1
0
# Build the model from the corpus vocabulary and run mini-batch SGD over `data`.
# NOTE(review): this snippet is truncated inside the batch loop; the forward
# pass, loss computation and optimizer step are not visible here.
vocab_size = len(corpus.words2id)
logging.info('vocabulary size: {}'.format(vocab_size))
model = DNN(vocab_size=vocab_size,
            embedding_size=200,
            hidden_size=512,
            embedding=embedding)

model.to(device)

# `weight` is defined outside this snippet and is passed through as the
# loss's `weight` argument.
loss_function = nn.CrossEntropyLoss(weight=weight)

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

model.train()

total_data = len(data)
batch_size = args['batch_size']
# Number of mini-batches per epoch; rounding up means the final batch of an
# epoch may hold fewer than `batch_size` items.
total_step = math.ceil(total_data / batch_size)
# Large sentinel value; it is not read in the visible part of the snippet.
last_training_loss = 1000000000000
for epoch in range(args.get('epoch')):
    # Restart from the beginning of `data` at every epoch (no shuffling here).
    start = 0
    training_loss = 0
    for _ in tqdm(range(int(total_step)), total=total_step):
        # Take the next contiguous slice of `data` as the mini-batch.
        batch = data[start:start + batch_size]
        start += batch_size

        # `padding` is defined elsewhere; it returns a length and the batch
        # sequences, which are then converted to a LongTensor on `device`.
        max_len, seq = padding(batch)  # list
        seq = torch.LongTensor(seq).to(device)
Exemplo n.º 2
0
def train():
    """Train the ``DNN`` model and keep the checkpoint with the best dev F1.

    Data is loaded and split through ``dh``; the training split is iterated
    with shuffling for ``args.n_epoch`` epochs.  The learning rate follows a
    cosine annealing schedule that is stepped once per batch over
    ``len(train_loader) * args.n_epoch`` steps.

    Every ``args.report_step`` steps the mean of the most recent losses is
    logged, and every ``args.eval_step`` steps the model is evaluated on the
    dev split.  Whenever the dev ``f1`` improves, the whole model object is
    saved to ``./model/<task_name>/torch.pt``.

    Reads the module-level ``args``, ``dh``, ``logger`` and ``evaluation``.
    """
    import os

    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    data_dict, topic_dict = dh.load_data(
    )  # data_dict, [group2topic, mem2topic]

    # The test split is returned by the helper but not used during training.
    (train_data, train_label, dev_data, dev_label, test_data,
     test_label) = dh.data_split(data_dict, topic_dict)
    train_dataset = dh.Dataset(train_data, train_label)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True)
    dev_dataset = dh.Dataset(dev_data, dev_label)
    dev_loader = DataLoader(dev_dataset, batch_size=128, shuffle=True)

    # global_step counts batches, so epoch progress must be measured in
    # batches per epoch rather than samples per epoch.
    steps_per_epoch = len(train_loader)

    model = DNN(args).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        steps_per_epoch * args.n_epoch)

    # Make sure the checkpoint directory exists before the first save.
    checkpoint_path = './model/{}/torch.pt'.format(args.task_name)
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

    global_step = 0
    best_f1 = 0.
    # Sliding window holding the last `report_step` batch losses.
    loss_deq = collections.deque([], args.report_step)
    for epoch in range(args.n_epoch):
        for batch in tqdm(train_loader):
            optimizer.zero_grad()
            inputs = batch['input'].to(device)
            group_topic = batch['group_topic'].to(device)
            mem_topic = batch['mem_topic'].to(device)
            labels = batch['label'].to(device)
            # The model returns a sequence whose first element is the loss
            # when labels are supplied.
            output = model(inputs, mem_topic, group_topic, label=labels)
            loss = output[0]
            loss.backward()
            loss_deq.append(loss.item())
            optimizer.step()
            scheduler.step()
            global_step += 1

            if global_step % args.report_step == 0:
                logger.info('loss: {}, lr: {}, epoch: {}'.format(
                    np.average(loss_deq).item(),
                    optimizer.param_groups[0]['lr'],
                    global_step / steps_per_epoch))
            if global_step % args.eval_step == 0:
                model.eval()
                eval_result = evaluation(model,
                                         data_loader=dev_loader,
                                         device=device)
                logger.info(eval_result)
                if eval_result['f1'] > best_f1:
                    torch.save(model, checkpoint_path)
                    best_f1 = eval_result['f1']
                # Switch back to training mode after evaluation.
                model.train()
0
def train(args, config, io):
    """Train a ``DNN`` regressor with an MSE loss and keep the best checkpoint.

    Args:
        args: Run options.  Reads ``cuda``, ``model_path``, ``use_sgd``,
            ``lr``, ``momentum``, ``epochs`` and ``exp_name``.
        config: Project configuration.  Reads ``root`` and ``model_path`` to
            build the path for the full-model checkpoint; also forwarded to
            ``get_loader``.
        io: Logger-like object exposing ``cprint(str)``.

    Each epoch runs one pass over the training loader (inputs are augmented
    with ``jitter`` then ``drop``) followed by one pass over the validation
    loader under ``torch.no_grad()``.  When the summed validation loss does
    not exceed the best seen so far, both the state dict and the whole model
    are saved.
    """
    train_loader, validation_loader = get_loader(args, config)
    device = torch.device("cuda" if args.cuda else "cpu")

    model = DNN(args).to(device)
    if args.model_path != "":
        # map_location lets a checkpoint saved on a GPU be restored on a
        # CPU-only run (and vice versa).
        model.load_state_dict(
            torch.load(args.model_path, map_location=device))

    if args.use_sgd:
        opt = optim.SGD(model.parameters(),
                        lr=args.lr * 100,
                        momentum=args.momentum,
                        weight_decay=1e-4)
    else:
        opt = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4)

    # NOTE(review): eta_min equals args.lr, so with Adam (initial lr ==
    # args.lr) the cosine schedule keeps the learning rate constant; only the
    # SGD branch (initial lr == 100 * args.lr) actually anneals. Confirm that
    # this is intended.
    scheduler = CosineAnnealingLR(opt, args.epochs, eta_min=args.lr)

    criterion = nn.MSELoss()

    # Start from +inf so the first evaluated epoch always produces a
    # checkpoint, whatever the scale of the loss.
    best_test_loss = float('inf')
    for epoch in range(args.epochs):
        startTime = time.time()

        ####################
        # Train
        ####################
        train_loss = 0.0
        train_dis = 0.0
        count = 0.0
        model.train()
        for data, label in train_loader:
            data, label = data.to(device), label.to(device)
            data = drop(jitter(data, device), device)
            batch_size = data.shape[0]
            logits = model(data)
            loss = criterion(logits, label)
            opt.zero_grad()
            loss.backward()
            opt.step()
            dis = distance(logits, label)
            # Weight per-batch means by batch size so the epoch figures are
            # per-sample averages even when the last batch is smaller.
            count += batch_size
            train_loss += loss.item() * batch_size
            train_dis += dis.item() * batch_size
        # The scheduler advances once per epoch.
        scheduler.step()
        outstr = 'Train %d, loss: %.6f, distance: %.6f' % (
            epoch, train_loss / count, train_dis / count)
        io.cprint(outstr)

        ####################
        # Evaluation
        ####################
        test_loss = 0.0
        test_dis = 0.0
        count = 0.0
        model.eval()
        with torch.no_grad():
            for data, label in validation_loader:
                data, label = data.to(device), label.to(device)
                batch_size = data.shape[0]
                logits = model(data)
                loss = criterion(logits, label)
                dis = distance(logits, label)
                count += batch_size
                test_loss += loss.item() * batch_size
                test_dis += dis.item() * batch_size
        outstr = 'Test %d, loss: %.6f, distance: %.6f' % (
            epoch, test_loss / count, test_dis / count)
        io.cprint(outstr)
        # The comparison uses the summed loss; the validation set size is the
        # same every epoch, so this orders epochs the same way as the mean.
        if test_loss <= best_test_loss:
            best_test_loss = test_loss
            torch.save(model.state_dict(),
                       'checkpoints/%s/models/model.t7' % args.exp_name)
            torch.save(model, (config.root + config.model_path))
        io.cprint('Time: %.3f sec' % (time.time() - startTime))
def train(args, config, io):
    """Train a ``DNN`` regressor with labelled and unlabelled data.

    Two copies of the network are kept: ``model`` (optimised by gradient
    descent) and ``ema_model`` (updated only through
    ``update_ema_variables`` after each optimiser step).  The loss is the
    supervised MSE on labelled batches plus a weighted MSE between the two
    networks' predictions on independently jittered views of the same
    unlabelled batch.  Validation and checkpointing use ``ema_model``.

    Args:
        args: Run options.  Reads ``cuda``, ``model_path``, ``use_sgd``,
            ``lr``, ``momentum``, ``epochs``, ``consistency_rampup_starts``,
            ``final_consistency``, ``ema_decay`` and ``exp_name``.
        config: Project configuration.  Reads ``root`` and ``model_path`` to
            build the path for the full-model checkpoint; also forwarded to
            ``get_loader``.
        io: Logger-like object exposing ``cprint(str)``.
    """
    train_loader, validation_loader, unlabelled_loader = get_loader(
        args, config)

    device = torch.device("cuda" if args.cuda else "cpu")

    model = DNN(args).to(device)
    ema_model = DNN(args).to(device)
    # The EMA copy is never trained by back-propagation.
    for param in ema_model.parameters():
        param.detach_()
    if device == torch.device("cuda"):
        model = nn.DataParallel(model)
        ema_model = nn.DataParallel(ema_model)
    if args.model_path != "":
        # Read the checkpoint once; map_location lets a GPU-saved file be
        # restored on a CPU-only run.  load_state_dict copies the values, so
        # both networks can share the same loaded dict.
        state_dict = torch.load(args.model_path, map_location=device)
        model.load_state_dict(state_dict)
        ema_model.load_state_dict(state_dict)

    if args.use_sgd:
        print("Use SGD")
        opt = optim.SGD(model.parameters(),
                        lr=args.lr * 100,
                        momentum=args.momentum,
                        weight_decay=1e-4)
    else:
        print("Use Adam")
        opt = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4)

    # NOTE(review): eta_min equals args.lr, so with Adam (initial lr ==
    # args.lr) the cosine schedule keeps the learning rate constant; only the
    # SGD branch actually anneals. Confirm that this is intended.
    scheduler = CosineAnnealingLR(opt, args.epochs, eta_min=args.lr)

    criterion = nn.MSELoss()
    consistency_criterion = nn.MSELoss()

    # Start from +inf so the first evaluated epoch always produces a
    # checkpoint, whatever the scale of the loss.
    best_test_loss = float('inf')
    global_step = 0
    for epoch in range(args.epochs):
        startTime = time.time()

        ####################
        # Train
        ####################
        train_loss = 0.0
        count = 0.0
        model.train()
        ema_model.train()
        # One epoch is one pass over the unlabelled loader; labelled batches
        # are recycled as often as needed to pair with it.
        paired_batches = zip(cycle(train_loader), unlabelled_loader)
        for i, ((data, label), (u, _)) in enumerate(paired_batches):
            if data.shape[0] != u.shape[0]:
                # Trim both batches to the smaller size so they stay aligned.
                bt_size = min(data.shape[0], u.shape[0])
                data = data[0:bt_size]
                label = label[0:bt_size]
                u = u[0:bt_size]
            data, label, u = data.to(device), label.to(device), u.to(device)
            batch_size = data.shape[0]
            logits = model(data)
            class_loss = criterion(logits, label)

            # Two independent jittered views of the same unlabelled batch.
            u_student = jitter(u, device)
            u_teacher = jitter(u, device)
            logits_unlabeled = model(u_student)
            # The EMA prediction is a fixed target: no graph is needed.
            with torch.no_grad():
                ema_logits_unlabeled = ema_model(u_teacher)
            consistency_loss = consistency_criterion(logits_unlabeled,
                                                     ema_logits_unlabeled)
            if epoch < args.consistency_rampup_starts:
                consistency_weight = 0.0
            else:
                consistency_weight = get_current_consistency_weight(
                    args, args.final_consistency, epoch, i,
                    len(unlabelled_loader))

            consistency_loss = consistency_weight * consistency_loss
            loss = class_loss + consistency_loss

            opt.zero_grad()
            loss.backward()
            opt.step()

            global_step += 1
            update_ema_variables(model, ema_model, args.ema_decay, global_step)

            count += batch_size
            train_loss += loss.item() * batch_size
        # The scheduler advances once per epoch.
        scheduler.step()
        outstr = 'Train %d, loss: %.6f' % (epoch, train_loss / count)
        io.cprint(outstr)

        ####################
        # Evaluation
        ####################
        test_loss = 0.0
        count = 0.0
        model.eval()
        ema_model.eval()
        # Gradients are not needed for validation.
        with torch.no_grad():
            for data, label in validation_loader:
                data, label = data.to(device), label.to(device)
                batch_size = data.shape[0]
                logits = ema_model(data)
                loss = criterion(logits, label)
                count += batch_size
                test_loss += loss.item() * batch_size
        outstr = 'Test %d, loss: %.6f' % (epoch, test_loss / count)
        io.cprint(outstr)
        if test_loss <= best_test_loss:
            best_test_loss = test_loss
            torch.save(ema_model.state_dict(),
                       'checkpoints/%s/models/model.t7' % args.exp_name)
            torch.save(ema_model, (config.root + config.model_path))
        io.cprint('Time: %.3f sec' % (time.time() - startTime))