if cmd_args.phase == 'train':
        # Training phase: fit `classifier` with Adam, score it on the test
        # graphs after every epoch, and checkpoint the best-scoring weights.
        #
        # NOTE(review): checkpoint selection is driven by the loss on
        # `test_glist`, although the log line below calls it the "valid loss";
        # confirm that selecting on the test split is intended.
        optimizer = optim.Adam(classifier.parameters(), lr=cmd_args.learning_rate)

        train_idxes = list(range(len(train_glist)))
        best_loss = None  # lowest test loss seen so far; None before epoch 0
        for epoch in range(cmd_args.num_epochs):
            # Visit the training graphs in a fresh random order every epoch.
            random.shuffle(train_idxes)

            # Results are read as (loss, acc) via [0]/[1], matching the log
            # format; passing `optimizer` presumably enables weight updates
            # inside loop_dataset -- TODO confirm against its definition.
            avg_loss = loop_dataset(
                train_glist, classifier, train_idxes, optimizer=optimizer)
            print('\033[92maverage training of epoch %d: loss %.5f acc %.5f\033[0m'
                  % (epoch, avg_loss[0], avg_loss[1]))

            # Evaluation pass: every test graph, in index order, no optimizer.
            test_loss = loop_dataset(
                test_glist, classifier, list(range(len(test_glist))))
            print('\033[93maverage test of epoch %d: loss %.5f acc %.5f\033[0m'
                  % (epoch, test_loss[0], test_loss[1]))

            # Guard clause: skip checkpointing unless this epoch strictly
            # beats the best test loss so far (the first epoch always saves).
            if best_loss is not None and not (test_loss[0] < best_loss):
                continue

            best_loss = test_loss[0]
            print('----saving to best model since this is the best valid loss so far.----')
            # Both paths are fixed, so each improvement overwrites the previous
            # best checkpoint and its accompanying argument dump.
            torch.save(classifier.state_dict(), cmd_args.save_dir + '/epoch-best.model')
            save_args(cmd_args.save_dir + '/epoch-best-args.pkl', cmd_args)
示例#2
0
        acc_train = acc_train.sum() / float(len(idx_train))
        loss_train.backward()
        optimizer.step()

        gcn.eval()
        _, loss_val, acc_val = gcn(features, orig_adj, idx_val, labels)
        acc_val = acc_val.sum() / float(len(idx_val))

        print('Epoch: {:04d}'.format(epoch + 1),
              'loss_train: {:.4f}'.format(loss_train.data[0]),
              'acc_train: {:.4f}'.format(acc_train),
              'loss_val: {:.4f}'.format(loss_val.data[0]),
              'acc_val: {:.4f}'.format(acc_val),
              'time: {:.4f}s'.format(time.time() - t))

        if best_val is None or acc_val > best_val:
            best_val = acc_val
            print(
                '----saving to best model since this is the best valid loss so far.----'
            )
            torch.save(
                gcn.state_dict(),
                cmd_args.save_dir + '/model-%s-epoch-best-%.2f.model' %
                (cmd_args.gm, cmd_args.del_rate))
            save_args(
                cmd_args.save_dir + '/model-%s-epoch-best-%.2f-args.pkl' %
                (cmd_args.gm, cmd_args.del_rate), cmd_args)

    run_test(gcn, features, orig_adj, idx_test, labels)
    # pred = gcn(features, adh)