IDA*算法

2018年04月21日 417点热度 0人点赞 0条评论

起因

上节人工智能课讲了IDA*算法在八数码问题中的应用,但可能但是没有把算法的详细过程给将清楚,导致了某些人的误解。昨晚和柏同学,赵同学简单讨论了一下,刚刚又在google上查了一些文献,也看到有人提了相似的问题在这里做个总结。

算法概述

和A*算法一样,IDA*也引入了启发式函数$f(n)=g(n)+h(n)$,其中$g(n)$代表了我们从初始状态到当前状态的实际代价,通常我们把迭代深度作为$g(n)$ ,而$h(n)$代表了从当前状态到目标状态的预估计代价。在八数码问题中,我们通常将曼哈顿距离之和作为$h(n)$,算法大致流程如下:

  1. 根据初始状态$s_0$,计算出上界$limit=f(s_0)$
  2. 对初始状态$s_0$进行DFS
  3. 若DFS过程中达到目标状态,则return
  4. 对DFS过程中所有$f(s_i)>limit$的branch进行剪枝
  5. 如果当前$limit$下,没有找到解,那么$limit$++,回到过程2

IDA*的算法流程还是比较清晰的,其本身相较A*也有:消耗的Memory低、无需维护open表和closed表、无需判重等优点,而且在某些时候还会比A*来得快,这是维基百科上的伪代码:

疑惑

网上几乎所有的伪代码以及算法实现,都是直接判断当前节点是否是目标节点,如果是目标节点则return,那么会不会遇到这种情况呢:即我们当前只搜了根节点的第一分支就找到了一种解,然而存不存在下图这种情况呢?

即在根节点的第三分支的某一层,也存在一个目标节点,且它比我们第一个找到的解更优呢?由于DFS对子节点的遍历顺序并没有一个规定的优先级,起初我认为这种情况是可能存在的。但是在后来在用IDA*写最短路的时候意识到,如果存在更短的路径17,那么在=17的时候就应该被搜到,而不会在=18的时候才会被访问到。因此如上问题不存在,找到的第一个node就必定是最短路径。

Kotlin代码

import java.lang.Integer.min

fun main(args: Array<String>) {

    class Node(newX: Int, newY: Int) {
        var x = newX
        var y = newY
        var value = 0
        var depth = 0

        fun calculate(goal: Node) {
            value = depth + Math.abs(x - goal.x) + Math.abs(y - goal.y)
        }
    }

    fun isEqual(a: Node, b: Node) = a.x == b.x && a.y == b.y

    val m = 10
    val n = 10
    val maze = arrayOf(       //S是起点,G是终点
            charArrayOf('1', '1', '1', '0', '0', '0', '0', '0', '1', 'G'),
            charArrayOf('0', '0', '1', '1', '1', '0', '0', '0', '0', '0'),
            charArrayOf('0', '0', '1', '1', '1', '0', '0', '0', '0', '0'),
            charArrayOf('0', '0', '1', '0', '0', '0', '0', '1', '1', '1'),
            charArrayOf('0', '0', '0', '1', '0', '1', '0', '0', '0', '1'),
            charArrayOf('0', '1', '0', '1', '0', '1', '1', '1', '0', '1'),
            charArrayOf('0', '1', '0', '1', '0', '0', '0', '1', '0', '0'),
            charArrayOf('0', '1', '0', '1', '1', '0', '0', '1', '0', '0'),
            charArrayOf('0', '1', '0', '0', '1', '1', '1', '1', '0', '0'),
            charArrayOf('S', '1', '1', '0', '0', '0', '0', '0', '0', '0')
    )

    val direction = arrayOf(     //上左右下
            intArrayOf(-1, 0),
            intArrayOf(0, -1),
            intArrayOf(0, 1),
            intArrayOf(1, 0)
    )

    var limit: Int
    var minValue: Int = Int.MAX_VALUE
    var flag = false

    val tempPath = Array(100, { _ -> -1})
    val path = ArrayList<Int>()

    lateinit var start: Node
    lateinit var goal: Node

    for (i in 0 until m) {
        for (j in 0 until n) {
            if (maze[i][j] == 'S') {
                start = Node(i, j)      //起点
            }
            if (maze[i][j] == 'G') {
                goal = Node(i, j)       //终点
            }
        }
    }
    start.calculate(goal)
    limit = start.value

    fun dfs(node: Node, preMove: Int) {
        if (isEqual(node, goal)) {
            flag = true
            for (item in tempPath) {
                if (item != -1)
                    path.add(item)
                else
                    break
            }
            return
        }
        for (i in 0..3) {
            if (i + preMove == 3 && node.depth > 0)
                continue
            val x = node.x + direction[i][0]
            val y = node.y + direction[i][1]
            if (x in 0..(m - 1) && y in 0..(n - 1) && maze[x][y] != '1') {
                val newNode = Node(x, y)
                newNode.depth = node.depth + 1
                newNode.calculate(goal)

                if (newNode.value <= limit) {
                    tempPath[node.depth] = i
                    dfs(newNode, i)
                    if (flag)
                        return
                } else {
                    minValue = min(minValue, newNode.value)
                }
            }
        }
    }

    val startTime = System.currentTimeMillis()

    do {
        minValue = Int.MAX_VALUE
        dfs(start,  0)
        limit = minValue
    } while(!flag)
    val endTime = System.currentTimeMillis()

    println("time=${endTime - startTime}ms")
    println("length=${path.size}")
    for (item in path) {
        when(item) {
            0 -> print("U")
            1 -> print("L")
            2 -> print("R")
            3 -> print("D")
        }
    }

}

Plus

文章评论