def train(model, args):
    """Train `model` on the dataset named by `args.dataset`, one user sequence at a time.

    Loads and splits the dataset, builds per-item inputs (knowledge-concept
    vectors for DK* models, exercise text otherwise), optimizes with Adam,
    and periodically logs progress and saves snapshots to `args.workspace`.

    Args:
        model: a torch module; DK*-prefixed classes receive knowledge vectors,
            all others receive text token ids from `get_topics`.
        args: namespace with dataset, random_level, split_method, frac, epochs,
            workspace, loss ('cross_entropy' or MSE), save_every, print_every.

    Side effects: writes snapshots via `save_snapshot`; resumes from the last
    snapshot found in `args.workspace`.
    """
    logging.info('args: %s' % str(args))
    logging.info('model: %s, setup: %s' %
                 (type(model).__name__, str(model.args)))

    logging.info('loading dataset')
    data = get_dataset(args.dataset)
    data.random_level = args.random_level
    if args.split_method == 'user':
        data, _ = data.split_user(args.frac)
    elif args.split_method == 'future':
        data, _ = data.split_future(args.frac)
    elif args.split_method == 'old':
        data, _, _, _ = data.split()
    data = data.get_seq()

    # DK* models consume multi-hot knowledge-concept vectors; every other
    # model consumes exercise text. Hoisted here: the check is loop-invariant.
    is_knowledge_model = type(model).__name__.startswith('DK')
    if is_knowledge_model:
        topic_dic = {}
        kcat = Categorical(one_hot=True)
        kcat.load_dict(open('data/know_list.txt').read().split('\n'))
        # assumes each line is "<uuid> <k1,k2,...>" — TODO confirm file format
        for line in open('data/id_know.txt'):
            uuid, know = line.strip().split(' ')
            know = know.split(',')
            # max(0) collapses the per-concept one-hot rows into one multi-hot row
            topic_dic[uuid] = \
                torch.LongTensor(kcat.apply(None, know)) \
                .max(0)[0] \
                .type(torch.LongTensor)
        # fallback vector for items whose topic has no known concepts
        zero = [0] * len(kcat.apply(None, '<NULL>'))
    else:
        topics = get_topics(args.dataset, model.words)

    optimizer = torch.optim.Adam(model.parameters())
    start_epoch = load_last_snapshot(model, args.workspace)
    if use_cuda:
        model.cuda()

    for epoch in range(start_epoch, args.epochs):
        logging.info('epoch {}:'.format(epoch))
        then = time.time()

        total_loss = 0
        total_mae = 0
        total_acc = 0
        total_seq_cnt = 0

        users = list(data)
        random.shuffle(users)
        seq_cnt = len(users)

        MSE = torch.nn.MSELoss()
        MAE = torch.nn.L1Loss()

        for user in users:
            seq = data[user]
            length = len(seq)
            if length == 0:
                # guard: an empty sequence would divide by zero below and
                # call .backward() on a plain int
                continue
            total_seq_cnt += 1

            optimizer.zero_grad()

            loss = 0
            mae = 0
            acc = 0
            h = None
            for i, item in enumerate(seq):
                if is_knowledge_model:
                    # unseen topics fall back to the all-zero concept vector
                    x = topic_dic.get(item.topic, zero)
                else:
                    x = topics.get(item.topic).content
                x = Variable(torch.LongTensor(x))
                # binarize the score so it is a valid 0/1 target
                score = Variable(torch.FloatTensor([round(item.score)]))
                t = Variable(torch.FloatTensor([item.time]))

                s, h = model(x, score, t, h)

                if args.loss == 'cross_entropy':
                    loss += F.binary_cross_entropy_with_logits(
                        s, score.view_as(s))
                    m = MAE(F.sigmoid(s), score).data[0]
                else:
                    loss += MSE(s, score)
                    m = MAE(s, score).data[0]
                mae += m
                acc += m < 0.5

            # average over the sequence so long users don't dominate
            loss /= length
            mae /= length
            acc /= length

            total_loss += loss.data[0]
            total_mae += mae
            total_acc += acc

            loss.backward()
            optimizer.step()

            if total_seq_cnt % args.save_every == 0:
                save_snapshot(model, args.workspace,
                              '%d.%d' % (epoch, total_seq_cnt))

            # only log every print_every sequences (and at the very end)
            if total_seq_cnt % args.print_every != 0 and \
                    total_seq_cnt != seq_cnt:
                continue
            now = time.time()
            duration = (now - then) / 60
            logging.info(
                '[%d:%d/%d] (%.2f seqs/min) '
                'loss %.6f, mae %.6f, acc %.6f' %
                (epoch, total_seq_cnt, seq_cnt,
                 ((total_seq_cnt - 1) % args.print_every + 1) / duration,
                 total_loss / total_seq_cnt,
                 total_mae / total_seq_cnt,
                 total_acc / total_seq_cnt))
            then = now

        save_snapshot(model, args.workspace, epoch + 1)
def trainn(model, args):
    """Train `model` with switchable knowledge-concept and text inputs.

    Like `train`, but input channels are selected by flags: knowledge-concept
    vectors when `args.input_knowledge` is set, exercise text when
    `args.input_text` is set. The model class name selects the forward
    signature: DK* (knowledge only), RA* (text only), EK* (text + knowledge).

    Args:
        model: torch module; `model.args['knows']` names the concept-dictionary
            file when knowledge input is enabled.
        args: namespace with dataset, random_level, split_method, frac,
            input_knowledge, input_text, epochs, workspace, loss,
            save_every, print_every.

    Raises:
        ValueError: if the model's class name matches none of DK*/RA*/EK*.

    Side effects: writes snapshots via `save_snapshot`; resumes from the last
    snapshot found in `args.workspace`.
    """
    logging.info('model: %s, setup: %s' %
                 (type(model).__name__, str(model.args)))

    logging.info('loading dataset')
    data = get_dataset(args.dataset)
    data.random_level = args.random_level
    if args.split_method == 'user':
        data, _ = data.split_user(args.frac)
    elif args.split_method == 'future':
        data, _ = data.split_future(args.frac)
    elif args.split_method == 'old':
        data, _, _, _ = data.split()
    data = data.get_seq()

    if args.input_knowledge:
        logging.info('loading knowledge concepts')
        topic_dic = {}
        kcat = Categorical(one_hot=True)
        kcat.load_dict(open(model.args['knows']).read().split('\n'))
        # renamed from `know`: the original shadowed the file path with the
        # per-line loop variable below
        know_file = 'data/id_firstknow.txt' if 'first' in model.args['knows'] \
            else 'data/id_know.txt'
        # assumes each line is "<uuid> <k1,k2,...>" — TODO confirm file format
        for line in open(know_file):
            uuid, know = line.strip().split(' ')
            know = know.split(',')
            # max(0) collapses per-concept one-hot rows into one multi-hot row
            topic_dic[uuid] = \
                torch.LongTensor(kcat.apply(None, know)).max(0)[0]
        # fallback vector for items whose topic has no known concepts
        zero = [0] * len(kcat.apply(None, '<NULL>'))
    if args.input_text:
        logging.info('loading exercise texts')
        topics = get_topics(args.dataset, model.words)

    optimizer = torch.optim.Adam(model.parameters())
    start_epoch = load_last_snapshot(model, args.workspace)
    if use_cuda:
        model.cuda()

    # the dispatch below depends only on the class name — hoist it
    model_name = type(model).__name__

    for epoch in range(start_epoch, args.epochs):
        logging.info('epoch {}:'.format(epoch))
        then = time.time()

        total_loss = 0
        total_mae = 0
        total_acc = 0
        total_seq_cnt = 0

        users = list(data)
        random.shuffle(users)
        seq_cnt = len(users)

        MSE = torch.nn.MSELoss()
        MAE = torch.nn.L1Loss()

        for user in users:
            seq = data[user]
            seq_length = len(seq)
            if seq_length == 0:
                # guard: an empty sequence would divide by zero below and
                # call .backward() on a plain int
                continue
            total_seq_cnt += 1

            optimizer.zero_grad()

            loss = 0
            mae = 0
            acc = 0
            h = None
            for i, item in enumerate(seq):
                if args.input_knowledge:
                    # unseen topics fall back to the all-zero concept vector
                    knowledge = topic_dic.get(item.topic, zero)
                    knowledge = Variable(torch.LongTensor(knowledge))
                if args.input_text:
                    text = topics.get(item.topic).content
                    text = Variable(torch.LongTensor(text))

                # NOTE: unlike `train`, the raw (unrounded) score is the target
                score = Variable(torch.FloatTensor([item.score]))
                item_time = Variable(torch.FloatTensor([item.time]))

                if model_name.startswith('DK'):
                    s, h = model(knowledge, score, item_time, h)
                elif model_name.startswith('RA'):
                    s, h = model(text, score, item_time, h)
                elif model_name.startswith('EK'):
                    s, h = model(text, knowledge, score, item_time, h)
                else:
                    # fail loudly instead of the original NameError on `s`
                    raise ValueError(
                        'unsupported model type: %s' % model_name)
                s = s[0]

                if args.loss == 'cross_entropy':
                    loss += F.binary_cross_entropy_with_logits(
                        s, score.view_as(s))
                    m = MAE(F.sigmoid(s), score).data[0]
                else:
                    loss += MSE(s, score)
                    m = MAE(s, score).data[0]
                mae += m
                acc += m < 0.5

            # average over the sequence so long users don't dominate
            loss /= seq_length
            mae /= seq_length
            acc = float(acc) / seq_length

            total_loss += loss.data[0]
            total_mae += mae
            total_acc += acc

            loss.backward()
            optimizer.step()

            if total_seq_cnt % args.save_every == 0:
                save_snapshot(model, args.workspace,
                              '%d.%d' % (epoch, total_seq_cnt))

            # only log every print_every sequences (and at the very end)
            if total_seq_cnt % args.print_every != 0 and \
                    total_seq_cnt != seq_cnt:
                continue
            now = time.time()
            duration = (now - then) / 60
            logging.info(
                '[%d:%d/%d] (%.2f seqs/min) loss %.6f, mae %.6f, acc %.6f' %
                (epoch, total_seq_cnt, seq_cnt,
                 ((total_seq_cnt - 1) % args.print_every + 1) / duration,
                 total_loss / total_seq_cnt,
                 total_mae / total_seq_cnt,
                 total_acc / total_seq_cnt))
            then = now

        save_snapshot(model, args.workspace, epoch + 1)